Cassava Leaf Disease Classification Part 3
In our previous posts, we outlined our EDA, baseline models, and data loading procedures. We also introduced the models we were using to classify cassava leaf diseases and provided checkpoints on our procedures, as well as issues we came across. We hope to address these issues in this post. Links to the first two posts are below!
Loading with Raw Data & Applying VGG/ResNet
Because of the class imbalance we found in the EDA, we changed the splitting method from a simple train/test split to StratifiedKFold, with n_splits equal to the number of labels. This keeps the label ratio in each fold the same as in the full dataset (or in the portion of raw data we used). Since the data has no group structure other than the labels themselves, we did not use StratifiedGroupKFold.
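As a rough illustration of the stratified split, here is a minimal sketch using scikit-learn. The label array is randomly generated with made-up class proportions (only the 0.61 majority share echoes our data), and the names `labels` and `image_ids` are placeholders rather than our actual variables.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical imbalanced labels for 5 classes (class 3 is the majority),
# standing in for the real cassava labels.
rng = np.random.default_rng(0)
labels = rng.choice(5, size=1000, p=[0.05, 0.10, 0.11, 0.61, 0.13])
image_ids = np.arange(len(labels))  # placeholder for image file names

n_classes = len(np.unique(labels))
skf = StratifiedKFold(n_splits=n_classes, shuffle=True, random_state=42)

# Class ratios in the full data, to compare against each validation fold.
overall = np.bincount(labels, minlength=n_classes) / len(labels)

fold_ratios = []
for fold, (train_idx, val_idx) in enumerate(skf.split(image_ids, labels)):
    ratio = np.bincount(labels[val_idx], minlength=n_classes) / len(val_idx)
    fold_ratios.append(ratio)
    print(f"fold {fold}: val class ratios = {np.round(ratio, 3)}")
```

Each validation fold should show nearly the same class ratios as `overall`, which is the property we wanted from the split.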
The first hyperparameter we tuned after building the model was the optimizer. Previously we used Adam for all models; now we also tried classic SGD and RMSprop. Overall, mini-batch SGD did not outperform Adam and took longer to train, although its validation accuracy was more stable during training. Adam (adaptive moment estimation) is essentially RMSprop with an added momentum term, and RMSprop did not give a significantly better validation score either, so our later models stick with Adam. Note that RMSprop has an extra hyperparameter called rho, the discounting factor for the gradient history. For Adam, the corresponding hyperparameter to tune is beta_1, the exponential decay rate for the first moment estimate.
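To make the roles of rho and beta_1 concrete, here is a small sketch of the textbook update rules for a single scalar weight. It is not the Keras implementation (which may differ in details such as where epsilon is added); the function names and the toy loss are illustrative only. In Keras, these hyperparameters correspond to the `rho` argument of `tf.keras.optimizers.RMSprop` and the `beta_1` argument of `tf.keras.optimizers.Adam`.

```python
import math

def sgd_step(w, grad, lr=0.01):
    """Plain gradient descent: step against the raw gradient."""
    return w - lr * grad

def rmsprop_step(w, grad, v, lr=0.001, rho=0.9, eps=1e-7):
    """rho discounts the running average of squared gradients."""
    v = rho * v + (1.0 - rho) * grad ** 2
    return w - lr * grad / (math.sqrt(v) + eps), v

def adam_step(w, grad, m, v, t, lr=0.001, beta_1=0.9, beta_2=0.999, eps=1e-7):
    """beta_1 decays the first moment (momentum); beta_2 decays the second."""
    m = beta_1 * m + (1.0 - beta_1) * grad
    v = beta_2 * v + (1.0 - beta_2) * grad ** 2
    m_hat = m / (1.0 - beta_1 ** t)  # bias correction for the zero initialization
    v_hat = v / (1.0 - beta_2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

# One step on the toy loss f(w) = w**2, whose gradient is 2 * w.
w0 = 1.0
grad = 2.0 * w0
w_sgd = sgd_step(w0, grad)
w_rms, _ = rmsprop_step(w0, grad, v=0.0)
w_adam, _, _ = adam_step(w0, grad, m=0.0, v=0.0, t=1)
```

The only structural difference between the RMSprop and Adam steps is the first-moment average `m`, which is why we describe Adam as RMSprop plus momentum.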
In the previous blog we mentioned that we would try different base models from the ResNet, VGG and DenseNet families. After doing so, we found that the differences within each family are not significant. For example, VGG19 has three more layers than VGG16, but the extra complexity does not improve the validation accuracy. The same holds for ResNet50 versus ResNet101 and DenseNet169 versus DenseNet201. In addition, we found that not freezing the base model tends to give better results. We also kept dropout.
We added more data augmentation, regularization, and learning rate scheduling to improve results and reduce overfitting. To isolate the effect of these add-ons, we kept categorical_crossentropy as the loss and accuracy as the metric.
We also applied learning rate schedules such as exponential decay to the SGD optimizer. In effect, we manually added a decay term that we treated as a rough counterpart to Adam’s beta_1 (strictly speaking, beta_1 decays the first moment estimate rather than the learning rate, so the two are not equivalent). This sometimes led to error messages during training, so we also tried different decay rates to compare against beta_1. The goal was to see the difference between the two optimizers rather than to tune hyperparameters.
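For reference, an exponential decay schedule can be written in a few lines. The sketch below follows the formula documented for `tf.keras.optimizers.schedules.ExponentialDecay`, namely initial_learning_rate * decay_rate ** (step / decay_steps); the default values are illustrative assumptions, not the settings from our runs.

```python
def exponential_decay(step, initial_lr=0.01, decay_steps=100,
                      decay_rate=0.9, staircase=False):
    """Learning rate after `step` optimizer steps under exponential decay."""
    if staircase:
        exponent = step // decay_steps  # decay in discrete jumps
    else:
        exponent = step / decay_steps   # decay smoothly at every step
    return initial_lr * decay_rate ** exponent

# Learning rate at a few hypothetical training steps.
for step in (0, 100, 200, 1000):
    print(step, exponential_decay(step))
```

With these values the learning rate shrinks by 10% every 100 steps. In Keras, a schedule object of this kind can be passed as the `learning_rate` argument of the SGD optimizer.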
On each base model we tried ReduceLROnPlateau, which reduces the learning rate by a chosen factor when the monitored quantity, “val_loss”, has not improved for 10 epochs; we tried several factors. We have used this callback since the last blog, along with ImageDataGenerator augmentation, and decided to keep both.
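A minimal sketch of how such a callback can be configured in Keras is shown below. The monitor and patience match what we describe above; the factor and min_lr values are illustrative assumptions, and `model`, `train_data` and `val_data` in the commented usage line are hypothetical names.

```python
from tensorflow import keras

# Reduce the learning rate when val_loss has not improved for 10 epochs.
# factor=0.5 and min_lr=1e-6 are illustrative assumptions.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=10,
    min_lr=1e-6,
    verbose=1,
)

# Hypothetical usage: pass the callback to fit() together with validation data.
# model.fit(train_data, validation_data=val_data, epochs=100, callbacks=[reduce_lr])
```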
Speaking of callbacks, we also added early stopping in recent runs. However, in one run that reached a ‘local maximum’ in validation accuracy (VGG16 with a limited number of raw images), we noticed that after the model found its first local minimum it was stuck for about 15 epochs, moved on to the next stage, and was stuck again for more than 20 epochs. Previously we used a patience of 10 and did not treat this parameter as a hyperparameter to tune. We then realized that training might stop too early and cause underfitting, so we removed early stopping. Instead, we eyeballed the training process and chose a larger number of epochs.
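The effect of patience on a plateau like the one described above can be checked with a small simulation. The function below mimics the basic patience rule of an early stopping callback (stop once the monitored loss has failed to improve for `patience` consecutive epochs); the validation loss curve is made up for illustration and is not from our runs.

```python
def stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which early stopping would halt, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss   # improvement: reset the counter
            wait = 0
        else:
            wait += 1     # no improvement this epoch
            if wait >= patience:
                return epoch
    return None

# Hypothetical curve: 5 improving epochs, a 15-epoch plateau, then more progress.
curve = [1.0, 0.9, 0.8, 0.7, 0.6] + [0.6] * 15 + [0.5, 0.45, 0.4]

print(stopping_epoch(curve, patience=10))  # stops inside the plateau
print(stopping_epoch(curve, patience=20))  # survives the plateau
```

With a patience of 10, training halts in the middle of the plateau and never sees the later improvement, which is the underfitting risk that led us to drop early stopping.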
One question came to mind from observing figures for runs like this and from talking with our instructor Dan: is the model really stuck at a local minimum, or are we running out of data? We never lost sight of the fact that all of this model selection and hyperparameter tuning serves to narrow down the candidates for our final model. In the previous blog we found that with a small amount of data the VGG family gives the best results, followed by the DenseNet family and the ResNet family. However, even before we tested them with Kevin’s TFRecord structure, we found during tuning that as we increased the amount of data, the number of epochs, and the batch_size (up to what our computational power could handle), DenseNet169 outperformed VGG16 and ResNet50. We therefore chose it as the base model for our final model.
Our team was advised to try running the whole dataset with raw images rather than TFRecord. The workaround I proposed is to feed in small portions of the data continuously, each kept within a safe limit for the available RAM, and to save the weights after each portion is trained; once all of the data has been run, we have our model. It would be interesting to compare the result of this method with the one using the TFRecord structure. Our final model is trained using TFRecord.
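The control flow of this chunked approach might look like the sketch below. It only shows the looping and checkpointing pattern: `train_on_chunk` and `save_weights` are hypothetical callbacks standing in for loading the images and calling `model.fit(...)` and `model.save_weights(...)`, and the file names are made up.

```python
def iter_chunks(items, chunk_size):
    """Yield consecutive slices of `items` with at most `chunk_size` elements."""
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]

def train_in_chunks(image_paths, chunk_size, train_on_chunk, save_weights):
    """Train on one chunk at a time and checkpoint after each chunk."""
    n_chunks = 0
    for i, chunk in enumerate(iter_chunks(image_paths, chunk_size)):
        train_on_chunk(chunk)  # e.g. load these images into RAM, then model.fit(...)
        save_weights(i)        # e.g. model.save_weights(...) for chunk i
        n_chunks += 1
    return n_chunks

# Usage with stand-in callbacks that only log what would happen.
paths = [f"img_{i}.jpg" for i in range(10)]  # hypothetical file names
log = []
n = train_in_chunks(
    paths,
    chunk_size=4,
    train_on_chunk=lambda chunk: log.append(("fit", len(chunk))),
    save_weights=lambda i: log.append(("save", i)),
)
```

Only one chunk is held in memory at a time, so `chunk_size` is the knob that keeps the run within the RAM limit.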
Lastly, we updated our plot function. It now has two y-axes with ticks, one for accuracy and the other for loss. We also restyled the two groups of lines so they are easier to read, and added an arrow annotation marking the best validation accuracy of the run. We also added a function that visualizes the confusion matrix as a heatmap. Most of the confusion matrices show that class 3 is the easiest class to predict while class 0 is the hardest. We checked the heatmap after each run to see whether it deviated from the usual pattern and, if so, tried to find the reason, but we are not including the heatmaps here since they are less representative as a result for model building.
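A plot function along these lines could be sketched with matplotlib as follows. This is not our exact function: the name `plot_history`, the colors and the annotation placement are assumptions, and the history values are invented. The dictionary keys follow what a Keras `History.history` contains when accuracy is the metric.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

def plot_history(history):
    """Plot accuracy (left axis) and loss (right axis) and mark the best val accuracy."""
    epochs = list(range(1, len(history["accuracy"]) + 1))
    fig, ax_acc = plt.subplots(figsize=(8, 5))
    ax_loss = ax_acc.twinx()  # second y-axis sharing the same x-axis

    ax_acc.plot(epochs, history["accuracy"], color="tab:blue", label="train acc")
    ax_acc.plot(epochs, history["val_accuracy"], color="tab:blue", linestyle="--", label="val acc")
    ax_loss.plot(epochs, history["loss"], color="tab:red", label="train loss")
    ax_loss.plot(epochs, history["val_loss"], color="tab:red", linestyle="--", label="val loss")

    # Find the epoch with the best validation accuracy and point an arrow at it.
    val_acc = history["val_accuracy"]
    best_idx = max(range(len(val_acc)), key=lambda i: val_acc[i])
    best_epoch, best_val = epochs[best_idx], val_acc[best_idx]
    ax_acc.annotate(
        f"best val acc {best_val:.3f}",
        xy=(best_epoch, best_val),
        xytext=(0, -40),
        textcoords="offset points",
        arrowprops=dict(arrowstyle="->"),
    )

    ax_acc.set_xlabel("epoch")
    ax_acc.set_ylabel("accuracy")
    ax_loss.set_ylabel("loss")
    lines = list(ax_acc.get_lines()) + list(ax_loss.get_lines())
    ax_acc.legend(lines, [line.get_label() for line in lines], loc="center right")
    return fig, best_epoch, best_val

# Invented history values, for illustration only.
demo = {
    "accuracy": [0.55, 0.70, 0.80, 0.88],
    "val_accuracy": [0.60, 0.70, 0.80, 0.78],
    "loss": [1.2, 0.8, 0.5, 0.3],
    "val_loss": [1.1, 0.8, 0.6, 0.65],
}
fig, best_epoch, best_val = plot_history(demo)
```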
Loading with TFRecord & Applying DenseNet
In our previous post, I elaborated on the data loading procedure, facilitated by TFRecord. Using TFRecord allowed me to train the model on more samples. I also talked about my struggles with DenseNet, as its validation accuracy remained constant all throughout training. Almost coincidentally, it converged at a validation accuracy equal to the proportion of the majority label (0.61). Many steps were taken to get DenseNet to learn in the first place, while achieving a decent accuracy score.
Some of the probable causes I attributed to my struggles were data imbalance & shuffling, learning rate, convergence, and overall approach to transfer learning. Ultimately, my struggles were largely due to my approach to transfer learning.
Initially, I froze the base DenseNet169 architecture, which was pre-trained with ImageNet weights. I came to realize that these frozen weights were not compatible with our data, so I unfroze the base model’s convolutional layers. This should allow the model to learn the representations embedded in our images in a way that is tailored to our data. I then followed the implementation of DenseNet around the output layer as outlined in the D2L textbook. This consisted of a BatchNormalization layer, ReLU activation, and Global Average Pooling just before flattening. This further improved my model, and more importantly, the model began learning. To add further complexity to the model, I added three fully-connected layers after flattening, with 512, 256, and 64 units respectively.
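The architecture described above could be assembled in Keras roughly as follows. This is a sketch under several assumptions that are not stated in the post: the input size, the ReLU activations on the fully-connected layers, the softmax output over 5 classes, and the function name `build_densenet169` are all illustrative, and input preprocessing is omitted.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_densenet169(input_shape=(224, 224, 3), num_classes=5, weights="imagenet"):
    """Unfrozen DenseNet169 base with a BN -> ReLU -> GAP -> Flatten -> Dense head."""
    base = tf.keras.applications.DenseNet169(
        include_top=False, weights=weights, input_shape=input_shape
    )
    base.trainable = True  # unfreeze the pre-trained convolutional layers

    inputs = tf.keras.Input(shape=input_shape)
    x = base(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Flatten()(x)
    for units in (512, 256, 64):  # three fully-connected layers
        x = layers.Dense(units, activation="relu")(x)  # activation is an assumption
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)

# weights=None and a small input keep this demo light; use "imagenet" for transfer learning.
model = build_densenet169(input_shape=(64, 64, 3), weights=None)
```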
Below is the training history for my best-performing model. We see that it achieves a maximum validation accuracy of approximately 0.83 at epoch 13. From that point on, the training accuracy continues to increase, while the validation accuracy slowly decreases. This is an indicator of overfitting. Nonetheless, this model was able to predict the test set with an accuracy of 84.8%.
I then decided to fine-tune the model above by creating a blank model and loading the weights from the previous model. I then froze all layers except for the 3 dense layers connecting the flattening to the output. Below is the training history for my fine-tuning.
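The freezing step can be sketched on a tiny stand-in model with the same head layout. The stand-in network, the learning rate, and the choice to keep the output layer frozen are assumptions made for illustration; in practice the weights would come from the trained DenseNet169 model.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_stand_in():
    """Tiny stand-in with the same head layout: three hidden Dense layers plus output."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(8, 8, 3)),
        layers.Conv2D(4, 3, activation="relu"),
        layers.GlobalAveragePooling2D(),
        layers.Flatten(),
        layers.Dense(512, activation="relu"),
        layers.Dense(256, activation="relu"),
        layers.Dense(64, activation="relu"),
        layers.Dense(5, activation="softmax"),
    ])

trained = build_stand_in()     # stands in for the previously trained model
fine_tuned = build_stand_in()  # the "blank" model with the same architecture
fine_tuned.set_weights(trained.get_weights())

# Keep only the three hidden Dense layers trainable.
dense_layers = [l for l in fine_tuned.layers if isinstance(l, layers.Dense)]
hidden_dense = dense_layers[:3]  # assumption: the output layer stays frozen
for layer in fine_tuned.layers:
    layer.trainable = any(layer is d for d in hidden_dense)

# Recompile so the new trainable flags take effect before calling fit().
fine_tuned.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),  # learning rate is an assumption
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
trainable_names = [l.name for l in fine_tuned.layers if l.trainable]
```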
The behavior with respect to the validation loss/accuracy is rather erratic. The fine-tuned model achieves a maximum validation accuracy of 0.8475 on the second epoch and decreases from there. The validation loss mirrors this behavior in the opposite direction. However, there was a slight improvement in test accuracy with the fine-tuned model, with an accuracy of 85.5%.
Overall, simply changing the architecture and adding minimal customizations made a great impact on DenseNet’s performance.
- Explore other DenseNet architectures (DenseNet121, DenseNet201)
- Addressing data imbalances (increasing batch size, stratified splitting with TFRecord)
- Further hyperparameter tuning to prevent overfitting
- More fine-tuning
- Apply other CNN architectures (Inception, EfficientNet, etc.)
Thanks for keeping up with our blog series about Cassava Leaf Disease Classification. Hope you enjoyed it!