6.1. 层和块
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

在首次介绍神经网络时,我们关注的是具有单一输出的线性模型。 在这里,整个模型只包含一个神经元。 注意,单个神经元(i)获取一组输入;(ii)生成一个相应的标量输出;(iii)具有一组相关 的参数(parameters),这些参数可以更新,以优化某个感兴趣的目标函数。 然后,当我们开始考虑具有多个输出的网络时,我们利用矢量化算术来描述整个层的神经元。 就像单个神经元一样,层(i)获取一组输入,(ii)生成相应的输出,(iii)由一组可调整参数描述。 当我们使用softmax回归时,一个单层本身就是模型。 然而,即使我们随后引入了多层感知机,我们仍然可以认为该模型保留了这种基本架构。

有趣的是,对于多层感知机而言,整个模型及其组成层都是这种结构。 整个模型获取原始输入(特征),生成输出(预测),并拥有参数(所有组成层的参数集合)。 同样,每个单独的层接收输入(由前一层提供),生成输出(到下一层的输入),并具有一组可调参数,这些参数根据从下一层反向传播的信号进行更新。

虽然你可能认为神经元、层和模型为我们提供了足够的抽象, 但事实证明,我们经常发现谈论比单个层大、但比整个模型小的组件更方便。 例如,在计算机视觉中广泛流行的ResNet-152架构拥有数百个层, 这些层是由层组的重复模式组成。 一次只实现一个层,实现这样的网络会变得很繁琐。 这种担忧不仅仅是假设,这种设计模式在实践中很常见。 上面提到的ResNet架构赢得了2015年ImageNet和COCO计算机视觉竞赛的识别和检测任务 (He et al., 2016), 至今仍然是许多视觉任务的首选架构。 在自然语言处理和语音等其他领域,层以各种重复模式排列的类似架构现在也无处不在。

为了实现这些复杂的网络,我们引入了神经网络的概念。 块(block)可以描述单个层、由多个层组成的组件或整个模型本身。 使用块进行抽象的一个好处是,它们可以组合成更大的组件,通常是递归的, 如 :numref:`fig_blocks` 中所示。 通过定义代码来按需生成任意复杂度的块, 我们可以编写出惊人紧凑的代码,并仍然实现复杂的神经网络。

../_images/blocks.svg

图 6.1.1 多个层被组合成块,形成更大模型的重复模式。

从编程的角度来看,块由(class)表示。 它的任何子类都必须定义一个将输入转换为输出的前向传播函数, 并且必须存储任何必需的参数。 注意,有些块不需要任何参数。 最后,为了计算梯度,块必须具有反向传播函数。 幸运的是,在定义我们自己的块时,由于自动微分(在 :numref:`sec_autograd` 中引入)提供了一些后台魔法,我们只需要考虑参数和前向传播函数。

import torch
from torch import nn
from torch.nn import functional as F
from mxnet import np, npx
from mxnet.gluon import nn

npx.set_np()
from typing import List
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf

首先,我们回顾一下用于实现多层感知机的代码 ( :numref:`sec_mlp`)。 下面的代码生成一个网络,其中包含一个具有256个单元和ReLU激活函数的全连接隐藏层, 然后是一个具有10个单元且不带激活函数的全连接输出层。

net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))

X = torch.rand(2, 20)
net(X).shape
torch.Size([2, 10])

在这个例子中,我们通过实例化`nn.Sequential`来构建我们的模型, 层的列表作为参数传递给`nn.Sequential`的构造函数。 简而言之,`nn.Sequential`定义了一种特殊的`Module`,即在PyTorch中表示一个块的类, 它维护了一个由`Module`组成的有序列表。 注意,两个全连接层都是`Linear`类的实例,`Linear`类本身就是`Module`的子类。 前向传播(`forward`)函数也非常简单:它将列表中的每个块连接在一起,将每个块的输出作为下一个块的输入。 注意,到目前为止,我们一直在通过`net(X)`调用我们的模型来获得它们的输出。 这实际上是`net.__call__(X)`的简写。

net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
net.initialize()

X = np.random.uniform(size=(2, 20))
net(X).shape
[21:53:59] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
(2, 10)

