Automatic road extraction using deep learning

  • 🔬 Data Science
  • 🥠 Deep Learning and pixel-based classification

Introduction

A road network is a required layer in many mapping exercises. It is a critical component of basemap preparation (essential for navigation), humanitarian aid, disaster management, transportation, and many other applications.

This sample shows how the ArcGIS API for Python can be used to train a deep learning model (the Multi-Task Road Extractor model) to extract the road network from satellite imagery. The trained models can be used with ArcGIS Pro or ArcGIS Enterprise, and they even support distributed processing for quick results.

Further details on the Multi-Task Road Extractor implementation in the API (working principle, architecture, best practices, etc.) can be found in the Guide, along with instructions on how to set up the Python environment.

Before working through this notebook, it is advisable to read the API Reference for the Multi-Task Road Extractor (prepare_data(), MultiTaskRoadExtractor()). It explains the Multi-Task Road Extractor's workflow in detail.

Objectives:

  1. Classify roads, utilizing API's Multi-Task Road Extractor model.

Area of Interest and data pre-processing

For this sample, we use a subset of the publicly available SpaceNet dataset. Vector labels in the form of 'road centerlines' can be downloaded along with the imagery, which is hosted on AWS S3 [1].

The area of interest is Paris, with 425 km of 'road centerline' length (as shown in Figure 1). Both inputs, the imagery and the vector layer, are used to create the data needed for model training: image chips and labels in the 'Classified Tiles' format.

Figure 1: SpaceNet dataset - AOI 3 - Paris

The downloaded data has 4 types of imagery: Multispectral, Pan, Pan-sharpened Multispectral, and Pan-sharpened RGB. The Multi-Task Road Extractor model supports 8-bit RGB imagery and offers experimental support for 16-bit RGB imagery (multispectral imagery will be supported in a subsequent release). In this sample, Pan-sharpened RGB imagery is used after converting it to 8-bit.

Pre-processing steps:

  • The downloaded vector labels, in .geojson format, are converted to a feature class/shapefile. (Refer to ArcGIS Pro's JSON To Features GP tool.)
  • The converted vector data is checked and repaired if any invalid geometry is found. (Refer to ArcGIS Pro's Repair Geometry GP tool.)
  • The 'Stretch function' is used to convert the 16-bit imagery to 8-bit imagery. (Refer to ArcGIS Pro's Stretch raster function.)
  • A projected coordinate system is applied to the imagery and the road vector data. This makes the results easier to interpret and lets tool parameters such as cell size and buffer distance be set in linear units.
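The 16-bit to 8-bit conversion can be pictured with a small NumPy sketch of a percent-clip linear stretch. Values below and above two chosen percentiles are clipped, and the range between them is mapped linearly onto 0-255. This is only an illustration under stated assumptions, not the exact algorithm of ArcGIS Pro's Stretch raster function, which offers several stretch types. The function name `stretch_to_8bit` and its 2/98 percentile defaults are hypothetical choices.

```python
import numpy as np

def stretch_to_8bit(band, low_pct=2.0, high_pct=98.0):
    """Hypothetical percent-clip stretch of one 16-bit band to uint8 (0-255)."""
    # Values at or below the low percentile map to 0, at or above the high one to 255
    low, high = np.percentile(band, [low_pct, high_pct])
    if high <= low:
        # Constant (or nearly constant) band: nothing to stretch
        return np.zeros(band.shape, dtype=np.uint8)
    scaled = (band.astype(np.float64) - low) / (high - low)
    return np.clip(np.round(scaled * 255.0), 0, 255).astype(np.uint8)

band = np.array([[0, 1000], [2000, 4000]], dtype=np.uint16)
print(stretch_to_8bit(band, low_pct=0.0, high_pct=100.0))  # values 0, 64, 128, 255
```

Clipping at percentiles rather than at the absolute minimum and maximum keeps a few very bright or very dark pixels from compressing the rest of the histogram into a narrow 8-bit range.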

The data is now ready for the Export Training Data For Deep Learning GP tool (as shown in Figure 2), which exports the data needed for model training. This tool is available in ArcGIS Pro as well as ArcGIS Enterprise.

Here, we exported the data in the 'Classified Tiles' format using a Cell Size of '30 cm'. Tile Size X and Tile Size Y are set to '512', while Stride X and Stride Y are set to '128'. If road centerlines are used directly as input, an appropriate buffer size can be set based on the area of interest and the types of roads in that region. Alternatively, ArcGIS Pro's Create Buffers GP tool can be used to convert road centerlines to road polygons, and the buffer value can be decided iteratively by checking the results of the Create Buffers GP tool.
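The tile and stride settings determine how much ground each chip covers and how much neighboring chips overlap. The following is a minimal sketch of that arithmetic. It assumes that only full tiles lying completely inside the raster are counted, and the export tool's handling of raster edges may differ. The function names and the 1300-pixel extent in the example are hypothetical.

```python
def tiles_along_axis(extent_px, tile_px, stride_px):
    """Number of full tiles that fit along one raster axis (edge handling assumed)."""
    if extent_px < tile_px:
        return 0
    return (extent_px - tile_px) // stride_px + 1

def tile_geometry(tile_px=512, stride_px=128, cell_size_m=0.30):
    """Ground footprint of one tile and overlap between consecutive tiles, in metres."""
    footprint_m = tile_px * cell_size_m
    overlap_m = (tile_px - stride_px) * cell_size_m
    return footprint_m, overlap_m

print(tiles_along_axis(1300, 512, 128))  # 7
print(tile_geometry())                   # approximately (153.6, 115.2)
```

With a 30 cm cell size, each 512-pixel chip covers 153.6 m on the ground. A 128-pixel stride means consecutive chips overlap by 384 pixels (115.2 m), which is 75% of a tile, so each location appears in several chips.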

Figure 2: Export Training Data For Deep Learning GP tool

This tool creates all the files needed for the next step in the specified Output Folder.

Data preparation

Imports:

from arcgis.learn import prepare_data, MultiTaskRoadExtractor

Preparing the exported data:

Some of the frequently used parameters that can be passed in prepare_data() are described below:

path: The path of the folder containing the training data (the output generated by the Export Training Data For Deep Learning GP tool).

chip_size: Images are cropped to the specified chip_size.

batch_size: The number of images the model trains on in each step within an epoch. The largest usable value depends directly on the memory of your graphics card.

val_split_pct: The percentage of the training data to keep aside for validation, expressed as a fraction (e.g. 0.1 for 10%).

resize_to: Resizes the cropped images to the specified size.
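The following sketch shows how val_split_pct and batch_size interact. A fraction of the exported chips is held out for validation, and the remaining chips are consumed batch_size at a time, which sets the number of steps per epoch. The function `split_and_count_steps` and its `seed` argument are hypothetical, and this is a plain random split used for illustration, not prepare_data()'s internal code.

```python
import math
import random

def split_and_count_steps(n_chips, val_split_pct=0.1, batch_size=4, seed=42):
    """Hypothetical random train/validation split plus batches per epoch."""
    indices = list(range(n_chips))
    random.Random(seed).shuffle(indices)           # reproducible shuffle
    n_val = round(n_chips * val_split_pct)         # chips held out for validation
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    train_steps = math.ceil(len(train_idx) / batch_size)  # last batch may be partial
    val_steps = math.ceil(len(val_idx) / batch_size)
    return train_idx, val_idx, train_steps, val_steps

train_idx, val_idx, train_steps, val_steps = split_and_count_steps(100, 0.1, 4)
print(len(train_idx), len(val_idx), train_steps, val_steps)  # 90 10 23 3
```

Halving batch_size roughly doubles the number of steps per epoch while lowering the GPU memory needed per step.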

Note: The data provided for 'Try it Live' is a very small subset of the data actually used for this sample notebook, so the training time, accuracy, visualization, etc. will differ from what is shown below.

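import os
import zipfile
from pathlib import Path
from arcgis.gis import GIS

# Connect using the active ArcGIS session (e.g. in ArcGIS Notebooks or ArcGIS Pro)
gis = GIS('home')
# Fetch the training data item by its item ID
training_data = gis.content.get('b7bbf2f5f4184960890afeabbdb51a32')
training_data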
automatic_road_extraction_using_deep_learning
Image Collection by api_data_owner
Last Modified: December 04, 2020
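# Download the zipped training data and extract it next to the downloaded file
filepath = training_data.download(file_name=training_data.name)
with zipfile.ZipFile(filepath, 'r') as zip_ref:
    zip_ref.extractall(Path(filepath).parent)
# The extracted folder is expected to share the zip file's name, minus its extension
output_path = Path(os.path.splitext(filepath)[0])
# Prepare 512 x 512 chips in batches of 4
data = prepare_data(output_path, chip_size=512, batch_size=4)
data.classes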

Visualization of prepared data

show_batch() can be used to display the prepared data, with the labels overlaid on the input imagery.

alpha controls the transparency of the label overlay.
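The meaning of alpha can be shown with standard alpha compositing. Each labeled pixel becomes alpha × label color + (1 − alpha) × image color, so alpha=1 draws the labels fully opaque and alpha=0 hides them. This NumPy sketch only illustrates that formula. It assumes an RGB image, a boolean road mask, and a single label color, and the function name `overlay` is hypothetical, not how show_batch() is implemented internally.

```python
import numpy as np

def overlay(image, label_rgb, mask, alpha=0.5):
    """Hypothetical alpha blend of one label colour onto an RGB image where mask is True."""
    image = image.astype(np.float64)
    label_rgb = np.asarray(label_rgb, dtype=np.float64)
    blended = image.copy()
    # Standard "over" compositing, applied only to labeled pixels
    blended[mask] = alpha * label_rgb + (1.0 - alpha) * image[mask]
    return blended

image = np.zeros((2, 2, 3))
mask = np.array([[True, False], [False, False]])
print(overlay(image, [255, 0, 0], mask, alpha=1.0)[0, 0])  # [255.   0.   0.]
```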

data.show_batch(alpha=1)

Training the model

First, the Multi-Task Road Extractor model object is created from the prepared data. Some model-specific advanced parameters can be set at this stage.

All of these parameters are optional, because sensible default values are already set and work well in most cases.

The advanced parameters are described below:

  • gaussian_thresh: sets the Gaussian threshold, which controls the required road width.
  • orient_bin_size: sets the bin size for orientation angles.
  • orient_theta: sets the width of the orientation mask.
  • mtl_model: selects one of the two architectures used to train the Multi-Task Road Extractor, "linknet" or "hourglass".

Backbones can only be specified with the 'linknet' architecture. The supported backbones are 'resnet18', 'resnet34', 'resnet50', 'resnet101' and 'resnet152'.
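The orientation parameters are easier to picture with a small sketch. Assume, as the parameter names suggest, that the model's auxiliary orientation task quantizes the direction of each road segment into angular bins orient_bin_size degrees wide. Under that assumption, the direction of a centerline segment could be mapped to a bin as shown below. The helper names are hypothetical, and this is not the arcgis.learn implementation.

```python
import math

def segment_orientation_deg(x0, y0, x1, y1):
    """Direction of the segment (x0, y0) -> (x1, y1), in degrees within [0, 360).

    The angle is measured in whatever axis convention the inputs use
    (in pixel coordinates the y axis usually points down).
    """
    return math.degrees(math.atan2(y1 - y0, x1 - x0)) % 360.0

def orientation_bin(angle_deg, bin_size):
    """Index of the hypothetical orientation bin (bin_size degrees wide) containing angle_deg."""
    n_bins = math.ceil(360.0 / bin_size)
    return int(angle_deg // bin_size) % n_bins

angle = segment_orientation_deg(0, 0, 0, 1)  # 90.0
print(angle, orientation_bin(angle, 20))     # 90.0 4
```

A smaller bin size gives finer orientation classes at the cost of more classes for the model to separate.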

model = MultiTaskRoadExtractor(data, mtl_model="hourglass")

Next, the lr_find() function is used to find the optimal learning rate. The learning rate controls how quickly existing information is overwritten by newly acquired information during training. If no value is specified, the optimal learning rate is extracted from the learning curve during the training process.
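To make the role of the learning rate concrete, here is plain gradient descent on a toy one-parameter loss, (w − 3)². Every update is the gradient scaled by the learning rate, so the learning rate decides how far each step moves. This illustrates the concept only. It is not the optimizer used by fit(), and the function name is hypothetical.

```python
def gradient_descent(lr, steps=100, w0=0.0, target=3.0):
    """Minimize (w - target)**2 with plain gradient descent and return the final w."""
    w = w0
    for _ in range(steps):
        grad = 2.0 * (w - target)  # derivative of (w - target)**2
        w -= lr * grad             # the learning rate scales every update
    return w

print(gradient_descent(0.1))  # very close to 3.0: error shrinks by 0.8x per step
print(gradient_descent(1.1))  # diverges: error grows by 1.2x per step
```

A learning rate that is too large makes training unstable, and one that is too small makes progress slow. This trade-off is why a suitable value is searched for before training.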

model.lr_find()
0.0005754399373371565

fit() is used to train the model. Either a new optimal learning rate is computed automatically, or the previously computed optimal learning rate can be passed in. (Any other user-defined learning rate can also be passed.)

If early_stopping is True, training stops when the model is no longer improving, regardless of the number of epochs specified. An 'epoch' is one complete forward and backward pass of the entire training dataset through the neural network.
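The idea behind early stopping can be sketched as a patience rule on the validation loss. Training stops once valid_loss has failed to improve on its best value for a given number of consecutive epochs. The `patience` and `min_delta` values below are assumptions for illustration, and the exact criterion used by fit() may differ. The function name is hypothetical.

```python
def early_stopping_epoch(valid_losses, patience=3, min_delta=0.0):
    """Return the epoch at which training would stop, or None if it never stops.

    Stops once valid_loss has not improved on the best value so far
    (by more than min_delta) for `patience` consecutive epochs.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None

print(early_stopping_epoch([1.0, 0.8, 0.7, 0.71, 0.72, 0.73], patience=3))  # 5
```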

accuracy, miou (mean Intersection over Union) and dice (Dice coefficient) are the performance metrics, reported after each epoch.
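For reference, these metrics follow the standard definitions for binary masks: IoU = |A ∩ B| / |A ∪ B| and Dice = 2|A ∩ B| / (|A| + |B|). The NumPy sketch below assumes two classes, background (0) and road (1), with mIoU averaged over both. How arcgis.learn averages classes internally may differ, and the function names are hypothetical.

```python
import numpy as np

def iou(pred, target):
    """Intersection over Union of two boolean masks."""
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return float(inter / union) if union else 1.0

def mean_iou(pred, target, classes=(0, 1)):
    """IoU averaged over the given class values (background = 0, road = 1 assumed)."""
    return float(np.mean([iou(pred == c, target == c) for c in classes]))

def dice(pred, target):
    """Dice coefficient 2 * |A and B| / (|A| + |B|) of two boolean masks."""
    inter = np.logical_and(pred, target).sum()
    total = pred.sum() + target.sum()
    return float(2.0 * inter / total) if total else 1.0

pred = np.array([[1, 1], [0, 0]])
target = np.array([[1, 0], [0, 0]])
print(iou(pred == 1, target == 1))   # 0.5
print(dice(pred == 1, target == 1))  # about 0.667
```

Because road pixels are a small fraction of each chip, pixel accuracy stays high even for poor predictions. The overlap-based miou and dice scores are the more informative metrics here.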

model.fit(50, 0.0005754399373371565, early_stopping=True)
epoch  train_loss  valid_loss  accuracy  miou      dice      time (h:mm:ss)
0      0.548448    0.539585    0.955007  0.813954  0.742421  1:36:12
1      0.388270    0.398783    0.964153  0.844706  0.781793  1:35:36
2      0.266606    0.292986    0.975121  0.891181  0.854522  1:35:52
3      0.221755    0.238013    0.981758  0.920282  0.893703  1:35:08
4      0.183346    0.205614    0.985037  0.933706  0.909582  1:35:14
5      0.157644    0.177976    0.987664  0.945289  0.925815  1:36:11
6      0.142403    0.179048    0.986943  0.941960  0.917824  1:36:39
7      0.125207    0.175523    0.987328  0.943237  0.917360  1:35:14
8      0.117216    0.149206    0.989932  0.954738  0.934074  1:36:27
9      0.118672    0.152716    0.988726  0.949733  0.931475  1:36:10
10     0.117524    0.141266    0.989973  0.955045  0.939116  1:35:23
11     0.102923    0.180505    0.985566  0.934490  0.889606  1:35:42
12     0.098052    0.133017    0.990930  0.959037  0.939222  1:36:56
13     0.084868    0.128443    0.991264  0.960469  0.938061  1:35:47
14     0.085274    0.135533    0.990642  0.957780  0.939627  1:36:54
15     0.079469    0.118889    0.992281  0.964931  0.948084  1:37:21
16     0.080457    0.117053    0.992530  0.966209  0.949776  1:36:09
17     0.079043    0.126219    0.991405  0.960909  0.945418  1:36:08
18     0.076621    0.118567    0.992838  0.967450  0.954610  1:36:14
19     0.082820    0.131838    0.989965  0.954687  0.933912  1:36:32
20     0.072295    0.114415    0.992959  0.968146  0.955399  1:36:26

Epoch 21: early stopping

Visualization of results

show_results() is used to visualize the model's results alongside the ground truth for the same scenes. Validation data is used for this.

  • The 1st column shows the ground truth image overlaid with its ground truth labels.
  • The 2nd column shows the same ground truth image overlaid with the predicted labels.
model.show_results(rows=4)

Saving the trained model

The last step related to training is saving the model using save(). Apart from the model files, the performance metrics, a graph of the training and validation losses, sample results, etc. are also saved.

model.save('road_model_for_spacenet_data')
WindowsPath('models/road_model_for_spacenet_data')

Inference using the trained model, in ArcGIS Pro

The model saved in the previous step can be used to produce a classified raster with the Classify Pixels Using Deep Learning tool (as shown in Figure 3).

The classified raster can then be converted into a vector road layer in ArcGIS Pro, and regularization-related GP tools can be used to remove unwanted artifacts from the output. Because the model was trained on a Cell Size of '30 cm', the Cell Size is also kept at '30 cm' at this step.

Figure 3: Classify Pixels Using Deep Learning tool

Conclusion

This notebook has summarized the end-to-end workflow for training a deep learning model for road classification. This type of model can predict roads occluded by small and medium-length shadows. However, when roads have larger occlusions from clouds or shadows, it is unable to create connected road networks.

References
