
I am a beginner with Keras, and I have started with the MNIST example to understand how the library actually works. The code snippet for the MNIST problem in the Keras examples folder is as follows:

```python
import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils

batch_size = 128
nb_classes = 10
nb_epoch = 12

# input image dimensions
img_rows, img_cols = 28, 28
# number of convolutional filters to use
nb_filters = 32
# size of pooling area for max pooling
nb_pool = 2
# convolution kernel size
nb_conv = 3

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
# ..........
```

I am unable to understand the reshape function here. What is it doing, and why have we applied it?

enterML


mnist.load_data() supplies the MNIST digits with structure (nb_samples, 28, 28), i.e. with 2 dimensions per example, representing a 28x28 greyscale image.

The Convolution2D layers in Keras, however, are designed to work with 3 dimensions per example, so they have 4-dimensional inputs and outputs. This covers colour images (nb_samples, nb_channels, height, width), but more importantly, it covers deeper layers of the network, where each example has become a set of feature maps, i.e. (nb_samples, nb_features, height, width).
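To make the "deeper layers" point concrete, here is a small shape-arithmetic sketch in plain Python. It assumes, for illustration only, two 3x3 convolutions with 'valid' padding and stride 1 followed by one 2x2 max-pool, reusing nb_filters, nb_conv and nb_pool from the question; the helper functions are hypothetical and not part of the Keras API. The point is that every intermediate tensor is 4-dimensional, with the channel axis turning into a feature-map axis.

```python
def conv_valid_output(rows, cols, kernel):
    # 'valid' padding, stride 1: each spatial dimension shrinks by kernel - 1
    return rows - kernel + 1, cols - kernel + 1


def max_pool_output(rows, cols, pool):
    # non-overlapping pooling (stride == pool size); leftover rows/cols are dropped
    return rows // pool, cols // pool


nb_samples = 128                        # batch_size from the question
nb_filters, nb_conv, nb_pool = 32, 3, 2

shapes = [(nb_samples, 1, 28, 28)]      # input after the reshape: 1 channel
rows, cols = 28, 28
for _ in range(2):                      # two assumed 3x3 conv layers
    rows, cols = conv_valid_output(rows, cols, nb_conv)
    shapes.append((nb_samples, nb_filters, rows, cols))
rows, cols = max_pool_output(rows, cols, nb_pool)
shapes.append((nb_samples, nb_filters, rows, cols))

for s in shapes:
    print(s)
# (128, 1, 28, 28)
# (128, 32, 26, 26)
# (128, 32, 24, 24)
# (128, 32, 12, 12)
```

Only the first tensor has a single channel; from the first convolution onward, the second axis holds nb_filters feature maps, which is why the layer is built around a 4-dimensional input in the first place.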

For the greyscale MNIST input there were two options: use a different CNN layer design (or a parameter to the layer constructor that accepts a different shape), or simply use a standard CNN and explicitly express each example as a 1-channel image. The Keras team chose the latter approach, which is why the reshape is needed.
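A minimal NumPy sketch of what the reshape does, assuming a small synthetic array as a stand-in for the real X_train (the data here is made up, not loaded from MNIST). Filling it with consecutive numbers makes it easy to check that no pixel moves: the reshape only inserts a channel axis of length 1.

```python
import numpy as np

img_rows, img_cols = 28, 28

# Hypothetical stand-in for X_train from mnist.load_data():
# 5 "images" of shape (28, 28), filled with 0..3919 so every pixel is traceable.
X = np.arange(5 * img_rows * img_cols, dtype=np.float32).reshape(5, img_rows, img_cols)

# Same call as in the question: add a channel axis of length 1 (channels first).
X_cf = X.reshape(X.shape[0], 1, img_rows, img_cols)

print(X.shape)     # (5, 28, 28)
print(X_cf.shape)  # (5, 1, 28, 28)

# The single channel of each example is exactly the original image.
assert np.array_equal(X_cf[:, 0, :, :], X)
# Indexing with np.newaxis is an equivalent way to insert the axis.
assert np.array_equal(X_cf, X[:, np.newaxis, :, :])
```

The total number of values is unchanged (5 x 28 x 28 in both cases); only the way the array is indexed changes, so the layer sees each digit as a 1-channel image.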

Neil Slater

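Just a small correction to the accepted answer: with the channels-last data format (the default when Keras runs on the TensorFlow backend), the input shape indices are ordered as follows: (n_images, x_shape, y_shape, channels). The code in the question uses the channels-first ordering, (n_images, channels, x_shape, y_shape).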
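A short NumPy sketch comparing the two layouts, again assuming a synthetic array in place of the real MNIST data. Which layout a given Keras setup expects depends on its configured data format, which this sketch does not query. For single-channel images both layouts are plain reshapes, but for multi-channel data converting between them requires a transpose; reshaping instead would silently scramble the pixels.

```python
import numpy as np

img_rows, img_cols = 28, 28
X = np.arange(5 * img_rows * img_cols, dtype=np.float32).reshape(5, img_rows, img_cols)

# Channels-first, as in the question's code: (n_images, channels, rows, cols)
X_cf = X.reshape(X.shape[0], 1, img_rows, img_cols)
# Channels-last: (n_images, rows, cols, channels)
X_cl = X.reshape(X.shape[0], img_rows, img_cols, 1)

# With one channel, the two layouts agree once the channel axis is moved to the end.
assert np.array_equal(np.transpose(X_cf, (0, 2, 3, 1)), X_cl)

# With several channels, switching layouts needs a transpose; a reshape scrambles pixels.
rgb_cf = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)   # (n, channels, rows, cols)
rgb_cl = np.transpose(rgb_cf, (0, 2, 3, 1))             # (n, rows, cols, channels)
scrambled = rgb_cf.reshape(2, 4, 4, 3)                  # same shape, wrong pixel order
print(rgb_cl.shape, scrambled.shape)  # (2, 4, 4, 3) (2, 4, 4, 3)
assert not np.array_equal(rgb_cl, scrambled)
assert rgb_cl[0, 1, 2, 1] == rgb_cf[0, 1, 1, 2]
```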

Chandan