Standalone pre-training recipe with JAX+Flax

Overview

Sabertooth

Sabertooth is a standalone pre-training recipe based on JAX+Flax, with data pipelines implemented in Rust. It runs on CPU, GPU, and/or TPU, but this README targets TPU usage.

Contents

  1. Installation
  2. Data preparation
  3. Tokenizer preparation
  4. Pre-training
  5. Fine-tuning on GLUE

Installation

Automatic installation in TPU VMs

The TPU code is not intended to work with the "JAX Cloud TPU Preview" that uses tpu_driver, but rather only with the "new JAX on Cloud TPU in private alpha" that involves direct SSH access. In order to sign up for private alpha access, please follow this link.

Rather than creating VMs manually and installing dependencies by hand, you can use the scripts in tpu_management to help with this.

Before using these scripts:

  • Follow the one-time setup instructions in the "Cloud TPU VM Alpha User Guide" document up to (but not including) the point where you actually create a TPU VM.
  • Then follow the instructions in tpu_management/config.env.template to set configuration variables used by the TPU management scripts.

After that, the following TPU management commands are available:

  • tpu_management/tpu.sh create NAME creates a single-host TPU VM with the specified NAME and prints an ssh config for accessing the VM. Adding this config entry to ~/.ssh/config will then allow you to use ssh NAME to access the VM.
  • tpu_management/tpu.sh provision NAME will push this git repository to the created VM, and build/install all required software within the VM. It also sets up NAME as a git remote, so you can do git push NAME ... to push code to the VM, and git fetch NAME to fetch commits developed within the VM.
  • tpu_management/tpu.sh delete NAME deletes the TPU instance. All files on the TPU VM filesystem will be lost.
  • tpu_management/tpu.sh config-ssh NAME prints the SSH config you need to connect to the VM. This command is useful when a VM automatically restarts and gets assigned a new IP address.

The tpu_management/pod.sh script behaves similarly, but for multi-host TPU training:

  • tpu_management/pod.sh create NAME creates a multi-host TPU setup with specified NAME, and sets up ssh config for each worker (${NAME}0, ${NAME}1, ...).
  • tpu_management/pod.sh provision NAME will configure all worker VMs, and will also set up a git remote such that git push NAME ... pushes to all four workers.
  • tpu_management/tpu.sh delete NAME deletes the TPU instance, including all workers. All files on the worker VM filesystems will be lost.
  • tpu_management/tpu.sh config-ssh NAME prints the SSH config you need to connect to the workers. This command is useful when a VM automatically restarts and gets assigned a new IP address.

Note: on a brand new TPU VM, doing import tensorflow for the first time can take minutes. This issue goes away on all subsequent calls. If a pre-training/fine-tuning script appears to hang on a new VM, this is probably the cause. You can test this by waiting out python3 -c "import tensorflow" and then running the actual script you want.

Manual installation

For Python dependencies, see requirements_tpu.txt. The input pipeline and data processing scripts are implemented in Rust. See https://rustup.rs/ for a one-line shell command that installs Rust.

Run ./install_sabertooth_pipeline.sh to build and install the sabertooth_pipeline helper package into the currently active python environment. This requires CMake to be installed.

Data preparation

Accepted formats

Sabertooth accepts pre-training data in either of the following formats:

  • Text format with one sentence per line, with blank lines in between documents (the BERT format). With this format, sentence segmentation is assumed to have been fully carried out during pre-processing.
  • JSONlines format, which is automatically used for all files ending in .jsonl. A zstandard-compressed version is also accepted, which is used for all files ending in .jsonl.zst or plain .zst. Each line must match the schema {"text": "[JSON-encoded text of an entire document...]"} (the Pile format). With this format, sentence segmentation will be performed at training time by background CPU threads that are also responsible for tokenization.
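
To make the relationship between the two formats concrete, here is a minimal Python sketch that turns BERT-format text into Pile-format JSONlines. It is not part of this repository; the function names are hypothetical, and joining sentences with a single space is an illustrative assumption rather than what the Rust tooling does.

```python
import json


def bert_text_to_documents(text):
    """Split BERT-format text into documents, each a list of sentences.

    BERT format: one sentence per line, blank lines between documents.
    """
    documents, sentences = [], []
    for line in text.splitlines():
        line = line.strip()
        if line:
            sentences.append(line)
        elif sentences:
            # A blank line closes the current document.
            documents.append(sentences)
            sentences = []
    if sentences:
        documents.append(sentences)
    return documents


def documents_to_jsonl(documents):
    """Emit one {"text": ...} JSON object per line (the Pile-style schema).

    Assumption: sentences are joined with a single space.
    """
    return "".join(
        json.dumps({"text": " ".join(sentences)}) + "\n" for sentences in documents
    )


sample = "First sentence.\nSecond sentence.\n\nAnother document.\n"
docs = bert_text_to_documents(sample)
jsonl = documents_to_jsonl(docs)
print(jsonl, end="")
```

Note that the conversion discards the original sentence boundaries, which is consistent with the JSONlines format deferring sentence segmentation to training time.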

Pre-processed downloads

If you don't want to run the processing commands described in the next subsection, you can skip ahead by downloading already processed data from one of the following sources:

  • English Wikipedia in BERT format (4.5GB compressed download), courtesy of GluonNLP. The processing scripts in this repository are largely identical to the GluonNLP processing scripts, except that ours are implemented in Rust. The downsides of this download are that it does not include any Books data (potentially resulting in a lower GLUE score after pre-training) and that the full corpus is not shuffled.
  • The Pile: 825GB of text from diverse sources, but some of this data is not quite as clean as our Wikipedia processing. Training BERT-base on the Pile consistently achieves a GLUE test score above 77, but we have not yet found a set of hyperparameters that matches the effectiveness of Wiki+Books data.

Preparing wikibooks data

Downloading and extracting: Wikipedia

First, download the Wikipedia dump and extract the pages. The Wikipedia dump can be downloaded from this link, and should contain the following file: enwiki-latest-pages-articles-multistream.xml.bz2. Note that older database dumps are periodically deleted from the official Wikimedia downloads website, so we can't pin a specific version of Wikipedia without hosting a mirror of the full dataset.

Next, run WikiExtractor (rust/create_pretraining_data/WikiExtractor.py) to extract the wiki pages from the XML. The generated wiki page files will be stored as <data dir>/LL/wiki_nn; for example <data dir>/AA/wiki_00. Each file is ~1MB, and each subdirectory has 100 files from wiki_00 to wiki_99, except the last subdirectory. For the dump around December 2020, the last file is FL/wiki_09.

REPO_ROOT="$PWD"    # Run from the root of this repository, so the WikiExtractor path still resolves after cd
DATA_ROOT="$HOME/prep_sabertooth"
mkdir -p "$DATA_ROOT"
cd "$DATA_ROOT"
wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles-multistream.xml.bz2    # Optionally use curl instead
bzip2 -d enwiki-latest-pages-articles-multistream.xml.bz2
python3 "$REPO_ROOT/rust/create_pretraining_data/WikiExtractor.py" enwiki-latest-pages-articles-multistream.xml    # Results are placed in text/
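
The LL/wiki_nn layout described above can be sketched in Python as follows. This assumes the two-letter subdirectory names count up alphabetically (AA, AB, ..., AZ, BA, ...), which is consistent with the AA/wiki_00 and FL/wiki_09 examples but is not stated explicitly in this README; the helper names are hypothetical.

```python
import string

LETTERS = string.ascii_uppercase


def wiki_path(file_index, files_per_dir=100):
    """Map a zero-based file index to its LL/wiki_nn relative path."""
    dir_index, file_number = divmod(file_index, files_per_dir)
    first, second = divmod(dir_index, len(LETTERS))
    return f"{LETTERS[first]}{LETTERS[second]}/wiki_{file_number:02d}"


def file_count(last_path, files_per_dir=100):
    """Number of files in the extraction, given the path of the last one."""
    directory, filename = last_path.split("/")
    dir_index = LETTERS.index(directory[0]) * len(LETTERS) + LETTERS.index(directory[1])
    file_number = int(filename.rsplit("_", 1)[1])
    return dir_index * files_per_dir + file_number + 1


print(wiki_path(0))              # AA/wiki_00
print(file_count("FL/wiki_09"))  # 14110
```

Under that assumption, the December 2020 dump corresponds to 14110 files, or roughly 14GB of extracted text at ~1MB per file.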

Downloading and extracting: Books

Download and extract books1.

Preprocessing and sharding

DATA_ROOT="$HOME/prep_sabertooth"
mkdir -p $DATA_ROOT/wikibooks
cd rust/create_pretraining_data
cargo build --release  # Removing `--release` will make the code *much* slower
target/release/create_pretraining_data --output $DATA_ROOT/wikibooks --num-shards 500 --wiki $DATA_ROOT/'text/??/wiki_??' --books $DATA_ROOT/'books1/epubtxt/*.txt'

This will write 500 shards in JSONlines format to the folder wikibooks/; each should be around 40MB in size. Set the RAYON_NUM_THREADS environment variable to limit the number of parallel threads used for processing; by default, a thread will be spawned per CPU core.

create_pretraining_data will load the full dataset into RAM to perform a global shuffle. If you run out of memory, you can run it multiple times, each time with a subset of the files. --wiki and --books both accept glob patterns, so you can do e.g. --wiki 'text/A?/wiki_??'. The --wiki and --books options can also be repeated (with different files each time), or omitted. The JSONlines format should also make it easy to manually concatenate/split/shuffle shards of data.
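
As a rough illustration of how easily JSONlines shards can be manipulated by hand, the following Python sketch concatenates the lines of several shards, shuffles them, and splits them into a new set of shards. It works on in-memory lists of lines; the function name and the round-robin split are illustrative choices, not something this repository provides.

```python
import json
import random


def reshard(shards, num_shards, seed=0):
    """Concatenate, shuffle, and re-split JSONlines shards.

    `shards` is a list of shards, each a list of JSON lines (one document
    per line). Returns `num_shards` new lists of lines.
    """
    lines = [line for shard in shards for line in shard]
    random.Random(seed).shuffle(lines)  # seeded for reproducibility
    # Round-robin split keeps shard sizes within one line of each other.
    return [lines[i::num_shards] for i in range(num_shards)]


old_shards = [
    [json.dumps({"text": f"document {i}"}) for i in range(0, 5)],
    [json.dumps({"text": f"document {i}"}) for i in range(5, 10)],
]
new_shards = reshard(old_shards, num_shards=3)
print([len(shard) for shard in new_shards])
```

Because each line is a complete document, lines can be reordered freely without breaking the format. For data that does not fit in RAM, the same idea would have to be applied to subsets of the files.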

Compressing the data

Compressing the data is optional, but it can save space when storing and transferring large datasets. All of our code supports uncompressed .jsonl and zstandard-compressed .jsonl.zst files interchangeably, so you can also skip this step (and change .jsonl.zst to .jsonl in all commands below).

To compress: download zstandard, compile it with make, and run a compression command such as zstd $DATA_ROOT/wikibooks/*.jsonl.

The overhead of decompressing data is negligible when compared to more costly operations like tokenization or sentence segmentation, which we will be doing at training time anyway. The only downside of compressing the data is that you can't inspect compressed files with a simple text viewer/editor.

Tokenizer preparation

We use SentencePiece for tokenization.

For large datasets, feeding the raw data directly to the sentencepiece trainer is very slow and memory-hungry (RAM usage is at least 10x the size of the uncompressed raw text, and there are single-threaded O(n) sections in the sentencepiece trainer code).

Instead, we will first count unique words in our data, separated only by whitespace. We will then pass a TSV file containing the counts to the sentencepiece trainer, which will further decompose them and build a subword vocabulary.
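
The counting step can be pictured with the following minimal Python sketch. It is only a conceptual stand-in for the Rust count_tokens tool described below: the function names are hypothetical, and the word-then-count column order of the TSV is an assumption.

```python
from collections import Counter


def count_words(documents):
    """Count unique words, splitting on whitespace only."""
    counts = Counter()
    for text in documents:
        counts.update(text.split())
    return counts


def to_tsv(counts):
    """Serialize counts as TSV lines (assumed layout: word<TAB>count)."""
    return "".join(f"{word}\t{count}\n" for word, count in counts.most_common())


documents = ["the cat sat on the mat", "the dog sat"]
counts = count_words(documents)
print(to_tsv(counts), end="")
```

The resulting table has one row per unique word rather than one per occurrence, which is why it is far smaller than the raw text.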

To count the tokens in our data, run:

DATA_ROOT="$HOME/prep_sabertooth"
pushd rust/count_tokens
cargo build --release
popd
rust/count_tokens/target/release/count_tokens $DATA_ROOT/wikibooks/*.jsonl.zst > $DATA_ROOT/counts_wikibooks.tsv

count_tokens accepts filenames as arguments, and supports the following formats:

  • .jsonl: JSONlines data, where each line matches the schema {"text": "[JSON-encoded text of an entire document...]"}
  • .jsonl.zst and .zst: same as above, but zstandard-compressed
  • .tsv: merges counts from another tsv file into the output. If your available CPU RAM is not sufficient to count your corpus all at once, you can count shards of the data separately and then merge the tsv files
  • All other extensions are treated as plain-text data
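
Merging separately counted shards amounts to summing the counts for each word. A minimal Python sketch of that idea, assuming a word-then-count TSV layout (the actual column order written by count_tokens is not shown in this README, and the function names are hypothetical):

```python
from collections import Counter


def parse_counts(tsv_text):
    """Parse TSV text (assumed layout: word<TAB>count) into a Counter."""
    counts = Counter()
    for line in tsv_text.splitlines():
        if not line:
            continue
        # Split on the last tab, so the count is always the final field.
        word, count = line.rsplit("\t", 1)
        counts[word] += int(count)
    return counts


def merge_counts(*tsv_texts):
    """Sum the per-word counts of several TSV files."""
    total = Counter()
    for tsv_text in tsv_texts:
        total.update(parse_counts(tsv_text))  # Counter.update adds counts
    return total


shard_a = "the\t3\nsat\t2\n"
shard_b = "the\t1\ndog\t4\n"
merged = merge_counts(shard_a, shard_b)
print(dict(merged))
```

Since addition is associative, the shards can be counted and merged in any order and yield the same totals as counting the corpus in one pass.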

Once the TSV file is created, use our provided script to train the tokenizer:

python3 rust/count_tokens/train_tokenizer.py --input $DATA_ROOT/counts_wikibooks.tsv --model_prefix $DATA_ROOT/wikibooks_32k --vocab_size 32128

After this you're ready to run pre-training!

Tip for data transfer: The data folder (with zstandard-compressed shards) can be packed into a tar archive using the command cd $DATA_ROOT/.. && tar cvf prep_sabertooth.tar prep_sabertooth/wikibooks/*.jsonl.zst prep_sabertooth/*.model prep_sabertooth/*.vocab. If you have multiple TPU VMs (either multiple workers for multi-host training, or just multiple VMs for different jobs), scp-ing data from one VM to another is 10x-100x faster than copying via gsutil cp or from any non-gcloud machine. SSH into a TPU VM and run ifconfig to determine its internal IP address that can be accessed by other VMs (typically of the form 10.x.y.z).

Pre-training

For single-host pre-training, run a command such as:

python3 run_pretraining.py --config=configs/pretraining.py --config.train_batch_size=1024 --config.optimizer="adam" --config.learning_rate=1e-4 --config.num_train_steps=1000000 --config.num_warmup_steps=10000 --config.max_seq_length=128 --config.max_predictions_per_seq=20

For BERT base, our best hyperparameter setting thus far involves pre-training with global batch size 4096. To launch this training on TPUv3-32 with 4 hosts, set up each host with the required environment variables and then run the following command:

python3 run_pretraining.py --config=configs/pretraining.py --config.optimizer="adam" --config.train_batch_size=4096 --config.learning_rate=1e-3 --config.num_train_steps=125000 --config.num_warmup_steps=3125 --config.adam_epsilon=1e-11 --config.adam_beta1=0.9 --config.adam_beta2=0.98 --config.weight_decay=0.1 --config.max_grad_norm=0.4

Use --config.input_files and --config.tokenizer to configure dataset and tokenizer paths for pre-training (see configs/pretraining.py for the full set of configuration options and hyperparameters).

Pre-training notes

Our pre-training recipe is close to BERT, but there are a few differences:

  • We use a SentencePiece unigram tokenizer, instead of WordPiece
  • The next-sentence prediction (NSP) task from BERT is replaced with the sentence order prediction (SOP) task from ALBERT. We do this primarily to simplify the data pipeline implementation, but past work has observed SOP to give better results than NSP.
  • BERT's Adam optimizer departs from the Adam paper in that it omits bias correction terms. This codebase uses Flax's implementation of Adam, which includes bias correction.
  • Pre-training uses a fixed maximum sequence length of 128, and does not increase the sequence length to 512 for the last 10% of training.
  • The wiki+books data used in this repository is designed to match the BERT paper as closely as possible, but it's not identical. The data used by BERT was never publicly available, so most BERT replications have this property.
  • Random masking and sentence shuffling occur each time a batch of examples is sampled during training, rather than a single time during the data generation step.
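
To illustrate the bias correction difference mentioned in the list above, here is a scalar Adam update in plain Python with a switch for the bias correction terms. This is a textbook-style sketch, not Flax's or BERT's actual implementation; the function name and default hyperparameters are chosen only for the example.

```python
import math


def adam_step(param, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
              eps=1e-8, bias_correction=True):
    """One scalar Adam update; `step` is 1-based."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    if bias_correction:
        # Compensate for m and v being initialized at zero.
        m_hat = m / (1 - beta1 ** step)
        v_hat = v / (1 - beta2 ** step)
    else:
        m_hat, v_hat = m, v
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


corrected, _, _ = adam_step(0.0, 1.0, 0.0, 0.0, step=1, bias_correction=True)
uncorrected, _, _ = adam_step(0.0, 1.0, 0.0, 0.0, step=1, bias_correction=False)
print(corrected, uncorrected)
```

With these example settings, the very first corrected update has magnitude close to lr, while the uncorrected one is about 3.16 times larger (0.1 / sqrt(0.001)). The two variants converge to the same behavior once beta1 ** step and beta2 ** step become negligible.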

Fine-tuning on GLUE

Sample command for fine-tuning on GLUE:

python3 run_classifier.py --config=configs/classifier.py --config.init_checkpoint="/path/to/checkpoint/folder/" --config.dataset_name="cola" --config.learning_rate="5e-5"

The dataset_name should be one of: cola, mrpc, qqp, sst2, stsb, mnli, qnli, rte. WNLI is not supported because BERT accuracy on WNLI is below the baseline, unless a special training recipe is used.

Leaderboard evaluation

To evaluate a model on GLUE, we typically run a sweep across different learning rates and use the development set to select the best one:

OUTPUT_DIR="$HOME/glue"
./sweep_glue.sh /path/to/checkpoint/folder $OUTPUT_DIR 5e-5 4e-5 3e-5 2e-5
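
The selection step of such a sweep can be sketched as follows: for each task, keep the learning rate with the highest development-set score. This is a hypothetical helper, not the logic of sweep_glue.sh, and the scores in the example are made-up placeholders rather than results from this repository.

```python
def select_best_learning_rates(dev_scores):
    """Pick, per task, the learning rate with the highest dev score.

    `dev_scores` maps task name -> {learning rate (as a string) -> dev score}.
    """
    return {
        task: max(scores_by_lr, key=scores_by_lr.get)
        for task, scores_by_lr in dev_scores.items()
    }


# Placeholder numbers for illustration only.
placeholder_scores = {
    "task_a": {"5e-5": 0.10, "4e-5": 0.30, "3e-5": 0.20, "2e-5": 0.15},
    "task_b": {"5e-5": 0.40, "4e-5": 0.35, "3e-5": 0.25, "2e-5": 0.20},
}
best = select_best_learning_rates(placeholder_scores)
print(best)
```

Selecting on the development set keeps the test set untouched until the final leaderboard submission.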

Here are the results from one such sweep, using a "base" size model trained with batch size 1024 for 1M steps (see the single-host training command above). Our understanding is that this option most closely approximates the total compute resources used to train the original BERT-base, except that we do not increase the sequence length to 512 at any point during training. The learning rates the sweep found are random to some extent, since the sweep doubles as both a learning rate search and a chance to try different random seeds.

       CoLA   SST-2   MRPC (f1/acc)   STS-B (p/s)   QQP (f1/acc)   MNLI (m/mm)   QNLI   RTE
dev    56.5   91.9    90.8 / 87.3     88.3 / 88.4   87.0 / 90.4    84.9 / 85.5   92.1   70.4
test   57.0   92.5    87.3 / 82.8     87.8 / 87.0   71.3 / 89.1    85.1 / 84.3   92.1   66.7
lr     3e-5   3e-5    3e-5            5e-5          3e-5           2e-5          3e-5   5e-5

Here are the results from another "base" size model, this time trained with batch size 4096 for 125K steps (see the multi-host training command above). Note that this model only sees half the number of examples compared to the single-host training command above.

       CoLA   SST-2   MRPC (f1/acc)   STS-B (p/s)   QQP (f1/acc)   MNLI (m/mm)   QNLI   RTE
dev    60.3   91.9    89.6 / 85.5     87.9 / 88.1   87.2 / 90.6    84.7 / 85.0   91.7   68.6
test   51.9   92.1    88.8 / 84.7     86.2 / 85.0   70.9 / 89.0    84.5 / 84.0   91.5   65.9
lr     4e-5   4e-5    5e-5            2e-5          4e-5           3e-5          3e-5   5e-5
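
The "half the number of examples" comparison between the two runs follows directly from the batch sizes and step counts in the single-host and multi-host commands above. A small Python check of that arithmetic (the helper name is hypothetical):

```python
def examples_seen(batch_size, num_steps):
    """Total number of training examples processed over a run."""
    return batch_size * num_steps


single_host = examples_seen(batch_size=1024, num_steps=1_000_000)
multi_host = examples_seen(batch_size=4096, num_steps=125_000)
ratio = multi_host / single_host

# Warmup as a fraction of total steps, from the same two commands.
single_host_warmup = 10_000 / 1_000_000
multi_host_warmup = 3_125 / 125_000

print(single_host, multi_host, ratio)
print(single_host_warmup, multi_host_warmup)
```

The single-host run processes 1,024,000,000 examples and the multi-host run 512,000,000, a ratio of 0.5. The warmup periods cover 1% and 2.5% of the respective step counts.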