
I'm implementing the U-Net model per the published paper here. This is my model so far:

from tensorflow.keras.layers import (Concatenate, Conv2D, Cropping2D, Input,
                                     MaxPool2D, UpSampling2D)
from tensorflow.keras.models import Model

IMAGE_SIZE = (572, 572)  # assumed: (height, width), per the input comment below


def create_unet_model(image_size = IMAGE_SIZE):
    # Input layer is a 572,572 colour image
    input_layer = Input(shape=(image_size) + (3,))

    """ Begin Downsampling """

    # Block 1
    conv_1 = Conv2D(64, 3, activation = 'relu')(input_layer)
    conv_2 = Conv2D(64, 3, activation = 'relu')(conv_1)

    max_pool_1 = MaxPool2D(strides=2)(conv_2)

    # Block 2
    conv_3 = Conv2D(128, 3, activation = 'relu')(max_pool_1)
    conv_4 = Conv2D(128, 3, activation = 'relu')(conv_3)

    max_pool_2 = MaxPool2D(strides=2)(conv_4)

    # Block 3
    conv_5 = Conv2D(256, 3, activation = 'relu')(max_pool_2)
    conv_6 = Conv2D(256, 3, activation = 'relu')(conv_5)

    max_pool_3 = MaxPool2D(strides=2)(conv_6)

    # Block 4
    conv_7 = Conv2D(512, 3, activation = 'relu')(max_pool_3)
    conv_8 = Conv2D(512, 3, activation = 'relu')(conv_7)

    max_pool_4 = MaxPool2D(strides=2)(conv_8)

    # Block 5 (bottleneck)
    conv_9 = Conv2D(1024, 3, activation = 'relu')(max_pool_4)
    conv_10 = Conv2D(1024, 3, activation = 'relu')(conv_9)

    """ Begin Upsampling """

    upsample_1 = UpSampling2D()(conv_10)

    # Copy and Crop
    conv_8_cropped = Cropping2D(cropping=4)(conv_8)
    merge_1 = Concatenate()([conv_8_cropped, upsample_1])

    # Block 6
    conv_11 = Conv2D(512, 3, activation = 'relu')(merge_1)
    conv_12 = Conv2D(512, 3, activation = 'relu')(conv_11)

    upsample_2 = UpSampling2D()(conv_12)

    # Copy and Crop
    conv_6_cropped = Cropping2D(cropping=16)(conv_6)
    merge_2 = Concatenate()([conv_6_cropped, upsample_2])

    # Block 7
    conv_13 = Conv2D(256, 3, activation = 'relu')(merge_2)
    conv_14 = Conv2D(256, 3, activation = 'relu')(conv_13)
    upsample_3 = UpSampling2D()(conv_14)

    # Copy and Crop
    conv_4_cropped = Cropping2D(cropping=40)(conv_4)
    merge_3 = Concatenate()([conv_4_cropped, upsample_3])

    # Block 8
    conv_15 = Conv2D(128, 3, activation = 'relu')(merge_3)
    conv_16 = Conv2D(128, 3, activation = 'relu')(conv_15)
    upsample_4 = UpSampling2D()(conv_16)

    # Copy and Crop
    conv_2_cropped = Cropping2D(cropping=88)(conv_2)
    merge_4 = Concatenate()([conv_2_cropped, upsample_4])

    # Block 9
    conv_17 = Conv2D(64, 3, activation = 'relu')(merge_4)
    conv_18 = Conv2D(64, 3, activation = 'relu')(conv_17)

    # Output layer
    output_layer = Conv2D(1, 1, activation='sigmoid')(conv_18)

    """ Define the model """
    unet = Model(input_layer, output_layer)

    return unet

The cropping is implemented as specified in this answer and is specific to 572x572 images.

Unfortunately this implementation causes a ResourceExhaustedError:

Exception has occurred: ResourceExhaustedError
 OOM when allocating tensor with shape[32,64,392,392] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[node model/cropping2d_3/strided_slice (defined at c:\main.py:74) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
 [Op:__inference_train_function_3026]

Function call stack:
train_function
  File "C:\main.py", line 74, in main
    unet_model.fit(train_images, epochs=epochs, validation_data=validation_images, callbacks=CALLBACKS)
  File "C:\main.py", line 276, in <module>
    main()

My GPU is a GeForce RTX 2070 Super 8GB.
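
For scale, the size of the single tensor named in the OOM message can be worked out from its shape. This is a rough sketch, assuming "type float" means 32-bit floats (4 bytes each); `tensor_bytes` is a hypothetical helper, not a TensorFlow function.

```python
import math


def tensor_bytes(shape, bytes_per_element=4):
    """Return the memory needed for one dense tensor of the given shape."""
    return math.prod(shape) * bytes_per_element


# Shape reported in the error: [batch, channels, height, width]
oom_shape = (32, 64, 392, 392)

total = tensor_bytes(oom_shape)
per_sample = tensor_bytes(oom_shape[1:])

print(f"whole batch: {total / 2**30:.2f} GiB")
print(f"per sample:  {per_sample / 2**20:.1f} MiB")
```

That comes to about 1.17 GiB for this one tensor at batch size 32, or about 37.5 MiB per sample. Training keeps many such activations alive for backpropagation, so the total grows with both the batch size and the image size.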

I verified that the image size was the source of this by reproducing the error in another u-net solution which I know works.

To work around this issue, I'm trying to lower the image size, e.g. to 256x256. I've changed the Cropping2D layers to crop to the expected sizes for each layer:

# Copy and Crop - 24 -> 16
conv_8_cropped = Cropping2D(cropping=4)(conv_8)
merge_1 = Concatenate()([conv_8_cropped, upsample_1])

# Copy and Crop - 57 -> 24
conv_6_cropped = Cropping2D(cropping=((17,16),(17,16)))(conv_6)
merge_2 = Concatenate()([conv_6_cropped, upsample_2])

# Copy and Crop - 122 -> 40
conv_4_cropped = Cropping2D(cropping=41)(conv_4)
merge_3 = Concatenate()([conv_4_cropped, upsample_3])

# Copy and Crop - 252 -> 72
conv_2_cropped = Cropping2D(cropping=90)(conv_2)
merge_4 = Concatenate()([conv_2_cropped, upsample_4])
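
The crop amounts above can also be derived instead of worked out by hand. This is a minimal sketch, assuming the same layout as the model in this question (two unpadded 3x3 convolutions per block, 2x2 max pooling with stride 2, and 2x upsampling); `unet_crops` is a hypothetical helper, not part of Keras.

```python
def unet_crops(size, depth=4):
    """Trace one spatial dimension through an unpadded U-Net.

    Returns (crops, output_size). crops[i] is the (before, after) crop
    for the i-th skip connection, ordered from the deepest one upwards.
    """
    skips = []
    for _ in range(depth):
        size -= 4            # two 3x3 'valid' convolutions lose 2 pixels each
        if size < 2:
            raise ValueError("input too small for this depth")
        skips.append(size)   # feature map copied across to the decoder
        size //= 2           # pooling floors odd sizes (e.g. 57 -> 28)
    size -= 4                # bottleneck convolutions
    if size < 1:
        raise ValueError("input too small for this depth")

    crops = []
    for skip in reversed(skips):
        size *= 2                                    # UpSampling2D doubles the size
        diff = skip - size                           # pixels to remove from the skip
        crops.append((diff - diff // 2, diff // 2))  # odd remainder goes first
        size -= 4                                    # two more unpadded convolutions
        if size < 1:
            raise ValueError("input too small for this depth")
    return crops, size


print(unet_crops(572))
print(unet_crops(256))
```

Each `(before, after)` pair can be passed to `Cropping2D` for both axes of a square image, e.g. `Cropping2D(cropping=((17, 16), (17, 16)))`. The second return value is the spatial size of the model's output, which is smaller than the input because none of the convolutions are padded.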

Updated model summary:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 254, 254, 64) 1792        input_1[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 252, 252, 64) 36928       conv2d[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 126, 126, 64) 0           conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 124, 124, 128 73856       max_pooling2d[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 122, 122, 128 147584      conv2d_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 61, 61, 128)  0           conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 59, 59, 256)  295168      max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 57, 57, 256)  590080      conv2d_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 28, 28, 256)  0           conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 26, 26, 512)  1180160     max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 24, 24, 512)  2359808     conv2d_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 12, 12, 512)  0           conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 10, 10, 1024) 4719616     max_pooling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 8, 8, 1024)   9438208     conv2d_8[0][0]
__________________________________________________________________________________________________
cropping2d (Cropping2D)         (None, 16, 16, 512)  0           conv2d_7[0][0]
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 16, 16, 1024) 0           conv2d_9[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 16, 16, 1536) 0           cropping2d[0][0]
                                                                 up_sampling2d[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 14, 14, 512)  7078400     concatenate[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 12, 12, 512)  2359808     conv2d_10[0][0]
__________________________________________________________________________________________________
cropping2d_1 (Cropping2D)       (None, 24, 24, 256)  0           conv2d_5[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 24, 24, 512)  0           conv2d_11[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 24, 24, 768)  0           cropping2d_1[0][0]
                                                                 up_sampling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 22, 22, 256)  1769728     concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 20, 20, 256)  590080      conv2d_12[0][0]
__________________________________________________________________________________________________
cropping2d_2 (Cropping2D)       (None, 40, 40, 128)  0           conv2d_3[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 40, 40, 256)  0           conv2d_13[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 40, 40, 384)  0           cropping2d_2[0][0]
                                                                 up_sampling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 38, 38, 128)  442496      concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 36, 36, 128)  147584      conv2d_14[0][0]
__________________________________________________________________________________________________
cropping2d_3 (Cropping2D)       (None, 72, 72, 64)   0           conv2d_1[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 72, 72, 128)  0           conv2d_15[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 72, 72, 192)  0           cropping2d_3[0][0]
                                                                 up_sampling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 70, 70, 64)   110656      concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 68, 68, 64)   36928       conv2d_16[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 68, 68, 1)    65          conv2d_17[0][0]
==================================================================================================
Total params: 31,378,945
Trainable params: 31,378,945
Non-trainable params: 0

This compiles fine but fails at train time with:

Exception has occurred: InvalidArgumentError
 Incompatible shapes: [32,68,68] vs. [32,256,256]
     [[node Equal (defined at c:\main.py:74) ]] [Op:__inference_train_function_3026]

Function call stack: train_function

Does anyone know why the shapes are so incorrect at runtime and how I can fix them?

Update: image loading as part of my custom Sequence implementation:

source_image = load_img(source_image_paths[i], target_size=self.image_size, color_mode='grayscale')
target_image = load_img(target_image_paths[i], target_size=self.image_size, color_mode='grayscale')

# Start classes at 0
target_image = np.array(target_image) - 1

target_image_array.append(target_image)
source_image_array.append(np.array(source_image))

TomSelleck

1 Answer


It appears that the original images are 68x68 pixels and the model expects 256x256.

You can use the Keras image processing API, in particular the smart_resize function, to transform the images to the expected number of pixels.

Something like this:

from tensorflow.keras.preprocessing.image import smart_resize

target_size = (256, 256)
image_resized = smart_resize(image_original, size=target_size, interpolation='bilinear')

Brian Spiering