Using LSTM write Tang poetry

Overview

本教程将通过一个示例对LSTM进行介绍。通过搭建训练LSTM网络,我们将训练一个模型来生成唐诗。本文将对该实现进行详尽的解释,并阐明此模型的工作方式和原因。并不需要过多专业知识,但是可能需要新手花一些时间来理解的模型训练的实际情况。为了节省时间,请尽量选择GPU进行训练。

1 简介

本项目基于paddlepaddle2.0,结合长短期记忆(Long short-term memory, LSTM),以唐诗为数据集通过监督学习的方式,训练生成唐诗。

参考文档:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

1.1 LSTM解决什么问题?

RNN突出的优点之一就是可以用来连接先前的信息到当前的任务上,有时候,我们仅仅需要知道先前的信息来执行当前的任务。例如,我们有一个语言模型用来基于先前的词来预测下一个词。如果试着预测 “the clouds are in the ___ ” 最后的词,我们并不需要任何其他的上下文下一个词显然就应该是sky。在这样的场景中,相关的信息和预测的词位置之间的间隔是非常小的,RNN 可以充分使用先前信息来预测。

但是同样会有一些更加复杂的场景。假设我们试着去预测“I grew up in France... I speak fluent French”最后的词。当前的信息建议下一个词可能是一种语言的名字,但是如果我们需要弄清楚是什么语言,我们是需要先前提到的离当前位置很远的 France 的上下文的。这说明相关信息和当前预测位置之间的间隔就肯定变得相当的大,在这个间隔不断增大时,RNN 会丧失学习到连接如此远的信息的能力。

LSTM的出现为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

1.2 什么是LSTM?

Long Short Term 网络,一般就叫做 LSTM ——是一种 RNN 特殊的类型,可以学习长期依赖信息。LSTM 由Hochreiter & Schmidhuber (1997)提出,并在近期被Alex Graves进行了改良和推广。在很多问题,LSTM 都取得相当巨大的成功,并得到了广泛的使用。 LSTM 通过专门的设计来避免长期依赖问题。 所有 RNN 都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层。 LSTM 同样是这样的结构,但是重复的模块拥有一个不同的结构。不同于 单一神经网络层,这里是有四个,以一种非常特殊的方式进行交互。 现在,我们先来熟悉一下图中使用的各种元素的图标。

在上面的图例中,每一条黑线传输着一整个向量,从一个节点的输出到其他节点的输入。粉色的圈代表 pointwise 的操作,诸如向量的和,而黄色的矩阵就是学习到的神经网络层。合在一起的线表示向量的连接,分开的线表示内容被复制,然后分发到不同的位置。

1.3 LSTM核心思想

LSTM 的关键就是细胞状态,水平线在图上方贯穿运行。细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。

LSTM 有通过精心设计的称作为“门”的结构来去除或者增加信息到细胞状态的能力。门是一种让信息选择式通过的方法。他们包含一个 sigmoid 神经网络层和一个 pointwise 乘法操作。Sigmoid 层输出 0 到 1 之间的数值,描述每个部分有多少量可以通过。0 代表“不许任何量通过”,1 就指“允许任意量通过”! LSTM 拥有三个门,来保护和控制细胞状态。

1.4 LSTM详解

在我们 LSTM 中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为遗忘门层完成。该门会读取$h_{t-1}$和$x_t$,输出一个在 0 到 1 之间的数值给每个在细胞状态C_{t-1}中的数字。1 表示“完全保留”,0 表示“完全舍弃”。让我们回到语言模型的例子中来基于已经看到的预测下一个词。在这个问题中,细胞状态可能包含当前主语的性别,因此正确的代词可以被选择出来。当我们看到新的主语,我们希望忘记旧的主语。 下一步是确定什么样的新信息被存放在细胞状态中。这里包含两个部分。第一,sigmoid 层称输入门层决定什么值我们将要更新。然后,一个 tanh 层创建一个新的候选值向量会被加入到状态中。下一步,我们会讲这两个信息来产生对状态的更新。在我们语言模型的例子中,我们希望增加新的主语的性别到细胞状态中,来替代旧的需要忘记的主语。 现在是时候去更新上一个状态值C_{t−1}了,将其更新为C_t。前面的步骤以及决定了应该做什么,我们只需实际执行即可。 我们将上一个状态值乘以f_t,以此表达期待忘记的部分。之后我们将得到的值加上i_t*\tilde{C}_t。这个得到的是新的候选值, 按照我们决定更新每个状态值的多少来衡量. 在语言模型的例子中,对应着实际删除关于旧主题性别的信息,并添加新信息,正如在之前的步骤中描述的那样。 最后,我们需要决定要输出的内容。此输出将基于我们的单元状态,但将是过滤后的版本。首先,我们运行一个Sigmoid层,该层决定要输出的单元状态的哪些部分。然后,我们通过激活函数Tanh(将值规范化介于−1和1之间)并乘以Sigmoid的输出,这样我们就只输出我们决定的部分。 对于语言模型示例,由于它只是看到一个主语,因此可能要输出与动词相关的信息,以防万一。例如,它可能输出主语是单数还是复数,以便我们知道如果接下来是动词,则应将动词以哪种形式组合。

