def get_dataset(num_classes, rank=0, size=1):
    """Load MNIST, keep this worker's shard of it, and prepare it for the CNN.

    Args:
        num_classes: Number of label classes used for one-hot encoding.
        rank: Index of this worker; it keeps the examples at positions
            ``rank, rank + size, rank + 2 * size, ...`` of each split.
        size: Total number of workers the data is split across.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. The images are float32
        arrays reshaped to ``(n, 28, 28, 1)`` and divided by 255 (pixel values
        are presumably 0-255, which gives [0, 1]). The labels are one-hot
        encoded with ``num_classes`` columns.
    """
    from tensorflow import keras

    # The cache file name includes the rank, presumably so that several
    # workers on one machine do not clash on the same downloaded file.
    train_split, test_split = keras.datasets.mnist.load_data('MNIST-data-%d' % rank)

    def _prepare(images, labels):
        # Strided slice: this worker takes every `size`-th example from `rank`.
        images = images[rank::size]
        labels = labels[rank::size]
        # Add the trailing single-channel axis that Conv2D expects.
        images = images.reshape(len(images), 28, 28, 1).astype('float32')
        images /= 255
        return images, keras.utils.to_categorical(labels, num_classes)

    return _prepare(*train_split), _prepare(*test_split)
def get_model(num_classes):
    """Build the (uncompiled) convolutional network used for MNIST.

    Layer stack, in order:
    - two 3x3 ReLU convolutions (32 filters, then 64);
    - 2x2 max pooling;
    - dropout 0.25;
    - flatten;
    - a 128-unit ReLU dense layer;
    - dropout 0.5;
    - a softmax output layer with ``num_classes`` units.

    Args:
        num_classes: Number of output classes.

    Returns:
        A ``tensorflow.keras`` ``Sequential`` model whose input shape is
        declared as ``(28, 28, 1)``.
    """
    from tensorflow.keras import layers, models

    stack = [
        layers.Conv2D(32, kernel_size=(3, 3), activation='relu',
                      input_shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.25),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax'),
    ]
    # Passing the list to the constructor adds the layers in this same order.
    return models.Sequential(stack)
# Training parameters, read as module-level globals by `train`:
#   batch_size  - examples per gradient update in `model.fit`
#   epochs      - number of full passes over the training split
#   num_classes - width of the one-hot labels and of the model's output layer
batch_size, epochs, num_classes = 128, 5, 10
def train(learning_rate=1.0):
    """Train the MNIST CNN on the full dataset in a single process.

    Loads all of MNIST (``get_dataset`` with its default ``rank=0, size=1``),
    builds the network with ``get_model``, and fits it with Adadelta. Uses the
    module-level ``batch_size``, ``epochs`` and ``num_classes`` settings and
    reports validation metrics on the test split after every epoch.

    Args:
        learning_rate: Learning rate passed to the Adadelta optimizer.
    """
    from tensorflow import keras

    (x_train, y_train), (x_test, y_test) = get_dataset(num_classes)
    model = get_model(num_classes)

    # The learning rate is a function parameter rather than a hard-coded value
    # so that a distributed (e.g. Horovod) variant of this function can adjust
    # it. Use the `learning_rate` keyword: the legacy `lr` alias was deprecated
    # in TF 2.x and is rejected by newer Keras optimizers.
    optimizer = keras.optimizers.Adadelta(learning_rate=learning_rate)

    # The labels from get_dataset are one-hot encoded, hence the categorical
    # (not sparse) cross-entropy.
    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=2,
              validation_data=(x_test, y_test))
# Measured runtime of the call below:
#   23.67 seconds on a 3-node GPU cluster
#   418.8 seconds on a 3-node non-GPU cluster
train(0.1)  # learning_rate = 0.1
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
8192/11490434 [..............................] - ETA: 0s
319488/11490434 [..............................] - ETA: 1s
679936/11490434 [>.............................] - ETA: 1s
1024000/11490434 [=>............................] - ETA: 1s
1409024/11490434 [==>...........................] - ETA: 1s
1802240/11490434 [===>..........................] - ETA: 1s
2195456/11490434 [====>.........................] - ETA: 1s
2588672/11490434 [=====>........................] - ETA: 1s
2981888/11490434 [======>.......................] - ETA: 1s
3358720/11490434 [=======>......................] - ETA: 1s
3866624/11490434 [=========>....................] - ETA: 1s
4440064/11490434 [==========>...................] - ETA: 0s
5160960/11490434 [============>.................] - ETA: 0s
6094848/11490434 [==============>...............] - ETA: 0s
7274496/11490434 [=================>............] - ETA: 0s
8683520/11490434 [=====================>........] - ETA: 0s
10436608/11490434 [==========================>...] - ETA: 0s
11493376/11490434 [==============================] - 1s 0us/step
Epoch 1/5
469/469 - 3s - loss: 0.6257 - accuracy: 0.8091 - val_loss: 0.2157 - val_accuracy: 0.9345
Epoch 2/5
469/469 - 3s - loss: 0.2950 - accuracy: 0.9127 - val_loss: 0.1450 - val_accuracy: 0.9569
Epoch 3/5
469/469 - 3s - loss: 0.2145 - accuracy: 0.9373 - val_loss: 0.1035 - val_accuracy: 0.9695
Epoch 4/5
469/469 - 3s - loss: 0.1688 - accuracy: 0.9512 - val_loss: 0.0856 - val_accuracy: 0.9738
Epoch 5/5
469/469 - 3s - loss: 0.1379 - accuracy: 0.9600 - val_loss: 0.0701 - val_accuracy: 0.9788
ScaDaMaLe Course site and book
The following is adapted from a Databricks blog post, with minor changes made with help from Tilo Wiklund.