神经网络-seq2seq


seq2seq(Sequence to Sequence)是一种序列生成神经网络结构。

1. 输入输出

训练输入:

编码器 \[ T=\{x_1,\cdots , x_N\} \] 解码器 \[ T=\{x'_1,\cdots , x'_N\} \]

其中,

  • \(x\) 为编码器输入特征域,\(x_i \in \mathbb{R}^{n},i =1,2,...,N\)
  • \(x'\) 为解码器输入特征域,\(x'_i \in \mathbb{R}^{n},i =1,2,...,N'\)

测试输入

编码器 \[ \hat X=\{\hat x_1, \cdots ,\hat x_{\hat N}\},\hat x_i \in \mathbb{R}^n \] 解码器 \[ \hat X'=\{\hat x'_1\} \] 测试输出: \[ \hat X=\{\hat x'_2, \cdots ,\hat x'_{\hat N + 1}\},\hat x'_i \in \mathbb{R}^n \]

2. 基本定义

2014年的论文:Sequence to Sequence Learning with Neural Networks

一个经典的seq2seq模型架构如下

上图是一个按时序展开的网络图,分为编码器(Encoder)和解码器(Decoder)两部分,编码器和解码器使用的基础单元都是rnn神经元。(RNN基础原理笔记

2.1. 数据流

seq2seq的输入比较特殊,编码器和解码器都需要输入。下面用中译英任务作为例子进行阐述。

  • 训练阶段

    • 需要准备两份文本数据作为输入,中文文本从编码器中输入,英文文本从解码器中输入。特别地,编码器的输入一般称为source,解码器的输入一般称为target。source和target是一一对应的。例如,在中译英任务中,source为 我爱你,target为 I love you

    • target需要统一在每个序列开始前,加入一个独特的开始标志符,例如 [s] I love you

    • source按时序进行输入,即 \[ h_{t+1} = a(W_x \cdot x_t + W_h\cdot h_t + b) \] 其中,\(a(\cdot)\) 表示激活函数。

    • 将最后一个时刻的隐层输出 \(c\) 传递到解码器中

    • target也按时序进行输入,每一时刻的输出预测下一个时刻的输入,即 \[ h_{t+1} = a(W_{x'}\cdot x'_{t+1}+ W_y\cdot y_{t}+W_c\cdot c + b) \\ y_{t+1} = \text{Search}(h_1,\cdots,h_t) \] 其中,\(\text{Search}(\cdot)\) 表示一种搜索算法,可以从以往时刻隐层的输出中搜索出当前时刻的最优解。

      例如,当前时刻输入为 [s],其输出预测为 I。像这种把当前时刻的输出作为下一时刻的输入的机制,叫做自回归(auto-regression)

    • target最后一个时刻的预测标签为一个特殊的结束标志符,例如,[s] I love you 对应的预测序列为 I love you [e]

  • 预测阶段

    • 只需准备编码器的输入文本即可。
    • source按时序进行输入,将最后一个时刻的隐层输出 \(c\) 传递到解码器中。
    • target一开始给定开始标志 [s] 进行输入,通过自回归机制,按时序进行预测,直到输出结束标志符 [e]

注:

  • 由于编码器在输入和预测阶段都可以看到完整的序列,所以可以使用双向的lstm,但解码器在预测阶段只能按时序输入,所以只能采用单向的lstm
  • 由于主要结构还是基于rnn,所以处理较长的序列,会出现遗忘,导致效果不佳。

2.2. 预测解码

在预测阶段,需要从以往时刻的输出中得到当前时刻的最优解,实际上就是一个最优路径问题

例如,已知当前最优解为 ACB[e],达到最优解时,其softmax后的输出矩阵如下

2.2.1. 枚举搜索

即枚举所有可能的路线,综合得出最优解,即 \[ y_{t+1} = \arg \max P(y | h_1, \cdots , h_t) \]

其中,\(h_i\) 为当前时刻所有可能取值的集合

对于词库大小为 \(n\),序列长度为 \(T\) 的预测任务,其时间复杂度为 \(O(n^T)\),基本不可能应用到实际中

2.2.2. 贪婪搜索

即只考虑当前时刻输出,得出最优解,即 \[ y_{t+1} = \arg \max P(y|h_t) \] 这种方法只能得到局部最优解,但对于每个时刻都是独立的情况,贪婪路径是可以得到最优解的。

例如,在上诉例子中,由于每个时刻是非独立的,时间步3的结果是基于时间步2为C时得到的,所以如果按照贪婪搜索,时间步2选择的应该为B,则时间步3的结果就未必还是[0.1,0.6,0.2,0.1] 的矩阵了

2.2.3. 束(beam)搜索

结合枚举搜索和贪婪搜索,是一种折衷的方案

设定一个束长度(beam size)\(k\),每个时刻选取概率值靠前的 \(k\) 个可能取值,然后对其组合排序枚举,得出最优解,即 \[ y_{t+1} = \arg \max P(y | h'_1, \cdots , h'_t) \] 其中,\(h'_i\)\(h_i\) 所有可能取值中概率最大的 \(k\) 个值的集合,例如,上述例子中,若\(k=2\),则\(h'_1=\{ A,B \},h'_2=\{ C,B \},...\)

下面还有一个束长度为2,序列长度为3的束搜索例子

束搜索的时间复杂度为 \(O(k^T)\)

3. references

Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation

https://www.cnblogs.com/yifanrensheng/p/13167724.html


评论
  目录