There are two parts to obtaining a BERT model for classification: preparing the dataset, and training/fine-tuning a pre-trained, multilingual BERT.

The training dataset is created by extracting text from PDFs and putting 512 tokens per document into each row of the TSV file that run_classifier.py requires. A modified BERT is used to create a SavedModel for fast serving via TensorFlow Serving (only run_classifier.py was modified).

The training/fine-tuning is best done on Google Cloud using TPUs for quick turnaround (about 15 minutes for 20k samples), but can also be done using only CPU (about 30 hours using 48GB RAM x 30 cores x 2GHz clock).
Use gen_bert_data.py to prepare training data for BERT: it ingests .pdf files, produces intermediate .txt files using pdftotext, and then processes those into the TSV file that BERT requires. The script prep_all_bert.sh shows how to run it (running it with no arguments prints a usage message):
./prep_all_bert.sh
Usage: basename_for_dataset dest_dir kill_list other_pdfs_dir research_pdfs_dir_1 [research_pdfs_dir_2]
# Example usage
TS=$(date '+%Y%m%dT%H')
./prep_all_bert.sh bert$TS mydest/ my_kill_list my_other_pdfs_dir/ my_research_pdfs_dir/
The basename is given a timestamp to record when the data was gathered. my_kill_list is a file of PDF file basenames (one per line, no .pdf extension, no path) used to ignore certain PDF files; it can be empty (e.g., /dev/null). The last arguments are two directories of PDF files, one for each class.
Note that the kill_list is only applied to the 'other' category. This is because the Internet Archive collection of 'research' PDFs is taken as a given since it is derived from authoritative sources. By contrast, the 'other' collection is derived from random PDFs at large, so it can contain docs that actually are 'research'. In practice, about 6% of random PDFs on the Internet are research docs. Our process has been to analyze misclassified docs by viewing them, and if they are labeled 'other' but should be 'research', they are added to the kill list. This makes it possible to decouple the exceptions from the gathering of random PDFs for 'other'.
For internal provenance reasons, an optional, additional research PDF directory is supported on the command line.
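As a sanity check after the script runs, you can peek at the generated TSV. The exact output file names and label encoding are determined by gen_bert_data.py; the sketch below assumes a CoLA-style tab-separated layout (id, label, placeholder, text), which is what the cola task of run_classifier.py reads. The path is a placeholder, not something guaranteed by the script.
# Hypothetical spot check; adjust the path to wherever prep_all_bert.sh wrote its output.
head -2 mydest/train.tsv
# Each row should be tab-separated: <id> <label> <placeholder> <text of up to 512 tokens>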
We are using run_classifier.py from BERT to train (fine-tune, actually) the classifier.
As noted above, training/fine-tuning is best done on Google Cloud using TPUs (about 15 minutes), but can also be done using only CPU (about 30 hours).
Spin up a Linux node in Google Cloud. The following was tested against Ubuntu 16, but should need only minor tweaks on other versions or flavors. The TPU used was a v2-8 (named tpu-node-1), which is basically the least-powerful TPU available but entirely sufficient for the present purpose.
The TPU is the most expensive part to rent (about $100/day), so do not start it until you are ready to train/fine-tune, and double-check that you have stopped it when done. Note that defining the TPU instance will auto-start it.
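For example, a quick way to check the TPU's state and stop it from the command line (a sketch using standard gcloud commands; the TPU name and zone match the values used later in this doc):
gcloud compute tpus list --zone=us-central1-b
# when finished training, make sure the TPU is stopped:
gcloud compute tpus stop tpu-node-1 --zone=us-central1-b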
# install miniconda
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh >Miniconda3-latest-Linux-x86_64.sh
chmod +x *.sh
./Miniconda3-latest-Linux-x86_64.sh
# update env
. ~/.bashrc
# should be in (base) conda env now
# create tf_hub env
conda create --name tf_hub python=3.7 scipy tensorflow=1.14.0 tensorboard=1.14.0 numpy tensorflow-hub -c conda-forge
conda activate tf_hub
pip install --upgrade google-api-python-client
pip install --upgrade oauth2client
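An optional sanity check that the environment is usable (it just verifies the imports and the TensorFlow version):
python -c "import tensorflow as tf, tensorflow_hub as hub; print(tf.__version__)"
# expect 1.14.0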
The run_classifier.py module of BERT is used to train/fine-tune and test/validate the model. Use this particular BERT repo because it has been modified to produce a SavedModel that works with the REST API of TensorFlow Serving; the REST definition matches the usage in app.py in the top-level directory:
git clone https://github.com/tralfamadude/bert.git
A bucket must be defined, called BUCKET below. When using a cloud TPU, cloud storage (gs://) must be used for the pretrained model and the output directory, as indicated by the BERT repo README. Also put the inputs into gs://${BUCKET}/$BASE, where BASE is the basename used for creating the dataset above. Note that the bucket, TPU, and machine instance need to be in the same region. In the next commands, we assume ~/$BASE corresponds to the directory generated by the dataset creation above, so scp your dataset to the cloud instance.
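For example, a hypothetical copy from your workstation (the user name, instance name, and local path are placeholders):
# copy the dataset directory (the $BASE directory from the prep step) to the cloud instance
scp -r mydataset_dir/ my-user@my-gcp-instance:~/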
BASE=?fill-in?
BUCKET=mybucket
# download the multilingual cased BERT-Base model (the model referenced below)
curl https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip > multi_cased_L-12_H-768_A-12.zip
unzip multi_cased_L-12_H-768_A-12.zip
export BERT_BASE_DIR=gs://$BUCKET/multi_cased_L-12_H-768_A-12
gsutil cp -r multi_cased_L-12_H-768_A-12 gs://$BUCKET/
gsutil cp -r ~/$BASE gs://$BUCKET/
Start your TPU; it is assumed to be called tpu-node-1 below.
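If the TPU was previously stopped, it can be started again from the command line (a sketch; adjust the name and zone to your setup):
gcloud compute tpus start tpu-node-1 --zone=us-central1-b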
TS=$(date '+%Y%m%dT%H%M')
BOUT=bert_output_$TS
nohup python ./run_classifier.py \
--task_name=cola \
--do_train=true \
--do_eval=true \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=512 \
--train_batch_size=64 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--do_lower_case=False \
--data_dir=gs://${BUCKET}/$BASE \
--output_dir=gs://${BUCKET}/$BOUT \
--use_tpu=true \
--tpu_name=tpu-node-1 \
--tpu_zone=us-central1-b \
--num_tpu_cores=8 \
>run.out 2>&1 &
wait
Wait for it to finish before proceeding. It should take about 15min for 20k samples. If it finishes in a minute, then something probably went wrong. Look at run.out to see what happened.
Using the Google Cloud console, look at gs://${BUCKET}/$BOUT to confirm the outputs were written.
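Alternatively, the eval results can be checked from the shell; run_classifier.py writes an eval_results.txt into the output directory when --do_eval=true:
gsutil cat gs://${BUCKET}/${BOUT}/eval_results.txt
# shows eval_accuracy, eval_loss, etc.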
This step measures accuracy against the withheld dataset. The model checkpoint name and bert_output dir must be substituted manually below. (ToDo: need a gsutil command to find the checkpoint with the highest step number; a sketch follows.)
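One possible (untested) way to locate the highest-numbered checkpoint with gsutil:
# list the checkpoint index files, strip everything but the step number, take the largest
CKPT_NUM=$(gsutil ls "gs://${BUCKET}/${BOUT}/model.ckpt-*.index" \
  | sed 's/.*model\.ckpt-\([0-9]*\)\.index$/\1/' | sort -n | tail -1)
echo "highest checkpoint: model.ckpt-${CKPT_NUM}"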
SAVED_MODEL=gs://${BUCKET}/bert_finetuned_${TS}
python ./run_classifier.py \
--task_name=cola \
--do_predict=true \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=gs://${BUCKET}/${BOUT}/model.ckpt-1136 \
--max_seq_length=512 \
--data_dir=gs://${BUCKET}/$BASE \
--use_tpu=true \
--tpu_name=tpu-node-1 \
--tpu_zone=us-central1-b \
--num_tpu_cores=8 \
--output_dir=gs://${BUCKET}/${BOUT} 2>&1 | tee measure.out
Back in research-pub/data_prep/bert_data_prep/ (not necessarily in Google Cloud), measure the withheld samples:
gsutil cp gs://${BUCKET}/${BOUT}/test_results.tsv .
gsutil cp gs://${BUCKET}/$BASE/test_original.tsv .
python ./evaluate_test_set_predictions.py --tsv ./test_original.tsv --results ./test_results.tsv > test_tally
Look at the file test_tally for details on the withheld-sample results. The last line shows the accuracy like this: n 1818 correct 1777 0.977448