Data Science Asked by Ryan Marinelli on June 11, 2021
I have been trying to reproduce this paper, which trains a CNN with TensorFlow.
Unfortunately, I only have access to limited computing resources. I have a few Raspberry Pis that I was thinking of networking with my machine to spread out the computation. However, I am having trouble figuring out how to refactor the code to run in a distributed way. I reviewed the documentation on TensorFlow's site and some of the simpler tutorials, but I am still stuck. I was considering using `ParameterServerStrategy`. Could anyone share examples of refactoring similar models? I think I am mostly having trouble getting the input to work with the `step_fn` from the TensorFlow tutorial.
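For reference, `ParameterServerStrategy` in TF 2.x typically discovers the other machines through the `TF_CONFIG` environment variable, which `tf.distribute.cluster_resolver.TFConfigClusterResolver` reads. A minimal sketch of building it for one desktop plus two Pis follows; the hostnames, ports and role assignment are hypothetical placeholders. Every machine uses the same `cluster` dictionary and differs only in its `task`.

```python
import json
import os

# Hypothetical addresses: replace with the real hostnames/IPs of your machines.
CLUSTER = {
    "chief": ["desktop:2222"],             # runs the training loop (coordinator)
    "ps": ["desktop:2223"],                # holds the model variables
    "worker": ["pi-1:2222", "pi-2:2222"],  # execute the scheduled step functions
}


def tf_config_for(task_type, task_index):
    """Return the TF_CONFIG JSON string for one process in CLUSTER."""
    if task_index >= len(CLUSTER[task_type]):
        raise ValueError(f"no {task_type} with index {task_index}")
    return json.dumps({"cluster": CLUSTER,
                       "task": {"type": task_type, "index": task_index}})


# Example: on pi-2, set this before TensorFlow starts.
os.environ["TF_CONFIG"] = tf_config_for("worker", 1)
```

In the TensorFlow parameter-server tutorial, the worker and ps processes only start a `tf.distribute.Server` and wait; only the chief runs the training code that uses `step_fn`.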
The input consists of images in two directories, one per class, for binary classification. I am not sure what the step function does other than iterate through the data. I would appreciate any advice on the input pipeline.
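Because `step_fn` unpacks `batch_data, labels = next(iterator)`, the dataset has to yield `(images, labels)` pairs, whereas `tf.data.Dataset.list_files` on its own yields only file paths. One common convention, assumed here, is to take each image's label from the name of the directory it sits in. A plain-Python sketch of that rule (the `fire`/`no_fire` directory names are assumed from the labels in the code, and `labelled_files` is a hypothetical helper):

```python
import pathlib

CLASS_NAMES = ["no_fire", "fire"]  # position in this list = label, so "fire" -> 1


def labelled_files(data_dir):
    """Return (path, label) pairs for data_dir/<class_name>/*.jpg, grouped by class."""
    data_dir = pathlib.Path(data_dir).expanduser()
    pairs = []
    for label, class_name in enumerate(CLASS_NAMES):
        for path in sorted((data_dir / class_name).glob("*.jpg")):
            pairs.append((str(path), label))
    return pairs
```

The pairs come back grouped by class, so they need shuffling before batching; otherwise most batches would contain a single class.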
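```python
import os
import pathlib

import tensorflow as tf

# `strategy`, `model`, `optimizer` and `accuracy` are assumed to be created
# elsewhere, as in the tutorial (model, optimizer and metric inside strategy.scope()).
DATA_DIR = pathlib.Path(
    "~/Fire-Detection-UAV-Aerial-Image-Classification-Segmentation-UnmannedAerialVehicle-main"
    "/frames/Training/Training").expanduser()  # pathlib does not expand "~" itself
IMAGE_SIZE = (254, 254)  # placeholder: use the input size of the paper's model
BATCH_SIZE = 32          # placeholder

def parse_image(path):
    # Label from the parent directory name: "fire" -> 1.0, anything else -> 0.0.
    class_name = tf.strings.split(path, os.sep)[-2]
    label = tf.reshape(tf.cast(tf.equal(class_name, 'fire'), tf.float32), [1])
    image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
    return tf.image.resize(image, IMAGE_SIZE) / 255.0, label

def dataset_fn(_):
    print(len(list(DATA_DIR.glob('*/*.jpg'))))  # image count
    # Shuffle file order: files are listed class by class.
    ds = tf.data.Dataset.list_files(str(DATA_DIR / '*/*.jpg'), shuffle=True)
    ds = ds.map(parse_image).batch(BATCH_SIZE).repeat()  # yields (images, labels)
    return ds.prefetch(tf.data.AUTOTUNE)

@tf.function
def step_fn(iterator):
    def replica_fn(batch_data, labels):
        with tf.GradientTape() as tape:
            pred = model(batch_data, training=True)
            per_example_loss = tf.keras.losses.BinaryCrossentropy(
                reduction=tf.keras.losses.Reduction.NONE)(labels, pred)
            loss = tf.nn.compute_average_loss(per_example_loss)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)
        accuracy.update_state(labels, actual_pred)
        return loss

    # Each call pulls one batch of (images, labels) from the distributed iterator.
    batch_data, labels = next(iterator)
    losses = strategy.run(replica_fn, args=(batch_data, labels))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)
```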
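On what `step_fn` does: in the tutorial, each call is one training step on one batch. It takes the next `(batch_data, labels)` pair from the iterator, `strategy.run` executes `replica_fn` (forward pass, loss, gradients, optimizer update) on each replica of the worker running the step, and `strategy.reduce(tf.distribute.ReduceOp.SUM, ...)` adds up the replicas' losses. It is not called directly; it is passed to `ClusterCoordinator.schedule`, and the iterator comes from `coordinator.create_per_worker_dataset`, which builds the dataset on each worker, so the image directory presumably has to exist on every Pi. Summing is correct because `tf.nn.compute_average_loss` divides by the global batch size rather than the per-replica one. A plain-Python sketch of that arithmetic (no TensorFlow; `average_over_global_batch` is a hypothetical stand-in and the loss values are made up):

```python
def average_over_global_batch(per_example_loss, global_batch_size):
    """Stand-in for tf.nn.compute_average_loss without sample weights:
    sum of this replica's per-example losses divided by the GLOBAL batch size."""
    return sum(per_example_loss) / global_batch_size


# Hypothetical step: a global batch of 4 examples split across 2 replicas.
replica_losses = [[0.2, 0.4], [0.6, 0.8]]
global_batch_size = sum(len(r) for r in replica_losses)

# strategy.run -> one scaled loss per replica; strategy.reduce(SUM) -> their sum.
per_replica = [average_over_global_batch(r, global_batch_size) for r in replica_losses]
step_loss = sum(per_replica)

# The summed replica losses equal the plain mean over the whole global batch.
all_losses = [x for r in replica_losses for x in r]
assert abs(step_loss - sum(all_losses) / len(all_losses)) < 1e-12
```

With one CPU-only replica per Pi, the reduction simply returns that replica's loss.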