在这个例子中,我们通过实例化`nn.Sequential`来构建模型,并将返回的对象赋给`net`变量。 接下来,我们重复调用它的`add`方法,按顺序追加层。 简而言之,`nn.Sequential`定义了一种特殊的`Block`,即在Gluon中表示一个的类。 它维护了一个由`Block`组成的有序列表。 `add`方法只是方便地将每个后续`Block`添加到列表中。 注意,每个层都是`Dense`类的实例,`Dense`类本身就是`Block`的子类。 前向传播(`forward`)函数也非常简单:它将列表中的每个`Block`连接在一起,将每个块的输出作为下一个块的输入。 注意,到目前为止,我们一直在通过`net(X)`调用我们的模型来获得它们的输出。 这实际上是`net.forward(X)`的简写,这是通过`Block`类的`__call__`方法实现的Python技巧。

net = nn.Sequential([nn.Dense(256), nn.relu, nn.Dense(10)])

# get_key is a d2l saved function returning jax.random.PRNGKey(random_seed)
X = jax.random.uniform(d2l.get_key(), (2, 20))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 10)
net = tf.keras.models.Sequential([
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(10),
])

X = tf.random.uniform((2, 20))
net(X).shape
TensorShape([2, 10])

在这个例子中,我们通过实例化`keras.models.Sequential`来构建模型,层的列表作为参数传递给它的构造函数。 简而言之,`Sequential`定义了一种特殊的`keras.Model`,即在Keras中表示一个块的类。 它维护了一个由`Model`组成的有序列表。 注意,两个全连接层都是`Dense`类的实例,`Dense`类本身就是`Model`的子类。 前向传播(`call`)函数也非常简单:它将列表中的每个块连接在一起,将每个块的输出作为下一个块的输入。 注意,到目前为止,我们一直在通过`net(X)`调用我们的模型来获得它们的输出。 这实际上是`net.call(X)`的简写,这是通过块类的`__call__`方法实现的Python技巧。

6.1.1. 自定义块

要了解块是如何工作的,最简单的方法可能就是自己实现一个。 在我们开始自己实现一个块之前,我们简要总结一下每个块必须提供的基本功能:

  1. 将输入数据作为其前向传播函数的参数。

  2. 通过前向传播函数来生成一个输出。请注意,输出的形状可能与输入的形状不同。例如,我们上面模型中的第一个全连接层接收一个任意维度的输入,但是返回一个维度256的输出。

  3. 计算其输出关于其输入的梯度,可通过其反向传播函数进行访问。通常这是自动发生的。

  4. 存储和访问执行前向传播计算所必需的那些参数。

  5. 根据需要初始化模型参数。

在下面的代码片段中,我们从零开始编写一个块, 该块对应于一个具有256个隐藏单元的单隐藏层和10维输出层的多层感知机。 注意,下面的`MLP`类继承了表示块的类。 我们将严重依赖父类的方法,只需要提供我们自己的构造函数(Python中的`__init__`函数)和前向传播函数。

class MLP(nn.Module):
    def __init__(self):
        # Call the constructor of the parent class nn.Module to perform
        # the necessary initialization
        super().__init__()
        self.hidden = nn.LazyLinear(256)
        self.out = nn.LazyLinear(10)

    # Define the forward propagation of the model, that is, how to return the
    # required model output based on the input X
    def forward(self, X):
        return self.out(F.relu(self.hidden(X)))
class MLP(nn.Block):
    def __init__(self):
        # Call the constructor of the MLP parent class nn.Block to perform
        # the necessary initialization
        super().__init__()
        self.hidden = nn.Dense(256, activation='relu')
        self.out = nn.Dense(10)

    # Define the forward propagation of the model, that is, how to return the
    # required model output based on the input X
    def forward(self, X):
        return self.out(self.hidden(X))
class MLP(nn.Module):
    def setup(self):
        # Define the layers
        self.hidden = nn.Dense(256)
        self.out = nn.Dense(10)

    # Define the forward propagation of the model, that is, how to return the
    # required model output based on the input X
    def __call__(self, X):
        return self.out(nn.relu(self.hidden(X)))
