Data Science Asked by userqwerty1 on April 10, 2021
I have a model with a custom loss function implemented along these lines:
```python
import tensorflow as tf
from tensorflow.keras import optimizers as opt  # assumed alias for `opt`

def custom_loss(labels, predictions):
    global diff
    # the actual code uses a decorator, so no globals
    diff = labels - predictions
    return tf.square(diff)

model.compile(loss=custom_loss, optimizer=opt.RMSprop())
...
model.train_on_batch(inputs, labels)
```
How can I get `diff` after running `train_on_batch` without making Keras rerun predict a second time behind the scenes (an unnecessary slowdown) and without messing with trainable/batchnorm state etc. (possible problems)?
I want to avoid writing a manual raw TensorFlow train_op loop, keeping track of the learning phase and whatnot. Is that my only choice?
I'm using TensorFlow 1.14's Keras module.
I've solved it (I discovered control dependencies and remembered variables). Basically, I create an operation that assigns `diff` to a variable, and with `tf.control_dependencies` I force TensorFlow to run that operation every time the loss (`op` in the test code) is calculated. This way, reading the variable afterwards doesn't cause the graph to be recalculated.
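```python
import tensorflow as tf

# validate_shape=False: the variable takes whatever (batch-dependent) shape is assigned
diff_var = tf.Variable(0.0, trainable=False, validate_shape=False)

def custom_loss(labels, predictions):
    diff = labels - predictions
    diff_var_op = tf.assign(diff_var, diff, validate_shape=False)
    # ops created in this block run only after diff_var_op
    with tf.control_dependencies([diff_var_op]):
        return tf.square(diff)
```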
Test code:
```python
import tensorflow as tf

sess = tf.Session()
var1 = tf.Variable(1, dtype=tf.float32)
var2 = tf.Variable(2, dtype=tf.float32)
counter = tf.Variable(1, dtype=tf.float32)
var2_op = tf.square(var2) * counter
diff_var = tf.Variable(10, dtype=tf.float32, trainable=False)
diff = var2_op - var1
diff_var_op = diff_var.assign(diff)
with tf.control_dependencies([diff_var_op]):
    op = tf.square(diff)

sess.run(tf.global_variables_initializer())
print('diff var:', sess.run(diff_var))  # 10
print('counter:', sess.run(counter))  # 1
print('op:', sess.run(op))  # 9
print('diff var:', sess.run(diff_var))  # 3
print('-')
counter_op = tf.assign_add(counter, 1)
print('counter:', sess.run(counter))  # 1
print('diff var:', sess.run(diff_var))  # 3, still the same
print('var2:', sess.run(var2_op))  # 4
print('-')
sess.run(counter_op)
print('after counter_op')
print('counter:', sess.run(counter))  # 2
print('var2:', sess.run(var2_op))  # 8
# diff_var is still 3 even though var2_op has changed because of counter_op
print('diff var:', sess.run(diff_var))  # 3
print('op:', sess.run(op))  # 49, running the full op
print('-')
print('after op')
print('diff var:', sess.run(diff_var))  # 7
# diff_var was updated as a side effect of running op; reading it runs nothing else
```
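The test above only uses plain ops. As a minimal sketch of the same trick inside an actual training step, the block below swaps Keras's `train_on_batch` for a hand-written `tf.compat.v1` train op (a substitution, since wiring this through Keras differs between TensorFlow versions). The linear model, the batch data, `BATCH` and the learning rate are all hypothetical, and it assumes a fixed batch size so that `diff_var` can have a static shape. It should run on TF 1.14 and on TF 2.x through `tf.compat.v1` graph mode.

```python
import numpy as np
import tensorflow as tf

BATCH = 4  # hypothetical fixed batch size, so diff_var can have a static shape

graph = tf.Graph()
with graph.as_default():
    # Hypothetical tiny linear "model": predictions = x @ w, with w initialised to ones
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name="x")
    y = tf.compat.v1.placeholder(tf.float32, shape=[None, 1], name="y")
    w = tf.compat.v1.Variable(tf.ones([3, 1]), name="w")
    predictions = tf.matmul(x, w)

    # Non-trainable variable that caches the diff computed by the last loss evaluation
    diff_var = tf.compat.v1.Variable(tf.zeros([BATCH, 1]), trainable=False, name="diff_var")

    def custom_loss(labels, preds):
        diff = labels - preds
        # read_value=False makes assign() return the assign Operation itself
        assign_op = diff_var.assign(diff, read_value=False)
        with tf.control_dependencies([assign_op]):
            # This op can only execute after assign_op has run
            return tf.square(diff)

    loss = tf.reduce_mean(custom_loss(y, predictions))
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.01).minimize(loss)
    init = tf.compat.v1.global_variables_initializer()

x_batch = np.arange(12, dtype=np.float32).reshape(BATCH, 3)
y_batch = np.ones((BATCH, 1), dtype=np.float32)

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    # Fetch the loss together with train_op so the square op (and therefore
    # the assign it depends on) is guaranteed to execute in this step.
    _, loss_value = sess.run([train_op, loss], feed_dict={x: x_batch, y: y_batch})
    # No feed_dict: reading diff_var depends only on the variable, not on x or y
    diff_value = sess.run(diff_var)

print(diff_value.ravel())
```

Fetching `loss` alongside `train_op` matters: the gradient of `tf.square` is built from `diff` directly, so running `train_op` on its own is not guaranteed to execute the square op that carries the control dependency (`train_on_batch` also returns the loss, so it fetches it). The final read works without any `feed_dict`, which is only possible because it does not rerun the forward pass. With Keras on TF 1.14, the equivalent read after `model.train_on_batch(...)` would presumably be `tf.keras.backend.get_value(diff_var)`, which evaluates the variable in the Keras session; that path is not exercised by this sketch.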
Answered by userqwerty1 on April 10, 2021