Track objects using SiamMask
Object Tracking is a methodology that helps to monitor the location of objects over a sequence of video frames.
SiamMask is a deep learning model architecture that performs both Visual Object Tracking (VOT) and semi-supervised Video Object Segmentation (VOS). Given the location of the object in the first frame of the sequence, the aim of VOT is to estimate the object's position in subsequent frames as accurately as possible. Similarly, the main goal of VOS is to output a binary segmentation mask that indicates whether or not each pixel belongs to the target. In other words, SiamMask takes a single object bounding box as input for initialization and outputs a segmentation mask and an object bounding box for each subsequent frame of a video.
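To make the relationship between the two outputs concrete, the sketch below derives an axis-aligned [x, y, w, h] box from a binary mask with NumPy. It is only an illustration of the mask and box formats described above, not the way SiamMask itself computes its bounding box; the helper name mask_to_bbox is hypothetical.

```python
import numpy as np

def mask_to_bbox(mask):
    """Return [x, y, w, h] of the tightest axis-aligned box around the
    nonzero pixels of a 2-D mask, or None if the mask is empty."""
    ys, xs = np.nonzero(mask)          # row and column indices of target pixels
    if ys.size == 0:
        return None
    x_min, x_max = xs.min(), xs.max()
    y_min, y_max = ys.min(), ys.max()
    # Width and height count pixels, hence the +1.
    return [int(x_min), int(y_min),
            int(x_max - x_min + 1), int(y_max - y_min + 1)]

# Toy mask: target occupies rows 2..4 and columns 3..6.
mask = np.zeros((6, 8), dtype=np.uint8)
mask[2:5, 3:7] = 1
bbox = mask_to_bbox(mask)
```

The returned list uses the same [x, y, w, h] ordering that this guide uses for bounding boxes.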
SiamMask improves over its siamese-network based predecessors by adding a new branch to produce a pixel-wise binary mask. As depicted below, there is a three-branch variant and a two-branch variant.
Backbone: SiamMask uses ResNet-50 as its backbone. The architecture depicted below uses the first four stages of ResNet, an adjust layer, and depth-wise cross-correlation, resulting in a feature map of size 17×17.
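As a rough illustration of where the 17×17 size comes from, the sketch below implements depth-wise cross-correlation in NumPy: each channel of the exemplar features slides over the matching channel of the search features. The 15×15 exemplar and 31×31 search sizes are assumptions taken from the usual setup in the SiamMask paper, and only 4 channels are used to keep the example light; the function name depthwise_xcorr is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def depthwise_xcorr(search, exemplar):
    """Depth-wise cross-correlation.

    search:   (C, Hs, Ws) feature map of the search region
    exemplar: (C, He, We) feature map of the target template
    returns:  (C, Hs - He + 1, Ws - We + 1) response map
    """
    _, he, we = exemplar.shape
    # windows has shape (C, Ho, Wo, He, We): every exemplar-sized patch.
    windows = sliding_window_view(search, (he, we), axis=(1, 2))
    # Correlate each patch with the exemplar of the same channel only.
    return np.einsum("cijkl,ckl->cij", windows, exemplar)

rng = np.random.default_rng(0)
exemplar = rng.standard_normal((4, 15, 15))   # assumed exemplar feature size
search = rng.standard_normal((4, 31, 31))     # assumed search feature size
response = depthwise_xcorr(search, exemplar)  # spatial size 31 - 15 + 1 = 17
```

Because channels are never mixed, the response keeps the channel count of the inputs, and the heads that follow operate on this per-channel response.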
Network heads: The conv5 block in the architecture contains a normalisation layer and a ReLU non-linearity, while conv6 consists only of a 1×1 convolutional layer.
Refinement: This module merges low- and high-resolution features over multiple refinement steps, using upsampling layers and skip connections.
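The data flow of such a refinement stage can be sketched as follows. This is a deliberately simplified illustration: it assumes nearest-neighbour 2× upsampling and merging by addition, whereas the actual module uses learned layers. The names upsample2x and refine_step and all array sizes are hypothetical.

```python
import numpy as np

def upsample2x(x):
    """Nearest-neighbour 2x upsampling of a (C, H, W) array."""
    return x.repeat(2, axis=1).repeat(2, axis=2)

def refine_step(coarse, skip):
    """Upsample the coarse map and merge it with a higher-resolution
    skip feature by element-wise addition (a stand-in for learned layers)."""
    up = upsample2x(coarse)
    if up.shape != skip.shape:
        raise ValueError(f"shape mismatch: {up.shape} vs {skip.shape}")
    return up + skip

# Two refinement steps: 4x4 -> 8x8 -> 16x16.
coarse = np.ones((2, 4, 4))
skips = [np.full((2, 8, 8), 0.5), np.zeros((2, 16, 16))]
out = coarse
for skip in skips:
    out = refine_step(out, skip)
```

Each step doubles the spatial resolution while folding in detail from an earlier, higher-resolution layer.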
Import the SiamMask class from the arcgis.learn module.
from arcgis.learn import SiamMask
To use the DAVIS pretrained weights, instantiate the model object as follows:
ot = SiamMask()
Note: The model must be initialized without providing any data. Because we are using the pretrained weights rather than training the model, no databunch is required. The initialized model can then be used to track objects.
Prepare a databunch for the SiamMask model using prepare_data() in arcgis.learn.
When the data is in the Youtube_VOS dataset format, call the prepare_data function with dataset_type='ObjectTracking'. For better results, use batch_size=64.
from arcgis.learn import prepare_data
data = prepare_data(r"path_to_data_folder", dataset_type="ObjectTracking", batch_size=64)
Once the data is prepared, the SiamMask model object can be instantiated as follows:
ot = SiamMask(data)
To use the model in ArcGIS Pro, pass the additional parameter framework="torchscript" when saving the model. Doing so creates additional model files inside a 'torch_scripts' folder, which can be loaded and used in ArcGIS Pro.
The init method initializes the object(s) to track from bounding boxes.
tracks = ot.init(img, [[x,y,w,h]], [["truck"]])
- The parameters to be passed are as follows:
frame: Required numpy array. Frame from the video used to initialize object(s) to track.
detections: Required list. A list of bounding boxes used to initialize object(s), e.g. [[x, y, w, h]].
x, y, w, h represent the x-coordinate, y-coordinate, width, and height of the bounding box, respectively.
labels: Optional list. A list of labels that represents the class of object(s).
reset: Optional Boolean. If set to True, all previous tracks are reset.
The method returns a list of initialized tracks.
Note: The length of detections must match the length of labels.
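Detectors and annotation tools often report boxes as corner coordinates rather than [x, y, w, h]. The sketch below shows one way to convert such boxes and check the length rule from the note above before calling init. Both helper names are hypothetical and are not part of arcgis.learn; the labels are passed through unchanged.

```python
def corners_to_xywh(box):
    """Convert (x1, y1, x2, y2) corner coordinates to [x, y, w, h]."""
    x1, y1, x2, y2 = box
    if x2 <= x1 or y2 <= y1:
        raise ValueError(f"degenerate box: {box}")
    return [x1, y1, x2 - x1, y2 - y1]

def build_init_args(boxes, labels=None):
    """Return (detections, labels) ready to pass to init.

    Raises ValueError if labels are given but their count differs
    from the number of detections.
    """
    detections = [corners_to_xywh(b) for b in boxes]
    if labels is not None and len(labels) != len(detections):
        raise ValueError(
            f"{len(detections)} detections but {len(labels)} labels")
    return detections, labels

# One box given as corners, with a label in the form used in this guide.
detections, labels = build_init_args([(10, 20, 50, 80)], [["truck"]])
```

The resulting values could then be passed as ot.init(frame, detections, labels).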
The update method updates the tracks in subsequent frames.
state = ot.update(frame)
- The parameters to be passed are as follows:
frame: Required numpy array. Frame from the video used to update the track(s) of object(s).
The method returns a list of updated tracks.
The sample code below demonstrates how to use the SiamMask model with ArcGIS API for Python.
- Execute the cell below to play the video.
- Press spacebar to pause the video.
- Use your mouse to annotate the object in the frame.
- Once annotated, press spacebar to track the object.
- Press q to quit.
import numpy as np
import cv2
from arcgis.learn import SiamMask

ot = SiamMask.from_model("path_to_save_model")

cap = cv2.VideoCapture(r"path_to_video_file")
initialized = False
while True:
    ret, frame = cap.read()
    if ret is False:
        break
    if initialized:
        state = ot.update(frame)  # Update the track location in the frame
        for track in state:
            mask = track.mask
            # Paint the masked pixels red (channel 2 of a BGR frame)
            frame[:, :, 2] = (mask > 0) * 255 + (mask == 0) * frame[:, :, 2]
            # np.int0 was removed in NumPy 2.0; cast to int32 for OpenCV instead
            points = np.asarray(track.location, dtype=np.int32).reshape((-1, 1, 2))
            cv2.polylines(frame, [points], True, (0, 255, 0), 1)
    cv2.imshow('frame', frame)
    key = cv2.waitKey(1)
    if key & 0xFF == ord('q'):
        break
    if key == 32:  # Spacebar: pause and annotate the object
        init_rect = cv2.selectROI('frame', frame, False, False)
        values = np.array(init_rect)
        if np.all(values == 0):  # Nothing was selected
            continue
        x, y, w, h = init_rect
        state = ot.init(frame, [[x, y, w, h]])  # Initialize the track in the frame
        initialized = True
cv2.waitKey()  # Wait for a key press before closing the window
cap.release()
cv2.destroyAllWindows()
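The mask overlay line in the sample above can be hard to read at a glance. The standalone sketch below shows the same idea with np.where on a toy frame: wherever the mask is positive, the red channel of a BGR frame is set to 255, and all other pixels are left untouched. The function name overlay_mask is hypothetical, and the sketch assumes a non-negative mask with the same height and width as the frame.

```python
import numpy as np

def overlay_mask(frame, mask):
    """Return a copy of a BGR frame with the red channel (index 2)
    set to 255 wherever mask > 0."""
    out = frame.copy()
    out[:, :, 2] = np.where(mask > 0, 255, frame[:, :, 2])
    return out

# Toy 2x3 BGR frame with every channel set to 10.
frame = np.full((2, 3, 3), 10, dtype=np.uint8)
# Soft mask: any positive value counts as target.
mask = np.array([[0.0, 1.0, 0.0],
                 [0.2, 0.0, 0.0]])
highlighted = overlay_mask(frame, mask)
```

Working on a copy keeps the original frame intact, which is useful if the unmodified frame is needed later, for example for re-initializing a track.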
ArcGIS Pro 2.8 only supports models saved using ArcGIS API for Python v1.8.5. If you are using ArcGIS API for Python v1.9.0 to train the model, follow the steps below:
Step 1: Train and save the model
ot = SiamMask(data)
ot.load("path_to_emd_file")
ot.fit(10)
ot.save("path_to_save_model")
Step 2: Load and save the model using the ArcGIS Pro 2.8 Python environment
ot = SiamMask.from_model("path_to_emd_file")
ot.save("path_to_save_model", framework="torchscript")
Step 3: Load the model saved inside the 'torch_scripts' folder in ArcGIS Pro 2.8, and follow Object tracking in motion imagery to track objects
Fast Online Object Tracking and Segmentation: A Unifying Approach https://arxiv.org/abs/1812.05050
Object tracking in motion imagery https://pro.arcgis.com/en/pro-app/latest/help/analysis/image-analyst/object-tracking-in-motion-imagery.htm