
I am trying to merge two Keras models into a single model and I am unable to accomplish this.

For example, in the attached figure, I would like to take the middle layer $A2$ (dimension 8) and use it as the input to layer $B1$ (also dimension 8) in Model $B$, and then combine Model $A$ and Model $B$ into a single model.

I am using the functional module to create Model $A$ and Model $B$ independently. How can I accomplish this task?

Note: $A1$ is the input layer to model $A$ and $B1$ is the input layer to model $B$.

See Picture


2 Answers


I figured out the answer to my question. Here is the code, which builds on the other answer.

from keras.layers import Input, Dense
from keras.models import Model
from keras.utils import plot_model

A1 = Input(shape=(30,),name='A1')
A2 = Dense(8, activation='relu',name='A2')(A1)
A3 = Dense(30, activation='relu',name='A3')(A2)

B2 = Dense(40, activation='relu',name='B2')(A2)
B3 = Dense(30, activation='relu',name='B3')(B2)

merged = Model(inputs=[A1],outputs=[A3,B3])
plot_model(merged,to_file='demo.png',show_shapes=True)

This gives the output structure that I wanted; plot_model writes the diagram to demo.png.
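If Model B really has to be built independently with its own input layer B1, as described in the question, one possible approach is to call the finished model on the tensor A2, since a Keras model can be called like a layer. The following is only a sketch under that assumption: the variable names model_a, model_b and b_out are made up for illustration, and the 'relu' activations are carried over from the code above.

```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

# Model A: 30 -> 8 -> 30
A1 = Input(shape=(30,), name='A1')
A2 = Dense(8, activation='relu', name='A2')(A1)
A3 = Dense(30, activation='relu', name='A3')(A2)
model_a = Model(inputs=A1, outputs=A3, name='model_A')

# Model B, defined on its own with a separate 8-dimensional input B1
B1 = Input(shape=(8,), name='B1')
B2 = Dense(40, activation='relu', name='B2')(B1)
B3 = Dense(30, activation='relu', name='B3')(B2)
model_b = Model(inputs=B1, outputs=B3, name='model_B')

# Calling model_b on A2 reuses B's layers and weights, with A2 taking the place of B1
b_out = model_b(A2)

# One model, one input, two outputs: A3 and the output of B
merged = Model(inputs=A1, outputs=[A3, b_out], name='merged')

# Quick shape check with dummy data (2 samples, 30 features)
outs = merged.predict(np.zeros((2, 30), dtype='float32'), verbose=0)
print([o.shape for o in outs])
```

In the summary or plot, Model B then shows up as a single nested block named model_B rather than as separate B2 and B3 layers.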


Keras offers a convenient way to define such a model: the functional API. With the functional API you can define a directed acyclic graph of layers, which lets you build arbitrary architectures, including ones where two branches share a layer. For your example:

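import numpy as np
from keras import layers, models

# Dummy data with the shapes used below: 1 sample, 30 features
A_data = np.zeros((1, 30))
A_labels = np.zeros((1, 30))
B_labels = np.zeros((1, 30))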

A1 = layers.Input(shape=(30,), name='A_input')
A2 = layers.Dense(8, activation='???')(A1)
A3 = layers.Dense(30, activation='???', name='A_output')(A2)


B2 = layers.Dense(40, activation='???')(A2)
B3 = layers.Dense(30, activation='???', name='B_output')(B2)

## define A
A = models.Model(inputs=A1, outputs=A3)

## define B
B = models.Model(inputs=A1, outputs=B3) 

B.compile(optimizer='??',
          loss={'B_output': '??'}
          )

B.fit({'A_input': A_data},
  {'B_output': B_labels},
  epochs=??, batch_size=??)

That's it. You can inspect the result with B.summary():

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
A_input (InputLayer)         (None, 30)                0
_________________________________________________________________
dense_8 (Dense)              (None, 8)                 248
_________________________________________________________________
dense_9 (Dense)              (None, 40)                360
_________________________________________________________________
B_output (Dense)             (None, 30)                1230
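To train both branches together rather than B alone, the same graph can be wrapped in one model with two outputs. Below is a minimal sketch with the placeholders filled in by assumed values: 'relu', 'adam', 'mse', the epoch count and the batch size are illustrative choices, not something this answer prescribes, and the name joint is hypothetical. Losses and targets are passed as lists in output order here; a dict keyed by the output layer names, as in the compile and fit calls above, is the alternative.

```python
import numpy as np
from keras import layers, models

# Dummy data: 4 samples, 30 features, one target per output
A_data = np.zeros((4, 30), dtype='float32')
A_labels = np.zeros((4, 30), dtype='float32')
B_labels = np.zeros((4, 30), dtype='float32')

A1 = layers.Input(shape=(30,), name='A_input')
A2 = layers.Dense(8, activation='relu')(A1)
A3 = layers.Dense(30, activation='relu', name='A_output')(A2)

# Branch B starts from the shared 8-dimensional layer A2
B2 = layers.Dense(40, activation='relu')(A2)
B3 = layers.Dense(30, activation='relu', name='B_output')(B2)

# Single model with one input and two outputs
joint = models.Model(inputs=A1, outputs=[A3, B3])

# One loss per output, in the same order as the outputs
joint.compile(optimizer='adam', loss=['mse', 'mse'])

history = joint.fit(A_data, [A_labels, B_labels],
                    epochs=1, batch_size=2, verbose=0)
```

The shared layer A2 then receives gradients from both losses, which may or may not be what you want; training B alone as shown above only updates the path A_input -> A2 -> B2 -> B_output.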