2 定义超参数

class Config(object):
    num_layers = 3                                      # LSTM层数
    data_path = 'work/tang_poem.npz'                    # 诗歌的文本文件存放路径
    lr = 1e-3                                           # 学习率
    use_gpu = True                                      # 是否使用GPU
    epoch = 20                                  
    batch_size = 4                                      # mini-batch大小
    maxlen = 125                                        # 超过这个长度的之后字被丢弃,小于这个长度的在前面补空格
    plot_every = 1000                                   # 隔batch 可视化一次
    max_gen_len = 200                                   # 生成诗歌最长长度
    model_path = "work/checkpoints/model.params.50"     # 预训练模型路径
    prefix_words = '欲穷千里目,更上一层楼'                 # 不是诗歌的组成部分,用来控制生成诗歌的意境
    start_words = '老夫聊发少年狂,'                       # 诗歌开始
    model_prefix = 'work/checkpoints/model.params'      # 模型保存路径
    embedding_dim = 256                                 # 词向量维度
    hidden_dim = 512                                    # LSTM hidden层维度

3.定义DataLoader

数据集由唐诗组成,包含唐诗57580首125字(不足和多余125字的都被补充或者截断)、ix2word以及word2ix共三个字典存储为npz格式

4.定义网络

网络由一层Embedding层和三层LSTM层再通过全连接层组成

  • input:[seq_len,batch_size]
  • 经过embedding层,embeddings(input)
    • output:[batch_size,seq_len,embedding_size]
  • 经过LSTM,lstm(embeds, (h_0, c_0)),输出output,hidden
    • output:[batch, seq_len, hidden_size]
  • Reshape再进过Linear层判别
    • output:[batch*seq_len, vocabsize]

5.训练过程

  • 输入的input为(batch_size,seq_len)
  • 通过input_,target = data_[:,:-1],data_[:,1:]将每句话分为前n-1个字作为真正的输入,后n-1个字作为label,size都是(batch_size, seq_len-1)
  • 经过网络,得出output:((seq_len-1)*batch, vocab_size)
  • 通过label经过reshape将target变成((seq_len-1)*batch)

损失函数为:crossEntropy,优化器为:Adam

6.生成唐诗

6.1 模式一 <首句续写唐诗>

例如:”老夫聊发少年狂“

老夫聊发少年狂,嫁得双鬟梳似桃。
青丝长发娇且红,輭舞脸低时未昏。
娇嚬欲尽一双转,笑语千里相竞言。
朝阳上去花欲尽,花时且落花前过。
秾雨霏霏满地晓,红妆白鸟飞下郭。
灯前织女嫁新租,袖里垂纶舞袖舞。
罗袖焰扬簷下樱,一宿十二花绵绵。
西施夹道春风暖,嫩粉萦丝弄金蘂。

6.2 模式二<藏头诗>

例如:”夜月一帘幽梦春风十里柔情“

夜半星初洽,月明星未稀。
一缄琼烛动,帘外玉环飞。
幽匣光华溢,梦中形影微。
春风吹蕙笏,风绪拂莓苔。
十月涵金井,里尘氛祲微。
柔荑暎肌骨,情酒围唇肌。

CM building dataset Timisoara

CM_building_dataset_Timisoara Date created: Febr-2020 The Timi\c{s}oara Building Dataset - TMBuD - is composed of 160 images with the resolution of 76

Orhei Ciprian 5 Sep 07, 2022
Leveraging Instance-, Image- and Dataset-Level Information for Weakly Supervised Instance Segmentation

Leveraging Instance-, Image- and Dataset-Level Information for Weakly Supervised Instance Segmentation This paper has been accepted and early accessed

Yun Liu 39 Sep 20, 2022
Uncertainty Estimation via Response Scaling for Pseudo-mask Noise Mitigation in Weakly-supervised Semantic Segmentation

