Using LSTM write Tang poetry

Overview

本教程将通过一个示例对LSTM进行介绍。通过搭建训练LSTM网络,我们将训练一个模型来生成唐诗。本文将对该实现进行详尽的解释,并阐明此模型的工作方式和原因。并不需要过多专业知识,但是可能需要新手花一些时间来理解的模型训练的实际情况。为了节省时间,请尽量选择GPU进行训练。

1 简介

本项目基于paddlepaddle2.0,结合长短期记忆(Long short-term memory, LSTM),以唐诗为数据集通过监督学习的方式,训练生成唐诗。

参考文档:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

1.1 LSTM解决什么问题?

RNN突出的优点之一就是可以用来连接先前的信息到当前的任务上,有时候,我们仅仅需要知道先前的信息来执行当前的任务。例如,我们有一个语言模型用来基于先前的词来预测下一个词。如果试着预测 “the clouds are in the ___ ” 最后的词,我们并不需要任何其他的上下文下一个词显然就应该是sky。在这样的场景中,相关的信息和预测的词位置之间的间隔是非常小的,RNN 可以充分使用先前信息来预测。

但是同样会有一些更加复杂的场景。假设我们试着去预测“I grew up in France... I speak fluent French”最后的词。当前的信息建议下一个词可能是一种语言的名字,但是如果我们需要弄清楚是什么语言,我们是需要先前提到的离当前位置很远的 France 的上下文的。这说明相关信息和当前预测位置之间的间隔就肯定变得相当的大,在这个间隔不断增大时,RNN 会丧失学习到连接如此远的信息的能力。

LSTM的出现为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

1.2 什么是LSTM?

Long Short Term 网络,一般就叫做 LSTM ——是一种 RNN 特殊的类型,可以学习长期依赖信息。LSTM 由Hochreiter & Schmidhuber (1997)提出,并在近期被Alex Graves进行了改良和推广。在很多问题,LSTM 都取得相当巨大的成功,并得到了广泛的使用。 LSTM 通过专门的设计来避免长期依赖问题。 所有 RNN 都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层。 LSTM 同样是这样的结构,但是重复的模块拥有一个不同的结构。不同于 单一神经网络层,这里是有四个,以一种非常特殊的方式进行交互。 现在,我们先来熟悉一下图中使用的各种元素的图标。

在上面的图例中,每一条黑线传输着一整个向量,从一个节点的输出到其他节点的输入。粉色的圈代表 pointwise 的操作,诸如向量的和,而黄色的矩阵就是学习到的神经网络层。合在一起的线表示向量的连接,分开的线表示内容被复制,然后分发到不同的位置。

1.3 LSTM核心思想

LSTM 的关键就是细胞状态,水平线在图上方贯穿运行。细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。

LSTM 有通过精心设计的称作为“门”的结构来去除或者增加信息到细胞状态的能力。门是一种让信息选择式通过的方法。他们包含一个 sigmoid 神经网络层和一个 pointwise 乘法操作。Sigmoid 层输出 0 到 1 之间的数值,描述每个部分有多少量可以通过。0 代表“不许任何量通过”,1 就指“允许任意量通过”! LSTM 拥有三个门,来保护和控制细胞状态。

1.4 LSTM详解

在我们 LSTM 中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为遗忘门层完成。该门会读取$h_{t-1}$和$x_t$,输出一个在 0 到 1 之间的数值给每个在细胞状态C_{t-1}中的数字。1 表示“完全保留”,0 表示“完全舍弃”。让我们回到语言模型的例子中来基于已经看到的预测下一个词。在这个问题中,细胞状态可能包含当前主语的性别,因此正确的代词可以被选择出来。当我们看到新的主语,我们希望忘记旧的主语。 下一步是确定什么样的新信息被存放在细胞状态中。这里包含两个部分。第一,sigmoid 层称输入门层决定什么值我们将要更新。然后,一个 tanh 层创建一个新的候选值向量会被加入到状态中。下一步,我们会讲这两个信息来产生对状态的更新。在我们语言模型的例子中,我们希望增加新的主语的性别到细胞状态中,来替代旧的需要忘记的主语。 现在是时候去更新上一个状态值C_{t−1}了,将其更新为C_t。前面的步骤以及决定了应该做什么,我们只需实际执行即可。 我们将上一个状态值乘以f_t,以此表达期待忘记的部分。之后我们将得到的值加上i_t*\tilde{C}_t。这个得到的是新的候选值, 按照我们决定更新每个状态值的多少来衡量. 在语言模型的例子中,对应着实际删除关于旧主题性别的信息,并添加新信息,正如在之前的步骤中描述的那样。 最后,我们需要决定要输出的内容。此输出将基于我们的单元状态,但将是过滤后的版本。首先,我们运行一个Sigmoid层,该层决定要输出的单元状态的哪些部分。然后,我们通过激活函数Tanh(将值规范化介于−1和1之间)并乘以Sigmoid的输出,这样我们就只输出我们决定的部分。 对于语言模型示例,由于它只是看到一个主语,因此可能要输出与动词相关的信息,以防万一。例如,它可能输出主语是单数还是复数,以便我们知道如果接下来是动词,则应将动词以哪种形式组合。

