转载本文请注明出处:https://yudonglee.me/ctc-explained  作者:yudonglee

现实应用中许多问题可以抽象为序列学习(sequence learning)问题,比如词性标注(POS Tagging)、语音识别(Speech Recognition)、手写字识别(Handwriting Recognition)、机器翻译(Machine Translation)等,其核心问题都是训练模型将一个领域的输入序列转换为另一个领域的输出序列。

近年来,基于 RNN 的序列到序列模型(sequence-to-sequence models)在这类任务中取得了显著的效果提升。本文介绍一种 RNN(Recurrent Neural Networks)的端到端训练方法——CTC(Connectionist Temporal Classification)算法。CTC 可以让 RNN 直接对序列数据进行学习,无需事先标注输入序列和输出序列之间的映射关系,从而打破了 RNN 应用于语音识别、手写字识别等领域的数据依赖约束,使模型在序列学习任务中取得更好的效果。

本系列文章总共分为三部分来全面阐述CTC算法(本篇为Part 1):
Part 1:Training the Network(训练算法篇),介绍CTC理论原理,包括问题定义、公式推导、算法过程等。Part 1链接
Part 2:Decoding the Network(解码算法篇),介绍CTC Decoding的几种常用算法。Part 2链接

Part 3:CTC Demo by Speech Recognition(语音识别实战篇),基于 TensorFlow 实现完整的 CTC 语音识别系统。Part 3链接

接下来,我们先从“问题”的背景说起。

1. 背景介绍

在序列学习任务中,RNN模型对训练样本一般有这样的依赖条件:输入序列和输出序列之间的映射关系已经事先标注好了。比如,在词性标注任务中,训练样本中每个词(或短语)对应的词性会事先标注好,如下图(DT、NN等都是词性的标注,具体含义请参考链接)。由于输入序列和输出序列是一一对应的,所以RNN模型的训练和预测都是端到端的,即可以根据输出序列和标注样本间的差异来直接定义RNN模型的Loss函数,传统的RNN训练和预测方式可直接适用。

然而,在语音识别、手写字识别等任务中,音频数据和图像数据都是将现实世界的模拟信号转为数字信号后采集得到的,这些数据天然就很难进行”分割”,这使得我们很难获取到包含输入序列和输出序列映射关系的大规模训练样本(人工标注成本巨高,且启发式挖掘方法存在很大局限性)。因此,在这种条件下,RNN无法直接进行端到端的训练和预测。

如下图,输入是“apple”对应的一段说话音频和手写字图片,从连续的音频信号和图像信号中逐一分割并标注出对应的输出序列非常费时费力,在大规模训练下这种数据要求是完全不切实际的。而如果输入序列和输出序列之间映射关系没有提前标注好,那传统的RNN训练方式就不能直接适用了,无法直接对音频数据和图像数据进行训练。

因此,在语音识别、图像识别等领域中,由于数据天然无法切割,且难以标注出输入和输出的序列映射关系,导致传统的RNN训练方法不能直接适用。那么,如何让RNN模型实现端到端的训练成为了关键问题。

Connectionist Temporal Classification(CTC)[1]是Alex Graves等人在ICML 2006上提出的一种端到端的RNN训练方法,它可以让RNN直接对序列数据进行学习,而无需事先标注好训练数据中输入序列和输入序列的映射关系,使得RNN模型在语音识别等序列学习任务中取得更好的效果,在语音识别和图像识别等领域CTC算法都有很比较广泛的应用。总的来说,CTC的核心思路主要分为以下几部分:

  • 它扩展了RNN的输出层,在输出序列和最终标签之间增加了多对一的空间映射,并在此基础上定义了CTC Loss函数
  • 它借鉴了HMM(Hidden Markov Model)的Forward-Backward算法思路,利用动态规划算法有效地计算CTC Loss函数及其导数,从而解决了RNN端到端训练的问题
  • 最后,结合CTC Decoding算法RNN可以有效地对序列数据进行端到端的预测

接下来,通过一个语音识别的实际例子来引出CTC的解决思路

2. 一个实际的例子–声学模型

语音识别的核心问题是把一段音频信号序列转化文字序列,传统的语音识别系统主要分为以下几部分,如下图。

其中,X表示音频信号,O是它的特征表示,一般基于LPC、MFCC等方法提取特征,也可以基于DNN的方式“学到”声学特征的表示。为了简化问题,我们暂且把O理解为是由实数数组组成的序列,它是音频信号的特征表示。Q是O对应的发音字符序列,即建模单元,一般可以是音素、音节、字、词等。W是音频信号X对应的文字序列,即我们最终的识别结果。

