
I pass Keras an input of shape input_shape=(500,).

For various reasons, I would like to split the input vector into two vectors of shapes input_shape_1=(300,) and input_shape_2=(200,).

I want to do this within the model definition, using the Functional API. Essentially, I would like to slice a tf.Tensor object.

Help is welcome!

Contestosis

1 Answer


If it's only the input you want to decompose, you can split the data before feeding it to the model and use two input layers:

```python
import numpy as np
import tensorflow as tf

inputs_first_half = tf.keras.Input(shape=(300,))
inputs_second_half = tf.keras.Input(shape=(200,))

# do something with it
first_half = tf.keras.layers.Dense(1, activation=tf.nn.relu)(inputs_first_half)
second_half = tf.keras.layers.Dense(1, activation=tf.nn.relu)(inputs_second_half)
outputs = tf.keras.layers.Add()([first_half, second_half])

model = tf.keras.Model(inputs=[inputs_first_half, inputs_second_half], outputs=outputs)

# split the (10, 500) data into (10, 300) and (10, 200) before predicting
data = np.random.randn(10, 500)
out = model.predict([data[:, :300], data[:, 300:]])
```
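The slicing in the predict call can be wrapped in a small helper so the split sizes live in one place. This is a minimal sketch assuming the data is a 2-D NumPy array of shape (n_samples, 500); split_features is a hypothetical name, not part of Keras or NumPy.

```python
import numpy as np

def split_features(data, sizes):
    """Split a (n_samples, sum(sizes)) array column-wise into len(sizes) blocks.

    Hypothetical helper for illustration only.
    """
    if data.shape[1] != sum(sizes):
        raise ValueError(f"expected {sum(sizes)} features, got {data.shape[1]}")
    # np.split wants the cut positions, i.e. the cumulative sizes without the last one
    cut_points = np.cumsum(sizes)[:-1]
    return np.split(data, cut_points, axis=1)

data = np.random.randn(10, 500)
first, second = split_features(data, (300, 200))
print(first.shape, second.shape)  # (10, 300) (10, 200)
```

The returned list has one array per input layer, in order, so it can be passed as the list argument of model.predict for the two-input model.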

If you want to split after the input layer, you could try reshaping and cropping, e.g.:

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(500,))

# do something
intermediate = tf.keras.layers.Dense(500, activation=tf.nn.relu)(inputs)

# split vector with cropping: add a pseudo-dimension -> (batch, 500, 1)
intermediate = tf.keras.layers.Reshape((500, 1))(intermediate)

# crop 200 steps from the end -> (batch, 300, 1), then drop the pseudo-dimension
first_half = tf.keras.layers.Cropping1D(cropping=(0, 200))(intermediate)
first_half = tf.keras.layers.Reshape((300,))(first_half)

# crop 300 steps from the start -> (batch, 200, 1), then drop the pseudo-dimension
second_half = tf.keras.layers.Cropping1D(cropping=(300, 0))(intermediate)
second_half = tf.keras.layers.Reshape((200,))(second_half)

# do something with decomposed vectors
first_half = tf.keras.layers.Dense(1, activation=tf.nn.relu)(first_half)
second_half = tf.keras.layers.Dense(1, activation=tf.nn.relu)(second_half)
outputs = tf.keras.layers.Add()([first_half, second_half])

model = tf.keras.Model(inputs=inputs, outputs=outputs)

data = np.random.randn(10, 500)
out = model.predict(data)
```

The Cropping1D layer expects a three-dimensional input (batch_size, axis_to_crop, features) and only crops along axis 1 (the first axis after the batch axis), so we need to add a "pseudo-dimension" to our vector by reshaping it.
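To make the shape bookkeeping concrete, here is a NumPy-only sketch of what the Reshape, Cropping1D, Reshape chain computes, assuming Cropping1D(cropping=(start, end)) drops start steps from the beginning and end steps from the end of axis 1. crop1d is a hypothetical stand-in written for illustration, not the Keras layer itself.

```python
import numpy as np

def crop1d(x, cropping):
    """Hypothetical NumPy stand-in for Cropping1D; x has shape (batch, steps, features)."""
    start, end = cropping
    steps = x.shape[1]
    # remove `start` steps at the front and `end` steps at the back of axis 1
    return x[:, start:steps - end, :]

batch = np.random.randn(10, 500)
x = batch.reshape(-1, 500, 1)                   # add the pseudo-dimension: (10, 500, 1)
first = crop1d(x, (0, 200)).reshape(-1, 300)    # (10, 300)
second = crop1d(x, (300, 0)).reshape(-1, 200)   # (10, 200)

# the whole chain is equivalent to plain slicing on the last axis
print(np.array_equal(first, batch[:, :300]), np.array_equal(second, batch[:, 300:]))  # True True
```

Using steps - end instead of -end as the stop index avoids the x[:, 0:-0] pitfall: -0 is 0, so that slice would be empty.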

Tinu