A modular, research-friendly framework for high-performance training, evaluation, and inference of sequence models at many scales

T5X

T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales.

It is essentially a new and improved implementation of the T5 codebase (based on Mesh TensorFlow) in JAX and Flax.

Installation

Note that all the commands in this document should be run on the command line of the TPU VM instance unless otherwise stated.

  1. Follow the instructions to set up a Google Cloud Platform (GCP) account and enable the Cloud TPU API.

    Note: While T5X works with GPUs as well, GPU usage has not been heavily tested.

  2. Create a Cloud TPU VM instance following these instructions. We recommend that you develop your workflow on a single v3-8 TPU (i.e., --accelerator-type=v3-8) and scale up to pod slices once the pipeline is ready. In this README, we focus on using a single v3-8 TPU. See here to learn more about TPU architectures.

  3. With Cloud TPU VMs, you ssh directly into the host machine of the TPU VM. You can install packages, run your code, etc. on the host machine. Once the TPU instance is created, ssh into it with

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE}

    where TPU_NAME and ZONE are the name and the zone used in step 2.

  4. Install T5X and its dependencies. JAX and Gin-config need to be installed from source. (A quick post-install sanity check is sketched after this list.)

    git clone --branch=main https://github.com/google-research/t5x
    cd t5x
    
    python3 -m pip install -e . -f \
      https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  5. Create a Google Cloud Storage (GCS) bucket to store the dataset and model checkpoints. To create a GCS bucket, see these instructions.
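
After installation, a quick sanity check (a minimal sketch, run inside the TPU VM) confirms that JAX can see the TPU cores:

# Run inside the TPU VM after installing T5X and JAX.
import jax

print(jax.device_count())  # Expect 8 on a v3-8.
print(jax.devices())       # Lists the individual TPU devices.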

Example: English to German translation

As a running example, we use the WMT14 En-De translation. The raw dataset is available in TensorFlow Datasets as "wmt_t2t_translate".

T5 casts a translation example such as the following

{'en': 'That is good.', 'de': 'Das ist gut.'}

to the form called "text-to-text":

{'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'}

This formulation allows many different classes of language tasks to be expressed in a uniform manner, so a single encoder-decoder architecture can handle them without any task-specific parameters. For more details, refer to the T5 paper (Raffel et al. 2019).
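
Concretely, the cast amounts to a simple field mapping. Below is a minimal sketch (to_text_to_text is a hypothetical helper; the actual preprocessing is done by t5.data.preprocessors.translate, used later in this README):

# Hypothetical helper illustrating the text-to-text cast; the real
# implementation is t5.data.preprocessors.translate.
def to_text_to_text(example):
  return {
      'inputs': 'translate English to German: ' + example['en'],
      'targets': example['de'],
  }

print(to_text_to_text({'en': 'That is good.', 'de': 'Das ist gut.'}))
# {'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'}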

For a scalable data pipeline and an evaluation framework, we use SeqIO, which was factored out of the T5 library. A seqio.Task packages together the raw dataset, the vocabulary, preprocessing (such as tokenization), and evaluation metrics (such as BLEU), and provides a tf.data instance.

The T5 library provides a number of seqio.Tasks that were used in the T5 paper. In this example, we use wmt_t2t_ende_v003.
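
To see what a seqio.Task bundles together, you can pull it from the registry (a minimal sketch; it assumes the t5 package is installed, since importing t5.data.tasks is what registers these tasks):

import seqio
import t5.data.tasks  # noqa: F401 -- importing this module registers the T5 tasks.

task = seqio.TaskRegistry.get('wmt_t2t_ende_v003')
print(task.output_features)  # Feature specs with the bundled vocabulary.
print(task.metric_fns)       # Evaluation metrics, e.g., BLEU.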

Training

To run a training job, we use the t5x/train.py script.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/t5_1_1_base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR="'${MODEL_DIR}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

The configuration for this training run is defined in the Gin file t5_1_1_base_wmt_from_scratch.gin. Gin-config is a library for handling configurations based on dependency injection. Among many benefits, Gin allows users to pass custom components such as a custom model to the T5X library without having to modify the core library. The custom components section shows how this is done.
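
To make the dependency-injection idea concrete, here is a minimal sketch of how Gin binds parameters in Python (build_model is a hypothetical function, not part of T5X):

import gin

@gin.configurable
def build_model(num_layers=12, dropout_rate=0.1):
  return {'num_layers': num_layers, 'dropout_rate': dropout_rate}

# A Gin file or a --gin.<binding> flag can rebind parameters without
# touching the call site.
gin.parse_config('build_model.num_layers = 24')
print(build_model())  # {'num_layers': 24, 'dropout_rate': 0.1}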

While the core library is independent of Gin, it is central to the examples we provide. Therefore, we provide a short introduction to Gin in the context of T5X. All the configurations are written to a file "config.gin" in MODEL_DIR. This makes debugging as well as reproducing the experiment much easier.

In addition to config.gin, the model-info.txt file summarizes the model parameters (shapes, names of the axes, partitioning info) as well as the optimizer states.

TensorBoard

To monitor the training in TensorBoard, it is much easier (due to authentication issues) to launch TensorBoard on your own machine, not on the TPU VM. So, on your own machine rather than the one you ssh'ed into, launch TensorBoard with the logdir pointing to MODEL_DIR.

# NB: run this on your own machine, not the TPU VM!
MODEL_DIR="..."  # Copy from the TPU VM.
tensorboard --logdir=${MODEL_DIR}

Alternatively, you can launch TensorBoard inside a Colab. In a Colab cell, run

from google.colab import auth
auth.authenticate_user()

to authorize the Colab to access the GCS bucket, then launch TensorBoard:

%load_ext tensorboard
model_dir = "..."  # Copy from the TPU VM.
%tensorboard --logdir={model_dir}

TODO(hwchung): Add tfds preparation instruction

Fine-tuning

We can leverage the benefits of self-supervised pre-training by initializing from one of our pre-trained models. Here we use the T5.1.1 Base checkpoint.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/t5_1_1_base_wmt_finetune.gin" \
  --gin.MODEL_DIR="'${MODEL_DIR}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Note: when supplying a string, dict, list, or tuple value, or a bash variable via a flag, you must put it in quotes. Strings in particular require nested quotes ("'<string>'"). For example: --gin.utils.DatasetConfig.split="'validation'" or --gin.MODEL_DIR="'${MODEL_DIR}'".

Gin makes it easy to change a number of configurations. For example, you can set partitioning.ModelBasedPjitPartitioner.num_partitions (overriding the value in t5_1_1_base_wmt_from_scratch.gin) to change the parallelism strategy and pass it as a command-line arg:

--gin.partitioning.ModelBasedPjitPartitioner.num_partitions=8

Evaluation

To run offline (i.e., without training) evaluation, you can use the t5x/eval.py script.

EVAL_OUTPUT_DIR="..."  # directory to write eval output
T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
TFDS_DATA_DIR="..."
CHECKPOINT_PATH="..."

python3 ${T5X_DIR}/t5x/eval.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/t5_1_1_base_wmt_eval.gin" \
  --gin.CHECKPOINT_PATH="'${CHECKPOINT_PATH}'" \
  --gin.EVAL_OUTPUT_DIR="'${EVAL_OUTPUT_DIR}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Inference

To run inference, you can use the t5x/infer.py script. Here we use the same seqio.Task, but for inference the targets feature is not used except to be logged alongside the predictions in a JSON file.

INFER_OUTPUT_DIR="..."  # directory to write infer output
T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
TFDS_DATA_DIR="..."
CHECKPOINT_PATH="..."

python3 ${T5X_DIR}/t5x/infer.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/t5_1_1_base_wmt_infer.gin" \
  --gin.CHECKPOINT_PATH="'${CHECKPOINT_PATH}'" \
  --gin.INFER_OUTPUT_DIR="'${INFER_OUTPUT_DIR}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}
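
The predictions land in INFER_OUTPUT_DIR. Below is a minimal sketch for inspecting them, assuming JSON-lines output (the file name is illustrative; check INFER_OUTPUT_DIR for the actual shard names):

import json

# Hypothetical shard name; look in INFER_OUTPUT_DIR for the real files.
with open('infer-output.jsonl') as f:
  for line in f:
    print(json.loads(line))  # Each record pairs a prediction with its logged target.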

Custom components

The translation example uses the encoder-decoder model that T5X provides as well as the dataset from the T5 library. This section shows how you can use your own dataset and model and pass them to T5X via Gin.

Example: custom dataset in a user directory

For this example, we have the following directory structure with ${HOME}/dir1/user_dir representing a user directory with custom components.

${HOME}
└── dir1
    └── user_dir
        ├── t5_1_1_base_de_en.gin
        └── tasks.py

As an example, let's define a new dataset. Here we use the same translation dataset but define the task in the opposite direction, i.e., German to English instead of English to German. We define this task in tasks.py:

# ${HOME}/dir1/user_dir/tasks.py

import functools
import seqio
import tensorflow_datasets as tfds
from t5.evaluation import metrics
from t5.data import preprocessors

vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000/sentencepiece.model', extra_ids=100)
output_features = {
    'inputs': seqio.Feature(vocabulary=vocabulary),
    'targets': seqio.Feature(vocabulary=vocabulary)
}

seqio.TaskRegistry.add(
    'wmt_t2t_de_en_v003',
    source=seqio.TfdsDataSource(tfds_name='wmt_t2t_translate/de-en:1.0.0'),
    preprocessors=[
        functools.partial(
            preprocessors.translate,
            source_language='de', target_language='en'),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[metrics.bleu],
    output_features=output_features)
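
Before launching a job, you can verify the registration by loading a few examples directly (a sketch; it assumes tasks.py is on PYTHONPATH and the TFDS data has been prepared under TFDS_DATA_DIR):

import seqio
import tasks  # noqa: F401 -- importing the module registers 'wmt_t2t_de_en_v003'.

task = seqio.TaskRegistry.get('wmt_t2t_de_en_v003')
ds = task.get_dataset(split='train', sequence_length={'inputs': 64, 'targets': 64})
for ex in ds.take(1):
  print(ex)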

In the Gin file, most of the settings are equivalent to those used in the En->De example, so we include the Gin file from that example. To use the "wmt_t2t_de_en_v003" task we just defined, we need to import the task module "tasks.py". Note that we use a relative path defined with respect to the user directory; this will be specified as a flag.

# ${HOME}/dir1/user_dir/t5_1_1_base_de_en.gin
from __gin__ import dynamic_registration
import tasks  # This imports the task defined in dir1/user_dir/tasks.py.

include "t5x/examples/t5/t5_1_1/examples/t5_1_1_base_wmt_from_scratch.gin"
MIXTURE_OR_TASK_NAME = "wmt_t2t_de_en_v003"

Finally, we launch training, passing the user directory via the gin_search_paths flag so that the Gin file and Python modules can be specified with relative paths.

PROJECT_DIR=${HOME}"/dir1/user_dir"
T5X_DIR="..."  # directory where the t5x is cloned.
TFDS_DATA_DIR="..."
MODEL_DIR="..."
export PYTHONPATH=${PROJECT_DIR}

python3 ${T5X_DIR}/t5x/train.py \
  --gin_search_paths=${PROJECT_DIR} \
  --gin_file="t5_1_1_base_de_en.gin" \
  --gin.MODEL_DIR="'${MODEL_DIR}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Released Checkpoints

We release checkpoints for the T5.1.1 models in the native T5X format. These were converted from the public Mesh TensorFlow checkpoints.

Compatibility with the Mesh TensorFlow checkpoints

The Mesh TensorFlow checkpoints trained using the T5 library can be directly loaded into T5X. For example, we can rerun the fine-tuning example initializing from the MTF checkpoint by changing the INIT_CHECKPOINT Gin macro.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/wmt19_ende_from_scratch.gin" \
  --gin.MODEL_DIR="'${MODEL_DIR}'" \
  --gin.MIXTURE_OR_TASK_NAME="'wmt_t2t_ende_v003'" \
  --gin.INIT_CHECKPOINT="'gs://t5-data/pretrained_models/t5.1.1.base/model.ckpt-1000000'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Note that restoring directly from Mesh TensorFlow checkpoints can be inefficient if heavy model parallelism is used for large models. This is because each host first loads an entire copy of the model and then keeps only the relevant slices dictated by the model parallelism specification. If you have Mesh TensorFlow checkpoints that you use often, we recommend converting them to the T5X native format using Checkpointer.convert_from_tf_checkpoint.

TODO(hwchung): Add a conversion script.

Note

This is not an officially supported Google product.
