6.6. 文件读写
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

到目前为止,我们讨论了如何处理数据,以及如何构建、训练和测试深度学习模型。然而,有时我们希望保存训练好的模型,以备将来在各种环境中使用(比如在部署中进行预测)。此外,当运行一个耗时良久的训练过程时,最佳的做法是定期保存中间结果(检查点),以确保在服务器电源被不小心断掉时,我们不会损失几天的计算结果。因此,现在是时候学习如何加载和存储权重向量和整个模型了。本节将解决这两个问题。

import torch
from torch import nn
from torch.nn import functional as F
from mxnet import np, npx
from mxnet.gluon import nn

npx.set_np()
import flax
import jax
from flax import linen as nn
from flax.training import checkpoints
from jax import numpy as jnp
from d2l import jax as d2l
WARNING:absl:GlobalAsyncCheckpointManager is not imported correctly. Checkpointing of GlobalDeviceArrays will not be available.To use the feature, install tensorstore.
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import numpy as np
import tensorflow as tf

6.6.1. 加载和保存张量

对于单个张量,我们可以直接调用loadsave函数分别读写它们。这两个函数都要求我们提供一个名称,save要求将要保存的变量作为输入。

x = torch.arange(4)
torch.save(x, 'x-file')
x = np.arange(4)
npx.save('x-file', x)
[21:49:50] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
x = jnp.arange(4)
jnp.save('x-file.npy', x)
x = tf.range(4)
np.save('x-file.npy', x)

我们现在可以从存储的文件中将数据读回内存。

x2 = torch.load('x-file')
x2
tensor([0, 1, 2, 3])
x2 = npx.load('x-file')
x2
[array([0., 1., 2., 3.])]
x2 = jnp.load('x-file.npy', allow_pickle=True)
x2
Array([0, 1, 2, 3], dtype=int32)
x2 = np.load('x-file.npy', allow_pickle=True)
x2
array([0, 1, 2, 3], dtype=int32)

我们可以存储一个张量列表,然后把它们读回内存。

y = torch.zeros(4)
torch.save([x, y],'x-files')
x2, y2 = torch.load('x-files')
(x2, y2)
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
y = np.zeros(4)
npx.save('x-files', [x, y])
x2, y2 = npx.load('x-files')
(x2, y2)
(array([0., 1., 2., 3.]), array([0., 0., 0., 0.]))
y = jnp.zeros(4)
jnp.save('xy-files.npy', [x, y])
x2, y2 = jnp.load('xy-files.npy', allow_pickle=True)
(x2, y2)
(Array([0., 1., 2., 3.], dtype=float32),
 Array([0., 0., 0., 0.], dtype=float32))
y = tf.zeros(4)
np.save('xy-files.npy', [x, y])
x2, y2 = np.load('xy-files.npy', allow_pickle=True)
(x2, y2)
(array([0., 1., 2., 3.]), array([0., 0., 0., 0.]))

我们甚至可以写入和读取一个从字符串映射到张量的字典。当我们要读取或写入模型中的所有权重时,这很方便。

mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
mydict2
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
mydict = {'x': x, 'y': y}
npx.save('mydict', mydict)
mydict2 = npx.load('mydict')
mydict2
{'x': array([0., 1., 2., 3.]), 'y': array([0., 0., 0., 0.])}
mydict = {'x': x, 'y': y}
jnp.save('mydict.npy', mydict)
mydict2 = jnp.load('mydict.npy', allow_pickle=True)
mydict2
array({'x': Array([0, 1, 2, 3], dtype=int32), 'y': Array([0., 0., 0., 0.], dtype=float32)},
      dtype=object)
mydict = {'x': x, 'y': y}
np.save('mydict.npy', mydict)
mydict2 = np.load('mydict.npy', allow_pickle=True)
mydict2
array({'x': <tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, 'y': <tf.Tensor: shape=(4,), dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>},
      dtype=object)

6.6.2. 加载和保存模型参数

保存单个权重向量(或其他张量)是 полезно,但如果我们想保存(并在以后加载)整个模型,这就变得非常繁琐了。毕竟,我们可能有数百个参数组散布在各处。因此,深度学习框架提供了内置功能来加载和保存整个网络。需要注意的一个重要细节是,这将保存模型的*参数*而不是整个模型。例如,如果我们有一个3层的多层感知机,我们需要单独指定架构。这样做的原因是模型本身可以包含任意代码,因此它们不能被自然地序列化。因此,为了恢复一个模型,我们需要用代码生成架构,然后从磁盘加载参数。让我们从我们熟悉的多层感知机开始。

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.LazyLinear(256)
        self.output = nn.LazyLinear(10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
class MLP(nn.Block):
    def __init__(self, **kwargs):
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Dense(256, activation='relu')
        self.output = nn.Dense(10)

    def forward(self, x):
        return self.output(self.hidden(x))

net = MLP()
net.initialize()
X = np.random.uniform(size=(2, 20))
Y = net(X)
class MLP(nn.Module):
    def setup(self):
        self.hidden = nn.Dense(256)
        self.output = nn.Dense(10)

    def __call__(self, x):
        return self.output(nn.relu(self.hidden(x)))

net = MLP()
X = jax.random.normal(jax.random.PRNGKey(d2l.get_seed()), (2, 20))
Y, params = net.init_with_output(jax.random.PRNGKey(d2l.get_seed()), X)
class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.hidden = tf.keras.layers.Dense(units=256, activation=tf.nn.relu)
        self.out = tf.keras.layers.Dense(units=10)

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.hidden(x)
        return self.out(x)

net = MLP()
X = tf.random.uniform((2, 20))
Y = net(X)

接下来,我们将模型的参数存储在一个名为“mlp.params”的文件中。

torch.save(net.state_dict(), 'mlp.params')
net.save_parameters('mlp.params')
checkpoints.save_checkpoint('ckpt_dir', params, step=1, overwrite=True)
'ckpt_dir/checkpoint_1'
net.save_weights('mlp.params')

为了恢复模型,我们实例化了原始 MLP 模型的一个克隆。我们没有随机初始化模型参数,而是直接读取存储在文件中的参数。

clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
MLP(
  (hidden): LazyLinear(in_features=0, out_features=256, bias=True)
  (output): LazyLinear(in_features=0, out_features=10, bias=True)
)
clone = MLP()
clone.load_parameters('mlp.params')
clone = MLP()
cloned_params = flax.core.freeze(checkpoints.restore_checkpoint('ckpt_dir',
                                                                target=None))
clone = MLP()
clone.load_weights('mlp.params')
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7fe86a880820>

由于两个实例具有相同的模型参数,因此相同输入X的计算结果应该相同。让我们来验证一下。

Y_clone = clone(X)
Y_clone == Y
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])
Y_clone = clone(X)
Y_clone == Y
array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]])
Y_clone = clone.apply(cloned_params, X)
Y_clone == Y
Array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]], dtype=bool)
Y_clone = clone(X)
Y_clone == Y
<tf.Tensor: shape=(2, 10), dtype=bool, numpy=
array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]])>

6.6.3. 小结

saveload函数可用于对张量对象进行文件I/O操作。我们可以通过参数字典为网络保存和加载整套参数。保存架构必须在代码中完成,而不是在参数中。

6.6.4. 练习

  1. 即使不需要将训练好的模型部署到不同的设备上,存储模型参数有什么实际的好处?

  2. 假设我们只想重用网络的一部分,将其整合到一个具有不同架构的网络中。你将如何使用,比如说,一个先前网络的前两层来构建一个新的网络?

  3. 你将如何保存网络架构和参数?你会对架构施加什么限制?