class MLP(tf.keras.Model):
    def __init__(self):
        # Call the constructor of the parent class tf.keras.Model to perform
        # the necessary initialization
        super().__init__()
        self.hidden = tf.keras.layers.Dense(units=256, activation=tf.nn.relu)
        self.out = tf.keras.layers.Dense(units=10)

    # Define the forward propagation of the model, that is, how to return the
    # required model output based on the input X
    def call(self, X):
        return self.out(self.hidden((X)))

我们首先关注前向传播函数。 注意,它以`X`作为输入, 计算带有激活函数的隐藏表示,并输出其未规范化的输出值。 在这个`MLP`实现中,两个层都是实例变量。 要知道为什么这样做是合理的,可以想象实例化两个多层感知机(`net1`和`net2`),并根据不同的数据对它们进行训练。 当然,我们希望它们学到两种不同的模型。

我们在构造函数中实例化多层感知机的层,然后在每次调用前向传播函数时调用这些层。 注意一些关键细节:首先,我们自定义的 `__init__` 函数通过 `super().__init__()` 调用父类的 `__init__` 函数,省去了重复编写模版代码的麻烦。然后,我们实例化两个全连接层,并将它们分别赋值给 `self.hidden` 和 `self.out`。注意,除非我们实现一个新的层,否则我们不必担心反向传播函数或参数初始化,系统将自动生成这些。我们来试一下:

net = MLP()
net(X).shape
torch.Size([2, 10])
net = MLP()
net.initialize()
net(X).shape
(2, 10)
net = MLP()
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 10)
net = MLP()
net(X).shape
TensorShape([2, 10])

块的一个主要优点是它的多功能性。 我们可以子类化块以创建层(如全连接层类)、 整个模型(如上面的`MLP`类),或具有中等复杂度的各种组件。 我们在接下来的章节中充分利用了这种多功能性,比如在处理卷积神经网络时。

6.1.2. Sequential 块

现在我们可以更仔细地看看`Sequential`类是如何工作的, 回忆一下`Sequential`的设计是为了把其他模块串起来。 为了构建我们自己的简化的`MySequential`,我们只需要定义两个关键函数:

  1. 一种将块逐个追加到列表中的函数。

  2. 一种前向传播函数,用于将输入按追加块的顺序传递给块列表。

下面的`MySequential`类提供了与默认`Sequential`类相同的功能。

class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, X):
        for module in self.children():
            X = module(X)
        return X

在 `__init__` 方法中,我们通过调用 `add_modules` 方法添加每个模块。这些模块以后可以通过 `children` 方法访问。通过这种方式,系统知道已添加的模块,并将正确初始化每个模块的参数。

class MySequential(nn.Block):
    def add(self, block):
        # Here, block is an instance of a Block subclass, and we assume that
        # it has a unique name. We save it in the member variable _children of
        # the Block class, and its type is OrderedDict. When the MySequential
        # instance calls the initialize method, the system automatically
        # initializes all members of _children
        self._children[block.name] = block

    def forward(self, X):
        # OrderedDict guarantees that members will be traversed in the order
        # they were added
        for block in self._children.values():
            X = block(X)
        return X

`add`方法将单个块添加到有序字典`_children`中。 你可能想知道为什么每个Gluon `Block`都有一个`_children`属性?为什么我们使用它而不是自己定义一个Python列表? 简而言之,`_children`的主要优点是在块的参数初始化期间, Gluon知道在`_children`字典中查找需要初始化参数的子块。

class MySequential(nn.Module):
    modules: List

    def __call__(self, X):
        for module in self.modules:
            X = module(X)
        return X
class MySequential(tf.keras.Model):
    def __init__(self, *args):
        super().__init__()
        self.modules = args

    def call(self, X):
        for module in self.modules:
            X = module(X)
        return X

当`MySequential`的前向传播函数被调用时, 每个添加的块都按照它们被添加的顺序执行。 现在,我们可以使用`MySequential`类重新实现多层感知机。

net = MySequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
net(X).shape
torch.Size([2, 10])
net = MySequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
net.initialize()
net(X).shape
(2, 10)
net = MySequential([nn.Dense(256), nn.relu, nn.Dense(10)])
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 10)
net = MySequential(
    tf.keras.layers.Dense(units=256, activation=tf.nn.relu),
    tf.keras.layers.Dense(10))
net(X).shape
TensorShape([2, 10])

请注意,`MySequential`的用法与之前为`Sequential`类编写的代码相同(如 :numref:`sec_mlp` 中所述)。

6.1.3. 在前向传播函数中执行代码

`Sequential`类可以使模型构造变得简单,允许我们组合新的架构,而无需定义自己的类。 然而,并不是所有的架构都是简单的顺序架构。 当需要更强的灵活性时,我们需要定义自己的块。 例如,我们可能希望在前向传播函数中执行Python的控制流。 此外,我们可能希望执行任意的数学运算,而不仅仅是依赖预定义的神经网络层。

你可能已经注意到,到目前为止,我们网络中的所有操作都作用于网络的激活值和参数。 然而,有时我们可能希望合并既不是上一层的结果也不是可更新参数的项,我们称之为常数参数(constant parameters)。 比如说,我们需要一个计算函数 \(f(\mathbf{x},\mathbf{w}) = c \cdot \mathbf{w}^\top \mathbf{x}\) 的层,其中 \(\mathbf{x}\) 是输入,\(\mathbf{w}\) 是参数,\(c\) 是某个在优化过程中不会更新的指定常量。因此我们实现了如下的 `FixedHiddenMLP` 类。

class FixedHiddenMLP(nn.Module):
    def __init__(self):
        super().__init__()
        # Random weight parameters that will not compute gradients and
        # therefore keep constant during training
        self.rand_weight = torch.rand((20, 20))
        self.linear = nn.LazyLinear(20)

    def forward(self, X):
        X = self.linear(X)
        X = F.relu(X @ self.rand_weight + 1)
        # Reuse the fully connected layer. This is equivalent to sharing
        # parameters with two fully connected layers
        X = self.linear(X)
        # Control flow
        while X.abs().sum() > 1:
            X /= 2
        return X.sum()
class FixedHiddenMLP(nn.Block):
    def __init__(self):
        super().__init__()
        # Random weight parameters created with the get_constant method
        # are not updated during training (i.e., constant parameters)
        self.rand_weight = self.params.get_constant(
            'rand_weight', np.random.uniform(size=(20, 20)))
        self.dense = nn.Dense(20, activation='relu')

    def forward(self, X):
        X = self.dense(X)
        # Use the created constant parameters, as well as the relu and dot
        # functions
        X = npx.relu(np.dot(X, self.rand_weight.data()) + 1)
        # Reuse the fully connected layer. This is equivalent to sharing
        # parameters with two fully connected layers
        X = self.dense(X)
        # Control flow
        while np.abs(X).sum() > 1:
            X /= 2
        return X.sum()
class FixedHiddenMLP(nn.Module):
    # Random weight parameters that will not compute gradients and
    # therefore keep constant during training
    rand_weight: jnp.array = jax.random.uniform(d2l.get_key(), (20, 20))

    def setup(self):
        self.dense = nn.Dense(20)

    def __call__(self, X):
        X = self.dense(X)
        X = nn.relu(X @ self.rand_weight + 1)
        # Reuse the fully connected layer. This is equivalent to sharing
        # parameters with two fully connected layers
        X = self.dense(X)
        # Control flow
        while jnp.abs(X).sum() > 1:
            X /= 2
        return X.sum()
class FixedHiddenMLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        # Random weight parameters created with tf.constant are not updated
        # during training (i.e., constant parameters)
        self.rand_weight = tf.constant(tf.random.uniform((20, 20)))
        self.dense = tf.keras.layers.Dense(20, activation=tf.nn.relu)

    def call(self, inputs):
        X = self.flatten(inputs)
        # Use the created constant parameters, as well as the relu and
        # matmul functions
        X = tf.nn.relu(tf.matmul(X, self.rand_weight) + 1)
        # Reuse the fully connected layer. This is equivalent to sharing
        # parameters with two fully connected layers
        X = self.dense(X)
        # Control flow
        while tf.reduce_sum(tf.math.abs(X)) > 1:
            X /= 2
        return tf.reduce_sum(X)

