Transfer Learning for Deep Neural Networks on VMware Tanzu Greenplum Database

August 13, 2019 Frank McQuillan

Transfer Learning

It can be expensive and time-consuming to train a deep neural network from scratch.  Even experienced data scientists have to try out many different model architectures and hyperparameters in order to generate a model with the right accuracy/cost trade-offs for the problem at hand.  Therefore, it is common in domains such as computer vision and natural language processing to take pre-trained models developed for one setting and apply them to a different, but related setting. This is called transfer learning.

For example, we may have an existing model that can identify dogs and cats, and use it as the basis for training a new model to identify different animals, say, cows and horses.  This exploits the fact that different classes of images may share the same low-level attributes like edges, shapes, and variations in shade and lighting [1].

Convolutional Neural Networks Example 

Convolutional neural networks (CNNs) are a special kind of neural network that is very good at image classification [2].  Like other neural networks, they consist of input and output layers with several hidden layers in between.  Convolutional layers are dedicated to learning features of the data (edges, shapes, etc.) and subsequent layers are designed for classification (Figure 1).

Figure 1:  Example of a CNN Showing Feature Layers and Classification Layers [3]

The idea behind transfer learning for CNNs is that the feature layers could be shared between different settings, with only the classification layers needing to be re-trained for the new setting.  That is, the weights for the feature layers are frozen and do not need to be re-trained.  This approach can produce accurate models while saving on training time.

Our simple transfer learning example uses the MNIST dataset [4], which is a well-known database of handwritten digits with a training set of 60,000 examples and a test set of 10,000 examples.   Following the example of [5], we start by training a simple CNN to classify the first 5 digits (0, 1, 2, 3, 4). Then we freeze the convolutional feature layers and fine-tune the dense layers for classification of the last 5 digits (5, 6, 7, 8, 9).  

We use the Apache MADlib open source project, which supports deep learning with Keras and TensorFlow on Greenplum Database [6].  The Jupyter notebook for this example is available at [7].

Image Preprocessing

Mini-batch gradient descent can perform better than stochastic gradient descent because it uses more than one training example at a time, typically resulting in faster and smoother convergence [8].  After we load training data into a table with one image per row, we need to call the MADlib image preprocessor to pack multiple images into a row for the Keras optimizer to work on mini-batches. For example, for the training examples for the first 5 digits (0, 1, 2, 3, 4) the SQL is:
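Here is a sketch of that call using MADlib's training_preprocessor_dl function, assuming the per-row images for digits 0-4 have been loaded into a table train_lt5 with a label column y and an image column x (the table and column names are illustrative):

SELECT madlib.training_preprocessor_dl(
    'train_lt5',         -- source table: one row per image, digits 0-4
    'train_lt5_packed',  -- output table: multiple images packed per row
    'y',                 -- dependent variable (digit label)
    'x',                 -- independent variable (28x28 pixel array)
    NULL,                -- buffer size; NULL lets MADlib choose
    255                  -- normalizing constant for pixel values
);

The same preprocessing step is repeated for the test set and for the digits 5-9 tables.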

Model Architecture

We define two groups of layers in the CNN: feature and classification, which are both trainable at this point.  The Keras code in Python to create the model is:
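The snippet below follows the Keras transfer learning example [5]: the two layer groups are plain Python lists that are concatenated into a Sequential model.

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Activation, Dropout, Flatten, Dense

num_classes = 5                  # classifying 5 digits at a time
input_shape = (28, 28, 1)        # MNIST images are 28x28 grayscale
filters, kernel_size, pool_size = 32, (3, 3), (2, 2)

# Feature layers: learn low-level attributes (edges, shapes, etc.)
feature_layers = [
    Conv2D(filters, kernel_size, padding='valid', input_shape=input_shape),
    Activation('relu'),
    Conv2D(filters, kernel_size),
    Activation('relu'),
    MaxPooling2D(pool_size=pool_size),
    Dropout(0.25),
    Flatten(),
]

# Classification layers: map the learned features to digit classes
classification_layers = [
    Dense(128),
    Activation('relu'),
    Dropout(0.5),
    Dense(num_classes),
    Activation('softmax'),
]

model = Sequential(feature_layers + classification_layers)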

The resulting model looks like:
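For the architecture above, model.summary() reports the following (parameter counts match the Keras example [5]):

Layer (type)                   Output Shape          Param #
conv2d_1 (Conv2D)              (None, 26, 26, 32)    320
activation_1 (Activation)      (None, 26, 26, 32)    0
conv2d_2 (Conv2D)              (None, 24, 24, 32)    9248
activation_2 (Activation)      (None, 24, 24, 32)    0
max_pooling2d_1 (MaxPooling2D) (None, 12, 12, 32)    0
dropout_1 (Dropout)            (None, 12, 12, 32)    0
flatten_1 (Flatten)            (None, 4608)          0
dense_1 (Dense)                (None, 128)           589952
activation_3 (Activation)      (None, 128)           0
dropout_2 (Dropout)            (None, 128)           0
dense_2 (Dense)                (None, 5)             645
activation_4 (Activation)      (None, 5)             0
Total params: 600,165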

Now load this model into the model architecture table:
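A sketch using MADlib's load_keras_model function; the JSON string comes from model.to_json() in Python, and the table name, model name, and description are illustrative:

SELECT madlib.load_keras_model(
    'model_arch_library',                      -- model architecture table
    '<output of model.to_json()>'::json,       -- architecture, all layers trainable
    NULL,                                      -- no pre-trained weights yet
    'full model',                              -- optional name
    'MNIST digits 0-4, all layers trainable'   -- optional description
);

This becomes model_id=1 in the architecture table.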

Next we freeze the feature layers to create the transfer model and load it into the model architecture table:
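In Keras, freezing is just a matter of marking the feature layers non-trainable before exporting the architecture again; the trainable flags are carried in the JSON. A sketch, reusing the layer lists from above:

# Freeze the feature layers; only the classification layers remain trainable
for layer in feature_layers:
    layer.trainable = False

transfer_model = Sequential(feature_layers + classification_layers)

Then load the frozen architecture as a second entry in the same table:

SELECT madlib.load_keras_model(
    'model_arch_library',
    '<output of transfer_model.to_json()>'::json,
    NULL,
    'transfer model',
    'MNIST digits 5-9, feature layers frozen'
);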

Now the model architecture table contains the two models that we need, one fully trainable and another with only the classification layers trainable:
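A quick query confirms the two entries (output is illustrative, matching the names used above):

SELECT model_id, name, description FROM model_arch_library ORDER BY model_id;

 model_id |      name      |               description
----------+----------------+------------------------------------------
        1 | full model     | MNIST digits 0-4, all layers trainable
        2 | transfer model | MNIST digits 5-9, feature layers frozen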

Training

Train the model for 5-digit classification (0,1,2,3,4) for 5 iterations using model_id=1:
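A sketch of the fit call using MADlib's madlib_keras_fit function; the table names and the compile/fit parameters are illustrative, and the exact argument list can vary across MADlib versions:

SELECT madlib.madlib_keras_fit(
    'train_lt5_packed',     -- preprocessed training data, digits 0-4
    'mnist_model',          -- output table for the trained model
    'model_arch_library',   -- model architecture table
    1,                      -- model_id: fully trainable architecture
    $$ loss='categorical_crossentropy', optimizer='adam()', metrics=['accuracy'] $$,  -- compile params
    $$ batch_size=128, epochs=1 $$,  -- fit params
    5                       -- number of iterations
);

Test-set accuracy can then be checked with madlib_keras_evaluate against the packed test table for digits 0-4.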

This classifier for the first 5 digits gets to 99.7% accuracy on the test set.  We want to use this model as a starting point for training the classifier for the last 5 digits.  To do this, we copy the trained weights for the first 5 digits from the model table mnist_model into the model architecture table model_arch_library so that the transfer model can use them:
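A sketch of that copy as a plain SQL UPDATE; note that the weights column in the fit output table is named model_data in some MADlib versions and model_weights in others:

UPDATE model_arch_library
SET model_weights = mnist_model.model_data
FROM mnist_model
WHERE model_arch_library.model_id = 2;

With weights present in the architecture table, the next fit call for model_id=2 starts from them instead of from random initialization.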

Now that we have the weights, we can train the dense layers for the new classification task for 5 iterations using model_id=2:
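The call mirrors the first fit, pointing at the digits 5-9 training table and the frozen architecture (again a sketch; table names are illustrative):

SELECT madlib.madlib_keras_fit(
    'train_gte5_packed',     -- preprocessed training data, digits 5-9
    'mnist_transfer_model',  -- output table for the trained model
    'model_arch_library',    -- model architecture table
    2,                       -- model_id: feature layers frozen, seeded weights
    $$ loss='categorical_crossentropy', optimizer='adam()', metrics=['accuracy'] $$,
    $$ batch_size=128, epochs=1 $$,
    5
);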

This classifier for the last 5 digits gets to 99.0% accuracy on the test set with the feature layers frozen.  Since there are fewer parameters to train, it runs almost 2x faster than if we had trained the whole model from scratch.

References:

[1]  Deep Learning, Goodfellow, Bengio and Courville, p. 526.

[2]  Le Cun, Denker, Henderson, Howard, Hubbard and Jackel, Handwritten digit recognition with a back-propagation network, in: Proceedings of the Advances in Neural Information Processing Systems (NIPS), 1989, pp. 396–404.

[3]  https://www.mathworks.com/solutions/deep-learning/convolutional-neural-network.html

[4]  The MNIST database of handwritten digits, http://yann.lecun.com/exdb/mnist/

[5]  MNIST transfer CNN, https://keras.io/examples/mnist_transfer_cnn/

[6]  GPU-Accelerated Deep Learning on Greenplum Database, https://content.pivotal.io/engineers/gpu-accelerated-deep-learning-on-greenplum-database

[7]  Apache MADlib community artifacts, https://github.com/apache/madlib-site/tree/asf-site/community-artifacts

[8]  Neural Networks for Machine Learning, Lectures 6a and 6b on mini-batch gradient descent, Geoffrey Hinton with Nitish Srivastava and Kevin Swersky, http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf

Learning More:

Ready to take the next step? Learn more about Apache MADlib and Pivotal Greenplum.

About the Author

Frank McQuillan

Frank McQuillan is Director of Product Management at Pivotal, focusing on analytics and machine learning for large data sets. Prior to Pivotal, Frank worked on projects in the areas of robotics, drones, flight simulation, and advertising technology. He holds a Master's degree from the University of Toronto and a Bachelor's degree from the University of Waterloo, both in Mechanical Engineering.
