Add a new model using Model Extension

Introduction

With arcgis.learn, there are a multitude of machine learning models available for many different tasks. There are models for object detection, pixel classification, image translation, natural language processing, point-cloud data, etc., and the list continues to grow.

However, what if you come across a deep learning model that is not yet a part of the learn module, and you want to use it from its library or its open-source code on GitHub? What if you created your own deep learning model for a specific task you are working on? Finally, what if you want to use these new models with all the capabilities of the ArcGIS ecosystem?

Fortunately, there is a solution: Model Extension, a general-purpose wrapper for any object detection or pixel classification model, built on top of our existing framework. It wraps all the details of our stack of PyTorch, fastai, and the learn module and provides an easy-to-implement structure for integrating a third-party deep learning model.

Parts of a deep learning model

Let's first look at the important parts of a deep learning model; we will then map them to Model Extension:

  1. Preprocessing of input data - Raw image pixels are usually not the best way to feed the model. As such, we need to transform the input data and the ground truth information in a way that the model can process.

  2. Model Architecture - This part defines the actual model architecture with different types of neural network layers (convolutional layers, fully connected layers, batch-norm, dropout, etc.) and the connections between them.

  3. Loss Calculation - A loss function is needed to calculate how much the model predictions deviate from the ground truth. The loss is used to adjust the weights of the layers so that the trained model makes better predictions.

  4. Postprocessing of the outputs - This step transforms the output from the model so that it can be visualized or understood by a user.

The parts described above give a higher-level perspective that covers only the aspects relevant to our requirements for model integration. A real deep learning model can be considerably more complex.
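To make the four parts concrete, here is a toy, pure-Python illustration (not arcgis.learn code): a one-weight linear "model" whose raw inputs are scaled (preprocessing), multiplied by a weight (architecture), scored with mean squared error (loss), and turned into labels (postprocessing). The hand-written gradient-descent loop stands in for the weight updates a real framework performs through backpropagation; all function names here are hypothetical.

```python
# Toy illustration of the four parts; plain Python, no arcgis.learn APIs.

def preprocess(raw_values, scale=255.0):
    """Part 1: scale raw 0-255 pixel-like values into the 0-1 range."""
    return [v / scale for v in raw_values]


def model(inputs, w):
    """Part 2: the 'architecture', here a single multiplicative weight."""
    return [w * x for x in inputs]


def loss(predictions, targets):
    """Part 3: mean squared error between predictions and ground truth."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def postprocess(predictions, threshold=0.5):
    """Part 4: turn raw scores into labels a user can read."""
    return ["positive" if p >= threshold else "negative" for p in predictions]


raw = [0, 51, 204, 255]
targets = [0.0, 0.2, 0.8, 1.0]
x = preprocess(raw)

w, lr = 0.5, 0.5
for _ in range(200):
    preds = model(x, w)
    # Gradient of the MSE loss with respect to w; a framework would obtain
    # this through backpropagation instead of a hand-written formula.
    grad = sum(2 * (p - t) * xi for p, t, xi in zip(preds, targets, x)) / len(x)
    w -= lr * grad

print(round(w, 3), round(loss(model(x, w), targets), 6))
print(postprocess(model(x, w)))
```

Each function corresponds to one of the parts listed above, which is the same split Model Extension asks for in the sections that follow.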

The Custom Model

To use Model Extension, you need to create a class with a few specific methods. These methods correspond to the parts of a deep learning model discussed earlier. The class can be given any name, but the required methods must use the exact names given below. All required libraries need to be imported inside the class body and accessed in the methods with a self prefix (for example, self.torch).

class MyModelName():
    import torch
    import fastai
    # import any other libraries the model needs here
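Why the self prefix works: a plain Python import statement inside a class body binds the module as a class attribute, so every instance can reach it as self.<module>. The sketch below shows this with the standard-library math and json modules standing in for torch and fastai; the class and method names are hypothetical.

```python
class ImportDemo():
    # Imports in the class body become class attributes,
    # just like torch and fastai in MyModelName above.
    import math
    import json

    def hypotenuse(self, a, b):
        # Hypothetical method: the module is reached through self.
        return self.math.sqrt(a * a + b * b)

    def describe(self, num_classes):
        return self.json.dumps({"num_classes": num_classes})


demo = ImportDemo()
print(demo.hypotenuse(3, 4))  # 5.0
print(demo.describe(2))       # {"num_classes": 2}
```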

Preprocessing of input data

There are three functions that need to be created to preprocess the data:

  1. on_batch_begin: This function transforms the input data and the target (the ground truth) used for training the model, in accordance with the model's input requirements. It is equivalent to fastai's on_batch_begin function but is called after it. Therefore, the inputs need to be transformed only if the format required by the model differs from the one fastai produces.

The function receives the following arguments:

  • learn - a fastai learner object

  • model_input_batch - fastai transformed batch of input images: tensor of shape [N, C, H, W] with values in the range -1 to 1, where
      N - batch size
      C - number of channels (bands) in the image
      H - height of the image
      W - width of the image

  • model_target_batch - fastai transformed batch of targets. The targets have a different type and shape for object detection and pixel classification.

    Object Detection: list of tensors [bboxes, classes], where
        bboxes: tensor of shape [N, B, 4], where
            N - batch size
            B - the maximum number of boxes present in any image of the batch
            4 - the bounding box coordinates in the order y1, x1, y2, x2, with values in the range -1 to 1
        classes: tensor of shape [N, B] representing the class of each bounding box

    Pixel Classification: tensor of shape [N, K, H, W] representing a binary raster, where
        N - batch size
        K - number of classes in the dataset
        H - height of the image
        W - width of the image

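def on_batch_begin(self, learn, model_input_batch, model_target_batch):
    """
    Function to transform the input data and the targets in accordance with the model for training.
    """
    # Replace these pass-through assignments with the transformations your
    # model needs; returning the batch unchanged is valid when fastai's
    # format already matches what the model expects.
    model_input = model_input_batch
    model_target = model_target_batch

    return model_input, model_target
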
  2. transform_input: This function is required to transform the input images during inference.

    The function receives the following arguments:

    • xb - fastai transformed batch of input images: tensor of shape [N, C, H, W], where
        N - batch size
        C - number of channels (bands) in the image
        H - height of the image
        W - width of the image

def transform_input(self, xb):
    """
    Function to transform the inputs for inferencing.
    """
    # Pass-through by default; replace with the model-specific transformation.
    model_input = xb

    return model_input

  3. transform_input_multispectral: This function is similar to transform_input. It is required to transform the input images during inference when the data is multispectral imagery. If multispectral data is not being used, the inputs can be returned as is.

def transform_input_multispectral(self, xb):
    """
    Function to transform the multispectral inputs for inferencing.
    """
    # Pass-through by default; replace if the model needs multispectral-specific handling.
    model_input = xb

    return model_input
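As a minimal sketch of what these preprocessing functions often do, the pure-Python helpers below mirror two steps from the FasterRCNN example later in this guide: undoing ImageNet normalization (value * std + mean) and converting one image's targets from the [y1, x1, y2, x2], -1 to 1 format into pixel-space [x1, y1, x2, y2] boxes in a torchvision-style dictionary. Real code would operate on tensors; the helper names are hypothetical.

```python
# ImageNet statistics, as used in the FasterRCNN example.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def denormalize_pixel(pixel):
    """Undo ImageNet normalization for one RGB pixel: value * std + mean."""
    return [v * s + m for v, s, m in zip(pixel, IMAGENET_STD, IMAGENET_MEAN)]


def to_torchvision_target(bboxes, labels, chip_size):
    """Convert [y1, x1, y2, x2] boxes in -1..1 into a dict with
    pixel-space [x1, y1, x2, y2] boxes (hypothetical helper)."""
    boxes = []
    for box in bboxes:
        # Rescale each coordinate from -1..1 to 0..chip_size.
        y1, x1, y2, x2 = [((c + 1) / 2) * chip_size for c in box]
        # Reorder from [y1, x1, y2, x2] to [x1, y1, x2, y2].
        boxes.append([x1, y1, x2, y2])
    labels = list(labels)
    if not boxes:
        # Images without boxes get a dummy zero box with label 0,
        # mirroring how the FasterRCNN example handles them.
        boxes, labels = [[0.0, 0.0, 0.0, 0.0]], [0]
    return {"boxes": boxes, "labels": labels}


print(denormalize_pixel([0.0, 0.0, 0.0]))
print(to_torchvision_target([[-1.0, -0.5, 0.0, 1.0]], [3], chip_size=224))
print(to_torchvision_target([], [], chip_size=224))
```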

Model architecture

Next comes the model architecture. The get_model function defines the model architecture. Here, you can either create the sequence of neural network layers yourself or import the model definition from libraries like torchvision, fastai, or a third-party repository built using PyTorch.

The required arguments are:

  • data - This is the databunch created using the prepare_data function.
  • kwargs - Any additional parameters required by the model can be used as keyword arguments.
def get_model(self, data, **kwargs):
    """
    Function used to define the model architecture.
    """
    
    model = ...
    self.model = model
    return model
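The FasterRCNN example later in this guide forwards only the keyword arguments that the torchvision constructor accepts, using fastai's split_kwargs_by_func. The sketch below shows the same idea with the standard-library inspect module. split_kwargs_for, build_detector (a stand-in for a real model constructor), and the dictionary standing in for the databunch are all hypothetical.

```python
import inspect


def split_kwargs_for(func, kwargs):
    """Split kwargs into (accepted by func, everything else)."""
    accepted = inspect.signature(func).parameters
    matched = {k: v for k, v in kwargs.items() if k in accepted}
    remaining = {k: v for k, v in kwargs.items() if k not in accepted}
    return matched, remaining


def build_detector(num_classes, min_size=800, max_size=1333, box_score_thresh=0.05):
    """Hypothetical stand-in for a real model constructor."""
    return {"num_classes": num_classes, "min_size": min_size,
            "max_size": max_size, "box_score_thresh": box_score_thresh}


class MyModel():
    def get_model(self, data, **kwargs):
        # Forward only the kwargs the constructor understands; keep the rest
        # (options the custom class handles itself) for later use.
        model_kwargs, self.extra_kwargs = split_kwargs_for(build_detector, kwargs)
        model = build_detector(len(data["classes"]), **model_kwargs)
        self.model = model  # keep a reference, as the template does
        return model


custom = MyModel()
net = custom.get_model({"classes": ["background", "pool"]},
                       box_score_thresh=0.3, pretrained_backbone=False)
print(net)
print(custom.extra_kwargs)
```

Splitting the keyword arguments this way lets users pass model options through to the constructor without the constructor failing on options it does not recognize.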

Loss Calculation

During model training, a loss is calculated to ascertain how much the predictions deviate from the ground truth. The loss is then used by the model to backpropagate the required changes in its weights to bring the predictions closer to the target. This function is used to define the loss for the model.

The function receives the following arguments:

  • model_output - the predictions by the model
  • *model_target - the corresponding ground truth, as returned by the on_batch_begin function explained earlier

Note: In certain models, the loss calculation is a part of the model definition. In such cases, this function can return the model output as it is.

def loss(self, model_output, *model_target):
    """
    Function to define the loss calculations.
    """
    final_loss = ...
    
    return final_loss
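When a model computes its own losses, as torchvision's Faster R-CNN does in training mode, the loss function often just combines them. Below is a pure-Python sketch of the pattern used in the FasterRCNN example later in this guide: sum every component and replace NaN or infinite values with zero. Plain floats stand in for the loss tensors and combine_losses is a hypothetical name; the dictionary keys follow the loss names torchvision's Faster R-CNN reports during training.

```python
import math


def combine_losses(loss_dict):
    """Sum per-component losses, treating NaN or infinite values as 0."""
    total = 0.0
    for value in loss_dict.values():
        if math.isnan(value) or math.isinf(value):
            value = 0.0  # skip a diverged component instead of poisoning the total
        total += value
    return total


losses = {
    "loss_classifier": 0.4,
    "loss_box_reg": 0.25,
    "loss_objectness": float("nan"),
    "loss_rpn_box_reg": 0.1,
}
print(round(combine_losses(losses), 4))
```

Zeroing a bad component keeps training running, but it also hides the problem, so occasional NaN losses may still be worth investigating.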

Post-processing of the outputs

Typically, the raw output of the model is not interpretable by the user and needs to be post-processed. The post_process function is used to transform the raw outputs of the model to a specific format for the final results and visualization pipeline to ingest.

Object Detection

The function receives the following arguments:

  • pred - Raw output of the model for a batch of images
  • nms_overlap - Non-maximum suppression (NMS) overlap threshold used to select among overlapping bounding boxes
  • thres - Confidence threshold to be used to filter the predictions
  • chip_size - Size of the image chips on which predictions are made
  • device - Device (CPU or GPU) on which the output needs to be put after post-processing

Returns:

  • post_processed_pred: List[Tuple(bboxes, labels, scores)], where for each image
        bboxes - tensor of shape [Number_of_bboxes_in_image, 4]
        labels - tensor of shape [Number_of_bboxes_in_image,]
        scores - tensor of shape [Number_of_bboxes_in_image,]
    The bounding box (bboxes) values need to be in the range -1 to 1 and in [y1, x1, y2, x2] format.

def post_process(self, pred, nms_overlap, thres, chip_size, device):
    """
    Function to post-process the output of the model in validation/inferencing mode.
    """
    post_processed_pred = []
    for p in pred:
        # Convert bboxes to the range -1 to 1.
        # Convert bboxes to the format [y1, x1, y2, x2].
        # Create a tuple of bboxes, labels and scores for each image and append it to the list.
        ...
        post_processed_pred.append(...)

    return post_processed_pred
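As a hedged sketch of the object detection contract above, the pure-Python function below takes per-image predictions in pixel [x1, y1, x2, y2] format (the layout torchvision's detectors use), drops detections scoring below thres, and rescales the remaining boxes to [y1, x1, y2, x2] in the range -1 to 1. It does not apply non-maximum suppression or move results to a device, and post_process_sketch is a hypothetical name; a real implementation would work on tensors.

```python
def to_model_extension_boxes(pixel_boxes, chip_size):
    """Convert pixel [x1, y1, x2, y2] boxes to [y1, x1, y2, x2] in -1..1."""
    half = chip_size / 2
    return [[y1 / half - 1, x1 / half - 1, y2 / half - 1, x2 / half - 1]
            for x1, y1, x2, y2 in pixel_boxes]


def post_process_sketch(pred, thres, chip_size):
    """pred: list of dicts with 'boxes', 'labels' and 'scores' per image."""
    post_processed_pred = []
    for p in pred:
        # Keep only detections at or above the confidence threshold.
        keep = [i for i, s in enumerate(p["scores"]) if s >= thres]
        bboxes = to_model_extension_boxes([p["boxes"][i] for i in keep], chip_size)
        labels = [p["labels"][i] for i in keep]
        scores = [p["scores"][i] for i in keep]
        post_processed_pred.append((bboxes, labels, scores))
    return post_processed_pred


pred = [{"boxes": [[56.0, 0.0, 224.0, 112.0], [0.0, 0.0, 10.0, 10.0]],
         "labels": [3, 1],
         "scores": [0.9, 0.2]}]
print(post_process_sketch(pred, thres=0.5, chip_size=224))
```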

Pixel Classification

The function receives the following arguments:

  • pred - Raw output of the model for a batch of images
  • thres - Confidence threshold to be used to filter the predictions

Returns:

  • post_processed_pred: tensor of shape [N, 1, H, W] or a List/Tuple of N tensors of shape [1, H, W], where
        N - batch size
        H - height of the image
        W - width of the image
    The values (type: LongTensor) of the tensor denote the predicted class of each pixel.

def post_process(self, pred, thres):
    """
    Function to post-process the output of the model in validation/inferencing mode.
    """
    post_processed_pred = ...

    return post_processed_pred
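For pixel classification, a common post-processing step is to pick, for each pixel, the class with the highest score. With a [N, K, H, W] tensor this is typically pred.argmax(dim=1, keepdim=True), which yields a LongTensor of shape [N, 1, H, W]. The pure-Python sketch below does the same for a single image stored as nested lists. It ignores thres, since how the threshold should be applied depends on the model, and scores_to_class_map is a hypothetical name.

```python
def scores_to_class_map(scores):
    """scores: nested lists [K][H][W] of per-class scores for one image.
    Returns [1][H][W] holding the highest-scoring class index per pixel."""
    num_classes = len(scores)
    height, width = len(scores[0]), len(scores[0][0])
    class_map = [
        [max(range(num_classes), key=lambda k: scores[k][r][c]) for c in range(width)]
        for r in range(height)
    ]
    return [class_map]  # leading dimension of size 1, matching [1, H, W]


scores = [
    [[0.9, 0.1], [0.4, 0.7]],  # class 0
    [[0.1, 0.9], [0.6, 0.3]],  # class 1
]
print(scores_to_class_map(scores))  # [[[0, 1], [1, 0]]]
```

On ties, max returns the first class with the highest score, which matches the lowest-index convention used by most argmax implementations.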

Example

Let's put all of this together in an example. For demonstration purposes, we have chosen to integrate the FasterRCNN model available in the torchvision library.

class MyFasterRCNN():
    """
    Custom class to integrate FasterRCNN using Model Extension
    """
    
    import torch
    import torchvision
    import fastai


    def on_batch_begin(self, learn, model_input_batch, model_target_batch):
        """
        Function to transform the input data and the targets in accordance to the model for training.
        """
        
        # During training, after each epoch, validation loss is required on validation dataset
        # Torchvision's FasterRCNN model gives losses only in training mode 
        # and therefore, the model is set to train mode
        learn.model.train()

        target_list = []

        # Denormalize from imagenet_stats
        if not learn.data._is_multispectral:
            imagenet_stats = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
            mean = self.torch.tensor(imagenet_stats[0], dtype=self.torch.float32).to(model_input_batch.device)
            std  = self.torch.tensor(imagenet_stats[1], dtype=self.torch.float32).to(model_input_batch.device)
            model_input_batch = (model_input_batch.permute(0, 2, 3, 1)*std + mean).permute(0, 3, 1, 2)
        
        for bbox, label in zip(*model_target_batch):
            
            # FasterRCNN requires bbox coordinates in pixels: y between 0 and H, x between 0 and W.
            bbox = ((bbox+1)/2)*learn.data.chip_size 
            
            # FasterRCNN requires the target of each image in the form of a dictionary.
            target = {} 
            
            # Handle images without any bboxes.
            if bbox.nelement() == 0:        
                bbox = self.torch.tensor([[0.,0.,0.,0.]]).to(learn.data.device)
                label = self.torch.tensor([0]).to(learn.data.device)
            
            # FasterRCNN requires bboxes in the format [x1,y1,x2,y2].
            bbox = self.torch.index_select(bbox, 1, self.torch.tensor([1,0,3,2]).to(learn.data.device))
            
            # FasterRCNN requires a batch of targets in the form of a list of dictionaries.
            target["boxes"] = bbox
            target["labels"] = label
            target_list.append(target)
        
        # In training mode, FasterRCNN requires the images together with the corresponding targets to return the losses;
        # therefore, the targets are included in model_input itself.
        model_input = [list(model_input_batch), target_list]
        
        # The model target is not required in training mode, so just return the same model_target to train the model.
        model_target = model_target_batch

        return model_input, model_target
  

    def transform_input(self, xb, thresh=0.5, nms_overlap=0.1):
        """
        Function to transform the inputs for inferencing.
        """
        
        # Storing the original threshold and nms overlap values to restore later and applying the user provided values
        self.nms_thres = self.model.roi_heads.nms_thresh
        self.thresh = self.model.roi_heads.score_thresh
        self.model.roi_heads.nms_thresh = nms_overlap
        self.model.roi_heads.score_thresh = thresh

        # Denormalize from imagenet_stats
        imagenet_stats = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
        mean = self.torch.tensor(imagenet_stats[0], dtype=self.torch.float32).to(xb.device)
        std  = self.torch.tensor(imagenet_stats[1], dtype=self.torch.float32).to(xb.device)
        xb = (xb.permute(0, 2, 3, 1)*std + mean).permute(0, 3, 1, 2)
        
        # Model inputs are required in the list format
        return list(xb)
    

    def transform_input_multispectral(self, xb, thresh=0.5, nms_overlap=0.1):
        """
        Function to transform the multispectral inputs for inferencing.
        """

        # Storing the original threshold and nms overlap values to restore later and applying the user provided values
        self.nms_thres = self.model.roi_heads.nms_thresh
        self.thresh = self.model.roi_heads.score_thresh
        self.model.roi_heads.nms_thresh = nms_overlap
        self.model.roi_heads.score_thresh = thresh
        
        # Model inputs are required in the list format
        return list(xb)
    

    def get_model(self, data, backbone=None, **kwargs):
        """
        Function that defines the model architecture using FasterRCNN model available in the torchvision library.
        An option to select different backbones from the resnet family has also been added.
        """
        
        self.fasterrcnn_kwargs, kwargs = self.fastai.core.split_kwargs_by_func(kwargs,
                                                                         self.torchvision.models.detection.FasterRCNN.__init__)
        if backbone is None:
            backbone = self.torchvision.models.resnet50

        elif type(backbone) is str:
            if hasattr(self.torchvision.models, backbone):
                backbone = getattr(self.torchvision.models, backbone)
            elif hasattr(self.torchvision.models.detection, backbone):
                backbone = getattr(self.torchvision.models.detection, backbone)
        else:
            backbone = backbone
        pretrained_backbone = kwargs.get('pretrained_backbone', True)
        assert type(pretrained_backbone) == bool
        if backbone.__name__ == 'resnet50':
            model = self.torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=pretrained_backbone,
                                                                              min_size = 1.5*data.chip_size,
                                                                              max_size = 2*data.chip_size,
                                                                              **self.fasterrcnn_kwargs)
        elif backbone.__name__ in ['resnet18','resnet34']:
            backbone_small = self.fastai.vision.learner.create_body(backbone, pretrained=pretrained_backbone)
            backbone_small.out_channels = 512
            model = self.torchvision.models.detection.FasterRCNN(backbone_small,
                                                                 91,
                                                                 min_size = 1.5*data.chip_size,
                                                                 max_size = 2*data.chip_size,
                                                                 **self.fasterrcnn_kwargs)
        else:
            backbone_fpn = self.torchvision.models.detection.backbone_utils.resnet_fpn_backbone(
                backbone.__name__,
                pretrained = pretrained_backbone
            )
            model = self.torchvision.models.detection.FasterRCNN(backbone_fpn,
                                                                 91,
                                                                 min_size = 1.5*data.chip_size,
                                                                 max_size = 2*data.chip_size,
                                                                 **self.fasterrcnn_kwargs)

        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = self.torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, len(data.classes))
        
        if data._is_multispectral:
            model.transform.image_mean = [0]*len(data._extract_bands)
            model.transform.image_std = [1]*len(data._extract_bands)

        self.model = model

        return model


    def loss(self, model_output, *model_target):
        """
        Function to define the loss calculations.
        """
        
        # Torchvision's FasterRCNN model itself returns the loss in training mode.
        # Therefore, we don't need to re-define the loss calculation here.
        final_loss = 0.
        for i in model_output.values():
            i[self.torch.isnan(i)] = 0.
            i[self.torch.isinf(i)] = 0.
            final_loss += i
        
        return final_loss
   

    def post_process(self, pred, nms_overlap, thres, chip_size, device):
        """
        Function to post-process the output of the model in validation/inferencing mode.
        """
        
        # Restoring the original threshold and nms_thresh
        self.model.roi_heads.score_thresh = self.thresh
        self.model.roi_heads.nms_thresh = self.nms_thres
        # Torchvision's FasterRCNN returns the output after applying the confidence threshold and NMS threshold;
        # therefore, `nms_overlap` and `thres` are not used in this function.

        post_processed_pred = []
        
        for p in pred:
            bbox, label, score = p["boxes"], p["labels"], p["scores"]
            # Convert bboxes to range -1 to 1.
            bbox = bbox/(chip_size/2) - 1
            # Convert bboxes to format [y1,x1,y2,x2]
            bbox = self.torch.index_select(bbox, 1, self.torch.tensor([1,0,3,2]).to(bbox.device))
            # Create a tuple of bboxes, labels and scores for each image and append it in a list
            post_processed_pred.append((bbox.data.to(device), label.to(device), score.to(device)))
            
        return post_processed_pred

Once we have created our model class, we need to save it as a Python script file. The model class can then be imported from that file using the Python statement from fasterrcnn import MyFasterRCNN. Here, the file fasterrcnn.py is saved in the current working directory, which is recommended; if you save it elsewhere, you need to provide the correct path to the file.
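The statement from fasterrcnn import MyFasterRCNN only works when the file's folder is on Python's module search path. If the file lives elsewhere, one option (general Python, not an arcgis.learn requirement) is to add that folder to sys.path before importing. In the sketch below, add_to_python_path is a hypothetical helper and fasterrcnn_demo is a hypothetical module name chosen so the demo does not shadow a real fasterrcnn.py.

```python
import importlib
import pathlib
import sys
import tempfile


def add_to_python_path(folder):
    """Make modules saved in `folder` importable with a plain import statement."""
    folder = str(pathlib.Path(folder).resolve())
    if folder not in sys.path:
        sys.path.insert(0, folder)
    importlib.invalidate_caches()  # pick up files created after startup


with tempfile.TemporaryDirectory() as folder:
    # Write a tiny stand-in for fasterrcnn.py into a separate folder.
    module_file = pathlib.Path(folder) / "fasterrcnn_demo.py"
    module_file.write_text("class MyFasterRCNN():\n    import math\n")

    add_to_python_path(folder)
    from fasterrcnn_demo import MyFasterRCNN

print(MyFasterRCNN.__name__)
```

Keeping the file in the current working directory, as recommended above, avoids this step entirely.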

Next, we need to import the prepare_data function and the ModelExtension class from arcgis.learn. The prepare_data function is used to create a fastai databunch, and the ModelExtension class is used to initialize our custom model.

from fasterrcnn import MyFasterRCNN

from arcgis.learn import prepare_data, ModelExtension

data = prepare_data(r'path\to\data')

model = ModelExtension(data, MyFasterRCNN)

From here, we can continue with the usual arcgis.learn workflow for a deep learning model. Refer to the Detecting Swimming Pools using Deep Learning sample notebook for the workflow of an object detection model, and to the Land Cover Classification using Satellite Imagery and Deep Learning sample notebook for the workflow of a pixel classification model.