如图所示,核心问题是通过解码器找到令P(W|X)最大化的的W,通过贝叶斯公式可将其分解为P(O|Q)、P(Q|W)、P(W),分别对应声学模型、发音模型、语言模型。

其中,声学模型就是对P(O|Q)进行建模,通过训练可以“学到”音频信号和文字发音间的联系。为了简化问题,我们假定声学模型的建模单元Q选择的是音节,O选择的是MFCC特征(由39维数组组成的序列)。

以下图为例,输入序列是一段”我爱你中国”的音频,输出序列是音节序列 “wo3 ai4 ni3 zhong1 guo2″。如果训练样本中已经将音频”分割”好,并标注了音频帧与音节的对应关系,则 RNN 模型的结构如下:

然而,如前面所述,对音频进行精确”分割”并标注映射关系在实际中是不可行的。实际做法是按照固定时间窗口滑动提取特征,例如每 10 毫秒提取一帧,得到一个 N 维特征向量。这种方式下,输入序列的长度远大于输出标签的长度,如下图所示:

由于人说话发音是连续的,且中间也会有“停顿”,所以输出序列中存在重复的元素,比如“wo3 wo3”,也存在表示间隔符号“_”。需从输出序列中去除掉重复的元素以及间隔符,才可得到最终的音节序列,比如,“wo3 wo3 ai4 _ ni3 _ zhong1 guo2 _” 归一处理后得到“wo3 ai4 ni3 zhong1 guo2”。因此,输出序列和最终的label之间存在多对一的映射关系,如下图:

RNN模型本质是对𝒑(𝒛│𝒙)建模,其中x表示输入序列,o表示输出序列,z表示最终的label,o和l存在多对一的映射关系,即:𝒑(𝒛│𝒙)=sum of all P(o|x),其中o是所有映射到z的输出序列。因此,只需要穷举出所有的o,累加一起即可得到𝒑(𝒛│𝒙),从而使得RNN模型对最终的label进行建模。

经过以上的映射转换,解决了端到端训练的问题,RNN模型实际上是对映射到最终label的输出序列的空间建模。然而,对每一个z都“穷举所有的o”,这个计算的复杂度太大,会使得训练速度变得非常慢,因此怎么更高效地进行端到端训练成为待解决的关键问题。

通过以上的实际例子,我们对问题的解决思路有了更加直观的了解,接下来就开始正式介绍CTC的理论原理。

3. 问题定义

以RNN声学模型为例子,建模的目标是通过训练得到一个RNN模型,使其满足:

本质上是最大似然预估, S是训练数据集,X和Z分别是输入空间(由音频信号向量序列组成的集合)和目标空间(由声学模型建模单元序列组成的集合),L是由输出的字符集(声学建模单元的集合),且x的序列长度小于或等于z的序列长度。

接下来,在介绍如何计算Loss函数之前,我们需要对RNN输出层做一个简单的扩展。

4. RNN输出层扩展

为了便于读者理解,下图简化了 RNN 的结构:仅使用单向单层 LSTM,将声学建模单元设为字母 {a-z}。在此基础上,对建模单元字符集做了两项扩展:一是增加了 blank 符号(表示”无输出”),二是定义了从输出层序列到最终 label 序列的多对一映射函数 B。通过该映射函数,多条不同的输出路径可以映射到同一个最终 label 序列。

所以,计算𝒑(𝒛│𝒙)的思路就是:枚举所有经映射函数 B 映射到最终 label z 的输出序列(即”路径”),将它们的概率累加即可。如下图所示:

5. CTC Loss函数定义

CTC Loss 函数的定义基于一个重要的条件独立假设:RNN 在各时间步的输出相互独立。在此假设下,一条路径的概率等于路径上各时间步输出概率的乘积,进而可以写出 CTC Loss 函数的完整定义,如下图所示:

假定选择单层 LSTM 作为 RNN 的具体结构,则整体模型架构(从输入特征到 CTC Loss 计算)如下图所示:

6. CTC Loss函数计算

由于直接穷举所有路径来计算 𝒑(𝒛│𝒙) 的时间复杂度是指数级的,作者借鉴了 HMM 的 Forward-Backward 算法思路,利用动态规划将复杂度降至多项式级别。

为了更形象地表示问题的搜索空间,如下图所示,用 X 轴表示时间步,Y 轴表示扩展后的输出序列。具体地,对最终标签 l 做标准化处理:在每个字符之间以及首尾都插入 blank 符号,得到扩展序列 l’,其长度满足 |l’| = 2|l| + 1。例如:l = “apple”(长度 5),则 l’ = “_a_p_p_l_e_”(长度 11)。

