10.6. 编码器—解码器架构
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

在诸如机器翻译(10.5节)之类的一般“序列到序列”问题中,输入和输出都是长度可变的、未对齐的序列。处理此类数据的标准方法是设计一个*编码器-解码器*(encoder-decoder)架构(图 10.6.1),它由两个主要组件组成:一个*编码器*,它接受一个可变长度的序列作为输入;一个*解码器*,它作为一个条件语言模型,接受编码后的输入和目标序列的左侧上下文,并预测目标序列中的下一个词元。

../_images/encoder-decoder.svg

图 10.6.1 编码器—解码器架构。

让我们以从英语到法语的机器翻译为例。给定一个英语的输入序列:“They”、“are”、“watching”、“.”,这个编码器—解码器架构首先将这个可变长度的输入编码成一个状态,然后解码该状态,以生成翻译后的序列,一次一个词元地作为输出:“Ils”、“regardent”、“.”。由于编码器—解码器架构是后续章节中不同序列到序列模型的基础,本节将把这个架构转换成一个接口,稍后将实现它。

from torch import nn
from d2l import torch as d2l
from mxnet.gluon import nn
from d2l import mxnet as d2l
from flax import linen as nn
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l

10.6.1. 编码器

在编码器接口中,我们只指定编码器接受可变长度序列作为输入X。具体的实现将由任何继承这个基础Encoder类的模型来提供。

class Encoder(nn.Module):  #@save
    """The base encoder interface for the encoder--decoder architecture."""
    def __init__(self):
        super().__init__()

    # Later there can be additional arguments (e.g., length excluding padding)
    def forward(self, X, *args):
        raise NotImplementedError
class Encoder(nn.Block):  #@save
    """The base encoder interface for the encoder--decoder architecture."""
    def __init__(self):
        super().__init__()

    # Later there can be additional arguments (e.g., length excluding padding)
    def forward(self, X, *args):
        raise NotImplementedError
class Encoder(nn.Module):  #@save
    """The base encoder interface for the encoder--decoder architecture."""
    def setup(self):
        raise NotImplementedError

    # Later there can be additional arguments (e.g., length excluding padding)
    def __call__(self, X, *args):
        raise NotImplementedError
class Encoder(tf.keras.layers.Layer):  #@save
    """The base encoder interface for the encoder--decoder architecture."""
    def __init__(self):
        super().__init__()

    # Later there can be additional arguments (e.g., length excluding padding)
    def call(self, X, *args):
        raise NotImplementedError

10.6.2. 解码器

在下面的解码器接口中,我们添加了一个额外的init_state方法,用于将编码器的输出(enc_all_outputs)转换为编码后的状态。请注意,此步骤可能需要额外的输入,例如输入的有效长度,这在10.5节中已经解释过。为了逐个词元地生成可变长度的序列,解码器每次可以将一个输入(例如,前一个时间步生成的词元)和编码状态映射到当前时间步的输出词元。

class Decoder(nn.Module):  #@save
    """The base decoder interface for the encoder--decoder architecture."""
    def __init__(self):
        super().__init__()

    # Later there can be additional arguments (e.g., length excluding padding)
    def init_state(self, enc_all_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
class Decoder(nn.Block):  #@save
    """The base decoder interface for the encoder--decoder architecture."""
    def __init__(self):
        super().__init__()

    # Later there can be additional arguments (e.g., length excluding padding)
    def init_state(self, enc_all_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
class Decoder(nn.Module):  #@save
    """The base decoder interface for the encoder--decoder architecture."""
    def setup(self):
        raise NotImplementedError

    # Later there can be additional arguments (e.g., length excluding padding)
    def init_state(self, enc_all_outputs, *args):
        raise NotImplementedError

    def __call__(self, X, state):
        raise NotImplementedError
class Decoder(tf.keras.layers.Layer):  #@save
    """The base decoder interface for the encoder--decoder architecture."""
    def __init__(self):
        super().__init__()

    # Later there can be additional arguments (e.g., length excluding padding)
    def init_state(self, enc_all_outputs, *args):
        raise NotImplementedError

    def call(self, X, state):
        raise NotImplementedError

10.6.3. 整合编码器和解码器

在前向传播中,编码器的输出用于生成编码状态,这个状态将进一步被解码器用作其输入之一。

class EncoderDecoder(d2l.Classifier):  #@save
    """The base class for the encoder--decoder architecture."""
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_all_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_all_outputs, *args)
        # Return decoder output only
        return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier):  #@save
    """The base class for the encoder--decoder architecture."""
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_all_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_all_outputs, *args)
        # Return decoder output only
        return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier):  #@save
    """The base class for the encoder--decoder architecture."""
    encoder: nn.Module
    decoder: nn.Module
    training: bool

    def __call__(self, enc_X, dec_X, *args):
        enc_all_outputs = self.encoder(enc_X, *args, training=self.training)
        dec_state = self.decoder.init_state(enc_all_outputs, *args)
        # Return decoder output only
        return self.decoder(dec_X, dec_state, training=self.training)[0]
class EncoderDecoder(d2l.Classifier):  #@save
    """The base class for the encoder--decoder architecture."""
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def call(self, enc_X, dec_X, *args):
        enc_all_outputs = self.encoder(enc_X, *args, training=True)
        dec_state = self.decoder.init_state(enc_all_outputs, *args)
        # Return decoder output only
        return self.decoder(dec_X, dec_state, training=True)[0]

在下一节中,我们将看到如何应用循环神经网络来设计基于此编码器—解码器架构的序列到序列模型。

10.6.4. 小结

编码器-解码器架构可以处理由可变长度序列组成的输入和输出,因此适用于如机器翻译等序列到序列问题。编码器接收一个可变长度的序列作为输入,并将其转换为一个固定形状的状态。解码器将固定形状的编码状态映射到一个可变长度的序列。

10.6.5. 练习

  1. 假设我们使用神经网络来实现编码器—解码器架构。编码器和解码器必须是同一种类型的神经网络吗?

  2. 除了机器翻译,你能想出另一个可以应用编码器—解码器架构的例子吗?