本文为Fairseq漫游指南系列的第二篇文章。前面一篇文章以基于Transformer的翻译模型为例,对Fairseq的命令行使用方法进行了初步的介绍。Fairseq预设了大量的任务和模型,可以根据需要准备数据,并参考对应任务、模型的参数进行训练和解码。
在实际的使用中,现有的模型可能无法满足真实任务的需要,我们可能需要处理不同类型的输入输出,或者需要对模型进行修改以验证新的想法。在这种情况下,只通过命令行调用预设任务和模型的方法就存在很大的局限,我们需要对Fairseq本身进行扩展,以满足实际多样化的需求。
本文以实现一个可以双向翻译(EN-DE和DE-EN)的Transformer模型为例,介绍Fairseq的插件扩展。
Fairseq扩展概述
Fairseq允许用户在不修改源代码的情况下,以插件的形式进行扩展。目前,可以自定义五种插件:
1. 任务(Tasks):任务定义了我们要完成的整个流程,包括读取数据组成batch、模型初始化、训练、测试等。
2. 模型(Models):模型定义了网络的结构、包含的参数、前向计算过程。
3. 评价准则(Criterions):评价准则也就是损失函数,用来根据网络输出和真实标签计算损失。
4. 优化器(Optimizers):在反向传播之后,优化器决定了更新模型参数的方式。
5. 学习率调度器(Learning Rate Schedulers):学习率调度器可以用来根据训练过的步数,动态调整学习率。
对于这五种插件,Fairseq自身的代码中提供了大量的预设,可以在对应的目录下查看,如fairseq/models
目录下提供了多种模型的实现。在指定了这五种插件(可以为预设值,也可以为用户编写的插件)之后,fairseq的训练流程可以抽象为:
1 | for epoch in range(num_epochs): |
如前所述,模型的单步训练过程在任务中定义,即task.train_step
。默认情况下,其实现如下:
1 | def train_step(self, batch, model, criterion, optimizer, **unused): |
只通过命令行的方式,可以选择使用不同的预设插件,如LSTM、Transformer等不同的模型。但如果我们想要扩展Fairseq没有提供的一些功能,那么就需要我们自己编写一些插件,并进行注册,以便Fairseq在运行的时候可以加载我们自定义的插件。接下来我们以一个最简单的例子,来实现自己的Transformer模型。
首先需要建立我们的代码仓库,假设代码存放在$HOME/codebase/custom
:
1 | ├── custom |
其中,__init__.py
的内容如下:
1 | from fairseq.models.transformer import TransformerModel, transformer_iwslt_de_en |
在Fairseq中,模型称为model
,模型对应的超参数称为model_architecture
。在这个例子中,我们定义了一个名为my_transformer
的模型,以及其对应的iwslt_arch
超参数。由于模型直接继承了预设的TransformerModel
,超参数直接调用了transformer_iwslt_de_en
,因此其功能没有任何的改变,只是名字发生了改变。在编写了这个简单的插件后,就可以通过命令行来进行调用了:
1 | fairseq-train data-bin --arch iwslt_arch --user-dir $HOME/codebase/custom --max-tokens 4096 --optimizer adam |
其中,data-bin
是上一篇文章”命令行工具“中预处理的数据路径。该命令可以在任何目录下执行,只要通过--user-dir $HOME/codebase/custom
参数指定我们的插件代码位置即可。
从上面的例子可以看出,自定义并使用一个模型插件需要以下几个步骤:
1. 创建一个python module,即包含__init__.py
文件的目录(这个例子中为$HOME/codebase/custom
);
2. 定义新的模型类(类名可以任意,只要不和其他重复即可),并用@register_model('model_name')
装饰器来进行注册(model_name即模型名,Fairseq通过这个名字来定位插件对应的类);
3. 定义模型对应的预设超参数model_architecture,这是一个函数,接收args
参数。比如想将dropout预设为0.1,可以通过args.dropout = 0.1
来完成。和模型类似,想要Fairseq能够将其识别为预设超参数,需要使用@register_model_architecture('model_name', 'arch_name')
来进行注册,其中model_name
是模型名,arch_name
是预设值的名字;
4. 如果插件的实现在__init__.py
之外的文件中,那么还需要在__init__.py
文件中导入注册的model和model_architecture,这是因为fairseq在运行时通过查找已经导入(加载)的插件名(如模型名)来定位具体的实现,如果不进行导入,那么即便指定了--user-dir
,fairseq也只能加载在__init__.py
中的代码,而找不到在其他文件中定义的插件。在这个例子中,由于model和model_architecture都定义在了__init__.py
文件中,因此不需要额外的导入;
5. 在命令行调用的时候,指定--user-dir
参数为插件路径,并使用--arch
来告诉Fairseq使用我们自定义的模型和超参数。
定义新的任务、优化器等,和定义新的模型基本一致,都是通过定义一个新的类,并通过@register_*
来注册。下面,我们将实现一个双向翻译、参数共享的翻译系统,来看一下扩展在实际中如何使用。
准备工作
我们使用和系列第一篇《命令行工具》中一致的环境:
1. python 3.7
2. pytorch 1.6.0
3. Fairseq,commit 522c76b
4. cuda 10.1
5. Apex 0.1
对于数据,我们同样使用iwslt 14英德平行数据来进行训练。由于我们的目的是进行两种语言的双向翻译,编码器和解码器都需要拥有处理两种语言的能力,因此我们需要对两种语言使用共享的词表,在fairseq的预处理命令中,可以通过--joined-dictionary
参数来指定:
1 | bash fairseq/examples/translation/prepare-iwslt14.sh |
默认情况下,预处理后的二进制数据文件保存在data-bin目录下。
目标
对于双向翻译任务,我们希望给定一个源语言的句子,模型能解码出一个目标语言的句子;给定一个目标语言的句子,模型能够解码出一个源语言的句子。为了达到这个目的,我们需要模型能够区分出输入是哪种语言,或者说,希望翻译为哪种语言。在多语言机器翻译中,一个简单而有效的做法是,在输入的句子前面加上一个标签来指明希望模型输出的语言,比如在句子前面加一个__2<en>__
,来告诉模型我们希望得到英文的翻译结果。
为了给输入句子加上标签,我们需要在读取数据和组成batch之间进行处理,即读取所有句对,给句对的源端部分加上指明目标语言的标签,再根据句长,将相似长度的句子打包为一个batch,并将这个batch数值化,来构成模型的输入。如前所述,读取数据组成batch的操作需要在Task中进行,因此我们需要自定义一个Task,来对数据进行处理。
在模型部分,我们希望编码器和解码器共享自注意力和前馈神经网络中的参数,即Transformer中self attention和feed forward模块的参数。这一部分的改变在模型中体现,因此我们还需要自定义一个基于Transformer的Model,以实现参数的共享。
在明确了目标之后,我们首先需要创建代码库,保存在codebase/custom
目录下:
1 | └── custom |
其中,bidirectional_transformer.py
保存我们自定义的模型,bidirectional_translation_task.py
保存我们自定义的任务。为了使Fairseq能够加载自定义模型和任务,需要在__init__.py
中将其导入:
1 | from . import ( |
接下来,我们的目标就是实现bidirectional_transformer
和bidirectional_translation_task
了。
参数共享的模型
模型部分相对比较简单,由于Fairseq中实现了大量的预设模型,因此我们在实现自定义模型的时候,应该尽量复用已有的代码,通过模型类的继承、方法的重载来实现功能上的修改和扩展。我们直接使用Transformer的实现,并在模型初始化之后,指定参数共享的部分:
1 | from fairseq.models.transformer import TransformerModel, transformer_iwslt_de_en |
通过继承Fairseq中的Transformer
模型,我们的BidirectionalTransformerModel
就可以实现与Transformer相同的功能。在模型的实例化方法__init__
中,首先调用父类TransformerModel
的初始化方法,来初始化模型及其参数,然后调用make_shared_component
方法,来共享编码器和解码器每一层中的self_attn
和fc1
、fc2
参数。同时,我们使用了transformer_iwslt_de_en
来定义名为iwslt_arch
的预设超参数。最后通过register_model
和register_model_architecture
来注册模型,就可以在Fairseq中使用了。
双向翻译任务
在自定义的双向翻译任务中,我们需要将标签加到每个源端句子前面。由于我们的目的和翻译任务基本一致,因此可以复用Fairseq中的TranslationTask
,只需要实现数据加载部分即可。完整代码如下:
1 | import os |
参考fairseq/tasks/translation.py
的代码可以看到,数据加载实在方法load_dataset
中完成的,我们可以在其基础上(加载源语言到目标语言的数据),增加目标语言到源语言数据的加载,并给加载的数据添加标签。load_dataset
方法的基本流程是,通过spilt
参数,来加载对应的数据,并将加载的数据赋值给self.datasets[split]
。其中split
参数一般为train
、valid
或者test
。默认情况下,训练、验证、解码分别使用对应的数据,但也可以通过命令行来指定,如fairseq-generate --gen-subset train
就会解码训练数据(即split为train)。
在我们的实现中,读取数据和添加标签的流程如下:
1. 仿照fairseq/tasks/translation.py
中的代码,使用data_utils.load_indexed_dataset
来分别读取两种语言预处理后的二进制数据;
2. 使用PrependTokenDataset
给两种语言的数据都创建一个加标签的版本;
3. 如果是测试的情况下split == 'test'
,只使用 src_prepend_dataset
和tgt_raw_dataset
来构建数据集;如果是训练或者验证,则将加标签的源语言和目标语言数据使用ConcatDataset
进行拼接,得到src_dataset
,将两种语言不加标签的数据拼接,得到tgt_dataset
,来构建数据集;
4. 根据src_dataset
和tgt_dataset
,创建一个LanguagePairDataset
,并赋值给self.datasets[split]
。
在这个例子中,我们使用到了PrependTokenDataset
、LanguagePairDataset
、ConcatDataset
三个Fairseq中定义的类来完成加标签、拼接数据等操作。在fairseq/data
目录下,还有大量预定义的数据类可供使用,同时,我们还可以继承预定义的类来扩展其功能,完成更复杂的数据处理。
最后,由于我们使用了额外的标签来指定目标语言,所以需要在词表中添加对应的语言标签。通过查看TranslationTask
的代码可知,词表的创建和初始化是在setup_task
中进行的,我们通过重写该方法,在任务创建完成后,为src_dict
和tgt_dict
分别添加源语言标签和目标语言标签。
训练和解码
在创建了自定义的任务和模型后,就可以使用该插件来进行训练了。进行训练和解码的命令和前文所介绍的基本一致,只需要指定插件代码的位置--user-dir
、模型结构--arch
和任务--task
:
1 | fairseq-train data-bin --max-tokens 4096 --max-update 50000 \ |
解码的命令不需要指定模型结构:
1 | fairseq-generate data-bin --path checkpoints/checkpoint_best.pt --remove-bpe --user-dir $HOME/codebase/custom --task bidirectional_translation_task --source-lang en --target-lang de |
其中,参数--source-lang
和--target-lang
可以进行特定方向的翻译,用来验证模型训练得到的双向翻译能力。如果不指定这两个参数,则默认是和数据预处理时相同的翻译方向(德语到英语)。
总结
本文通过一个双向翻译的例子,介绍了Fairseq扩展插件的基本使用方法。大多数的NLP任务都可以在不修改源码的情况下,通过编写插件来实现,这在很大程度上简化了实验的流程,我们只需要编写插件实现与原方法、模型不同的部分,而不需要关注重复的模式和训练流程。
在实际开发插件的过程中,关键的问题在于如何定位我们需要修改的部分,以及如何最大程度地复用Fairseq已经实现的部分。后续文章将介绍Fairseq中已经实现的一些任务、模型,以及数据集等常用的工具,以便了解我们要实现的功能在fairseq中是否已经有对应的实现及实现对应的位置。