PyTorch trainer and model for Sequence Classification

Overview

PyTorch-trainer-and-model-for-Sequence-Classification

After cloning the repository, modify your training data so that the training data is a .csv file and it has 2 columns: Text and Label

In the below example, we will assume that our training data has 3 labels, the name of our training data file is train_data.csv

Example Usage

Import dependencies

import pandas as pd
import numpy as np
from transformers import AutoModel, AutoTokenizer, AutoConfig

from EarlyStopping import *
from modelling import *
from utils import *

Specify arguments

args.pretrained_path will be the path of our pretrained language model

class args:
    fold = 0
    pretrained_path = 'bert-base-uncased'
    max_length = 400
    train_batch_size = 16
    val_batch_size = 64
    epochs = 5
    learning_rate = 1e-5
    accumulation_steps = 2
    num_splits = 5

Create train and validation data

In this example we will train the model using cross-validation. We will split our training data into args.num_splits folds.

df = pd.read_csv('./train_data.csv')
df = create_k_folds(df, args.num_splits)

df_train = df[df['kfold'] == args.fold].reset_index(drop = True)
df_valid = df[df['kfold'] == args.fold].reset_index(drop = True)

Load the language model and its tokenizer

config = AutoConfig.from_pretrained(args.path)
tokenizer = AutoTokenizer.from_pretrained(args.path)
model_transformer = AutoModel.from_pretrained(args.path)

Prepare train and validation dataloaders

features = []
for i in range(len(df_train)):
    features.append(prepare_features(tokenizer, df_train.iloc[i, :].to_dict(), args.max_length))
    
train_dataset = CreateDataset(features)
train_dataloader = create_dataloader(train_dataset, args.train_batch_size, 'train')

features = []
for i in range(len(df_valid)):
    features.append(prepare_features(tokenizer, df_valid.iloc[i, :].to_dict(), args.max_length))
    
val_dataset = CreateDataset(features)
val_dataloader = create_dataloader(val_dataset, args.val_batch_size, 'val')

Use EarlyStopping and customize the score function

NOTE: The customized score function should have 2 parameters: the logits, and the actual label

def accuracy(logits, labels):
    logits = logits.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()
    pred_classes = np.argmax(logits * (1 / np.sum(logits, axis = -1)).reshape(logits.shape[0], 1), axis = -1)
    pred_classes = pred_classes.reshape(labels.shape)
    
    return np.sum(pred_classes == labels) / labels.shape[0]

es = EarlyStopping(mode = 'max', patience = 3, monitor = 'val_acc', out_path = 'model.bin')
es.monitor_score_function = accuracy

Create and train the model

Calling the fit method, the training process will begin

model = Model(config, model_transformer, num_labels = 3)
model.to('cuda')
num_train_steps = int(len(train_dataset) / args.train_batch_size * args.epochs)
model.fit(args.epochs, args.learning_rate, num_train_steps, args.accumulation_steps, 
          train_dataloader, val_dataloader, es)

NOTE: To complete the cross-validation training process, run the code above again with args.fold equals 1, 2, ..., args.num_splits - 1

Owner
NhanTieu
NhanTieu
This is the official code of L2G, Unrolling and Recurrent Unrolling in Learning to Learn Graph Topologies.

Learning to Learn Graph Topologies This is the official code of L2G, Unrolling and Recurrent Unrolling in Learning to Learn Graph Topologies. Requirem

Stacy X PU 16 Dec 09, 2022
A TikTok-like recommender system for GitHub repositories based on Gorse

GitRec GitRec is the missing recommender system for GitHub repositories based on Gorse. Architecture The trending crawler crawls trending repositories

337 Jan 04, 2023
DeOldify - A Deep Learning based project for colorizing and restoring old images (and video!)

DeOldify - A Deep Learning based project for colorizing and restoring old images (and video!)

Jason Antic 15.8k Jan 04, 2023
Revealing and Protecting Labels in Distributed Training

Revealing and Protecting Labels in Distributed Training

Google Interns 0 Nov 09, 2022
A Fast Sequence Transducer Implementation with PyTorch Bindings

transducer A Fast Sequence Transducer Implementation with PyTorch Bindings. The corresponding publication is Sequence Transduction with Recurrent Neur

Awni Hannun 184 Dec 18, 2022
✅ How Robust are Fact Checking Systems on Colloquial Claims?. In NAACL-HLT, 2021.

