Implementation of Sequence Generative Adversarial Nets with Policy Gradient

SeqGAN

Requirements:

  • TensorFlow r1.0.1
  • Python 2.7
  • CUDA 7.5+ (for GPU)

Introduction

SeqGAN applies Generative Adversarial Nets to generating sequences of discrete tokens.

An illustration of SeqGAN. Left: D is trained on the real data and on the data generated by G. Right: G is trained by policy gradient, where the final reward signal is provided by D and passed back to the intermediate action values via Monte Carlo search.
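The Monte Carlo search in the right panel can be sketched as follows. This is a minimal, framework-free illustration in Python 3, not the repository's TensorFlow code: the generator and discriminator are stand-in callables (`sample_next` and `discriminator` are hypothetical names), and rollouts are drawn from the same generator, which is a simplifying assumption. For every prefix Y_{1:t} of a finished sequence, the sketch completes the sequence several times, averages D's scores to estimate the action value, and uses D's score on the full sequence at the last step.

```python
import math
import random


def rollout(prefix, length, sample_next):
    """Complete a partial sequence to full length by sampling tokens from G."""
    seq = list(prefix)
    while len(seq) < length:
        seq.append(sample_next(seq))
    return seq


def action_values(seq, sample_next, discriminator, n_rollouts):
    """Estimate Q(Y_{1:t-1}, y_t) for every position t of a finished sequence.

    For t < T: mean discriminator score over n_rollouts Monte Carlo
    completions of the prefix Y_{1:t}. For t = T: D(Y_{1:T}) directly.
    """
    T = len(seq)
    q = []
    for t in range(1, T + 1):
        prefix = seq[:t]
        if t < T:
            scores = [discriminator(rollout(prefix, T, sample_next))
                      for _ in range(n_rollouts)]
            q.append(sum(scores) / n_rollouts)
        else:
            q.append(discriminator(seq))
    return q


def policy_gradient_loss(log_probs, q_values):
    """REINFORCE-style surrogate: minimizing it raises log G(y_t | Y_{1:t-1})
    in proportion to the estimated reward Q_t."""
    return -sum(lp * q for lp, q in zip(log_probs, q_values))


# Toy usage: a uniform random generator over 4 tokens and a hand-written D
# that scores a sequence by its fraction of even tokens (a stand-in classifier).
random.seed(0)
VOCAB = 4
toy_g = lambda prefix: random.randrange(VOCAB)
toy_d = lambda s: sum(1 for tok in s if tok % 2 == 0) / len(s)
seq = rollout([], 5, toy_g)
q = action_values(seq, toy_g, toy_d, n_rollouts=16)
log_probs = [math.log(1.0 / VOCAB)] * len(seq)  # uniform G: log(1/4) per token
loss = policy_gradient_loss(log_probs, q)
```

Each Q value weights the log-probability of the corresponding token, so tokens that tend to lead to sequences D judges as real are reinforced even though D only scores complete sequences.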

The research paper SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient has been accepted at the Thirty-First AAAI Conference on Artificial Intelligence (AAAI-17).

We provide example code to reproduce the synthetic data experiments with the oracle evaluation mechanism. To run the experiment with default parameters:

$ python sequence_gan.py

You can change all the parameters in sequence_gan.py.

The experiment has two stages. In the first stage, the generator is trained by supervised learning, using Maximum Likelihood Estimation on the positive data provided by the oracle model. In the second stage, adversarial training is used to improve the generator.
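The two stages map onto a simple training loop. The outline below is a sketch that follows the alternating g-step / d-step scheme described in the SeqGAN paper, including a discriminator pretraining pass before the adversarial stage; the callables and step counts are hypothetical placeholders, not the names or defaults used in sequence_gan.py.

```python
def train_seqgan(pretrain_g_epoch, train_d_epoch, g_policy_step,
                 pre_epochs_g, pre_epochs_d, adv_epochs, g_steps, d_steps):
    """Outline of the two-stage SeqGAN schedule (all callables are hypothetical)."""
    # Stage 1: supervised MLE pretraining of G on the oracle's positive data.
    for _ in range(pre_epochs_g):
        pretrain_g_epoch()
    # Pretrain D on real sequences vs. samples drawn from the pretrained G.
    for _ in range(pre_epochs_d):
        train_d_epoch()
    # Stage 2: adversarial training, alternating G and D updates.
    for _ in range(adv_epochs):
        for _ in range(g_steps):
            g_policy_step()  # policy-gradient update using D's rewards
        for _ in range(d_steps):
            train_d_epoch()  # refit D on fresh negatives from the updated G


# Usage: record the order of calls instead of training real models.
calls = []
train_seqgan(lambda: calls.append("mle"), lambda: calls.append("d"),
             lambda: calls.append("pg"),
             pre_epochs_g=2, pre_epochs_d=1, adv_epochs=2, g_steps=1, d_steps=3)
```

Keeping D training inside the adversarial loop matters: as G improves, D must be refit on G's newest samples so that its rewards stay informative.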

After running the experiment, the negative log-likelihood (NLL) performance is saved in save/experiment-log.txt, for example:

pre-training...
epoch:	0	nll:	10.1716
epoch:	5	nll:	9.42939
epoch:	10	nll:	9.2388
epoch:	15	nll:	9.11899
epoch:	20	nll:	9.13099
epoch:	25	nll:	9.14474
epoch:	30	nll:	9.12539
epoch:	35	nll:	9.13982
epoch:	40	nll:	9.135
epoch:	45	nll:	9.13081
epoch:	50	nll:	9.10678
epoch:	55	nll:	9.10694
epoch:	60	nll:	9.10349
epoch:	65	nll:	9.10403
epoch:	70	nll:	9.07613
epoch:	75	nll:	9.091
epoch:	80	nll:	9.08909
epoch:	85	nll:	9.0807
epoch:	90	nll:	9.08434
epoch:	95	nll:	9.08936
epoch:	100	nll:	9.07443
epoch:	105	nll:	9.08305
epoch:	110	nll:	9.06973
epoch:	115	nll:	9.07058
adversarial training...
epoch:	0	nll:	9.08457
epoch:	5	nll:	9.04511
epoch:	10	nll:	9.03079
epoch:	15	nll:	8.99239
epoch:	20	nll:	8.96401
epoch:	25	nll:	8.93864
epoch:	30	nll:	8.91642
epoch:	35	nll:	8.87761
epoch:	40	nll:	8.88582
epoch:	45	nll:	8.8592
epoch:	50	nll:	8.83388
epoch:	55	nll:	8.81342
epoch:	60	nll:	8.80247
epoch:	65	nll:	8.77778
epoch:	70	nll:	8.7567
epoch:	75	nll:	8.73002
epoch:	80	nll:	8.72488
epoch:	85	nll:	8.72233
epoch:	90	nll:	8.71473
epoch:	95	nll:	8.71163
epoch:	100	nll:	8.70113
epoch:	105	nll:	8.69879
epoch:	110	nll:	8.69208
epoch:	115	nll:	8.69291
epoch:	120	nll:	8.68371
epoch:	125	nll:	8.689
epoch:	130	nll:	8.68989
epoch:	135	nll:	8.68269
epoch:	140	nll:	8.68647
epoch:	145	nll:	8.68066
epoch:	150	nll:	8.6832
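The nll values above score samples from the generator under the oracle model, so a falling curve means G's output is becoming more probable under the true data distribution. In the paper the oracle is a randomly initialized LSTM; the sketch below substitutes a hypothetical bigram probability table so it stays self-contained, and it averages over tokens, which is an assumption about normalization rather than a description of the repository's code.

```python
import math

START = -1  # hypothetical context id used for the first token of each sequence


def oracle_nll(samples, oracle):
    """Per-token NLL of generated samples under an oracle P(token | previous token)."""
    total, count = 0.0, 0
    for seq in samples:
        prev = START
        for tok in seq:
            total -= math.log(oracle[prev][tok])
            prev = tok
            count += 1
    return total / count


# Usage: a two-token oracle that strongly prefers repeating token 0.
oracle = {START: {0: 0.5, 1: 0.5}, 0: {0: 0.9, 1: 0.1}, 1: {0: 0.5, 1: 0.5}}
good = oracle_nll([[0, 0, 0, 0]], oracle)  # samples the oracle finds likely
bad = oracle_nll([[1, 0, 1, 0]], oracle)   # samples the oracle finds unlikely
```

Lower is better: a generator whose samples the oracle finds likely gets a smaller score, which is the trend visible across both training stages in the log.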

Note: this code is based on the previous work by ofirnachum. Many thanks to ofirnachum.

Owner
Lantao Yu
Ph.D. Student at Stanford CS Department