2 定义超参数

class Config(object):
    num_layers = 3                                      # LSTM层数
    data_path = 'work/tang_poem.npz'                    # 诗歌的文本文件存放路径
    lr = 1e-3                                           # 学习率
    use_gpu = True                                      # 是否使用GPU
    epoch = 20                                  
    batch_size = 4                                      # mini-batch大小
    maxlen = 125                                        # 超过这个长度的之后字被丢弃,小于这个长度的在前面补空格
    plot_every = 1000                                   # 隔batch 可视化一次
    max_gen_len = 200                                   # 生成诗歌最长长度
    model_path = "work/checkpoints/model.params.50"     # 预训练模型路径
    prefix_words = '欲穷千里目,更上一层楼'                 # 不是诗歌的组成部分,用来控制生成诗歌的意境
    start_words = '老夫聊发少年狂,'                       # 诗歌开始
    model_prefix = 'work/checkpoints/model.params'      # 模型保存路径
    embedding_dim = 256                                 # 词向量维度
    hidden_dim = 512                                    # LSTM hidden层维度

3.定义DataLoader

数据集由唐诗组成,包含唐诗57580首125字(不足和多余125字的都被补充或者截断)、ix2word以及word2ix共三个字典存储为npz格式

4.定义网络

网络由一层Embedding层和三层LSTM层再通过全连接层组成

  • input:[seq_len,batch_size]
  • 经过embedding层,embeddings(input)
    • output:[batch_size,seq_len,embedding_size]
  • 经过LSTM,lstm(embeds, (h_0, c_0)),输出output,hidden
    • output:[batch, seq_len, hidden_size]
  • Reshape再进过Linear层判别
    • output:[batch*seq_len, vocabsize]

5.训练过程

  • 输入的input为(batch_size,seq_len)
  • 通过input_,target = data_[:,:-1],data_[:,1:]将每句话分为前n-1个字作为真正的输入,后n-1个字作为label,size都是(batch_size, seq_len-1)
  • 经过网络,得出output:((seq_len-1)*batch, vocab_size)
  • 通过label经过reshape将target变成((seq_len-1)*batch)

损失函数为:crossEntropy,优化器为:Adam

6.生成唐诗

6.1 模式一 <首句续写唐诗>

例如:”老夫聊发少年狂“

老夫聊发少年狂,嫁得双鬟梳似桃。
青丝长发娇且红,輭舞脸低时未昏。
娇嚬欲尽一双转,笑语千里相竞言。
朝阳上去花欲尽,花时且落花前过。
秾雨霏霏满地晓,红妆白鸟飞下郭。
灯前织女嫁新租,袖里垂纶舞袖舞。
罗袖焰扬簷下樱,一宿十二花绵绵。
西施夹道春风暖,嫩粉萦丝弄金蘂。

6.2 模式二<藏头诗>

例如:”夜月一帘幽梦春风十里柔情“

夜半星初洽,月明星未稀。
一缄琼烛动,帘外玉环飞。
幽匣光华溢,梦中形影微。
春风吹蕙笏,风绪拂莓苔。
十月涵金井,里尘氛祲微。
柔荑暎肌骨,情酒围唇肌。

Entity-Based Knowledge Conflicts in Question Answering.

Entity-Based Knowledge Conflicts in Question Answering Run Instructions | Paper | Citation | License This repository provides the Substitution Framewo

Apple 35 Oct 19, 2022
Blind Video Temporal Consistency via Deep Video Prior

deep-video-prior (DVP) Code for NeurIPS 2020 paper: Blind Video Temporal Consistency via Deep Video Prior PyTorch implementation | paper | project web

Chenyang LEI 272 Dec 21, 2022
CAST: Character labeling in Animation using Self-supervision by Tracking

CAST: Character labeling in Animation using Self-supervision by Tracking (Published as a conference paper at EuroGraphics 2022) Note: The CAST paper c

15 Nov 18, 2022
This is the winning solution of the Endocv-2021 grand challange.

Endocv2021-winner [Paper] This is the winning solution of the Endocv-2021 grand challange. Dependencies pytorch # tested with 1.7 and 1.8 torchvision

