Introduction
With arcgis.learn, there is a multitude of machine learning models available for many different tasks. There are models for object detection, pixel classification, image translation, natural language processing, point cloud data, etc., and the list continues to grow.
However, what if you come across a deep learning model that is not yet a part of the learn module, and you want to use it from its library or its open-source code on GitHub? What if you created your own deep learning model for a specific task you are working on? Finally, what if you want to use these new models with all the capabilities of the ArcGIS ecosystem?
Fortunately, there is a solution: Model Extension, a general-purpose wrapper for any object detection or pixel classification model, built on top of our existing framework. It wraps all the details of our stack of PyTorch, fastai, and the learn module, and provides an easy-to-implement structure for integrating a third-party deep learning model.
Parts of a deep learning model
Let's first look at the important parts of a deep learning model, and then we will correlate them with Model Extension:

Preprocessing of input data: Raw image pixels are usually not the best way to feed the model. As such, we need to transform the input data and the ground truth information into a form that the model can process.

Model Architecture: This part defines the actual model architecture with different types of neural network layers (convolutional layers, fully connected layers, batchnorm, dropout, etc.) and the connections between them.

Loss Calculation: A loss function is needed to calculate how much the model predictions deviate from the ground truth. This value is used to adjust the weights of the layers during training so that the model makes better predictions.

Post-processing of the outputs: This step transforms the output from the model so that it can be visualized or understood by a user.
The parts described above provide a higher-level perspective that covers only the aspects relevant to our requirements for model integration. In its details, a deep learning model can be much more complex.
The Custom Model
To use Model Extension, a class needs to be created with a few specific functions. These functions correspond to the different parts of a deep learning model that we discussed earlier. The class can be given any name, but the required functions must use exactly the names given in this guide. All of the required libraries need to be imported within this class and used in the functions with a self. prefix (for example, self.torch).
```python
class MyModelName():
    import torch
    import fastai
    # import any other libraries your model needs here
```
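The self. prefix works because an import statement inside a class body binds the module as a class attribute. The standalone illustration below uses only the standard library; the class name ImportDemo and its method are hypothetical and exist only to show the pattern.

```python
class ImportDemo:
    """Hypothetical class showing the class-level import pattern."""

    # An import in the class body binds the module as a class attribute.
    import math

    def circle_area(self, radius):
        # Inside methods, the module is reached through self, not as a global name.
        return self.math.pi * radius ** 2


area = ImportDemo().circle_area(2.0)
print(area)
```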
Preprocessing of input data
There are three functions that need to be created to preprocess the data:
on_batch_begin: This function transforms the input data and the target (the ground truth) used for training the model, according to the model's input requirements. It is equivalent to the fastai on_batch_begin callback, but it is called after it. Therefore, the inputs need to be transformed only if the format required by the model differs from the one fastai produces.
The function receives the following arguments:

learn: a fastai learner object.

model_input_batch: the fastai-transformed batch of input images, a tensor of shape [N, C, H, W] with values in the range -1 to 1, where N is the batch size, C the number of channels (bands) in the image, H the height of the image, and W the width of the image.

model_target_batch: the fastai-transformed batch of targets. The type and shape of the targets differ between object detection and pixel classification.

Object Detection: a list of tensors [bboxes, classes], where bboxes is a tensor of shape [N, B, 4] (N is the batch size, B the maximum number of boxes present in any image of the batch, and 4 the bounding box coordinates in the order y1, x1, y2, x2, with values in the range -1 to 1), and classes is a tensor of shape [N, B] representing the class of each bounding box.

Pixel Classification: a tensor of shape [N, K, H, W] representing a binary raster, where N is the batch size, K the number of classes in the dataset, H the height of the image, and W the width of the image.
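Many third-party detectors expect pixel coordinates in [x1, y1, x2, y2] order rather than the [y1, x1, y2, x2] layout scaled to -1 to 1 described above. The sketch below, in plain Python on a single box, shows the conversion for a square chip of side chip_size and its inverse; the same arithmetic appears in the FasterRCNN example later in this guide. The helper names are hypothetical, and a real implementation would operate on whole tensors.

```python
def to_pixel_xyxy(box, chip_size):
    """Convert one [y1, x1, y2, x2] box in the range -1 to 1 to pixel [x1, y1, x2, y2]."""
    # Map -1..1 to 0..chip_size, then reorder the coordinates.
    y1, x1, y2, x2 = ((v + 1) / 2 * chip_size for v in box)
    return [x1, y1, x2, y2]


def to_normalized_yxyx(box, chip_size):
    """Inverse of to_pixel_xyxy: pixel [x1, y1, x2, y2] back to [y1, x1, y2, x2] in -1 to 1."""
    # Map 0..chip_size back to -1..1, then restore the y-first order.
    x1, y1, x2, y2 = (v / (chip_size / 2) - 1 for v in box)
    return [y1, x1, y2, x2]


pixel_box = to_pixel_xyxy([-1.0, -0.5, 0.0, 0.5], 224)
print(pixel_box)  # [56.0, 0.0, 168.0, 112.0]
print(to_normalized_yxyx(pixel_box, 224))  # [-1.0, -0.5, 0.0, 0.5]
```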
```python
def on_batch_begin(self, learn, model_input_batch, model_target_batch):
    """
    Function to transform the input data and the targets in accordance to the model for training.
    """
    model_input = ...
    model_target = ...
    return model_input, model_target
```

transform_input: This function transforms the input images during inference. It receives the following argument:

xb: the fastai-transformed batch of input images, a tensor of shape [N, C, H, W], where N is the batch size, C the number of channels (bands) in the image, H the height of the image, and W the width of the image.
```python
def transform_input(self, xb):
    """
    Function to transform the inputs for inferencing.
    """
    model_input = ...
    return model_input
```
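If a third-party model expects unnormalized pixel values, transform_input can undo the ImageNet normalization, as the FasterRCNN example later in this guide does with torch tensors. The sketch below, assuming NumPy, applies the same per-channel arithmetic (x * std + mean) to an [N, C, H, W] batch; the function name denormalize is hypothetical. Reshaping the statistics to [1, C, 1, 1] gives the same result as the example's permute to channels-last and back.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def denormalize(xb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Undo (x - mean) / std per channel for a batch shaped [N, C, H, W]."""
    # Reshape the stats to [1, C, 1, 1] so they broadcast over batch, height and width.
    return xb * std.reshape(1, -1, 1, 1) + mean.reshape(1, -1, 1, 1)


rng = np.random.default_rng(0)
raw = rng.random((2, 3, 4, 4), dtype=np.float32)  # fake RGB images in [0, 1)
normalized = (raw - IMAGENET_MEAN.reshape(1, -1, 1, 1)) / IMAGENET_STD.reshape(1, -1, 1, 1)
restored = denormalize(normalized)
print(np.allclose(restored, raw, atol=1e-6))
```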
transform_input_multispectral: This function is similar to transform_input, but it transforms the input images during inference when the data is multispectral imagery. If multispectral data is not being used, the inputs can be returned as-is.
```python
def transform_input_multispectral(self, xb):
    """
    Function to transform the multispectral inputs for inferencing.
    """
    model_input = ...
    return model_input
```
Model architecture
Next comes the model architecture, which is defined in the get_model function. Here, you can either create the sequence of neural network layers yourself or import the model definition from libraries like torchvision, fastai, or a third-party repository built using PyTorch.
The required arguments are:

data: the databunch created using the prepare_data function.

kwargs: any additional parameters required by the model, passed as keyword arguments.
```python
def get_model(self, data, **kwargs):
    """
    Function used to define the model architecture.
    """
    model = ...
    self.model = model
    return model
```
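Because get_model receives arbitrary keyword arguments, it helps to forward only those the underlying model constructor accepts. The FasterRCNN example below does this with fastai's split_kwargs_by_func. The sketch here shows the same idea using the standard library's inspect.signature; split_kwargs_for and build_detector are hypothetical stand-ins, not part of arcgis.learn or fastai.

```python
import inspect


def split_kwargs_for(func, kwargs):
    """Split kwargs into (accepted by func, everything else) using func's signature."""
    params = inspect.signature(func).parameters
    accepted = {k: v for k, v in kwargs.items() if k in params}
    rest = {k: v for k, v in kwargs.items() if k not in params}
    return accepted, rest


def build_detector(num_classes, min_size=800, box_score_thresh=0.05):
    """Hypothetical model constructor standing in for a third-party class."""
    return {"num_classes": num_classes, "min_size": min_size, "box_score_thresh": box_score_thresh}


# pretrained_backbone is not a build_detector parameter, so it stays in other_kwargs.
model_kwargs, other_kwargs = split_kwargs_for(
    build_detector, {"box_score_thresh": 0.3, "pretrained_backbone": False}
)
model = build_detector(num_classes=3, **model_kwargs)
```

This filtering only looks at named parameters. A constructor that itself accepts **kwargs would need different handling.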
Loss Calculation
During model training, a loss is calculated to ascertain how much the predictions deviate from the ground truth. The model then uses the loss to backpropagate the required changes in its weights and bring the predictions closer to the target. The loss function defines this calculation for the model.
The function receives the following arguments:

model_output: the predictions made by the model.

*model_target: the corresponding ground truth (as returned by the on_batch_begin function explained earlier).
Note: In certain models, the loss calculation is a part of the model definition. In such cases, this function can return the model output as it is.
```python
def loss(self, model_output, *model_target):
    """
    Function to define the loss calculations.
    """
    final_loss = ...
    return final_loss
```
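When the model already returns a dictionary of loss terms, as torchvision's detection models do in training mode, the loss function can simply sum them. The FasterRCNN example below does this with tensors and zeroes NaN and infinite values first. The sketch here shows the same pattern with plain floats standing in for tensors. The key names mirror torchvision's FasterRCNN loss dictionary, and total_loss is a hypothetical helper.

```python
import math


def total_loss(loss_dict):
    """Sum a dict of scalar loss terms, treating NaN or infinite terms as 0."""
    total = 0.0
    for value in loss_dict.values():
        # Replace a non-finite term with 0 so it cannot poison the sum.
        total += value if math.isfinite(value) else 0.0
    return total


losses = {"loss_classifier": 0.25, "loss_box_reg": 0.5, "loss_objectness": float("nan")}
print(total_loss(losses))  # 0.75
```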
Postprocessing of the outputs
Typically, the user cannot interpret the raw output of the model, so it needs to be post-processed. The post_process function transforms the raw outputs of the model into the specific format that the final results and visualization pipeline ingest.
Object Detection
The function receives the following arguments:

pred: the raw output of the model for a batch of images.

nms_overlap: the non-maximum suppression (NMS) overlap value used to select among overlapping bounding boxes.

thres: the confidence threshold used to filter the predictions.

chip_size: the size of the image chips on which predictions are made.

device: the device (CPU or GPU) on which the output needs to be placed after post-processing.

Returns:

post_processed_pred: List[Tuple(bboxes, labels, scores)], where for each image, bboxes is a tensor of shape [Number_of_bboxes_in_image, 4], labels is a tensor of shape [Number_of_bboxes_in_image], and scores is a tensor of shape [Number_of_bboxes_in_image]. The bounding box (bboxes) values need to be in the range -1 to 1 and in [y1, x1, y2, x2] format.
```python
def post_process(self, pred, nms_overlap, thres, chip_size, device):
    """
    Function to post process the output of the model in validation/inferencing mode.
    """
    post_processed_pred = []
    for p in pred:
        # Convert bboxes to range -1 to 1.
        # Convert bboxes to format [y1,x1,y2,x2]
        # Create a tuple of bboxes, labels and scores for each image and append it in a list
        ...
        post_processed_pred.append(...)
    return post_processed_pred
```
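A model that does not filter its own predictions (unlike torchvision's FasterRCNN, used in the example) needs post_process to apply thres and nms_overlap itself. The sketch below uses plain Python to drop low-confidence boxes and then run greedy non-maximum suppression on [y1, x1, y2, x2] boxes. The helpers iou and filter_detections are hypothetical. In practice, torchvision.ops.nms provides a vectorized NMS on [x1, y1, x2, y2] tensors.

```python
def iou(a, b):
    """Intersection over union of two [y1, x1, y2, x2] boxes."""
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def filter_detections(boxes, scores, thres, nms_overlap):
    """Drop boxes scoring below thres, then apply greedy non-maximum suppression.

    Returns the indices of the kept boxes, highest score first.
    """
    order = sorted((i for i, s in enumerate(scores) if s >= thres),
                   key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Keep a box only if it does not overlap an already-kept box too much.
        if all(iou(boxes[i], boxes[j]) <= nms_overlap for j in keep):
            keep.append(i)
    return keep


boxes = [[-1, -1, 0, 0], [-1, -0.9, 0, 0.1], [0.2, 0.2, 0.8, 0.8], [0.5, 0.5, 1, 1]]
scores = [0.9, 0.8, 0.6, 0.3]
print(filter_detections(boxes, scores, thres=0.5, nms_overlap=0.5))  # [0, 2]
```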
Pixel Classification
The function receives the following arguments:

pred: the raw output of the model for a batch of images.

thres: the confidence threshold used to filter the predictions.

Returns:

post_processed_pred: a tensor of shape [N, 1, H, W], or a list/tuple of N tensors of shape [1, H, W], where N is the batch size, H the height of the image, and W the width of the image. The values of the tensor (type: LongTensor) denote the predicted class of each pixel.
```python
def post_process(self, pred, thres):
    """
    Function to post process the output of the model in validation/inferencing mode.
    """
    post_processed_pred = ...
    return post_processed_pred
```
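For a typical segmentation network that outputs one score map per class, post_process can reduce the [N, K, H, W] scores to [N, 1, H, W] class indices. The sketch below assumes NumPy and computes a softmax and an argmax over the class axis. This guide only describes thres as a confidence threshold, so applying it by sending low-confidence pixels to a background class 0 is an assumption made for illustration. The function logits_to_classes is hypothetical. The int64 output corresponds to a torch LongTensor.

```python
import numpy as np


def logits_to_classes(pred, thres=None, background=0):
    """Turn raw scores of shape [N, K, H, W] into class indices of shape [N, 1, H, W].

    If thres is given, pixels whose top softmax probability is below it are
    assigned to the (assumed) background class.
    """
    # Numerically stable softmax over the class axis.
    shifted = pred - pred.max(axis=1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    classes = probs.argmax(axis=1)[:, None]  # shape [N, 1, H, W]
    if thres is not None:
        top = probs.max(axis=1, keepdims=True)
        classes = np.where(top >= thres, classes, background)
    return classes.astype(np.int64)


pred = np.zeros((1, 3, 1, 2))
pred[0, :, 0, 0] = [0.0, 5.0, 0.0]  # confidently class 1
pred[0, :, 0, 1] = [0.0, 0.1, 0.0]  # barely class 1
print(logits_to_classes(pred, thres=0.5))
```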
Example
Let's put all of this together in an example. For demonstration purposes, we have chosen to integrate the FasterRCNN model available in the torchvision library.
```python
class MyFasterRCNN():
    """
    Custom class to integrate FasterRCNN using Model Extension
    """
    import torch
    import torchvision
    import fastai

    def on_batch_begin(self, learn, model_input_batch, model_target_batch):
        """
        Function to transform the input data and the targets in accordance to the model for training.
        """
        # During training, after each epoch, validation loss is required on validation dataset
        # Torchvision's FasterRCNN model gives losses only in training mode
        # and therefore, the model is set to train mode
        learn.model.train()
        target_list = []
        # Denormalize from imagenet_stats
        if not learn.data._is_multispectral:
            imagenet_stats = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
            mean = self.torch.tensor(imagenet_stats[0], dtype=self.torch.float32).to(model_input_batch.device)
            std = self.torch.tensor(imagenet_stats[1], dtype=self.torch.float32).to(model_input_batch.device)
            model_input_batch = (model_input_batch.permute(0, 2, 3, 1)*std + mean).permute(0, 3, 1, 2)
        for bbox, label in zip(*model_target_batch):
            # FasterRCNN requires bboxes with values between 0 and H and 0 and W.
            bbox = ((bbox+1)/2)*learn.data.chip_size
            # FasterRCNN requires the target of each image in the format of a dictionary.
            target = {}
            # Handle images without any bboxes.
            if bbox.nelement() == 0:
                bbox = self.torch.tensor([[0.,0.,0.,0.]]).to(learn.data.device)
                label = self.torch.tensor([0]).to(learn.data.device)
            # FasterRCNN requires the format of bboxes as [x1,y1,x2,y2].
            bbox = self.torch.index_select(bbox, 1, self.torch.tensor([1,0,3,2]).to(learn.data.device))
            # FasterRCNN requires a batch of targets in the form of a list of dictionaries.
            target["boxes"] = bbox
            target["labels"] = label
            target_list.append(target)
        # FasterRCNN requires model input with images and corresponding targets in training mode to return the losses,
        # therefore, append the targets in model_input itself.
        model_input = [list(model_input_batch), target_list]
        # Model target is not required in training mode, so just return the same model_target to train the model.
        model_target = model_target_batch
        return model_input, model_target

    def transform_input(self, xb, thresh=0.5, nms_overlap=0.1):
        """
        Function to transform the inputs for inferencing.
        """
        # Storing the original threshold and nms overlap values to restore later and applying the user provided values
        self.nms_thres = self.model.roi_heads.nms_thresh
        self.thresh = self.model.roi_heads.score_thresh
        self.model.roi_heads.nms_thresh = nms_overlap
        self.model.roi_heads.score_thresh = thresh
        # Denormalize from imagenet_stats
        imagenet_stats = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
        mean = self.torch.tensor(imagenet_stats[0], dtype=self.torch.float32).to(xb.device)
        std = self.torch.tensor(imagenet_stats[1], dtype=self.torch.float32).to(xb.device)
        xb = (xb.permute(0, 2, 3, 1)*std + mean).permute(0, 3, 1, 2)
        # Model inputs are required in the list format
        return list(xb)

    def transform_input_multispectral(self, xb, thresh=0.5, nms_overlap=0.1):
        """
        Function to transform the multispectral inputs for inferencing.
        """
        # Storing the original threshold and nms overlap values to restore later and applying the user provided values
        self.nms_thres = self.model.roi_heads.nms_thresh
        self.thresh = self.model.roi_heads.score_thresh
        self.model.roi_heads.nms_thresh = nms_overlap
        self.model.roi_heads.score_thresh = thresh
        # Model inputs are required in the list format
        return list(xb)

    def get_model(self, data, backbone=None, **kwargs):
        """
        Function that defines the model architecture using FasterRCNN model available in the torchvision library.
        An option to select different backbones from the resnet family has also been added.
        """
        self.fasterrcnn_kwargs, kwargs = self.fastai.core.split_kwargs_by_func(kwargs,
            self.torchvision.models.detection.FasterRCNN.__init__)
        if backbone is None:
            backbone = self.torchvision.models.resnet50
        elif type(backbone) is str:
            if hasattr(self.torchvision.models, backbone):
                backbone = getattr(self.torchvision.models, backbone)
            elif hasattr(self.torchvision.models.detection, backbone):
                backbone = getattr(self.torchvision.models.detection, backbone)
        else:
            backbone = backbone
        pretrained_backbone = kwargs.get('pretrained_backbone', True)
        assert type(pretrained_backbone) == bool
        if backbone.__name__ == 'resnet50':
            model = self.torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=pretrained_backbone,
                                                                              min_size = 1.5*data.chip_size,
                                                                              max_size = 2*data.chip_size,
                                                                              **self.fasterrcnn_kwargs)
        elif backbone.__name__ in ['resnet18','resnet34']:
            backbone_small = self.fastai.vision.learner.create_body(backbone, pretrained=pretrained_backbone)
            backbone_small.out_channels = 512
            model = self.torchvision.models.detection.FasterRCNN(backbone_small,
                                                                 91,
                                                                 min_size = 1.5*data.chip_size,
                                                                 max_size = 2*data.chip_size,
                                                                 **self.fasterrcnn_kwargs)
        else:
            backbone_fpn = self.torchvision.models.detection.backbone_utils.resnet_fpn_backbone(
                backbone.__name__,
                pretrained = pretrained_backbone
            )
            model = self.torchvision.models.detection.FasterRCNN(backbone_fpn,
                                                                 91,
                                                                 min_size = 1.5*data.chip_size,
                                                                 max_size = 2*data.chip_size,
                                                                 **self.fasterrcnn_kwargs)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = self.torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, len(data.classes))
        if data._is_multispectral:
            model.transform.image_mean = [0]*len(data._extract_bands)
            model.transform.image_std = [1]*len(data._extract_bands)
        self.model = model
        return model

    def loss(self, model_output, *model_target):
        """
        Function to define the loss calculations.
        """
        # Torchvision's FasterRCNN model itself returns the loss in training mode.
        # Therefore, we don't need to redefine the loss calculation here.
        final_loss = 0.
        for i in model_output.values():
            i[self.torch.isnan(i)] = 0.
            i[self.torch.isinf(i)] = 0.
            final_loss += i
        return final_loss

    def post_process(self, pred, nms_overlap, thres, chip_size, device):
        """
        Function to post process the output of the model in validation/inferencing mode.
        """
        # Restoring the original threshold and nms_thresh
        self.model.roi_heads.score_thresh = self.thresh
        self.model.roi_heads.nms_thresh = self.nms_thres
        # Torchvision's FasterRCNN returns the output after applying the confidence threshold and the NMS threshold,
        # therefore, `nms_overlap` and `thres` are not used in this function.
        post_processed_pred = []
        for p in pred:
            bbox, label, score = p["boxes"], p["labels"], p["scores"]
            # Convert bboxes to range -1 to 1.
            bbox = bbox/(chip_size/2) - 1
            # Convert bboxes to format [y1,x1,y2,x2]
            bbox = self.torch.index_select(bbox, 1, self.torch.tensor([1,0,3,2]).to(bbox.device))
            # Create a tuple of bboxes, labels and scores for each image and append it in a list
            post_processed_pred.append((bbox.data.to(device), label.to(device), score.to(device)))
        return post_processed_pred
```
Once we have created our model class, we need to save it as a Python script file. The model class can then be imported from that file with the Python statement from fasterrcnn import MyFasterRCNN. Here, the file fasterrcnn.py is in the current working directory, which is the suggested location. If you save the file elsewhere, adjust the path accordingly.
Next, we need to import the prepare_data function and the ModelExtension class from arcgis.learn. The prepare_data function creates a fastai databunch, and the ModelExtension class initializes our custom model.
```python
from fasterrcnn import MyFasterRCNN
from arcgis.learn import prepare_data, ModelExtension

data = prepare_data(r'path\to\data')
model = ModelExtension(data, MyFasterRCNN)
```
From here, we can continue with the usual workflow of using an arcgis.learn deep learning model. Refer to the Detecting Swimming Pools using Deep Learning sample notebook for the workflow of an object detection model, and to the Land Cover Classification using Satellite Imagery and Deep Learning sample notebook for the workflow of a pixel classification model.