Plant species identification using a TensorFlow Lite model on mobile devices
Deep learning models are often large and require substantial computation for inference. Can we train deep learning models that are smaller, need less computation, and can be deployed on mobile phones? The answer is yes. The ArcGIS API for Python can now train TensorFlow Lite models, so we can train smaller DL models that can be deployed on mobile devices.
Where can we use them? We can use this capability to train DL models that perform classification tasks specifically on mobile devices. One such integration is in the Survey123 app, a simple and intuitive form-centric data-gathering solution used by surveyors during ground surveys. There, we integrated a tf-lite model that classifies plant species when the user takes a picture of a plant in the app.
This notebook shows how to train a deep learning model for real-time inference in mobile applications using the TensorFlow Lite framework. As an example, we will train the plant species classification model discussed above, but with a smaller dataset.
PlantCLEF data is available in three sets:
- A “trusted” training set based on the online collaborative Encyclopedia of Life (EoL).
- A “noisy” training set obtained from Google and Bing image search results, including mislabeled or irrelevant images.
- The images from previous years (2015-2016), depicting only a subset of the species.
For this notebook, we took a subset of the “trusted” training set containing 39,354 images of 100 plant species. We replaced the species numbers with species names; for example, species number '42' becomes 'Acanthus mollis'. Each image file comes with an XML file that holds the species name. We wrote a script that maps species numbers to species names. To see how we did this, have a look at the script here.
Use the following command to run the downloaded script. It requires three arguments to be passed:
- path to downloaded PlantCLEF data
- path of the destination folder
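The downloaded script is the reference for the renaming step. To illustrate the idea only (this is not that script), the minimal sketch below does two things:

- It reads a class number and a species name from an image's XML file.
- It copies the image into a folder named after the species.

Several details are assumptions: the tag names ClassId and Species, the convention that each X.jpg has an X.xml beside it, and the helper names parse_species and organize_by_species are all hypothetical.

```python
import shutil
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Dict, Optional, Tuple, Union


def parse_species(xml_data: Union[str, bytes]) -> Optional[Tuple[str, str]]:
    """Return (class_id, species_name) from one metadata file, or None if missing.

    Assumption: the root element has <ClassId> and <Species> child elements.
    """
    root = ET.fromstring(xml_data)
    class_id = root.findtext("ClassId")
    species = root.findtext("Species")
    if not class_id or not species:
        return None
    return class_id.strip(), species.strip()


def organize_by_species(src_dir: str, dest_dir: str) -> Dict[str, str]:
    """Hypothetical helper: copy every X.jpg that has an X.xml next to it into
    dest_dir/<species name>/ and return the {class_id: species_name} mapping."""
    mapping = {}
    for xml_path in sorted(Path(src_dir).glob("*.xml")):
        image_path = xml_path.with_suffix(".jpg")
        parsed = parse_species(xml_path.read_bytes())
        if parsed is None or not image_path.exists():
            continue  # skip records without a label or without an image
        class_id, species = parsed
        mapping[class_id] = species
        target_dir = Path(dest_dir) / species
        target_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(image_path, target_dir / image_path.name)
    return mapping
```

Organizing images into one folder per species name matches the folder-per-class layout that the rest of this notebook reads from the images folder.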
We will train our model using the arcgis.learn module within the ArcGIS API for Python. arcgis.learn contains the tools and deep learning capabilities required for this study. Detailed documentation for installing and setting up the environment is available here.
First, we need to set an environment variable so that arcgis.learn uses TensorFlow as its backend. To do this, we set ARCGIS_ENABLE_TF_BACKEND to 1 before importing arcgis.learn.
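import os
from pathlib import Path

# Use TensorFlow as the arcgis.learn backend; this must be set before arcgis.learn is imported
os.environ['ARCGIS_ENABLE_TF_BACKEND'] = '1'

from arcgis.gis import GIS
from arcgis.learn import prepare_data, FeatureClassifier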
gis = GIS('home')
training_data = gis.content.get('81932a51f77b4d2d964218a7c5a4af17')
training_data
filepath = training_data.download(file_name=training_data.name)
import zipfile

# Extract the downloaded archive into the folder that contains it
with zipfile.ZipFile(filepath, 'r') as zip_ref:
    zip_ref.extractall(Path(filepath).parent)
# The extracted folder has the same name as the archive, minus the '.zip' extension
data_path = Path(os.path.splitext(filepath)[0])
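from glob import glob
from PIL import Image

# Remove images that are not RGB (e.g. grayscale or CMYK) so every image has three channels
for image_filepath in glob(os.path.join(data_path, 'images', '**', '*.jpg')):
    with Image.open(image_filepath) as img:
        mode = img.mode  # read the mode, then close the file before deleting it
    if mode != 'RGB':
        os.remove(image_filepath)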
We will now use the prepare_data() function to apply various transformations and augmentations to the training data. Augmentations let us train a better model from limited data and help prevent the model from overfitting.
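prepare_data() chooses and applies its own transforms. The sketch below is therefore only an illustration of what an augmentation does, not what arcgis.learn runs. It is written with NumPy and does two things to an image:

- It mirrors the image left-to-right half of the time.
- It then cuts out a random 300 x 300 crop, matching the chip_size used in this notebook.

The function name random_flip_and_crop and the random 320 x 320 stand-in image are hypothetical.

```python
import numpy as np


def random_flip_and_crop(image: np.ndarray, crop: int, rng: np.random.Generator) -> np.ndarray:
    """Return a randomly mirrored, randomly cropped copy of an (H, W, C) image."""
    h, w = image.shape[:2]
    if crop > min(h, w):
        raise ValueError("crop size must not exceed the image size")
    if rng.random() < 0.5:
        image = image[:, ::-1, :]  # mirror left-right with probability 0.5
    top = rng.integers(0, h - crop + 1)  # random crop position
    left = rng.integers(0, w - crop + 1)
    return image[top:top + crop, left:left + crop, :].copy()


rng = np.random.default_rng(0)
photo = rng.integers(0, 256, size=(320, 320, 3), dtype=np.uint8)  # stand-in for a real photo
views = [random_flip_and_crop(photo, 300, rng) for _ in range(4)]
print([v.shape for v in views])  # four 300x300x3 views of the same photo
```

Each call can return a slightly different view of the same photo, so across epochs the model rarely sees exactly the same input twice. This is why augmentation helps against overfitting.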
Here, we pass the following parameters to the prepare_data() function:
- path: path of the folder containing the training data.
- dataset_type: format of the training data; 'Imagenet' indicates images stored in one sub-folder per class.
- chip_size: size, in pixels, of the images the model is trained on.
- batch_size: number of images the model trains on in each step of an epoch. The right value depends on the memory of your graphics card and on the type of model you are working with. For this sample, a batch size of 64 worked for us on a GPU with 11 GB of memory.
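The short calculation below shows what batch_size means in practice by counting the optimizer steps in one epoch. The inputs and behaviors are assumptions:

- It uses the 39,354 images quoted earlier. The non-RGB images removed above make the real count somewhat lower.
- For illustration, it holds out 10% of the images for validation. The actual split is controlled by prepare_data()'s val_split_pct parameter.
- Some data loaders drop the final partial batch, so both cases are shown.

The function name steps_per_epoch is hypothetical.

```python
import math


def steps_per_epoch(n_images: int, batch_size: int, val_fraction: float = 0.1,
                    drop_last: bool = False) -> int:
    """Number of mini-batches needed to pass once over the training split."""
    n_val = round(n_images * val_fraction)  # images held out for validation
    n_train = n_images - n_val              # images the model learns from
    if drop_last:
        return n_train // batch_size        # loaders that drop the final partial batch
    return math.ceil(n_train / batch_size)  # otherwise the last batch is simply smaller


print(steps_per_epoch(39_354, 64))                  # 554
print(steps_per_epoch(39_354, 64, drop_last=True))  # 553
```

Halving batch_size roughly doubles the number of steps per epoch, which is the usual trade-off when a larger batch does not fit in GPU memory.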
data = prepare_data(
    path=data_path,
    dataset_type='Imagenet',
    batch_size=64,
    chip_size=300
)
To make sense of the training data, we will use the show_batch() method in arcgis.learn, which randomly picks a few samples from the training data and visualizes them. We use the following parameter:
- rows: number of rows of samples to display.
arcgis.learn provides the FeatureClassifier model to determine the class of each feature. For in-depth information about how it works and how to use it, have a look at this link.
As we are training a model to be deployed on mobile phones, we must create the model with the TensorFlow backend. To do that, we set the backend parameter to "tensorflow".
model = FeatureClassifier(data, backbone='MobileNetV2', backend='tensorflow')
The learning rate is one of the most important hyperparameters in model training. Here, we explore a range of learning rates to guide our choice. arcgis.learn leverages fast.ai's learning rate finder to find an optimum learning rate for training models. We use the lr_find() method to find a learning rate at which we can train a robust model quickly.
lr = model.lr_find()
Based on the learning rate plot above, the learning rate suggested by lr_find() for our training data is 0.000691831 (about 6.9e-4), and we use it to train our model. In recent releases of arcgis.learn, we can also train models without specifying a learning rate. In that case, the learning rate finder runs internally and its suggestion is used.
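fast.ai's learning rate finder trains briefly while the learning rate keeps increasing, and it watches how the loss responds. The toy sketch below illustrates that idea in plain Python; it is not the arcgis.learn or fast.ai code. It works as follows:

- It fits a single weight w to synthetic data y = 3x + noise with mini-batch SGD.
- It raises the learning rate exponentially from 1e-5 to 10 over at most 100 steps.
- It smooths the loss with a bias-corrected moving average.
- It stops once the smoothed loss exceeds four times its best value.
- It suggests the learning rate where the smoothed loss falls fastest, after skipping the noisy first points and the diverging tail.

The data, bounds, smoothing factor, stopping rule and the function name lr_range_test are all assumptions chosen for the demo.

```python
import math
import random


def lr_range_test(lr_min=1e-5, lr_max=10.0, steps=100, batch_size=64,
                  beta=0.9, skip_start=10, skip_end=5, seed=0):
    """Toy LR range test on y = 3x + noise; returns (suggested_lr, lrs, losses)."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1, 1) for _ in range(1000)]
    ys = [3 * x + rng.gauss(0, 0.1) for x in xs]
    w, avg, best = 0.0, 0.0, float("inf")
    lrs, losses = [], []
    for i in range(steps):
        # Learning rate grows exponentially from lr_min to lr_max
        lr = lr_min * (lr_max / lr_min) ** (i / (steps - 1))
        batch = rng.sample(range(len(xs)), batch_size)
        errs = [w * xs[j] - ys[j] for j in batch]
        loss = sum(e * e for e in errs) / batch_size                         # MSE
        grad = sum(2 * e * xs[j] for e, j in zip(errs, batch)) / batch_size  # dMSE/dw
        # Bias-corrected exponential moving average of the mini-batch loss
        avg = beta * avg + (1 - beta) * loss
        smoothed = avg / (1 - beta ** (i + 1))
        lrs.append(lr)
        losses.append(smoothed)
        best = min(best, smoothed)
        if smoothed > 4 * best:
            break  # the loss is exploding: stop the sweep
        w -= lr * grad  # one SGD step at the current learning rate

    def slope(k):
        # Change in smoothed loss per unit of log(lr) between points k and k + 1
        return (losses[k + 1] - losses[k]) / (math.log(lrs[k + 1]) - math.log(lrs[k]))

    # Skip the noisy first points and the diverging tail, then take the steepest drop
    candidates = range(skip_start, len(losses) - skip_end - 1)
    if len(candidates) == 0:
        candidates = range(len(losses) - 1)
    k_best = min(candidates, key=slope)
    return lrs[k_best], lrs, losses


suggested, lrs, losses = lr_range_test()
print(f"tried {len(lrs)} learning rates; suggested lr = {suggested:.1e}")
```

The value this toy prints applies only to this synthetic problem. It is unrelated to the 0.000691831 found for the plant data.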
To train the model, we use the fit() method. To start, we will train for 25 epochs. An epoch is one complete pass of the model over the entire training set.
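The sketch below makes the definition of an epoch concrete. It is written in plain Python with a hypothetical function name, simulate_epochs. Each epoch, it shuffles the image indices and walks through them in mini-batches, counting how often each image is visited. No model is trained; a real loop such as the one behind fit() would also update the model on every batch.

```python
import random
from collections import Counter


def simulate_epochs(n_images, batch_size, epochs, seed=0):
    """Count how often each image index is visited; model updates are omitted."""
    rng = random.Random(seed)
    indices = list(range(n_images))
    seen = Counter()
    steps = 0
    for _ in range(epochs):
        rng.shuffle(indices)  # a new random order every epoch
        for start in range(0, n_images, batch_size):
            batch = indices[start:start + batch_size]
            seen.update(batch)  # a real loop would run forward/backward passes here
            steps += 1
    return seen, steps


seen, steps = simulate_epochs(n_images=1000, batch_size=64, epochs=25)
print(steps)               # 400 (16 batches per epoch x 25 epochs)
print(set(seen.values()))  # {25}
```

After 25 epochs, every image has been seen exactly 25 times, only in a different order each time.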
The show_results() method picks a few random samples and shows the ground truth and the corresponding model predictions side by side. This lets us validate the results of our model in the notebook itself. Once satisfied, we can save the model and use it further in our workflow.
Here, a subset of ground truth from the training data is visualized along with the model's predictions. As we can see, the model performs well, and its predictions are comparable to the ground truth.
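Looking at a handful of samples is a quick sanity check. To put a number on it, predicted and ground-truth labels can be compared directly. The plain-Python sketch below computes overall and per-class accuracy from two equal-length label lists. It does not show how the predictions are collected from the trained model, and the function name accuracy_report and the example labels are hypothetical.

```python
from collections import defaultdict


def accuracy_report(y_true, y_pred):
    """Return overall accuracy and a {class: accuracy} dict for paired label lists."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty label lists of equal length")
    correct = defaultdict(int)
    total = defaultdict(int)
    for truth, pred in zip(y_true, y_pred):
        total[truth] += 1
        correct[truth] += truth == pred  # True counts as 1
    overall = sum(correct.values()) / len(y_true)
    per_class = {label: correct[label] / total[label] for label in total}
    return overall, per_class


truth = ["species_a", "species_a", "species_b", "species_b"]
preds = ["species_a", "species_b", "species_b", "species_b"]
overall, per_class = accuracy_report(truth, preds)
print(overall)    # 0.75
print(per_class)  # {'species_a': 0.5, 'species_b': 1.0}
```

With 100 species, the per-class numbers are worth checking. A good overall accuracy can still hide a few species that the model rarely gets right.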
We will save the trained model in TensorFlow Lite (tf-lite) format.
We use the save() method to save the trained model. By default, it is saved to the 'models' sub-folder within our training data folder.
The tf-lite model can now be deployed on mobile devices. Survey123 for ArcGIS has an upcoming feature that integrates such tf-lite models. To learn more about deploying this model in Survey123, join the Early Adopter Community to access the Survey123 private beta.
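On the device, a classifier like this produces one score per class for each photo, and the app then has to turn those scores into a species name. The minimal plain-Python sketch below shows that last step as a softmax followed by a top-k sort. It assumes the model outputs raw scores (logits) in the same order as a list of class names. If the exported model already ends in a softmax, applying it again changes the values but not the ranking. The function name top_k_species and the labels are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def top_k_species(scores: Sequence[float], class_names: Sequence[str],
                  k: int = 3) -> List[Tuple[str, float]]:
    """Return the k most likely (class name, probability) pairs for one image."""
    if len(scores) != len(class_names):
        raise ValueError("expected exactly one score per class")
    largest = max(scores)
    exps = [math.exp(s - largest) for s in scores]  # shift by the max to avoid overflow
    total = sum(exps)
    probs = [e / total for e in exps]  # softmax: probabilities that sum to 1
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


print(top_k_species([2.0, 0.5, 1.0], ["species_a", "species_b", "species_c"], k=2))
```

Showing the top few candidates with their probabilities, rather than a single label, lets a surveyor confirm or correct the suggestion in the field.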