Vajira Thambawita 14 Dec 03, 2022
✨风纪委员会自动投票脚本,利用Github Action帮你进行裁决操作(为了让其他风纪委员有案件可判,本程序从中午12点才开始运行,有需要请自己修改运行时间)

风纪委员会自动投票 本脚本通过使用Github Action来实现B站风纪委员的自动投票功能,喜欢请给我点个STAR吧! 如果你不是风纪委员,在符合风纪委员申请条件的情况下,本脚本会自动帮你申请 投票时间是早上八点,如果有需要请自行修改.github/workflows/Judge.yml中的时间,

Pesy Wu 25 Feb 17, 2021
Code for paper Adaptively Aligned Image Captioning via Adaptive Attention Time

Adaptively Aligned Image Captioning via Adaptive Attention Time This repository includes the implementation for Adaptively Aligned Image Captioning vi

Lun Huang 45 Aug 27, 2022
Immortal tracker

Immortal_tracker Prerequisite Our code is tested for Python 3.6. To install required liabraries: pip install -r requirements.txt Waymo Open Dataset P

74 Dec 03, 2022
adversarial_multi_armed_bandit_variable_plays

Adversarial Multi-Armed Bandit with Variable Plays This code is for paper: Adversarial Online Learning with Variable Plays in the Evasion-and-Pursuit

Yiyang Wang 1 Oct 28, 2021
Pytorch implementation of MLP-Mixer with loading pre-trained models.

MLP-Mixer-Pytorch PyTorch implementation of MLP-Mixer: An all-MLP Architecture for Vision with the function of loading official ImageNet pre-trained p

Qiushi Yang 2 Sep 29, 2022
A DeepStack custom model for detecting common objects in dark/night images and videos.

DeepStack_ExDark This repository provides a custom DeepStack model that has been trained and can be used for creating a new object detection API for d

MOSES OLAFENWA 98 Dec 24, 2022
Keeper for Ricochet Protocol, implemented with Apache Airflow

Ricochet Keeper This repository contains Apache Airflow DAGs for executing keeper operations for Ricochet Exchange. Usage You will need to run this us

Ricochet Exchange 5 May 24, 2022
Keras implementation of PersonLab for Multi-Person Pose Estimation and Instance Segmentation.

PersonLab This is a Keras implementation of PersonLab for Multi-Person Pose Estimation and Instance Segmentation. The model predicts heatmaps and vari

OCTI 160 Dec 21, 2022
Partial implementation of ODE-GAN technique from the paper Training Generative Adversarial Networks by Solving Ordinary Differential Equations

ODE GAN (Prototype) in PyTorch Partial implementation of ODE-GAN technique from the paper Training Generative Adversarial Networks by Solving Ordinary

Somshubra Majumdar 15 Feb 10, 2022
Image Lowpoly based on Centroid Voronoi Diagram via python-opencv and taichi

CVTLowpoly: Image Lowpoly via Centroid Voronoi Diagram Image Sharp Feature Extraction using Guide Filter's Local Linear Theory via opencv-python. The

Pupa 4 Jul 29, 2022
(CVPR2021) Kaleido-BERT: Vision-Language Pre-training on Fashion Domain

Kaleido-BERT: Vision-Language Pre-training on Fashion Domain Mingchen Zhuge*, Dehong Gao*, Deng-Ping Fan#, Linbo Jin, Ben Chen, Haoming Zhou, Minghui

248 Dec 04, 2022
Least Square Calibration for Peer Reviews

Least Square Calibration for Peer Reviews Requirements gurobipy - for solving convex programs GPy - for Bayesian baseline numpy pandas To generate p

Sigma <a href=[email protected]"> 1 Nov 01, 2021
An OpenAI Gym environment for multi-agent car racing based on Gym's original car racing environment.

Multi-Car Racing Gym Environment This repository contains MultiCarRacing-v0 a multiplayer variant of Gym's original CarRacing-v0 environment. This env

Igor Gilitschenski 56 Nov 01, 2022
A PyTorch-centric hybrid classical-quantum machine learning framework

torchquantum A PyTorch-centric hybrid classical-quantum dynamic neural networks framework. News Add a simple example script using quantum gates to do

MIT HAN Lab 400 Jan 02, 2023
All the code and files related to the MI-Lab of UE19CS305 course in sem 5

Machine-Intelligence-Lab-CS305 The compilation of all the code an drelated files from MI-Lab UE19CS305 (of batch 2019-2023) offered by PES University

Arvind Krishna 3 Nov 10, 2022
OBG-FCN - implementation of 'Object Boundary Guided Semantic Segmentation'

OBG-FCN This repository is to reproduce the implementation of 'Object Boundary Guided Semantic Segmentation' in http://arxiv.org/abs/1603.09742 Object

Jiu XU 3 Mar 11, 2019