在这个 `FixedHiddenMLP` 模型中,我们实现了一个隐藏层,其权重(`self.rand_weight`)在实例化时被随机初始化,之后就是常量。 这个权重不是一个模型参数,因此它永远不会被反向传播更新。 然后,网络将这个“固定”层的输出通过一个全连接层。

请注意,在返回输出之前,我们的模型做了一些不寻常的事情。 我们运行了一个while循环,在 L1 范数大于 \(1\) 的条件下, 将输出向量除以 \(2\),直到它满足条件为止。 最后,我们返回了`X`中所有项的和。 据我们所知,没有标准的神经网络会执行此操作。 请注意,这个特定操作可能在任何实际任务中都没有用。 我们的重点是向你展示如何将任意代码集成到神经网络计算的流程中。

net = FixedHiddenMLP()
net(X)
tensor(-0.3836, grad_fn=<SumBackward0>)
net = FixedHiddenMLP()
net.initialize()
net(X)
array(0.52637565)
net = FixedHiddenMLP()
params = net.init(d2l.get_key(), X)
net.apply(params, X)
Array(0.32849464, dtype=float32)
net = FixedHiddenMLP()
net(X)
<tf.Tensor: shape=(), dtype=float32, numpy=0.6186229>

我们可以混合搭配各种组合块的方法。 在下面的例子中,我们以一些创造性的方式嵌套块。

class NestMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.LazyLinear(64), nn.ReLU(),
                                 nn.LazyLinear(32), nn.ReLU())
        self.linear = nn.LazyLinear(16)

    def forward(self, X):
        return self.linear(self.net(X))

chimera = nn.Sequential(NestMLP(), nn.LazyLinear(20), FixedHiddenMLP())
chimera(X)
tensor(0.0679, grad_fn=<SumBackward0>)
class NestMLP(nn.Block):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.net = nn.Sequential()
        self.net.add(nn.Dense(64, activation='relu'),
                     nn.Dense(32, activation='relu'))
        self.dense = nn.Dense(16, activation='relu')

    def forward(self, X):
        return self.dense(self.net(X))

chimera = nn.Sequential()
chimera.add(NestMLP(), nn.Dense(20), FixedHiddenMLP())
chimera.initialize()
chimera(X)
array(0.97720534)
class NestMLP(nn.Module):
    def setup(self):
        self.net = nn.Sequential([nn.Dense(64), nn.relu,
                                  nn.Dense(32), nn.relu])
        self.dense = nn.Dense(16)

    def __call__(self, X):
        return self.dense(self.net(X))


chimera = nn.Sequential([NestMLP(), nn.Dense(20), FixedHiddenMLP()])
params = chimera.init(d2l.get_key(), X)
chimera.apply(params, X)
Array(-0.1306174, dtype=float32)
class NestMLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.net = tf.keras.Sequential()
        self.net.add(tf.keras.layers.Dense(64, activation=tf.nn.relu))
        self.net.add(tf.keras.layers.Dense(32, activation=tf.nn.relu))
        self.dense = tf.keras.layers.Dense(16, activation=tf.nn.relu)

    def call(self, inputs):
        return self.dense(self.net(inputs))

chimera = tf.keras.Sequential()
chimera.add(NestMLP())
chimera.add(tf.keras.layers.Dense(20))
chimera.add(FixedHiddenMLP())
chimera(X)
<tf.Tensor: shape=(), dtype=float32, numpy=0.59283525>

6.1.4. 小结

一个块可以由许多层组成;一个块可以由许多块组成。

块可以包含代码。块负责大量的内部处理,包括参数初始化和反向传播。层和块的顺序连接由`Sequential`块处理。

6.1.5. 练习

  1. 如果将`MySequential`中存储块的方式更改为Python列表,会发生什么问题?

  2. 实现一个块,它以两个块为参数,例如 `net1` 和 `net2`,并在前向传播中返回两个网络的串联输出。这也被称为并行块

  3. 假设你想要连接同一网络的多个实例。实现一个工厂函数,该函数可以生成同一块的多个实例,并使用它构建更大的网络。