TransWikia.com

How to feed data to multi-output Keras model from a single TFRecords file

Data Science Asked by magomar on September 1, 2021

I know how to feed data to a multi-output Keras model using NumPy arrays for the training data. However, I have all my data in a single TFRecords file comprising several feature columns: an image, which is used as input to the Keras model, plus a sequence of outputs corresponding to different classification tasks, e.g. one output encodes the age of the person in the image, another encodes the gender, and so on.

From what I have seen in examples, when the model has several output heads, it should be fed with multiple data sources: one for the input and one for each of the outputs.

Is there an easy way to do that when all the data is in a single TFRecords file, i.e. without having to create separate TFRecords files for the input and for each of the outputs?

2 Answers

After playing around with tf.data.Dataset.map operations, I found the answer was easier than expected: I simply had to preprocess the data and put the labels for each output of the model under a different key of a dictionary.
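A minimal sketch of that idea, using small in-memory NumPy arrays in place of the TFRecords file (the array sizes and label values are made up): each dataset element is an (input, labels) pair whose labels part is a dict keyed by output name.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 4 RGB images of 8x8 pixels and three integer labels each
images = np.zeros((4, 8, 8, 3), dtype=np.float32)
age = np.array([0, 3, 1, 2], dtype=np.int64)
gender = np.array([0, 1, 1, 0], dtype=np.int64)
ethnicity = np.array([2, 0, 4, 1], dtype=np.int64)

# Each element is (input, labels); labels is a dict keyed by output-layer name
dataset = tf.data.Dataset.from_tensor_slices(
    (images, {'Age': age, 'Gender': gender, 'Ethnicity': ethnicity}))

image, labels = next(iter(dataset))
print(image.shape)            # (8, 8, 3)
print(sorted(labels.keys()))  # ['Age', 'Ethnicity', 'Gender']
```

A single dataset therefore carries the input and every target at once; Keras splits the dict across the outputs during fit.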

First, I create a dataset from the TFRecords file:

dataset = tf.data.TFRecordDataset(tfrecords_file)
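The answer does not show how the TFRecords file was written. A minimal sketch, assuming PNG-encoded images and a temporary file path (both are assumptions), that writes a few synthetic records using the same keys as the feature spec used for parsing (image/encoded, image/shape, age, gender, ethnicity), with all labels stored in one tf.train.Example per image:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def _bytes(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64s(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

def make_example(image, age, gender, ethnicity):
    # Store the PNG-encoded image, its shape and every label in ONE tf.train.Example
    encoded = tf.io.encode_png(image).numpy()
    feature = {
        'image/encoded': _bytes(encoded),
        'image/shape': _int64s(image.shape),
        'age': _int64s([age]),
        'gender': _int64s([gender]),
        'ethnicity': _int64s([ethnicity]),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

# Hypothetical location for the synthetic file
tfrecords_file = os.path.join(tempfile.mkdtemp(), 'faces.tfrecords')
rng = np.random.default_rng(0)
with tf.io.TFRecordWriter(tfrecords_file) as writer:
    for i in range(3):
        image = rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8)
        writer.write(make_example(image, age=i, gender=i % 2, ethnicity=i).SerializeToString())

dataset = tf.data.TFRecordDataset(tfrecords_file)
print(sum(1 for _ in dataset))  # 3
```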

Next, I parse each serialized record into its features:

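import tensorflow as tf

feature = {'image/encoded': tf.io.FixedLenFeature([], tf.string),
           'image/shape': tf.io.FixedLenFeature([3], tf.int64),
           'age': tf.io.FixedLenFeature([], tf.int64),
           'gender': tf.io.FixedLenFeature([], tf.int64),
           'ethnicity': tf.io.FixedLenFeature([], tf.int64),
}

def parser(protobuff_message):
    # tf.io.parse_single_example stands in for the author's tf_util.parse_pb_message helper
    return tf.io.parse_single_example(protobuff_message, feature)

def process_example(example):
    # Hypothetical (not shown in the answer): decode the image, return the fields in order
    image = tf.io.decode_image(example['image/encoded'], channels=3, expand_animations=False)
    return (image, example['image/shape'], example['age'], example['gender'], example['ethnicity'])

dataset = dataset.map(parser).map(process_example)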
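To check what the parsing step returns, the sketch below builds one serialized record in memory (the 4x6 image and the label values are made up), parses it with tf.io.parse_single_example and decodes the image. It assumes the images were stored as PNG or JPEG bytes, both of which tf.io.decode_image handles.

```python
import numpy as np
import tensorflow as tf

feature = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/shape': tf.io.FixedLenFeature([3], tf.int64),
    'age': tf.io.FixedLenFeature([], tf.int64),
    'gender': tf.io.FixedLenFeature([], tf.int64),
    'ethnicity': tf.io.FixedLenFeature([], tf.int64),
}

def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

# One hypothetical record, standing in for a single entry of the TFRecords file
pixels = np.full((4, 6, 3), 200, dtype=np.uint8)
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(
        value=[tf.io.encode_png(pixels).numpy()])),
    'image/shape': int64_feature([4, 6, 3]),
    'age': int64_feature([3]),
    'gender': int64_feature([1]),
    'ethnicity': int64_feature([2]),
}))
serialized = example.SerializeToString()

parsed = tf.io.parse_single_example(serialized, feature)
# expand_animations=False guarantees a 3-D (height, width, channels) tensor
image = tf.io.decode_image(parsed['image/encoded'], channels=3, expand_animations=False)
print(parsed['image/shape'].numpy())  # [4 6 3]
print(int(parsed['age']))             # 3
```

Scalar specs ([]) yield scalar tensors, while [3] yields a length-3 vector, which is why image/shape and the labels come back with different ranks.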

At this point we have a standard dataset on which we can apply batching, shuffling, augmentation or whatever we want. Finally, before feeding the data into the model, we have to transform it to fit the model's requirements. The code below shows an example of both input and label preprocessing. Previously I concatenated all the labels; now I create a dictionary with the names of the model's outputs as keys.
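Because every element is an (image, labels) pair, the usual tf.data operations apply unchanged to the nested structure. A sketch on random data (the shapes, batch size and the flip augmentation are arbitrary choices):

```python
import numpy as np
import tensorflow as tf

# Hypothetical already-parsed elements: (image, {output_name: label})
images = np.random.rand(10, 8, 8, 3).astype(np.float32)
labels = {'Gender': np.random.randint(0, 2, size=10),
          'Age': np.random.randint(0, 5, size=10)}
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# Shuffle, augment and batch as for a single-output dataset;
# map unpacks each (image, labels) pair into the lambda's two arguments
dataset = (dataset
           .shuffle(buffer_size=10, seed=0)
           .map(lambda x, y: (tf.image.random_flip_left_right(x), y))
           .batch(4)
           .prefetch(tf.data.AUTOTUNE))

batch_images, batch_labels = next(iter(dataset))
print(batch_images.shape)            # (4, 8, 8, 3)
print(batch_labels['Gender'].shape)  # (4,)
```

Batching stacks each dict entry separately, so the labels stay a dict of batched tensors.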

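IMAGE_SIZE = (224, 224)  # hypothetical placeholder: input size expected by the model
ETHNIC_GROUPS = 5        # hypothetical placeholder: number of ethnicity classes
AGE_GROUPS = 10          # hypothetical placeholder: number of age classes

def preprocess_input_fn():
    def _preprocess_input(image, image_shape, age, gender, ethnicity):
        image = preprocess_image(image)
        labels = preprocess_labels(age, gender, ethnicity)
        return image, labels

    return _preprocess_input

def preprocess_image(image):
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, IMAGE_SIZE)
    image = (image / 127.5) - 1.0  # scale pixels to [-1, 1]
    return image

def preprocess_labels(age, gender, ethnicity):
    # One-hot encode each label; the dict keys must match the output layer names
    gender = tf.one_hot(gender, 2)
    ethnicity = tf.one_hot(ethnicity, ETHNIC_GROUPS)
    age = tf.one_hot(age, AGE_GROUPS)
    return {'Gender': gender, 'Ethnicity': ethnicity, 'Age': age}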
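A quick numeric check of the two transformations above, on made-up values: dividing by 127.5 and subtracting 1 maps pixel values from [0, 255] onto [-1, 1], and tf.one_hot turns integer class ids into the vectors a categorical_crossentropy loss expects.

```python
import tensorflow as tf

# Pixel scaling as in preprocess_image: x / 127.5 - 1 maps [0, 255] onto [-1, 1]
pixels = tf.constant([0.0, 127.5, 255.0])
scaled = pixels / 127.5 - 1.0
print(scaled.numpy())  # [-1.  0.  1.]

# One-hot encoding as in preprocess_labels (depth=3 is an arbitrary example)
one_hot = tf.one_hot(tf.constant([2, 0], dtype=tf.int64), depth=3)
print(one_hot.numpy())
# [[0. 0. 1.]
#  [1. 0. 0.]]
```

If you keep the integer ids instead of one-hot vectors, sparse_categorical_crossentropy is the matching loss.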

In my model, Gender, Ethnicity and Age are the names of the three output layers, so the model is defined as having three outputs:

model = Model(inputs=inputs,
              outputs=[gender, ethnic_group, age_group])
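A sketch of how such a model might be built with the functional API, assuming a toy convolutional backbone and placeholder class counts. The only part that matters for the dataset is the name argument of each output layer, which must equal the corresponding key in the label dict.

```python
import tensorflow as tf

# Hypothetical sizes; replace with your own
IMAGE_SIZE = (64, 64)
ETHNIC_GROUPS = 5
AGE_GROUPS = 8

inputs = tf.keras.Input(shape=IMAGE_SIZE + (3,))
x = tf.keras.layers.Conv2D(16, 3, activation='relu')(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)

# The `name` of each head is what the label dict keys must match
gender = tf.keras.layers.Dense(2, activation='softmax', name='Gender')(x)
ethnic_group = tf.keras.layers.Dense(ETHNIC_GROUPS, activation='softmax', name='Ethnicity')(x)
age_group = tf.keras.layers.Dense(AGE_GROUPS, activation='softmax', name='Age')(x)

model = tf.keras.Model(inputs=inputs, outputs=[gender, ethnic_group, age_group])

# Losses can be keyed by the same names; one-hot labels call for categorical cross-entropy
model.compile(optimizer='adam',
              loss={'Gender': 'categorical_crossentropy',
                    'Ethnicity': 'categorical_crossentropy',
                    'Age': 'categorical_crossentropy'})
```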

Now I can use a dataset to fit the model by applying the preprocessing function first:

data = dataset.map(preprocess_input_fn()).batch(32)  # batch size is a placeholder; skip if already batched

model.fit(data, epochs=...)    
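An end-to-end sketch on synthetic data (image size, class counts and the tiny dense model are assumptions): preprocessing returns a label dict, the dataset is batched, and fit consumes it directly. Here outputs is passed as a dict with the same keys as the labels, which keeps the matching explicit; with a list of outputs, as in the answer, Keras matches the dict keys against the output layer names instead.

```python
import numpy as np
import tensorflow as tf

NUM_AGE_GROUPS = 4  # hypothetical class count
n = 16

# Hypothetical decoded examples: uint8 images plus integer labels
images = np.random.randint(0, 256, size=(n, 32, 32, 3)).astype(np.uint8)
ages = np.random.randint(0, NUM_AGE_GROUPS, size=n)
genders = np.random.randint(0, 2, size=n)

def preprocess(image, age, gender):
    image = tf.image.resize(tf.cast(image, tf.float32), (16, 16)) / 127.5 - 1.0
    labels = {'Gender': tf.one_hot(gender, 2), 'Age': tf.one_hot(age, NUM_AGE_GROUPS)}
    return image, labels

dataset = (tf.data.Dataset.from_tensor_slices((images, ages, genders))
           .map(preprocess)
           .batch(8))

inputs = tf.keras.Input(shape=(16, 16, 3))
x = tf.keras.layers.Flatten()(inputs)
outputs = {
    'Gender': tf.keras.layers.Dense(2, activation='softmax', name='Gender')(x),
    'Age': tf.keras.layers.Dense(NUM_AGE_GROUPS, activation='softmax', name='Age')(x),
}
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam',
              loss={'Gender': 'categorical_crossentropy', 'Age': 'categorical_crossentropy'})

history = model.fit(dataset, epochs=1, verbose=0)
print(sorted(history.history))
```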

Correct answer by magomar on September 1, 2021

Assuming your model receives an image as input and has two outputs, age and gender, and that you have generated a TFRecords file containing them, you can decode and use it through tf.data this way:

import tensorflow as tf

decode_features = {
  'image'  : tf.io.FixedLenFeature([], tf.string),
  'age'    : tf.io.FixedLenFeature([1], tf.int64),
  'gender' : tf.io.FixedLenFeature([1], tf.int64),
}

def decode(serialized_example):
  features = tf.io.parse_single_example(serialized_example, features=decode_features)
  # Look up the encoded bytes under the same key used in decode_features;
  # channels=3 assumes RGB images
  image = tf.image.decode_image(features['image'], channels=3,
                                expand_animations=False, name="InputImage")
  image = tf.cast(image, tf.float32) / 128. - 1.
  # The dict keys must match the names of the model's output layers
  labels = {}
  labels['age']    = tf.cast(features['age'], tf.int32)
  labels['gender'] = tf.cast(features['gender'], tf.int32)
  return image, labels

dataset = tf.data.TFRecordDataset('path/to/file.tfrecords')
dataset = dataset.map(decode).batch(32)  # batch size is a placeholder

model.fit(dataset, ...)
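A self-contained variant of this answer that runs without a file: a few serialized records are built in memory, decoded with the same feature spec, and fed to a small model whose output layers are named age and gender. The image size, class counts and the resize step (which gives the decoded images a static shape) are assumptions. Because the labels are integer ids of shape (1,), sparse_categorical_crossentropy is used rather than the one-hot variant.

```python
import numpy as np
import tensorflow as tf

decode_features = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'age': tf.io.FixedLenFeature([1], tf.int64),
    'gender': tf.io.FixedLenFeature([1], tf.int64),
}

def serialize(image, age, gender):
    # Build one tf.train.Example with the keys decode_features expects
    feature = {
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[tf.io.encode_png(image).numpy()])),
        'age': tf.train.Feature(int64_list=tf.train.Int64List(value=[age])),
        'gender': tf.train.Feature(int64_list=tf.train.Int64List(value=[gender])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

def decode(serialized_example):
    features = tf.io.parse_single_example(serialized_example, features=decode_features)
    image = tf.io.decode_image(features['image'], channels=3, expand_animations=False)
    image = tf.image.resize(image, (16, 16))  # assumed size; also fixes the static shape
    image = tf.cast(image, tf.float32) / 128. - 1.
    labels = {'age': tf.cast(features['age'], tf.int32),
              'gender': tf.cast(features['gender'], tf.int32)}
    return image, labels

rng = np.random.default_rng(0)
records = [serialize(rng.integers(0, 256, (16, 16, 3), dtype=np.uint8), int(a), int(a % 2))
           for a in range(6)]
dataset = tf.data.Dataset.from_tensor_slices(records).map(decode).batch(3)

inputs = tf.keras.Input(shape=(16, 16, 3))
x = tf.keras.layers.Flatten()(inputs)
model = tf.keras.Model(inputs, {
    'age': tf.keras.layers.Dense(6, activation='softmax', name='age')(x),
    'gender': tf.keras.layers.Dense(2, activation='softmax', name='gender')(x),
})
# Integer labels of shape (batch, 1) pair with the sparse loss
model.compile(optimizer='adam',
              loss={'age': 'sparse_categorical_crossentropy',
                    'gender': 'sparse_categorical_crossentropy'})
history = model.fit(dataset, epochs=1, verbose=0)
```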

Answered by barbolo on September 1, 2021


© 2024 TransWikia.com. All rights reserved.