TransWikia.com

TensorFlow training with batch shape (1, None, features), but the model expects an extra dimension

Data Science Asked by komodovaran_ on February 28, 2021

I’ve made an autoencoder, shown below, that accepts variable-length inputs. It works for a single sample if I add a batch axis with np.expand_dims(x, axis=0) before calling model.fit, but it doesn’t work when I pass in an entire dataset. What’s the simplest approach in this case?

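import numpy as np
# Public Keras API (TF 2.x / Keras 2); tensorflow.python.keras is a private module
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, LSTM, Lambda
from tensorflow.keras.models import Model


def repeat(x):
    # x = [inputs, encoded]: copy the latent vector once per input timestep,
    # so the decoder sequence has the same (variable) length as the input
    step_matrix = K.ones_like(x[0][:, :, :1])       # (batch, timesteps, 1)
    latent_matrix = K.expand_dims(x[1], axis=1)     # (batch, 1, latent_dim)
    return K.batch_dot(step_matrix, latent_matrix)  # (batch, timesteps, latent_dim)

timesteps = None  # variable-length sequences
features = 2
latent_dim = 10

inputs = Input(shape=(timesteps, features))
encoded = LSTM(latent_dim, name="encoded")(inputs)
decoded = Lambda(repeat)([inputs, encoded])
outputs = LSTM(features, return_sequences=True)(decoded)
autoenc = Model(inputs=inputs, outputs=outputs)
autoenc.compile(optimizer="adam", loss="mse")
encoder = Model(
    inputs=autoenc.input, outputs=autoenc.get_layer("encoded").output
)

x1 = np.ones((20, 2))
x2 = np.ones((30, 2))
x3 = np.ones((40, 2))
# Sequences of different lengths only form a 1-D object array of shape (3,);
# recent NumPy versions require dtype=object to build it at all
X_train = np.array((x1, x2, x3), dtype=object)

# This is the call that does not work: the object array cannot be
# converted into a single (3, None, 2) tensor
autoenc.fit(x=X_train, y=X_train, epochs=10, batch_size=1)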
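The shape bookkeeping behind the question can be checked with NumPy alone. This is a minimal sketch, not the model itself: `repeat_np` is a hypothetical helper that mirrors `repeat` with `np.matmul`, assuming (as holds for these 3-D shapes) that `K.batch_dot` contracts the last axis of the first argument with the middle axis of the second, exactly like a batched matrix product.

```python
import numpy as np


def repeat_np(seq_batch, latent):
    """Hypothetical NumPy mirror of the question's `repeat` Lambda.

    seq_batch: (batch, timesteps, features); latent: (batch, latent_dim).
    Returns (batch, timesteps, latent_dim): the latent vector copied once per timestep.
    """
    step_matrix = np.ones_like(seq_batch[:, :, :1])  # (batch, timesteps, 1)
    latent_matrix = np.expand_dims(latent, axis=1)   # (batch, 1, latent_dim)
    # Batched (timesteps, 1) @ (1, latent_dim) product, as K.batch_dot does here
    return np.matmul(step_matrix, latent_matrix)


x1 = np.ones((20, 2))
x2 = np.ones((30, 2))
x3 = np.ones((40, 2))

# One sequence becomes a batch of one: (20, 2) -> (1, 20, 2)
single = np.expand_dims(x1, axis=0)
print(single.shape)  # (1, 20, 2)

# The decoder length is taken from the input at run time, so any length works
latent = np.arange(10, dtype=float).reshape(1, 10)
print(repeat_np(single, latent).shape)  # (1, 20, 10)

# Sequences of different lengths cannot be stacked into one numeric
# (3, None, 2) array; NumPy can only hold them in a 1-D object array
ragged = np.array((x1, x2, x3), dtype=object)
print(ragged.shape, ragged.dtype)  # (3,) object
```

Each element of `ragged` becomes a valid (1, n_i, 2) batch once `np.expand_dims` is applied to it, which is why feeding samples one at a time works while passing the whole collection at once does not.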

One Answer

I managed to solve the problem with a generator that adds a batch axis to each sequence, so that every batch it yields has shape (1, None, 2).

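import numpy as np
import tensorflow as tf


class SingleBatchGenerator:
    """Yields each variable-length sequence as its own batch of shape (1, n_i, 2)."""

    def __init__(self, X):
        self.X = X

    def __call__(self):
        for i in range(len(self.X)):
            # (n_i, 2) -> (1, n_i, 2): add the batch axis the model expects
            xi = np.expand_dims(self.X[i], axis=0)
            yield xi, xi  # autoencoder: input and target are the same

X = [np.ones((np.random.randint(1, 100), 2)) for _ in range(100)]
gen = SingleBatchGenerator(X)

# output_signature (TF 2.4+) replaces the deprecated output_types/output_shapes
ds = tf.data.Dataset.from_generator(
    generator=gen,
    output_signature=(
        tf.TensorSpec(shape=(1, None, 2), dtype=tf.float64),
        tf.TensorSpec(shape=(1, None, 2), dtype=tf.float64),
    ),
)

# One sequence per step, so one epoch is len(X) steps; repeat() keeps the
# dataset from running dry across epochs
autoenc.fit(ds.repeat(), steps_per_epoch=len(X), epochs=500)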
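To check the generator against the declared (1, None, 2) shape without building the model, here is a NumPy-only sketch. `single_batch_generator` is a hypothetical function form of `SingleBatchGenerator`, and the seeded `default_rng` stands in for `np.random.randint` only so that the run is repeatable.

```python
import numpy as np


def single_batch_generator(sequences):
    """Hypothetical function form of SingleBatchGenerator: one sequence per batch."""
    for seq in sequences:
        batch = np.expand_dims(seq, axis=0)  # (n_i, 2) -> (1, n_i, 2)
        yield batch, batch                   # autoencoder: target == input


rng = np.random.default_rng(0)
X = [np.ones((rng.integers(1, 100), 2)) for _ in range(100)]

shapes = [xb.shape for xb, _ in single_batch_generator(X)]

# Every batch fits (1, None, 2): a leading 1, a trailing 2,
# and a time axis whose length changes from batch to batch
print(all(s[0] == 1 and s[2] == 2 for s in shapes))  # True
```

Because each batch holds a single sequence, one pass over X is len(X) optimizer steps, which is what steps_per_epoch=len(X) encodes; the trade-off is that no two sequences ever share a gradient step.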

Answered by komodovaran_ on February 28, 2021
