[ICLR'19] Trellis Networks for Sequence Modeling

Overview

TrellisNet for Sequence Modeling

PWC PWC

This repository contains the experiments done in paper Trellis Networks for Sequence Modeling by Shaojie Bai, J. Zico Kolter and Vladlen Koltun.

On the one hand, a trellis network is a temporal convolutional network with special structure, characterized by weight tying across depth and direct injection of the input into deep layers. On the other hand, we show that truncated recurrent networks are equivalent to trellis networks with special sparsity structure in their weight matrices. Thus trellis networks with general weight matrices generalize truncated recurrent networks. This allows trellis networks to serve as bridge between recurrent and convolutional architectures, benefitting from algorithmic and architectural techniques developed in either context. We leverage these relationships to design high-performing trellis networks that absorb ideas from both architectural families. Experiments demonstrate that trellis networks outperform the current state of the art on a variety of challenging benchmarks, including word-level language modeling on Penn Treebank and WikiText-103 (UPDATE: recently surpassed by Transformer-XL), character-level language modeling on Penn Treebank, and stress tests designed to evaluate long-term memory retention.

Our experiments were done in PyTorch. If you find our work, or this repository helpful, please consider citing our work:

@inproceedings{bai2018trellis,
  author    = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
  title     = {Trellis Networks for Sequence Modeling},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year      = {2019},
}

Datasets

The code should be directly runnable with PyTorch 1.0.0 or above. This repository contains the training script for the following tasks:

  • Sequential MNIST handwritten digit classification
  • Permuted Sequential MNIST that randomly permutes the pixel order in sequential MNIST
  • Sequential CIFAR-10 classification (more challenging, due to more intra-class variations, channel complexities and larger images)
  • Penn Treebank (PTB) word-level language modeling (with and without the mixture of softmax); vocabulary size 10K
  • Wikitext-103 (WT103) large-scale word-level language modeling; vocabulary size 268K
  • Penn Treebank medium-scale character-level language modeling

Note that these tasks are on very different scales, with unique properties that challenge sequence models in different ways. For example, word-level PTB is a small dataset that a typical model easily overfits, so judicious regularization is essential. WT103 is a hundred times larger, with less danger of overfitting, but with a vocabulary size of 268K that makes training more challenging (due to large embedding size).

Pre-trained Model(s)

We provide some reasonably good pre-trained weights here so that the users don't need to train from scratch. We'll update the table from time to time. (Note: if you train from scratch using different seeds, it's likely you will get better results :-))

Description Task Dataset Model
TrellisNet-LM Word-Level Language Modeling Penn Treebank (PTB) download (.pkl)
TrellisNet-LM Character-Level Language Modeling Penn Treebank (PTB) download (.pkl)

To use the pre-trained weights, use the flag --load_weight [.pkl PATH] when starting the training script (e.g., you can just use the default arg parameters). You can use the flag --eval turn on the evaluation mode only.

Usage

All tasks share the same underlying TrellisNet model, which is in file trellisnet.py (and the eventual models, including components like embedding layer, are in model.py). As discussed in the paper, TrellisNet is able to benefit significantly from techniques developed originally for RNNs as well as temporal convolutional networks (TCNs). Some of these techniques are also included in this repository. Each task is organized in the following structure:

[TASK_NAME] /
    data/
    logs/
    [TASK_NAME].py
    model.py
    utils.py
    data.py

where [TASK_NAME].py is the training script for the task (with argument flags; use -h to see the details).

Owner
CMU Locus Lab
Zico Kolter's Research Group
CMU Locus Lab
SSL_SLAM2: Lightweight 3-D Localization and Mapping for Solid-State LiDAR (mapping and localization separated) ICRA 2021

SSL_SLAM2 Lightweight 3-D Localization and Mapping for Solid-State LiDAR (Intel Realsense L515 as an example) This repo is an extension work of SSL_SL

Wang Han 王晗 1.3k Jan 08, 2023
A Lightweight Face Recognition and Facial Attribute Analysis (Age, Gender, Emotion and Race) Library for Python

deepface Deepface is a lightweight face recognition and facial attribute analysis (age, gender, emotion and race) framework for python. It is a hybrid

Sefik Ilkin Serengil 5.2k Jan 02, 2023
"NAS-Bench-301 and the Case for Surrogate Benchmarks for Neural Architecture Search".

NAS-Bench-301 This repository containts code for the paper: "NAS-Bench-301 and the Case for Surrogate Benchmarks for Neural Architecture Search". The

AutoML-Freiburg-Hannover 57 Nov 30, 2022
This is an official implementation of CvT: Introducing Convolutions to Vision Transformers.

Introduction This is an official implementation of CvT: Introducing Convolutions to Vision Transformers. We present a new architecture, named Convolut

Microsoft 408 Dec 30, 2022
SE3 Pose Interp - Interpolate camera pose or trajectory in SE3, pose interpolation, trajectory interpolation

SE3 Pose Interpolation Pose estimated from SLAM system are always discrete, and

Ran Cheng 4 Dec 15, 2022
How to train a CNN to 99% accuracy on MNIST in less than a second on a laptop

Training a NN to 99% accuracy on MNIST in 0.76 seconds A quick study on how fast you can reach 99% accuracy on MNIST with a single laptop. Our answer

Tuomas Oikarinen 42 Dec 10, 2022
Learning Optical Flow from a Few Matches (CVPR 2021)

Learning Optical Flow from a Few Matches This repository contains the source code for our paper: Learning Optical Flow from a Few Matches CVPR 2021 Sh

Shihao Jiang (Zac) 159 Dec 16, 2022
Source code for our paper "Empathetic Response Generation with State Management"

Source code for our paper "Empathetic Response Generation with State Management" this repository is maintained by both Jun Gao and Yuhan Liu Model Ove

Yuhan Liu 3 Oct 08, 2022
Command-line tool for downloading and extending the RedCaps dataset.

RedCaps Downloader This repository provides the official command-line tool for downloading and extending the RedCaps dataset. Users can seamlessly dow

RedCaps dataset 33 Dec 14, 2022
Implementation of the GVP-Transformer, which was used in the paper "Learning inverse folding from millions of predicted structures" for de novo protein design alongside Alphafold2

GVP Transformer (wip) Implementation of the GVP-Transformer, which was used in the paper Learning inverse folding from millions of predicted structure

Phil Wang 19 May 06, 2022
Pre-training of Graph Augmented Transformers for Medication Recommendation

G-Bert Pre-training of Graph Augmented Transformers for Medication Recommendation Intro G-Bert combined the power of Graph Neural Networks and BERT (B

101 Dec 27, 2022
Codebase for Image Classification Research, written in PyTorch.

pycls pycls is an image classification codebase, written in PyTorch. It was originally developed for the On Network Design Spaces for Visual Recogniti

Facebook Research 2k Jan 01, 2023
Predicting lncRNA–protein interactions based on graph autoencoders and collaborative training

Predicting lncRNA–protein interactions based on graph autoencoders and collaborative training Code for our paper "Predicting lncRNA–protein interactio

zhanglabNKU 1 Nov 29, 2022
discovering subdomains, hidden paths, extracting unique links

python-website-crawler discovering subdomains, hidden paths, extracting unique links pip install -r requirements.txt discover subdomain: You can give

merve 4 Sep 05, 2022
GNEE - GAT Neural Event Embeddings

GNEE - GAT Neural Event Embeddings This repository contains source code for the GNEE (GAT Neural Event Embeddings) method introduced in the paper: "Se

João Pedro Rodrigues Mattos 0 Sep 15, 2021
A pytorch-based real-time segmentation model for autonomous driving

CFPNet: Channel-Wise Feature Pyramid for Real-Time Semantic Segmentation This project contains the Pytorch implementation for the proposed CFPNet: pap

342 Dec 22, 2022
Bolt Online Learning Toolbox

Bolt Online Learning Toolbox Bolt features discriminative learning of linear predictors (e.g. SVM or Logistic Regression) using fast online learning a

Peter Prettenhofer 87 Dec 12, 2022
Gas detection for Raspberry Pi using ADS1x15 and MQ-2 sensors

Gas detection Gas detection for Raspberry Pi using ADS1x15 and MQ-2 sensors. Description The MQ-2 sensor can detect multiple gases (CO, H2, CH4, LPG,

Filip Š 15 Sep 30, 2022
Code for paper entitled "Improving Novelty Detection using the Reconstructions of Nearest Neighbours"

NLN: Nearest-Latent-Neighbours A repository containing the implementation of the paper entitled Improving Novelty Detection using the Reconstructions

Michael (Misha) Mesarcik 4 Dec 14, 2022
Implement face detection, and age and gender classification, and emotion classification.

YOLO Keras Face Detection Implement Face detection, and Age and Gender Classification, and Emotion Classification. (image from wider face dataset) Ove

Chloe 10 Nov 14, 2022