How Robust are Fact Checking Systems on Colloquial Claims? Official PyTorch implementation of our NAACL paper: Byeongchang Kim*, Hyunwoo Kim*, Seokhee

Byeongchang Kim 19 Mar 15, 2022
Code for paper "Extract, Denoise and Enforce: Evaluating and Improving Concept Preservation for Text-to-Text Generation" EMNLP 2021

The repo provides the code for paper "Extract, Denoise and Enforce: Evaluating and Improving Concept Preservation for Text-to-Text Generation" EMNLP 2

Yuning Mao 18 May 24, 2022
Bottleneck Transformers for Visual Recognition

Bottleneck Transformers for Visual Recognition Experiments Model Params (M) Acc (%) ResNet50 baseline (ref) 23.5M 93.62 BoTNet-50 18.8M 95.11% BoTNet-

Myeongjun Kim 236 Jan 03, 2023
Experimenting with computer vision techniques to generate annotated image datasets from gameplay recordings automatically.

Experimenting with computer vision techniques to generate annotated image datasets from gameplay recordings automatically. The collected data will then be used to train a deep neural network that can

Martin Valchev 3 Apr 24, 2022
Repo for EchoVPR: Echo State Networks for Visual Place Recognition

EchoVPR Repo for EchoVPR: Echo State Networks for Visual Place Recognition Currently under development Dirs: data: pre-collected hidden representation

Anil Ozdemir 4 Oct 04, 2022
Memory-Augmented Model Predictive Control

Memory-Augmented Model Predictive Control This repository hosts the source code for the journal article "Composing MPC with LQR and Neural Networks fo

Fangyu Wu 1 Jun 19, 2022
AdelaiDet is an open source toolbox for multiple instance-level detection and recognition tasks.

AdelaiDet is an open source toolbox for multiple instance-level detection and recognition tasks.

Adelaide Intelligent Machines (AIM) Group 3k Jan 02, 2023
Code for BMVC2021 "MOS: A Low Latency and Lightweight Framework for Face Detection, Landmark Localization, and Head Pose Estimation"

MOS-Multi-Task-Face-Detect Introduction This repo is the official implementation of "MOS: A Low Latency and Lightweight Framework for Face Detection,

104 Dec 08, 2022
Speeding-Up Back-Propagation in DNN: Approximate Outer Product with Memory

Approximate Outer Product Gradient Descent with Memory Code for the numerical experiment of the paper Speeding-Up Back-Propagation in DNN: Approximate

2 Mar 02, 2022
Point detection through multi-instance deep heatmap regression for sutures in endoscopy

Suture detection PyTorch This repo contains the reference implementation of suture detection model in PyTorch for the paper Point detection through mu

artificial intelligence in the area of cardiovascular healthcare 3 Jul 16, 2022
[NeurIPS2021] Code Release of Learning Transferable Perturbations

Learning Transferable Adversarial Perturbations This is an official release of the paper Learning Transferable Adversarial Perturbations. The code is

Krishna Kanth 17 Nov 11, 2022
A PyTorch implementation of "SimGNN: A Neural Network Approach to Fast Graph Similarity Computation" (WSDM 2019).

SimGNN ⠀⠀⠀ A PyTorch implementation of SimGNN: A Neural Network Approach to Fast Graph Similarity Computation (WSDM 2019). Abstract Graph similarity s

Benedek Rozemberczki 534 Dec 25, 2022
PyTorch implementation of the Flow Gaussian Mixture Model (FlowGMM) model from our paper

Flow Gaussian Mixture Model (FlowGMM) This repository contains a PyTorch implementation of the Flow Gaussian Mixture Model (FlowGMM) model from our pa

Pavel Izmailov 124 Nov 06, 2022
Fbone (Flask bone) is a Flask (Python microframework) starter/template/bootstrap/boilerplate application.

Fbone (Flask bone) is a Flask (Python microframework) starter/template/bootstrap/boilerplate application.

Wilson 1.7k Dec 30, 2022
Official code repository for ICCV 2021 paper: Gravity-Aware Monocular 3D Human Object Reconstruction

GraviCap Official code repository for ICCV 2021 paper: Gravity-Aware Monocular 3D Human Object Reconstruction. Gravity-Aware Monocular 3D Human-Object

Rishabh Dabral 15 Dec 09, 2022