6.4. 延后初始化
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

到目前为止,我们建立网络时似乎一直在做一些“不寻常”的事情。我们一直在执行下面这些可能看起来不应该工作的操作:

  • 我们定义了网络架构,但没有指定输入维度。

  • 我们添加层时没有指定前一层的输出维度。

  • 我们甚至在提供足够信息来确定模型应该包含多少参数之前就“初始化”了这些参数。

你可能会惊讶于我们的代码居然能运行。毕竟,深度学习框架无法知道网络的输入维度是什么。这里的诀窍是框架会*延后初始化*(deferred initialization),即等到第一次将数据传递给模型时,才会动态地推断出每一层的形状大小。

稍后,在处理卷积神经网络时,这项技术会变得更加方便,因为输入维度(例如,图像的分辨率)将影响每个后续层的维度。因此,能够在编写代码时无需知道维度值就可以设置参数,可以大大简化指定和修改模型的任务。接下来,我们将更深入地探讨初始化的机制。

import torch
from torch import nn
from d2l import torch as d2l
from mxnet import np, npx
from mxnet.gluon import nn

npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf

首先,让我们实例化一个多层感知机。

net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
net = nn.Sequential([nn.Dense(256), nn.relu, nn.Dense(10)])
net = tf.keras.models.Sequential([
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(10),
])

此时,网络不可能知道输入层权重的维度,因为输入维度仍然未知。

因此,框架尚未初始化任何参数。我们通过尝试访问下面的参数来确认这一点。

net[0].weight
<UninitializedParameter>

因此,框架尚未初始化任何参数。我们通过尝试访问下面的参数来确认这一点。

print(net.collect_params)
print(net.collect_params())
<bound method Block.collect_params of Sequential(
  (0): Dense(-1 -> 256, Activation(relu))
  (1): Dense(-1 -> 10, linear)
)>
sequential0_ (
  Parameter dense0_weight (shape=(256, -1), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, -1), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

请注意,虽然参数对象存在,但每个层的输入维度都列为-1。MXNet 使用特殊值-1来表示参数维度仍然未知。此时,尝试访问 net[0].weight.data() 将触发一个运行时错误,指出在访问参数之前必须初始化网络。现在让我们看看当我们尝试通过 initialize 方法初始化参数时会发生什么。

net.initialize()
net.collect_params()
[22:11:11] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
sequential0_ (
  Parameter dense0_weight (shape=(256, -1), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, -1), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

正如我们所看到的,没有任何改变。当输入维度未知时,调用initialize并不会真正地初始化参数。相反,这个调用向MXNet注册了我们希望(可选地,根据哪个分布)初始化参数的意图。

正如在 6.2.1节 中提到的,在Jax和Flax中,参数和网络定义是解耦的,用户需要手动处理两者。Flax模型是无状态的,因此没有 parameters 属性。

因此,框架尚未初始化任何参数。我们通过尝试访问下面的参数来确认这一点。

[net.layers[i].get_weights() for i in range(len(net.layers))]
[[], []]

请注意,每个层对象都存在,但权重是空的。使用 net.get_weights() 会抛出一个错误,因为权重还没有被初始化。

接下来,让我们将数据传递给网络,让框架最终初始化参数。

X = torch.rand(2, 20)
net(X)

net[0].weight.shape
torch.Size([256, 20])
X = np.random.uniform(size=(2, 20))
net(X)

net.collect_params()
sequential0_ (
  Parameter dense0_weight (shape=(256, 20), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, 256), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)
params = net.init(d2l.get_key(), jnp.zeros((2, 20)))
jax.tree_util.tree_map(lambda x: x.shape, params).tree_flatten_with_keys()
(((DictKey(key='params'),
   {'layers_0': {'bias': (256,), 'kernel': (20, 256)},
    'layers_2': {'bias': (10,), 'kernel': (256, 10)}}),),
 ('params',))
X = tf.random.uniform((2, 20))
net(X)
[w.shape for w in net.get_weights()]
[(20, 256), (256,), (256, 10), (10,)]

一旦我们知道了输入维度是20,框架就可以通过代入20这个值来确定第一层权重矩阵的形状。在识别出第一层的形状后,框架会继续处理第二层,以此类推,直到计算图中所有形状都已知。请注意,在这种情况下,只有第一层需要延后初始化,但框架是按顺序初始化的。一旦所有参数形状都已知,框架最终就可以初始化参数了。

以下方法通过网络传递虚拟输入进行一次“演练”,以推断所有参数的形状,并随后初始化参数。当不希望使用默认的随机初始化时,稍后会用到此方法。

@d2l.add_to_class(d2l.Module)  #@save
def apply_init(self, inputs, init=None):
    self.forward(*inputs)
    if init is not None:
        self.net.apply(init)

在Flax中,参数初始化总是由用户手动完成和处理的。以下方法接受一个虚拟输入和一个键字典作为参数。这个键字典包含了用于初始化模型参数的rngs和用于为带dropout层的模型生成dropout掩码的dropout rng。关于dropout的更多内容将在 5.6节 中稍后介绍。最终,该方法初始化模型并返回参数。在之前的章节中,我们也在后台使用了它。

@d2l.add_to_class(d2l.Module)  #@save
def apply_init(self, dummy_input, key):
    params = self.init(key, *dummy_input)  # dummy_input tuple unpacked
    return params

6.4.1. 小结

延后初始化可能很方便,它允许框架自动推断参数形状,使得修改架构变得容易,并消除了一个常见的错误来源。我们可以通过模型传递数据,让框架最终初始化参数。

6.4.2. 练习

  1. 如果你为第一层指定了输入维度,但没有为后续层指定,会发生什么?你会得到立即初始化吗?

  2. 如果你指定了不匹配的维度,会发生什么?

  3. 如果你有不同维度的输入,你需要做什么?提示:查看参数绑定。