
How to get tf tensor value computed in loss function in keras train_on_batch without computing it twice or writing custom loop?

Asked by userqwerty1 on April 10, 2021

I have a model and I’ve implemented a custom loss function, something along the lines of:

def custom_loss(labels, predictions):
    global diff
    #the actual code uses a decorator, so no globals
    diff = labels - predictions
    return tf.square(diff)

model.compile(loss=custom_loss, optimizer=opt.RMSprop())

...

model.train_on_batch(input, labels)


How do I get diff after I’ve run train_on_batch, without causing it to rerun the forward pass a second time behind the scenes (unnecessary slowdown) and without messing with trainable/batch-norm state (possible problems)?

I want to avoid writing a manual raw-TensorFlow train_op loop, keeping track of the learning phase and so on. Is that my only choice?

I’m using TensorFlow 1.14’s keras module.
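
For reference, the obvious workaround would be to recompute the predictions after the training step, roughly as in the sketch below (using the same input and labels as above); that is exactly the second forward pass I want to avoid:

#recompute predictions outside the loss function
preds = model.predict_on_batch(input)   #second forward pass: the slowdown in question
diff = labels - preds                   #also runs in inference mode, unlike the training step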

One Answer

I've solved it (I discovered control dependencies and remembered that variables exist). Basically, I create an operation that assigns diff to a variable, and with control_dependencies I force TensorFlow to run that assignment every time the loss op is computed. That way, reading the variable afterwards doesn't cause any graph recalculation:

#shape left unvalidated so the variable can hold a whole batch of differences
diff_var = tf.Variable(0.0, trainable=False, validate_shape=False)

def custom_loss(labels, predictions):
    diff = labels - predictions
    diff_var_op = diff_var.assign(diff)
    #the control dependency forces the assign to run whenever the loss is evaluated
    with tf.control_dependencies([diff_var_op]):
        return tf.square(diff)
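
With this in place, the stored value can be read back through the Keras backend right after train_on_batch. A minimal sketch of the usage (assuming the same model, input and labels as in the question):

from tensorflow.keras import backend as K

model.compile(loss=custom_loss, optimizer='rmsprop')
model.train_on_batch(input, labels)

#the assign already ran as a side effect of evaluating the loss,
#so reading the variable does not trigger another forward pass
diff_value = K.get_value(diff_var)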

Test code:

import tensorflow as tf
sess = tf.Session()
var1 = tf.Variable(1, dtype=tf.float32)
var2 = tf.Variable(2, dtype=tf.float32)
counter = tf.Variable(1, dtype=tf.float32)
var2_op = tf.square(var2)*counter

diff_var = tf.Variable(10, dtype=tf.float32, trainable=False)

diff = var2_op - var1
diff_var_op = diff_var.assign(diff)

with tf.control_dependencies([diff_var_op]):
    op = tf.square(diff)

sess.run(tf.global_variables_initializer())

print('diff var:', sess.run(diff_var))     #10
print('counter:', sess.run(counter))       #1
print('op:', sess.run(op))                 #9
print('diff var:', sess.run(diff_var))     #3
print('-')
counter_op = tf.assign_add(counter, 1)
print('counter:', sess.run(counter))       #1
print('diff var:', sess.run(diff_var))     #3  #still the same
print('var2:', sess.run(var2_op))          #4
print('-')
sess.run(counter_op)
print('after counter_op')
print('counter:', sess.run(counter))       #2
print('var2:', sess.run(var2_op))          #8
#diff_var is still 3 even though var2_op now evaluates to 8, because op itself hasn't been run
print('diff var:', sess.run(diff_var))     #3
print('op:', sess.run(op))                 #49 #running full op
print('-')
print('after op')
print('diff var:', sess.run(diff_var))     #7
#variable changed, no operations involved

Answered by userqwerty1 on April 10, 2021