需要注意的是,并非搜索空间中所有路径都是合法的。合法路径需要满足以下约束条件(例如:路径必须从 l’ 的前两个符号之一开始,必须在最后两个符号之一结束,且不能跳过非 blank 字符等),具体规则如下图所示:

所以,依据以上约束规则,遍历所有映射为“apple”的合法路径,最终时序T=8,标签labeling=“apple”的全部路径如下图:

接下来的问题是:如何高效地计算这些合法路径的概率总和?作者借鉴了 HMM 的 Forward-Backward 算法思路,利用动态规划求解。核心思想是定义前向概率 α(从起点到当前位置的路径概率之和)和后向概率 β(从当前位置到终点的路径概率之和),通过递推关系避免重复计算,如下图所示:

通过动态规划递推求出全部前向概率后,在最后一个时间步对 l’ 的末尾两个位置求和,即可得到 𝒑(𝒛│𝒙),进而计算 CTC Loss 函数。具体公式如下图所示:

类似的方式,我们可以定义后向概率 β,即从最后一个时间步反向递推到当前位置。同样地,后向概率也可以用来计算 CTC Loss 函数,如下图所示:

更进一步,将任意时间步 t 的前向概率 α 和后向概率 β 相乘,也可以计算 CTC Loss 函数。这一等价关系对后续的梯度求导推导至关重要,如下图所示:

总结一下,根据前向概率计算CTC Loss函数,得到以下结论:

根据后向概率计算CTC Loss函数,得到以下结论:

根据任意时刻的前向概率和后向概率计算CTC Loss函数,得到以下结论:

至此,我们已经得到了 CTC Loss 的高效计算方法。接下来,对其进行求导,以便通过反向传播算法训练 RNN 模型。

7. CTC Loss函数求导

我们先回顾 RNN 的网络结构。如下图所示,红色标注部分是 CTC Loss 函数求导的核心环节——即 CTC Loss 对 RNN 输出层(softmax 层)输出值的偏导数:

CTC Loss函数对 RNN 输出层元素的求导,核心思路是通过前向概率和后向概率将对总路径概率的求导分解为对每个时间步输出概率的求导,具体推导过程如下图所示:

8. 总结

本篇以 RNN 声学模型为例,从问题背景出发,逐步介绍了 CTC Loss 函数的定义、基于动态规划的高效计算方法,以及梯度求导过程,最终通过反向传播算法实现了对 RNN 模型的端到端训练。

值得注意的是,CTC 算法也存在一些局限性:首先,条件独立假设意味着模型在各时间步的输出相互独立,无法建模输出序列内部的依赖关系(例如语言模型信息需要外部引入);其次,CTC 要求输入序列的长度不小于输出序列的长度,这在某些任务场景下可能成为限制;此外,CTC 训练出的模型倾向于产生”尖峰”(peaky)的后验概率分布,大部分时间步的输出集中在 blank 符号上。这些局限性在后续的 Attention-based 模型和 RNN-Transducer 等方法中得到了不同程度的改进。

至此,CTC 算法的模型训练过程与原理已介绍完毕。下一篇将详细介绍 CTC 算法的推理解码过程与原理,Part2链接

References

  1. Graves et al., Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with RNNs. In ICML, 2006. (Graves提出CTC算法的原始论文)
  2. Graves et al., A Novel Connectionist System for Unconstrained Handwriting Recognition. In IEEE Transactions on PAML, 2009.(CTC算法在手写字识别中的应用)
  3. Graves et al., Towards End-to-End Recognition with RNNs. In JMLR, 2014.(CTC算法在端到端声学模型中的应用)
  4. Alex Graves, Supervised Sequence Labelling with Recurrent Neural Networks. In Studies in Computational Intelligence, Springer, 2012.( Graves 的博士论文,关于sequence learning的研究,主要是CTC)
  5. Watanabe et al., Hybrid CTC/Attention Architecture for End-to-End Speech Recognition. IEEE Journal of Selected Topics in Signal Processing, 2017.(CTC/Attention 联合解码的代表性工作)
  6. Miao et al., EESEN: End-to-End Speech Recognition using Deep RNN Models and WFST-Based Decoding. In ASRU, 2015.(基于 WFST 的 CTC 解码框架)
  7. TensorFlow CTC API 文档:https://www.tensorflow.org/api_docs/python/tf/nn/ctc_loss
  8. Librosa 文档:https://librosa.org/doc/latest/index.html

Loading