Introduction
Address standardization is the process of changing addresses to adhere to USPS standards. In this notebook, we aim to abbreviate addresses according to the standard USPS abbreviations.
Address correction aims to correct misspelled place names.
We will train a model using the SequenceToSequence class of the arcgis.learn.text module to translate non-standard and erroneous addresses into their standard, correct form.
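To make the standardization task concrete, here is a minimal sketch of abbreviation by dictionary lookup. It is not how the model works. The lookup tables are a tiny hand-picked subset rather than the full USPS lists, the function name `abbreviate` is hypothetical, and the field layout (number, street, city, state, zip, country) is an assumption based on the sample addresses in this notebook.

```python
# Tiny illustrative subsets; the real USPS tables are much larger.
STREET_WORDS = {"street": "st", "avenue": "ave", "drive": "dr",
                "north": "n", "south": "s", "east": "e", "west": "w"}
STATES = {"iowa": "ia", "illinois": "il", "kansas": "ks"}


def abbreviate(address):
    """Abbreviate a lowercase address laid out as
    'number, street, city, state, zip, country' (assumed layout)."""
    fields = [field.strip() for field in address.split(",")]
    if len(fields) > 1:
        # Abbreviate suffixes and directionals in the street field only.
        fields[1] = " ".join(STREET_WORDS.get(w, w) for w in fields[1].split())
    if len(fields) > 3:
        # Abbreviate the state only when the whole field is a known state name.
        fields[3] = STATES.get(fields[3], fields[3])
    return ", ".join(fields)


print(abbreviate("211, 7th street, carmi, illinois, 62821.0, us"))
# A misspelled word such as "avneue" is not in the table and passes through unchanged.
print(abbreviate("940, north pennsylvania avneue, mason icty, iowa, 50401, us"))
```

A lookup like this cannot handle misspelled words, which is one reason to learn the mapping with a sequence-to-sequence model instead.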
The dataset consists of pairs of non-standard, incorrect (synthetically corrupted) house addresses and the corresponding correct, standard house addresses from the United States. The correct addresses are taken from OpenAddresses data.
Disclaimer: The correct addresses were synthetically corrupted to prepare the training dataset. This could have led to some unexpected corruptions in the addresses, which will affect the translation learned by the model.
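The exact corruption procedure used for this dataset is not described here. As a rough illustration only, the sketch below introduces typos by swapping adjacent characters or substituting a random letter. The function names, the error rate, and the choice of the two operations are all assumptions.

```python
import random


def swap_adjacent(word, rng):
    """Swap two neighbouring characters, e.g. 'city' -> 'ctiy'."""
    i = rng.randrange(len(word) - 1)
    return word[:i] + word[i + 1] + word[i] + word[i + 2:]


def substitute(word, rng):
    """Replace one character with a random lowercase letter."""
    i = rng.randrange(len(word))
    return word[:i] + rng.choice("abcdefghijklmnopqrstuvwxyz") + word[i + 1:]


def corrupt(address, rng, rate=0.3):
    """Corrupt alphabetic words longer than three letters with probability `rate`."""
    noisy = []
    for token in address.split(" "):
        core = token.rstrip(",")      # keep a trailing comma out of the edit
        tail = token[len(core):]
        if core.isalpha() and len(core) > 3 and rng.random() < rate:
            operation = rng.choice([swap_adjacent, substitute])
            core = operation(core, rng)
        noisy.append(core + tail)
    return " ".join(noisy)


rng = random.Random(0)  # fixed seed so the run is repeatable
print(corrupt("940, north pennsylvania avenue, mason city, iowa, 50401, us", rng))
```

Because the edits are random, some corruptions can turn a word into something that is hard to recover, which is the kind of unexpected corruption the disclaimer refers to.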
A note on the dataset
- The data was collected around 2020-04-29 by OpenAddresses.
- The data licenses can be found in data/address_standardization_correction_data/LICENSE.txt.
Prerequisites
- Data preparation and model training workflows using arcgis.learn have a dependency on transformers. Refer to the section "Install deep learning dependencies of arcgis.learn module" on this page for detailed documentation on installing the dependencies.
- Labeled data: For the SequenceToSequence model to learn, it needs to see texts that have been assigned a label (here, the standard, correct form of each address). Labeled data for this sample notebook is located at data/address_standardization_correction_data/address_standardization_correction.csv
- To learn more about how SequenceToSequence works, please see the guide on How SequenceToSequence works.
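As a sketch of what labeled data for this task can look like, the snippet below parses a small in-memory CSV with one input column and one label column. The column names are the ones passed to prepare_textdata() later in this notebook; the quoting style is an assumption about the file, needed in this sketch because each address itself contains commas.

```python
import csv
import io

# Two sample rows; each address is quoted because it contains commas.
sample_csv = io.StringIO(
    'non-std-address,std-address\n'
    '"211, 7th street, carmi, illinois, 62821.0, us","211, 7th st, carmi, il, 62821.0, us"\n'
    '"916, cleary dvenue, junction city, kansas, 66441, us","916, cleary ave, junction city, ks, 66441, us"\n'
)

rows = list(csv.DictReader(sample_csv))
for row in rows:
    print(row["non-std-address"], "->", row["std-address"])
```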
!pip install transformers==3.3.0
Successfully installed filelock-3.0.12 sacremoses-0.0.45 sentencepiece-0.1.95 tokenizers-0.8.1rc2 transformers-3.3.0
Note: Please restart the kernel before running the cells below.
Imports
import os
import zipfile
from pathlib import Path
from arcgis.gis import GIS
from arcgis.learn import prepare_textdata
from arcgis.learn.text import SequenceToSequence
gis = GIS('home')
Data preparation
Data preparation involves splitting the data into training and validation sets, creating the necessary data structures for loading data into the model, and so on. The prepare_textdata() function can directly read the training samples and automate the entire process.
training_data = gis.content.get('06200bcbf46a4f58b2036c02b0bff41e')
training_data
Note: This address dataset is a subset (~15%) of the dataset available under item ID "ea94e88b5a56412995fd1ffcb85d60e9".
# Download the zipped training data and extract it next to the archive.
filepath = training_data.download(file_name=training_data.name)
with zipfile.ZipFile(filepath, 'r') as zip_ref:
    zip_ref.extractall(Path(filepath).parent)
# The data is expected in a folder named after the archive, minus its extension.
data_root = Path(os.path.splitext(filepath)[0])
data = prepare_textdata(path=data_root, batch_size=16, task='sequence_translation',
                        text_columns='non-std-address', label_columns='std-address',
                        train_file='address_standardization_correction_data_small.csv')
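prepare_textdata() handles the train/validation split internally. To illustrate the idea only, here is a minimal sketch of a random split over (input, label) pairs. The helper name `train_valid_split`, the 20% validation fraction, and the seed are assumptions, not values taken from arcgis.learn.

```python
import random


def train_valid_split(pairs, valid_fraction=0.2, seed=42):
    """Shuffle the pairs and hold out a fraction of them for validation."""
    rng = random.Random(seed)
    shuffled = list(pairs)
    rng.shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_fraction)
    return shuffled[n_valid:], shuffled[:n_valid]


# Placeholder pairs standing in for (non-standard address, standard address).
pairs = [(f"input {i}", f"target {i}") for i in range(10)]
train, valid = train_valid_split(pairs)
print(len(train), len(valid))
```

The validation pairs are never used for weight updates, so the validation loss reported during training reflects how well the model handles addresses it has not seen.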
The show_batch() method can be used to see the training samples, along with their labels.
data.show_batch()
non-std-address | std-address |
---|---|
4967, red violet dr, dubuque, ia, 52002, us | 4967, red violet dr, dubuque, ia, 52002, us |
211, 7th street, carmi, illinois, 62821.0, us | 211, 7th st, carmi, il, 62821.0, us |
916, cleary dvenue, junction city, kansas, 66441, us | 916, cleary ave, junction city, ks, 66441, us |
1919, freychrn drive south west, cedar rapids, iowa, 52404.0, us | 1919, gretchen dr sw, cedar rapids, ia, 52404.0, us |
512, new haven drive, cary, illinois, 60013.0, us | 512, new haven dr, cary, il, 60013.0, us |
SequenceToSequence model
The SequenceToSequence model in arcgis.learn.text is built on top of the Hugging Face Transformers library. The model training and inferencing workflows are similar to those of the computer vision models in arcgis.learn.
Run the command below to see what backbones are supported for the sequence translation task.
SequenceToSequence.supported_backbones
['T5', 'Bart', 'Marian']
Call the model's available_backbone_models() method with the backbone name to get the available models for that backbone. This call lists only a few of the available models for each backbone. Visit this link to get a complete list of models for each backbone.
SequenceToSequence.available_backbone_models("T5")
['t5-small', 't5-base', 't5-large', 't5-3b', 't5-11b', 'See all T5 models at https://huggingface.co/models?filter=t5 ']
Load model architecture
Invoke the SequenceToSequence class by passing the data and the backbone you have chosen. The dataset consists of house addresses in a non-standard format with synthetic errors. We will fine-tune a pretrained t5-base model on it. The model will attempt to learn how to standardize and correct the input addresses.
model = SequenceToSequence(data, backbone='t5-base')
Model training
The learning rate [1] is a tuning parameter that determines the step size at each iteration while moving toward a minimum of a loss function. It represents the speed at which a machine learning model "learns". arcgis.learn includes a learning rate finder, accessible through the model's lr_find() method, which can automatically select a suitable learning rate without requiring repeated experiments.
lr = model.lr_find()
lr
0.1445439770745928
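To see why the step size matters, here is a toy sketch of gradient descent on f(x) = x**2, whose minimum is at x = 0. It is unrelated to how lr_find() searches for a rate; the function name and the three rates are chosen purely for illustration.

```python
def gradient_descent(lr, steps=20, x0=1.0):
    """Minimise f(x) = x**2 starting from x0 with a fixed learning rate."""
    x = x0
    for _ in range(steps):
        grad = 2 * x          # derivative of x**2
        x = x - lr * grad     # step against the gradient, scaled by lr
    return x


for lr in (0.01, 0.1, 1.1):
    print(f"lr={lr}: x after 20 steps = {gradient_descent(lr):.4f}")
```

With a very small rate, x moves toward 0 slowly; with a moderate rate, it gets close quickly; with a rate that is too large, each step overshoots and x moves away from the minimum.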
Training the model is an iterative process. We can keep training the model with its fit() method as long as the validation loss (or error rate) continues to go down with each training pass, also known as an epoch. This is indicative of the model learning the task.
model.fit(1, lr=lr)
epoch | train_loss | valid_loss | seq2seq_acc | bleu | time |
---|---|---|---|---|---|
0 | 1.287657 | 0.951018 | 0.863680 | 0.752205 | 11:10 |
By default, the earlier layers of the model (i.e., the backbone) are frozen. Once the later layers have been sufficiently trained, the earlier layers are unfrozen (by calling the unfreeze() method of the class) to further fine-tune the model.
model.unfreeze()
lr = model.lr_find()
lr
model.fit(5, lr)
epoch | train_loss | valid_loss | seq2seq_acc | bleu | time |
---|---|---|---|---|---|
0 | 0.331751 | 0.278617 | 0.962188 | 0.916663 | 17:45 |
1 | 0.177372 | 0.153773 | 0.982446 | 0.959307 | 17:36 |
2 | 0.143805 | 0.118750 | 0.987322 | 0.970336 | 17:37 |
3 | 0.118908 | 0.105951 | 0.989088 | 0.974331 | 17:40 |
4 | 0.124536 | 0.103347 | 0.989461 | 0.975195 | 17:41 |
model.fit(3, lr)
epoch | train_loss | valid_loss | seq2seq_acc | bleu | time |
---|---|---|---|---|---|
0 | 0.116942 | 0.100216 | 0.989321 | 0.974961 | 17:49 |
1 | 0.103494 | 0.088271 | 0.990844 | 0.978451 | 17:43 |
2 | 0.091599 | 0.084226 | 0.991426 | 0.979786 | 17:40 |
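The rule of training for as long as the validation loss keeps falling can be checked with a few lines of code. The sketch below uses the valid_loss values from the two training tables above, after unfreezing; the helper name `epochs_without_improvement` and the `min_delta` parameter are hypothetical and not part of arcgis.learn.

```python
# valid_loss per epoch, copied from the two post-unfreeze training runs above.
valid_loss = [0.278617, 0.153773, 0.118750, 0.105951, 0.103347,
              0.100216, 0.088271, 0.084226]


def epochs_without_improvement(losses, min_delta=0.0):
    """Count how many trailing epochs failed to beat the best loss so far."""
    best = float("inf")
    stale = 0
    for loss in losses:
        if loss < best - min_delta:
            best = loss
            stale = 0
        else:
            stale += 1
    return stale


print(epochs_without_improvement(valid_loss))
```

A count of zero means the most recent epoch still set a new best validation loss, so further training may still help; a growing count suggests it is time to stop.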
Validate results
Once we have the trained model, we can inspect its results to see how it performs.
model.show_results()
text | target | pred |
---|---|---|
940, north pennsylvania avneue, mason icty, iowa, 50401, us | 940, n pennsylvania ave, mason city, ia, 50401, us | 940, n pennsylvania ave, mason city, ia, 50401, us |
24640, a-b 305th srreet, nora speings, iowa, 50458, us | 24640, a-b 305th st, nora springs, ia, 50458, us | 24640, a-b 305th st, nora cetings, ia, 50458, us |
2920, 1st srteet south west, mason ciry, iowa, 50401, us | 2920, 1st st sw, mason city, ia, 50401, us | 2920, 1st st sw, mason city, ia, 50401, us |
210, s rhode island ave, mason ctiy, ia, 50401, us | 210, s rhode island ave, mason city, ia, 50401, us | 210, s rhode island ave, mason city, ia, 50401, us |
427, n massachudetts ave, mason coty, ia, 50401, us | 427, n massachusetts ave, mason city, ia, 50401, us | 427, n massachudetts ave, mason city, ia, 50401, us |
Model metrics
To get a sense of how well the model is trained, we will calculate some important metrics for our SequenceToSequence model. To see the model's accuracy [2] and BLEU score [3] on the validation dataset, we will call the model's get_model_metrics() method.
model.get_model_metrics()
{'seq2seq_acc': 0.9914, 'bleu': 0.9798}
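How arcgis.learn computes seq2seq_acc internally is not described here. As a simplified illustration of a token-level accuracy, the sketch below compares a target and a prediction word by word after splitting on whitespace. The function name and the whitespace tokenization are assumptions; the model itself works on subword tokens, so the numbers will not match the reported metric.

```python
def token_accuracy(target, pred):
    """Fraction of whitespace-separated tokens that match position by position."""
    target_tokens, pred_tokens = target.split(), pred.split()
    if not target_tokens and not pred_tokens:
        return 1.0
    matches = sum(t == p for t, p in zip(target_tokens, pred_tokens))
    # Divide by the longer sequence so missing or extra tokens count as errors.
    return matches / max(len(target_tokens), len(pred_tokens))


# Target and prediction taken from the second row of the show_results() output.
target = "24640, a-b 305th st, nora springs, ia, 50458, us"
pred = "24640, a-b 305th st, nora cetings, ia, 50458, us"
print(round(token_accuracy(target, pred), 4))  # 8 of 9 tokens match
print(target == pred)                          # exact match is stricter
```

A token-level score rewards a prediction that is mostly right, while an exact-match check treats any single wrong word as a failure.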
Saving the trained model
Once you are satisfied with the model, you can save it using the save() method. This creates an Esri Model Definition (EMD) file that can be used for inferencing on unseen data.
model.save('seq2seq_unfrozen8E_bleu_98', publish=True)
Published DLPK Item Id: ed79aa1b34dd406aae4eed0123bc4608
WindowsPath('models/seq2seq_unfrozen8E_bleu_98')
Model inference
The trained model can be used to translate new text documents using the predict() method. This method accepts a string or a list of strings and returns the translated (standardized and corrected) text for each input.
txt=['940, north pennsylvania avneue, mason icty, iowa, 50401, us',
'220, soyth rhodeisland aveune, mason city, iowa, 50401, us']
model.predict(txt, num_beams=6, max_length=50)
[('940, north pennsylvania avneue, mason icty, iowa, 50401, us', '940, n pennsylvania ave, mason city, ia, 50401, us'), ('220, soyth rhodeisland aveune, mason city, iowa, 50401, us', '220, s rhode island ave, mason city, ia, 50401, us')]
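To review what the model changed, the (input, prediction) pairs can be compared word by word. The sketch below uses the two result pairs shown above with Python's standard difflib module; the helper name `word_changes` is hypothetical and the snippet does not call the model.

```python
import difflib

# (input, prediction) pairs copied from the predict() output above.
results = [
    ('940, north pennsylvania avneue, mason icty, iowa, 50401, us',
     '940, n pennsylvania ave, mason city, ia, 50401, us'),
    ('220, soyth rhodeisland aveune, mason city, iowa, 50401, us',
     '220, s rhode island ave, mason city, ia, 50401, us'),
]


def word_changes(source, predicted):
    """Return (before, after) pairs for every run of words that differs."""
    a, b = source.split(), predicted.split()
    matcher = difflib.SequenceMatcher(None, a, b, autojunk=False)
    changes = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag != "equal":
            changes.append((" ".join(a[i1:i2]), " ".join(b[j1:j2])))
    return changes


for source, predicted in results:
    print(predicted)
    for before, after in word_changes(source, predicted):
        print(f"  {before!r} -> {after!r}")
```

Listing the changed runs makes it easier to spot whether a change is an abbreviation, a spelling correction, or an unwanted edit.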
Conclusion
In this notebook, we built an address standardization and correction model using the SequenceToSequence class of the arcgis.learn.text module. The dataset consisted of pairs of non-standard, incorrect (synthetically corrupted) house addresses and the corresponding correct, standard house addresses from the United States. To achieve this, we used a pretrained t5-base transformer to build a SequenceToSequence model that standardizes and corrects the input house addresses. Below are the results on sample inputs.
Non-Standard → Standard, Error → Correction
- 940, north pennsylvania avneue, mason icty, iowa, 50401, us → 940, n pennsylvania ave, mason city, ia, 50401, us
- 220, soyth rhodeisland aveune, mason city, iowa, 50401, us → 220, s rhode island ave, mason city, ia, 50401, us
References
[1] [Learning Rate](https://en.wikipedia.org/wiki/Learning_rate)
[2] [Accuracy](https://en.wikipedia.org/wiki/Accuracy_and_precision)
[3] [BLEU score](https://en.wikipedia.org/wiki/BLEU)