Implementation of ETSformer, state of the art time-series Transformer, in Pytorch

Overview

ETSformer - Pytorch

Implementation of ETSformer, state of the art time-series Transformer, in Pytorch

Install

$ pip install etsformer-pytorch

Usage

import torch
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,                # in paper they use 512
    embed_kernel_size = 3,          # kernel size for 1d conv for input embedding
    layers = 2,                     # number of encoder and corresponding decoder layers
    heads = 8,                      # number of exponential smoothing attention heads
    K = 4,                          # num frequencies with highest amplitude to keep (attend to)
    dropout = 0.2                   # dropout (in paper they did 0.2)
)

timeseries = torch.randn(1, 1024, 4)

pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, num steps forecast, num time features)

For using ETSFormer for classification, using cross attention pooling on all latents and level output

import torch
from etsformer_pytorch import ETSFormer, ClassificationWrapper

etsformer = ETSFormer(
    time_features = 1,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

adapter = ClassificationWrapper(
    etsformer = etsformer,
    dim_head = 32,
    heads = 16,
    dropout = 0.2,
    level_kernel_size = 5,
    num_classes = 10
)

timeseries = torch.randn(1, 1024)

logits = adapter(timeseries) # (1, 10)

Citation

@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting}, 
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • What are your thoughts on using latents for additional classification task

    What are your thoughts on using latents for additional classification task

    Hi! I was wondering if you have thought about aggregating seasonal and growth latents for additional tasks (for example classification)? What are the possible ways to bring latents into single feature vector in your opinion? The easiest one would be just get the mean along layers and time dimensions but that seams to be too naive. Another idea I had it to use Cross Attention mechanic with single time query key to aggregate latents:

    all_latents = torch.cat([latent_growths, latent_seasonals], dim=-1)
    all_latents = rearrange(all_latents, 'b n l d -> (b l) n d')
    # q = nn.Parameter(torch.randn(all_latents_dim))
    q = repeat(q, 'd -> b 1 d', b = all_latents.shape[0])
    agg_latent = cross_attention(query=q, context=all_latents)
    agg_latent = rearrange(all_latents, '(b l) n d -> b (l n) d')
    agg_latent = agg_latent.mean(dim=1) # may be we should have done it before cross attention?
    

    Would be great to hear your thoughts

    opened by inspirit 15
  • Pre LayerNorm might be required for k,v?

    Pre LayerNorm might be required for k,v?

    https://github.com/lucidrains/ETSformer-pytorch/blob/2561053007e919409b3255eb1d0852c68799d24f/etsformer_pytorch/etsformer_pytorch.py#L440

    In my early tests I see some instability in training results, I was wondering if it might be good idea to LayerNorm latents before constructing key and values?

    opened by inspirit 5
  • growth_term calculation error

    growth_term calculation error

    https://github.com/lucidrains/ETSformer-pytorch/blob/e1d8514b44d113ead523aa6307986833e68eecc5/etsformer_pytorch/etsformer_pytorch.py#L233-L235

    It looks like you are not using growth and growth_smoothing_weightsto calculate growth_term

    opened by inspirit 4
  • Backward gradient error

    Backward gradient error

    Hello,

    i was trying to run the provided class and see following error: Function ScatterBackward0 returned an invalid gradient at index 1 - got [64, 4, 128] but expected shape compatible with [64, 33, 128]

    model = ETSFormer(
                time_features = 9,
                model_dim = 128,
                embed_kernel_size = 3,
                layers = 2,
                heads = 4,
                K = 4,
                dropout = 0.2
            )
    

    input = torch.rand(64, 64, 9) x = model(input, num_steps_forecast = 16)

    opened by inspirit 3
  • Does ETS-Former allow adding features

    Does ETS-Former allow adding features

    @lucidrains Thanks for making the code of the model available!

    In your paper, you state that the model infers seasonal patterns itself, so that there is no need to add time features like week, month, etc.

    Still, to increase the applicability of your approach, does the current implementation allow to add any (time-invariant and time-varying) features, e.g., categorical or numeric?

    opened by StatMixedML 2
  • wrong order of arguments

    wrong order of arguments

    https://github.com/lucidrains/ETSformer-pytorch/blob/2e0d465576c15fc8d84c4673f93fdd71d45b799c/etsformer_pytorch/etsformer_pytorch.py#L327

    you pass latents on wrong order to Level module: according to forward method first should be growth and then seasonal

    opened by inspirit 1
  • Clarification regarding data pre-processing

    Clarification regarding data pre-processing

    Hello,

    I was trying to run the ETSformer for ETT dataset. The paper mentions that the dataset is split as 60/20/20 for train, validation and test. Could you give some insight as to how the dataset split is happening in the code.

    Thank you.

    opened by vageeshmaiya 2
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
A repo to show how to use custom dataset to train s2anet, and change backbone to resnext101

