10.6. 编码器—解码器架构¶ 在 SageMaker Studio Lab 中打开 Notebook
在诸如机器翻译(10.5节)之类的一般“序列到序列”问题中,输入和输出都是长度可变的、未对齐的序列。处理此类数据的标准方法是设计一个*编码器-解码器*(encoder-decoder)架构(图 10.6.1),它由两个主要组件组成:一个*编码器*,它接受一个可变长度的序列作为输入;一个*解码器*,它作为一个条件语言模型,接受编码后的输入和目标序列的左侧上下文,并预测目标序列中的下一个词元。
图 10.6.1 编码器—解码器架构。¶
让我们以从英语到法语的机器翻译为例。给定一个英语的输入序列:“They”、“are”、“watching”、“.”,这个编码器—解码器架构首先将这个可变长度的输入编码成一个状态,然后解码该状态,以生成翻译后的序列,一次一个词元地作为输出:“Ils”、“regardent”、“.”。由于编码器—解码器架构是后续章节中不同序列到序列模型的基础,本节将把这个架构转换成一个接口,稍后将实现它。
from torch import nn
from d2l import torch as d2l
from mxnet.gluon import nn
from d2l import mxnet as d2l
from flax import linen as nn
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l
10.6.1. 编码器¶
在编码器接口中,我们只指定编码器接受可变长度序列作为输入X
。具体的实现将由任何继承这个基础Encoder
类的模型来提供。
class Encoder(nn.Module): #@save
"""The base encoder interface for the encoder--decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def forward(self, X, *args):
raise NotImplementedError
class Encoder(nn.Block): #@save
"""The base encoder interface for the encoder--decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def forward(self, X, *args):
raise NotImplementedError
class Encoder(nn.Module): #@save
"""The base encoder interface for the encoder--decoder architecture."""
def setup(self):
raise NotImplementedError
# Later there can be additional arguments (e.g., length excluding padding)
def __call__(self, X, *args):
raise NotImplementedError
class Encoder(tf.keras.layers.Layer): #@save
"""The base encoder interface for the encoder--decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def call(self, X, *args):
raise NotImplementedError
10.6.2. 解码器¶
在下面的解码器接口中,我们添加了一个额外的init_state
方法,用于将编码器的输出(enc_all_outputs
)转换为编码后的状态。请注意,此步骤可能需要额外的输入,例如输入的有效长度,这在10.5节中已经解释过。为了逐个词元地生成可变长度的序列,解码器每次可以将一个输入(例如,前一个时间步生成的词元)和编码状态映射到当前时间步的输出词元。
class Decoder(nn.Module): #@save
"""The base decoder interface for the encoder--decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
class Decoder(nn.Block): #@save
"""The base decoder interface for the encoder--decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
class Decoder(nn.Module): #@save
"""The base decoder interface for the encoder--decoder architecture."""
def setup(self):
raise NotImplementedError
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def __call__(self, X, state):
raise NotImplementedError
class Decoder(tf.keras.layers.Layer): #@save
"""The base decoder interface for the encoder--decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def call(self, X, state):
raise NotImplementedError
10.6.3. 整合编码器和解码器¶
在前向传播中,编码器的输出用于生成编码状态,这个状态将进一步被解码器用作其输入之一。
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder--decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder--decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder--decoder architecture."""
encoder: nn.Module
decoder: nn.Module
training: bool
def __call__(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args, training=self.training)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state, training=self.training)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder--decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def call(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args, training=True)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state, training=True)[0]
在下一节中,我们将看到如何应用循环神经网络来设计基于此编码器—解码器架构的序列到序列模型。
10.6.4. 小结¶
编码器-解码器架构可以处理由可变长度序列组成的输入和输出,因此适用于如机器翻译等序列到序列问题。编码器接收一个可变长度的序列作为输入,并将其转换为一个固定形状的状态。解码器将固定形状的编码状态映射到一个可变长度的序列。