Generating RGB imagery from a digital surface model using Pix2Pix

  • 🔬 Data Science
  • 🥠 Deep Learning and image translation

Introduction

In this notebook, we will focus on using Pix2Pix [1], one of the best-known and most successful deep learning models for paired image-to-image translation. In geospatial sciences, this approach could help in a wide range of applications that were traditionally not possible, where we may want to go from one domain of images to another.

The aim of this notebook is to use the arcgis.learn Pix2Pix model to translate a gray-scale DSM into RGB imagery. For more details about the model and how it works, refer to "How Pix2Pix works?" in the guide section.

Necessary imports

Input
import os, zipfile
from pathlib import Path
from os import listdir
from os.path import isfile, join

from arcgis import GIS
from arcgis.learn import Pix2Pix, prepare_data

Connect to your GIS

Input
gis = GIS('home')
ent_gis = GIS('https://pythonapi.playground.esri.com/portal', 'arcgis_python', 'amazing_arcgis_123')

Export image domain data

For this use case, we have high-resolution NAIP airborne imagery in the form of IR-G-B tiles and lidar data converted into a DSM, collected over St. George, Utah, by the State of Utah and partners [5], both with the same spatial resolution of 0.5 m. We will export the data using the "Export_Tiles" metadata format available in the Export Training Data For Deep Learning tool. This tool is available in ArcGIS Pro as well as ArcGIS Image Server. The inputs required by the tool are described below.

  • Input Raster: DSM
  • Additional Input Raster: NAIP airborne imagery
  • Tile Size X & Tile Size Y: 256
  • Stride X & Stride Y: 128
  • Meta Data Format: 'Export_Tiles' as we are training a Pix2Pix model.
  • Environments: Set an optimum Cell Size and Processing Extent.
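
To illustrate how the tile size and stride listed above interact, the following minimal sketch lists the start offsets a sliding window would produce along one axis. The raster dimension (1024 pixels) and the function name `tile_offsets` are hypothetical, and the export tool's handling of partial tiles at the raster edge may differ; the sketch only shows the 50% overlap implied by a stride of 128 on 256-pixel tiles.

```python
def tile_offsets(length, tile_size=256, stride=128):
    """Return start offsets of full tiles along one axis (edge remainders ignored)."""
    if length < tile_size:
        return []
    return list(range(0, length - tile_size + 1, stride))


# Hypothetical raster that is 1024 pixels wide
offsets = tile_offsets(1024)

# Fraction of each tile shared with its neighbour
overlap = 1 - 128 / 256

print(offsets)
print(f"overlap between neighbouring tiles: {overlap:.0%}")
```

A smaller stride yields more (and more heavily overlapping) training chips from the same raster, at the cost of a larger export.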

The rasters used for exporting the training dataset are provided below.

Input
naip_domain_b_raster = ent_gis.content.get('a55890fcd6424b5bb4edddfc5a4bdc4b')
naip_domain_b_raster
Output
naip_train_area_domain_b
naip raster or domain b Imagery Layer by api_data_owner
Last Modified: March 12, 2021
0 comments, 10 views
Input
dsm_domain_a_raster = ent_gis.content.get('aa31a374f889487d951e15063944b921')
dsm_domain_a_raster
Output
dsm_train_area_domain_a
dsm raster or domain a Imagery Layer by api_data_owner
Last Modified: January 08, 2021
0 comments, 7 views

Inside the exported data folder, the 'Images' and 'Images2' folders contain the image tiles from the two domains, exported from the DSM and the NAIP imagery respectively. Now we are ready to train the Pix2Pix model.
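
Because Pix2Pix is trained on paired images, it can be worth checking that every tile has a counterpart in the other folder before training. The sketch below is one way to do that, assuming that corresponding tiles share the same file name in 'Images' and 'Images2' (an assumption about the export layout, not something confirmed here). The helper name `unpaired_tiles` is hypothetical.

```python
from pathlib import Path


def unpaired_tiles(export_dir):
    """Return file names that appear in only one of 'Images' / 'Images2'."""
    root = Path(export_dir)
    names_a = {p.name for p in (root / "Images").iterdir() if p.is_file()}
    names_b = {p.name for p in (root / "Images2").iterdir() if p.is_file()}
    # Symmetric difference: names missing from either folder
    return sorted(names_a ^ names_b)
```

An empty list would mean that every tile in one folder has a same-named tile in the other.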

Model training

Alternatively, we have provided a subset of the training data containing a few samples that follow the same directory structure mentioned above, along with the rasters used for exporting the training dataset. You can use this data directly to run the experiments.

Input
training_data = gis.content.get('2a3dad36569b48ed99858e8579611a80')
training_data
Output
data_for_pix2pix_with_trained_model
data_for_pix2pix_with_trained_model Image Collection by api_data_owner
Last Modified: January 08, 2021
0 comments, 0 views
Input
filepath = training_data.download(file_name=training_data.name)
Input
# Extract the data from the zipped image collection

with zipfile.ZipFile(filepath, 'r') as zip_ref:
    zip_ref.extractall(Path(filepath).parent)

Prepare data

Input
output_path = Path(os.path.splitext(filepath)[0])
Input
data = prepare_data(output_path, dataset_type="Pix2Pix", batch_size=5)

Visualize a few samples from your training data

To get a sense of what the training data looks like, the arcgis.learn.show_batch() method randomly picks a few training chips and visualizes them. On the left are DSMs (digital surface models), with the corresponding RGB imagery of various locations on the right.

Input
data.show_batch()

Load Pix2Pix model architecture

Input
model = Pix2Pix(data)

Tuning for optimal learning rate

The learning rate is one of the most important hyperparameters in model training. ArcGIS API for Python provides a learning rate finder that automatically suggests a suitable learning rate for you.

Input
lr = model.lr_find()

Output
2.5118864315095795e-05
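
For intuition, a common way to build a learning rate finder is to sweep the learning rate exponentially while recording the loss, and then suggest a value where the loss is falling fastest. The sketch below illustrates that idea on a made-up loss curve. It is not the implementation behind `lr_find()`; the function names, sweep bounds, and loss values are all hypothetical.

```python
import math


def lr_sweep(start=1e-6, end=1e-1, steps=6):
    """Exponentially spaced learning rates from start to end (inclusive)."""
    ratio = (end / start) ** (1 / (steps - 1))
    return [start * ratio ** i for i in range(steps)]


def suggest_lr(lrs, losses):
    """Pick the learning rate where the loss drops fastest per unit of log10(lr)."""
    slopes = [
        (losses[i + 1] - losses[i]) / (math.log10(lrs[i + 1]) - math.log10(lrs[i]))
        for i in range(len(lrs) - 1)
    ]
    steepest = min(range(len(slopes)), key=slopes.__getitem__)
    # Return the rate at the start of the steepest-descent interval
    return lrs[steepest]


lrs = lr_sweep()
losses = [14.0, 13.9, 13.2, 12.9, 13.5, 20.0]  # made-up loss per sweep step
best = suggest_lr(lrs, losses)
print(f"suggested learning rate: {best:.1e}")
```

Rates beyond the point where the loss starts to rise again are usually too aggressive, which is why the suggestion is taken from the descending part of the curve.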

Fit the model

The model is trained for 30 epochs with the suggested learning rate.

Input
model.fit(30, lr)
epoch train_loss valid_loss gen_loss l1_loss D_loss time
0 13.203547 13.980255 0.576110 0.126274 0.412447 01:01
1 12.675353 13.891787 0.573363 0.121020 0.411131 01:02
2 12.830377 13.652339 0.577334 0.122530 0.410224 01:00
3 12.826028 13.478950 0.578673 0.122474 0.410028 01:01
4 12.830496 13.464501 0.579446 0.122510 0.407034 01:01
5 12.978190 13.808777 0.581329 0.123969 0.405155 01:01
6 12.933887 14.188525 0.579817 0.123541 0.402280 01:01
7 12.660383 13.273459 0.583129 0.120773 0.398041 01:01
8 12.493378 13.234705 0.584513 0.119089 0.395928 01:02
9 12.704373 14.314936 0.583671 0.121207 0.393755 01:01
10 12.283652 12.872752 0.586115 0.116975 0.391496 01:01
11 12.008025 12.989032 0.585851 0.114222 0.386542 01:02
12 11.848214 12.356230 0.586706 0.112615 0.385120 01:01
13 11.648248 12.387824 0.586294 0.110620 0.383345 01:01
14 11.220642 12.051290 0.586354 0.106343 0.380747 01:01
15 11.065363 11.816018 0.587154 0.104782 0.379417 01:01
16 11.107099 11.579307 0.587886 0.105192 0.377144 01:02
17 10.680603 11.504006 0.587307 0.100933 0.375779 01:03
18 10.604408 11.234290 0.587380 0.100170 0.373917 01:03
19 10.459021 11.162776 0.586817 0.098722 0.372892 01:05
20 10.251445 10.944400 0.587671 0.096638 0.371933 01:02
21 10.173382 10.966841 0.587322 0.095861 0.371821 01:01
22 9.945634 10.783834 0.587247 0.093584 0.371387 01:01
23 9.681182 10.716444 0.587864 0.090933 0.369668 01:01
24 9.872039 10.600616 0.588303 0.092837 0.369563 01:00
25 9.786720 10.603912 0.588364 0.091984 0.369715 01:02
26 9.680658 10.506352 0.587878 0.090928 0.369863 01:02
27 9.386904 10.502596 0.587328 0.087996 0.368502 01:01
28 9.835923 10.505837 0.588324 0.092476 0.369696 01:01
29 9.630071 10.498654 0.586929 0.090431 0.368856 01:00

Here, with 30 epochs, we can see reasonable results: both training and validation losses have gone down considerably, indicating that the model is learning to translate between the two image domains.
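
The columns in the training log appear consistent with the Pix2Pix generator objective described in [1], in which an adversarial loss is combined with an L1 reconstruction term weighted by a factor λ (100 in the paper). The sketch below checks that relationship against three rows copied from the table above. Treating `train_loss` as `gen_loss + 100 * l1_loss` is an inference from the logged numbers, not something stated in this notebook.

```python
LAMBDA_L1 = 100  # weight used in the Pix2Pix paper [1]; assumed to apply here

# (epoch, train_loss, gen_loss, l1_loss) copied from the training log
rows = [
    (0, 13.203547, 0.576110, 0.126274),
    (15, 11.065363, 0.587154, 0.104782),
    (29, 9.630071, 0.586929, 0.090431),
]

residuals = []
for epoch, train_loss, gen_loss, l1_loss in rows:
    combined = gen_loss + LAMBDA_L1 * l1_loss
    residuals.append(abs(combined - train_loss))
    print(f"epoch {epoch}: combined={combined:.4f} logged={train_loss:.4f}")
```

If this reading is right, most of the drop in the training loss comes from the weighted L1 term, while the adversarial part of the generator loss stays roughly flat.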

Save the model

We will save the trained model as a 'Deep Learning Package' ('.dlpk' format). A Deep Learning Package is the standard format used to deploy deep learning models on the ArcGIS platform.

We will use the save() method to save the trained model. By default, it will be saved to the 'models' sub-folder within our training data folder.

Input
model.save("pix2pix_model_e30", publish=True)

Visualize results in validation set

It is good practice to view the model's results vis-à-vis the ground truth. The code below picks random samples and shows the ground truth and model predictions side by side. This enables us to preview the results of the model within the notebook.

Input
model.show_results()

Compute evaluation metrics

The Fréchet Inception Distance score, or FID for short, is a metric that calculates the distance between feature vectors computed for real and generated images. Lower scores indicate that the two groups of images are more similar, or have more similar statistics, with a perfect score of 0.0 indicating that the two groups of images are identical.

Input
model.compute_metrics()
Output
263.63128885232044
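
To make the definition concrete, the sketch below computes the squared Fréchet distance between two one-dimensional Gaussians fitted to toy feature values. This is a simplified illustration only: the actual FID is computed on multivariate Inception features with full covariance matrices, and the function name and sample values here are hypothetical.

```python
from statistics import fmean, pstdev


def frechet_distance_1d(real, generated):
    """Squared Fréchet distance between 1-D Gaussians fitted to two samples."""
    mu_r, mu_g = fmean(real), fmean(generated)
    sd_r, sd_g = pstdev(real), pstdev(generated)
    # In one dimension the distance reduces to mean and spread differences
    return (mu_r - mu_g) ** 2 + (sd_r - sd_g) ** 2


real = [0.2, 0.4, 0.6, 0.8]  # toy "feature" values

# Identical sets give a distance of zero
same = frechet_distance_1d(real, real)

# Shifting every value by 1.0 changes only the mean term
shifted = frechet_distance_1d(real, [x + 1.0 for x in real])

print(same, shifted)
```

The same pattern explains the score: the further the statistics of the generated images drift from those of the real images, the larger the distance.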

Model inferencing

Inference on a single imagery chip

We can translate a DSM to RGB imagery with the help of the predict() method.

Using the predict() method, we can apply the trained model to an image chip kept for validation, which we want to translate.

  • img_path: path to the image file.
Input
valid_data = gis.content.get('f682b16bcc6d40419a775ea2cad8f861')
valid_data
Output
dsm raster chip for inferencing
dsm raster chip for inferencing Image by api_data_owner
Last Modified: January 08, 2021
0 comments, 8 views
Input
filepath2 = valid_data.download(file_name=valid_data.name)
Input
# Visualize the image chip used for inferencing 
from fastai.vision import open_image
open_image(filepath2)
Output
Input
# Run inference on a single imagery chip
model.predict(filepath2)
Output

Generate raster using the Classify Pixels Using Deep Learning tool

After training the Pix2Pix model and saving its weights, we can use the Classify Pixels Using Deep Learning tool, available in both ArcGIS Pro and ArcGIS Enterprise, for inferencing at scale.

Input
test_data = ent_gis.content.get('86bed58f977c4c0aa39053d93141cdb1')
test_data
Output
dsm_test_area
Test area dsm for large scale inferencing Imagery Layer by api_data_owner
Last Modified: January 22, 2021
0 comments, 0 views

import arcpy

out_classified_raster = arcpy.ia.ClassifyPixelsUsingDeepLearning("Imagery", r"C:\path\to\model.emd", "padding 64;batch_size 2")
out_classified_raster.save(r"C:\sample\sample.gdb\predicted_img2dsm")
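
The third argument in the call above is a semicolon-separated string of name and value pairs. A small helper such as the hypothetical `format_model_arguments` below can build that string from a dictionary, which may be easier to maintain when experimenting with different settings. The parameter values are simply the ones used above.

```python
def format_model_arguments(arguments):
    """Join {'name': value} pairs into a 'name value;name value' string."""
    return ";".join(f"{name} {value}" for name, value in arguments.items())


arg_string = format_model_arguments({"padding": 64, "batch_size": 2})
print(arg_string)
```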

Results visualization

The RGB output raster is generated using ArcGIS Pro. The output raster is published on the portal for visualization.

Input
inferenced_results = ent_gis.content.get('30951690103047f096c6339398593d79')
inferenced_results
Output
predicted_rgb_imagery
Inferenced rgb imagery Imagery Layer by api_data_owner
Last Modified: January 22, 2021
0 comments, 4 views

Create map widgets

Two map widgets are created, showing the DSM and the inferenced RGB raster.

Input
map1 = ent_gis.map('Washington Fields', 13)
map1.add_layer(test_data)
map2 = ent_gis.map('Washington Fields', 13)
map2.add_layer(inferenced_results)

Synchronize web maps

The maps are synchronized with each other using the MapView.sync_navigation functionality, which helps in comparing the inferenced results with the DSM. A detailed description of the advanced map widget options is available in the map widget documentation.

Input
map2.sync_navigation(map1)

Set the map layout

Input
from ipywidgets import HBox, VBox, Label, Layout

HBox and VBox are used to set the layout of the map widgets.

Input
hbox_layout = Layout()
hbox_layout.justify_content = 'space-around'

hb1=HBox([Label('DSM'),Label('RGB results')])
hb1.layout=hbox_layout

Results

The predictions are provided as a map for better visualization.

Input
VBox([hb1,HBox([map1,map2])])

Input
map2.zoom_to_layer(inferenced_results)

Conclusion

In this notebook, we demonstrated how to use the Pix2Pix model with the ArcGIS API for Python to translate imagery from one domain to another.

References

  • [1]. Isola, Phillip, Jun-Yan Zhu, Tinghui Zhou, and Alexei A. Efros. "Image-to-image translation with conditional adversarial networks." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1125-1134. 2017.
  • [2]. Goodfellow, Ian, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. "Generative adversarial nets." In Advances in neural information processing systems, pp. 2672-2680. 2014.
  • [3]. https://stephan-osterburg.gitbook.io/coding/coding/ml-dl/tensorfow/chapter-4-conditional-generative-adversarial-network/acgan-architectural-design
  • [4]. Kang, Yuhao, Song Gao, and Robert E. Roth. "Transferring multiscale map styles using generative adversarial networks." International Journal of Cartography 5, no. 2-3 (2019): 115-141.
  • [5]. State of Utah and Partners, 2019, Regional Utah high-resolution lidar data 2015 - 2017: Collected by Quantum Spatial, Inc., Digital Mapping, Inc., and Aero-Graphics, Inc. and distributed by OpenTopography, https://doi.org/10.5069/G9RV0KSQ. Accessed: 2020-12-08