Uncertainty Estimation via Response Scaling for Pseudo-mask Noise Mitigation in Weakly-supervised Semantic Segmentation Introduction This is a PyTorch

XMed-Lab 30 Sep 23, 2022
LiDAR R-CNN: An Efficient and Universal 3D Object Detector

LiDAR R-CNN: An Efficient and Universal 3D Object Detector Introduction This is the official code of LiDAR R-CNN: An Efficient and Universal 3D Object

TuSimple 295 Jan 05, 2023
Official Code Implementation of the paper : XAI for Transformers: Better Explanations through Conservative Propagation

Official Code Implementation of The Paper : XAI for Transformers: Better Explanations through Conservative Propagation For the SST-2 and IMDB expermin

Ameen Ali 23 Dec 30, 2022
A simple rest api serving a deep learning model that classifies human gender based on their faces. (vgg16 transfare learning)

this is a simple rest api serving a deep learning model that classifies human gender based on their faces. (vgg16 transfare learning)

crispengari 5 Dec 09, 2021
2 Jul 19, 2022
[CVPR'21] Projecting Your View Attentively: Monocular Road Scene Layout Estimation via Cross-view Transformation

Projecting Your View Attentively: Monocular Road Scene Layout Estimation via Cross-view Transformation Weixiang Yang, Qi Li, Wenxi Liu, Yuanlong Yu, Y

118 Dec 26, 2022
Churn-Prediction-Project - In this project, a churn prediction model is developed for a private bank as a term project for Data Mining class.

Churn-Prediction-Project In this project, a churn prediction model is developed for a private bank as a term project for Data Mining class. Project in

1 Jan 03, 2022
Official implementation of the paper "Topographic VAEs learn Equivariant Capsules"

Topographic Variational Autoencoder Paper: https://arxiv.org/abs/2109.01394 Getting Started Install requirements with Anaconda: conda env create -f en

T. Andy Keller 69 Dec 12, 2022
Principled Detection of Out-of-Distribution Examples in Neural Networks

ODIN: Out-of-Distribution Detector for Neural Networks This is a PyTorch implementation for detecting out-of-distribution examples in neural networks.

189 Nov 29, 2022
End-to-end speech secognition toolkit

End-to-end speech secognition toolkit This is an E2E ASR toolkit modified from Espnet1 (version 0.9.9). This is the official implementation of paper:

Jinchuan Tian 147 Dec 28, 2022
Real-time LIDAR-based Urban Road and Sidewalk detection for Autonomous Vehicles 🚗

urban_road_filter: a real-time LIDAR-based urban road and sidewalk detection algorithm for autonomous vehicles Dependency ROS (tested with Kinetic and

JKK - Vehicle Industry Research Center 180 Dec 12, 2022
Code for the paper "Improving Vision-and-Language Navigation with Image-Text Pairs from the Web" (ECCV 2020)

Improving Vision-and-Language Navigation with Image-Text Pairs from the Web Arjun Majumdar, Ayush Shrivastava, Stefan Lee, Peter Anderson, Devi Parikh

Arjun Majumdar 44 Dec 14, 2022
Freecodecamp Scientific Computing with Python Certification; Solution for Challenge 2: Time Calculator

Assignment Write a function named add_time that takes in two required parameters and one optional parameter: a start time in the 12-hour clock format

Hellen Namulinda 0 Feb 26, 2022
TResNet: High Performance GPU-Dedicated Architecture

TResNet: High Performance GPU-Dedicated Architecture paperV2 | pretrained models Official PyTorch Implementation Tal Ridnik, Hussam Lawen, Asaf Noy, I

426 Dec 28, 2022
A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron.

The GatedTabTransformer. A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron. C

Radi Cho 60 Dec 15, 2022
A Python package for time series augmentation

tsaug tsaug is a Python package for time series augmentation. It offers a set of augmentation methods for time series, as well as a simple API to conn

Arundo Analytics 278 Jan 01, 2023
Tensorflow python implementation of "Learning High Fidelity Depths of Dressed Humans by Watching Social Media Dance Videos"

Learning High Fidelity Depths of Dressed Humans by Watching Social Media Dance Videos This repository is the official tensorflow python implementation

Yasamin Jafarian 287 Jan 06, 2023
Build a small, 3 domain internet using Github pages and Wikipedia and construct a crawler to crawl, render, and index.

TechSEO Crawler Build a small, 3 domain internet using Github pages and Wikipedia and construct a crawler to crawl, render, and index. Play with the r

JR Oakes 57 Nov 24, 2022