A repo to show how to use custom dataset to train s2anet, and change backbone to resnext101

jedibobo 3 Dec 28, 2022
clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation

README clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation CVPR 2021 Authors: Suprosanna Shit and Johannes C. Paetzo

110 Dec 29, 2022
MagFace: A Universal Representation for Face Recognition and Quality Assessment

MagFace MagFace: A Universal Representation for Face Recognition and Quality Assessment in IEEE Conference on Computer Vision and Pattern Recognition

Qiang Meng 523 Jan 05, 2023
PAIRED in PyTorch 🔥

PAIRED This codebase provides a PyTorch implementation of Protagonist Antagonist Induced Regret Environment Design (PAIRED), which was first introduce

UCL DARK Lab 46 Dec 12, 2022
Learning Calibrated-Guidance for Object Detection in Aerial Images

Learning Calibrated-Guidance for Object Detection in Aerial Images arxiv We propose a simple yet effective Calibrated-Guidance (CG) scheme to enhance

51 Sep 22, 2022
Transformer - Transformer in PyTorch

Transformer 完成进度 Embeddings and PositionalEncoding with example. MultiHeadAttent

Tianyang Li 1 Jan 06, 2022
Official Repository for our ICCV2021 paper: Continual Learning on Noisy Data Streams via Self-Purified Replay

Continual Learning on Noisy Data Streams via Self-Purified Replay This repository contains the official PyTorch implementation for our ICCV2021 paper.

Jinseo Jeong 22 Nov 23, 2022
Differentiable rasterization applied to 3D model simplification tasks

nvdiffmodeling Differentiable rasterization applied to 3D model simplification tasks, as described in the paper: Appearance-Driven Automatic 3D Model

NVIDIA Research Projects 336 Dec 30, 2022
Born-Infeld (BI) for AI: Energy-Conserving Descent (ECD) for Optimization

Born-Infeld (BI) for AI: Energy-Conserving Descent (ECD) for Optimization This repository contains the code for the BBI optimizer, introduced in the p

G. Bruno De Luca 5 Sep 06, 2022
Rethinking Portrait Matting with Privacy Preserving

Rethinking Portrait Matting with Privacy Preserving This is the official repository of the paper Rethinking Portrait Matting with Privacy Preserving.

184 Jan 03, 2023
Validated, scalable, community developed variant calling, RNA-seq and small RNA analysis

Validated, scalable, community developed variant calling, RNA-seq and small RNA analysis. You write a high level configuration file specifying your in

Blue Collar Bioinformatics 917 Jan 03, 2023
Template repository for managing machine learning research projects built with PyTorch-Lightning

Tutorial Repository with a minimal example for showing how to deploy training across various compute infrastructure.

Sidd Karamcheti 3 Feb 11, 2022
Tool which allow you to detect and translate text.

Text detection and recognition This repository contains tool which allow to detect region with text and translate it one by one. Description Two pretr

Damian Panek 176 Nov 28, 2022
Benchmarks for the Optimal Power Flow Problem

Power Grid Lib - Optimal Power Flow This benchmark library is curated and maintained by the IEEE PES Task Force on Benchmarks for Validation of Emergi

A Library of IEEE PES Power Grid Benchmarks 207 Dec 08, 2022
Some bravo or inspiring research works on the topic of curriculum learning.

Towards Scalable Unpaired Virtual Try-On via Patch-Routed Spatially-Adaptive GAN Official code for NeurIPS 2021 paper "Towards Scalable Unpaired Virtu

131 Jan 07, 2023
Official implementation of "Membership Inference Attacks Against Self-supervised Speech Models"

Introduction Official implementation of "Membership Inference Attacks Against Self-supervised Speech Models". In this work, we demonstrate that existi

Wei-Cheng Tseng 7 Nov 01, 2022
Simulate genealogical trees and genomic sequence data using population genetic models

msprime msprime is a population genetics simulator based on tskit. Msprime can simulate random ancestral histories for a sample of individuals (consis

Tskit developers 150 Dec 14, 2022
Ready-to-use code and tutorial notebooks to boost your way into few-shot image classification.

Easy Few-Shot Learning Ready-to-use code and tutorial notebooks to boost your way into few-shot image classification. This repository is made for you

Sicara 399 Jan 08, 2023
Official pytorch implementation of the paper: "SinGAN: Learning a Generative Model from a Single Natural Image"

SinGAN Project | Arxiv | CVF | Supplementary materials | Talk (ICCV`19) Official pytorch implementation of the paper: "SinGAN: Learning a Generative M

Tamar Rott Shaham 3.2k Dec 25, 2022
Face Recognition Attendance Project

Face-Recognition-Attendance-Project In This Project You will learn how to mark attendance using face recognition, Hello Guys This is Gautam Kumar, Thi

Gautam Kumar 1 Dec 03, 2022