3.2. 面向实现的面向对象设计
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

在线性回归的介绍中,我们演练了各种组件,包括数据、模型、损失函数和优化算法。的确,线性回归是机器学习中最简单的模型之一。然而,训练它需要使用本书中其他模型所需的许多相同组件。因此,在深入研究实现细节之前,有必要设计一些我们在本书中使用的API。将深度学习中的组件视为对象,我们可以从为这些对象及其交互定义类开始。这种面向实现的面向对象设计将极大地简化演示,你甚至可能想在你的项目中使用它。

PyTorch Lightning 等开源库的启发,在高层次上,我们希望有三个类:(i)Module 包含模型、损失和优化方法;(ii)DataModule 提供用于训练和验证的数据加载器;(iii)这两个类都使用 Trainer 类进行组合,该类允许我们在各种硬件平台上训练模型。本书中的大多数代码都改编自 ModuleDataModule。我们只在讨论GPU、CPU、并行训练和优化算法时才会涉及到 Trainer 类。

import time
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l
import time
import numpy as np
from mxnet.gluon import nn
from d2l import mxnet as d2l
import time
from dataclasses import field
from typing import Any
import jax
import numpy as np
from flax import linen as nn
from flax.training import train_state
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import time
import numpy as np
import tensorflow as tf
from d2l import tensorflow as d2l

3.2.1. 实用工具

我们需要一些实用工具来简化Jupyter笔记本中的面向对象编程。其中一个挑战是,类定义往往是相当长的代码块。笔记本的可读性要求代码片段要短,并穿插着解释,这一要求与Python库常见的编程风格不兼容。第一个实用工具函数允许我们在类创建*之后*将函数注册为类中的方法。事实上,我们甚至可以在创建了类的实例*之后*再这样做!它允许我们将一个类的实现拆分成多个代码块。

def add_to_class(Class):  #@save
    """Register functions as methods in created class."""
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
    return wrapper

让我们快速看一下如何使用它。我们计划实现一个类 A,它有一个方法 do。我们不必将 Ado 的代码放在同一个代码块中,而是可以先声明类 A 并创建一个实例 a

class A:
    def __init__(self):
        self.b = 1

a = A()

接下来,我们像通常那样定义方法 do,但不在类 A 的作用域内。相反,我们用 add_to_class 装饰这个方法,并将类 A 作为其参数。这样做,该方法就能够访问 A 的成员变量,就像我们期望它作为 A 定义的一部分一样。让我们看看当我们为实例 a 调用它时会发生什么。

@add_to_class(A)
def do(self):
    print('Class attribute "b" is', self.b)

a.do()
Class attribute "b" is 1
@add_to_class(A)
def do(self):
    print('Class attribute "b" is', self.b)

a.do()
Class attribute "b" is 1
@add_to_class(A)
def do(self):
    print('Class attribute "b" is', self.b)

a.do()
Class attribute "b" is 1
@add_to_class(A)
def do(self):
    print('Class attribute "b" is', self.b)

a.do()
Class attribute "b" is 1

第二个是一个实用工具类,它将一个类的 __init__ 方法中的所有参数保存为类属性。这使我们能够隐式地扩展构造函数调用签名而无需额外的代码。

class HyperParameters:  #@save
    """The base class of hyperparameters."""
    def save_hyperparameters(self, ignore=[]):
        raise NotImplemented

我们将其实现推迟到 第 23.7 节。要使用它,我们定义我们的类,它继承自 HyperParameters 并在 __init__ 方法中调用 save_hyperparameters

# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
    def __init__(self, a, b, c):
        self.save_hyperparameters(ignore=['c'])
        print('self.a =', self.a, 'self.b =', self.b)
        print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
    def __init__(self, a, b, c):
        self.save_hyperparameters(ignore=['c'])
        print('self.a =', self.a, 'self.b =', self.b)
        print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
    def __init__(self, a, b, c):
        self.save_hyperparameters(ignore=['c'])
        print('self.a =', self.a, 'self.b =', self.b)
        print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
    def __init__(self, a, b, c):
        self.save_hyperparameters(ignore=['c'])
        print('self.a =', self.a, 'self.b =', self.b)
        print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True

最后一个实用工具允许我们在实验进行时以交互方式绘制实验进度。为了向功能更强大(也更复杂)的 TensorBoard 致敬,我们将其命名为 ProgressBoard。其实现推迟到 第 23.7 节。现在,让我们简单地看看它的实际效果。

draw 方法在图中绘制一个点 (x, y),并在图例中指定 label。可选的 every_n 通过仅在图中显示 \(1/n\) 个点来平滑线条。它们的值是从原始图中 \(n\) 个相邻点的平均值计算出来的。

class ProgressBoard(d2l.HyperParameters):  #@save
    """The board that plots data points in animation."""
    def __init__(self, xlabel=None, ylabel=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        raise NotImplemented

在下面的例子中,我们用不同的平滑度绘制 sincos。如果你运行这个代码块,你会看到线条以动画形式增长。

board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
    board.draw(x, np.sin(x), 'sin', every_n=2)
    board.draw(x, np.cos(x), 'cos', every_n=10)
../_images/output_oo-design_a0c19f_56_0.svg
board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
    board.draw(x, np.sin(x), 'sin', every_n=2)
    board.draw(x, np.cos(x), 'cos', every_n=10)
../_images/output_oo-design_a0c19f_59_0.svg
board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
    board.draw(x, np.sin(x), 'sin', every_n=2)
    board.draw(x, np.cos(x), 'cos', every_n=10)
../_images/output_oo-design_a0c19f_62_0.svg
board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
    board.draw(x, np.sin(x), 'sin', every_n=2)
    board.draw(x, np.cos(x), 'cos', every_n=10)
../_images/output_oo-design_a0c19f_65_0.svg

3.2.2. 模型

Module 类是我们所有将要实现的模型的基类。我们至少需要三个方法。第一个是 __init__,它存储可学习的参数;training_step 方法接受一个数据批次以返回损失值;最后,configure_optimizers 返回用于更新可学习参数的优化方法,或一个优化方法列表。可选地,我们可以定义 validation_step 来报告评估指标。有时我们会将计算输出的代码放入一个单独的 forward 方法中,以使其更具可重用性。

class Module(nn.Module, d2l.HyperParameters):  #@save
    """The base class of models."""
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()
        self.save_hyperparameters()
        self.board = ProgressBoard()

    def loss(self, y_hat, y):
        raise NotImplementedError

    def forward(self, X):
        assert hasattr(self, 'net'), 'Neural network is defined'
        return self.net(X)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        self.board.draw(x, value.to(d2l.cpu()).detach().numpy(),
                        ('train_' if train else 'val_') + key,
                        every_n=int(n))

    def training_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        raise NotImplementedError

你可能会注意到 Modulenn.Module 的一个子类,后者是 PyTorch 中神经网络的基类。它为处理神经网络提供了便利的功能。例如,如果我们定义一个 forward 方法,比如 forward(self, X),那么对于一个实例 a,我们可以通过 a(X) 调用这个方法。这是因为它在内置的 __call__ 方法中调用了 forward 方法。你可以在 第 6.1 节 中找到更多关于 nn.Module 的细节和例子。

class Module(nn.Block, d2l.HyperParameters):  #@save
    """The base class of models."""
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()
        self.save_hyperparameters()
        self.board = ProgressBoard()
    def loss(self, y_hat, y):
        raise NotImplementedError

    def forward(self, X):
        assert hasattr(self, 'net'), 'Neural network is defined'
        return self.net(X)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        self.board.draw(x, value.asnumpy(), (
            'train_' if train else 'val_') + key, every_n=int(n))
    def training_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        raise NotImplementedError

您可能注意到,Modulenn.Block 的一个子类,后者是 Gluon 中神经网络的基类。它为处理神经网络提供了便利的功能。例如,如果我们定义一个 forward 方法,比如 forward(self, X),那么对于一个实例 a,我们可以通过 a(X) 调用这个方法。这是因为它在内置的 __call__ 方法中调用了 forward 方法。您可以在 第 6.1 节 中找到更多关于 nn.Block 的细节和示例。

随着 Python 3.7 中 dataclasses 的引入,用 @dataclass 装饰的类会自动添加诸如 __init____repr__ 之类的魔法方法。成员变量使用类型注解来定义。所有的 Flax 模块都是 Python 3.7 的数据类。

class Module(nn.Module, d2l.HyperParameters):  #@save
    """The base class of models."""
    # No need for save_hyperparam when using Python dataclass
    plot_train_per_epoch: int = field(default=2, init=False)
    plot_valid_per_epoch: int = field(default=1, init=False)
    # Use default_factory to make sure new plots are generated on each run
    board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
                                 init=False)

    def loss(self, y_hat, y):
        raise NotImplementedError

    # JAX & Flax do not have a forward-method-like syntax. Flax uses setup
    # and built-in __call__ magic methods for forward pass. Adding here
    # for consistency
    def forward(self, X, *args, **kwargs):
        assert hasattr(self, 'net'), 'Neural network is defined'
        return self.net(X, *args, **kwargs)

    def __call__(self, X, *args, **kwargs):
        return self.forward(X, *args, **kwargs)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        self.board.draw(x, jax.device_put(value, d2l.cpu()),
                        ('train_' if train else 'val_') + key,
                        every_n=int(n))

    def training_step(self, params, batch, state):
        l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
                                                 batch[-1], state)
        self.plot("loss", l, train=True)
        return l, grads

    def validation_step(self, params, batch, state):
        l = self.loss(params, batch[:-1], batch[-1], state)
        self.plot('loss', l, train=False)

    def apply_init(self, dummy_input, key):
        """To be defined later in :numref:`sec_lazy_init`"""
        raise NotImplementedError

    def configure_optimizers(self):
        raise NotImplementedError

您可能注意到,Modulelinen.Module 的一个子类,后者是 Flax 中神经网络的基类。它为处理神经网络提供了便利的功能。例如,它处理模型参数,提供 nn.compact 装饰器来简化代码,调用 __call__ 方法等等。这里我们也将 __call__ 重定向到 forward 方法。我们这样做是为了使我们的代码与其他框架的实现更相似。

class Module(tf.keras.Model, d2l.HyperParameters):  #@save
    """The base class of models."""
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()
        self.save_hyperparameters()
        self.board = ProgressBoard()
        self.training = None

    def loss(self, y_hat, y):
        raise NotImplementedError

    def forward(self, X):
        assert hasattr(self, 'net'), 'Neural network is defined'
        return self.net(X)

    def call(self, X, *args, **kwargs):
        if kwargs and "training" in kwargs:
            self.training = kwargs['training']
        return self.forward(X, *args)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        self.board.draw(x, value.numpy(), (
            'train_' if train else 'val_') + key, every_n=int(n))
    def training_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        raise NotImplementedError

您可能注意到,Moduletf.keras.Model 的一个子类,后者是 TensorFlow 中神经网络的基类。它为处理神经网络提供了便利的功能。例如,它在内置的 __call__ 方法中调用 call 方法。这里我们将 call 重定向到 forward 方法,并将其参数保存为类属性。我们这样做是为了使我们的代码与其他框架的实现更相似。

3.2.3. 数据

DataModule 类是数据的基类。通常,__init__ 方法用于准备数据。这包括在需要时下载和预处理。 train_dataloader 返回训练数据集的数据加载器。数据加载器是一个(Python)生成器,每次使用时都会产生一个数据批次。然后将此批次送入 Moduletraining_step 方法以计算损失。有一个可选的 val_dataloader 用于返回验证数据集加载器。它的行为方式相同,只是它为 Module 中的 validation_step 方法产生数据批次。

class DataModule(d2l.HyperParameters):  #@save
    """The base class of data."""
    def __init__(self, root='../data', num_workers=4):
        self.save_hyperparameters()

    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)
class DataModule(d2l.HyperParameters):  #@save
    """The base class of data."""
    def __init__(self, root='../data', num_workers=4):
        self.save_hyperparameters()

    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)
class DataModule(d2l.HyperParameters):  #@save
    """The base class of data."""
    def __init__(self, root='../data'):
        self.save_hyperparameters()

    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)
class DataModule(d2l.HyperParameters):  #@save
    """The base class of data."""
    def __init__(self, root='../data'):
        self.save_hyperparameters()

    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)

3.2.4. 训练

Trainer 类使用在 DataModule 中指定的数据来训练 Module 类中的可学习参数。关键方法是 fit,它接受两个参数:modelModule 的实例)和 dataDataModule 的实例)。然后,它会遍历整个数据集 max_epochs 次来训练模型。和之前一样,我们将把这个方法的实现推迟到后面的章节。

class Trainer(d2l.HyperParameters):  #@save
    """The base class for training models with data."""
    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        self.save_hyperparameters()
        assert num_gpus == 0, 'No GPU support yet'

    def prepare_data(self, data):
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        self.num_val_batches = (len(self.val_dataloader)
                                if self.val_dataloader is not None else 0)

    def prepare_model(self, model):
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        self.model = model

    def fit(self, model, data):
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()
        self.epoch = 0
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        raise NotImplementedError

Trainer 类使用在 DataModule 中指定的数据来训练 Module 类中的可学习参数。关键方法是 fit,它接受两个参数:modelModule 的实例)和 dataDataModule 的实例)。然后,它会遍历整个数据集 max_epochs 次来训练模型。和之前一样,我们将把这个方法的实现推迟到后面的章节。

class Trainer(d2l.HyperParameters):  #@save
    """The base class for training models with data."""
    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        self.save_hyperparameters()
        assert num_gpus == 0, 'No GPU support yet'

    def prepare_data(self, data):
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        self.num_val_batches = (len(self.val_dataloader)
                                if self.val_dataloader is not None else 0)

    def prepare_model(self, model):
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        self.model = model

    def fit(self, model, data):
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()
        self.epoch = 0
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        raise NotImplementedError

Trainer 类使用在 DataModule 中指定的数据训练可学习参数 params。关键方法是 fit,它接受三个参数:modelModule 的实例),dataDataModule 的实例),以及 key(一个 JAX PRNGKeyArray)。我们在这里让 key 参数可选以简化接口,但在 JAX 和 Flax 中,始终建议传递并使用根密钥初始化模型参数。然后它会遍历整个数据集 max_epochs 次来训练模型。和之前一样,我们将把这个方法的实现推迟到后面的章节。

class Trainer(d2l.HyperParameters):  #@save
    """The base class for training models with data."""
    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        self.save_hyperparameters()
        assert num_gpus == 0, 'No GPU support yet'

    def prepare_data(self, data):
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        self.num_val_batches = (len(self.val_dataloader)
                                if self.val_dataloader is not None else 0)

    def prepare_model(self, model):
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        self.model = model

    def fit(self, model, data, key=None):
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()

        if key is None:
            root_key = d2l.get_key()
        else:
            root_key = key
        params_key, dropout_key = jax.random.split(root_key)
        key = {'params': params_key, 'dropout': dropout_key}

        dummy_input = next(iter(self.train_dataloader))[:-1]
        variables = model.apply_init(dummy_input, key=key)
        params = variables['params']

        if 'batch_stats' in variables.keys():
            # Here batch_stats will be used later (e.g., for batch norm)
            batch_stats = variables['batch_stats']
        else:
            batch_stats = {}

        # Flax uses optax under the hood for a single state obj TrainState.
        # More will be discussed later in the dropout and batch
        # normalization section
        class TrainState(train_state.TrainState):
            batch_stats: Any
            dropout_rng: jax.random.PRNGKeyArray

        self.state = TrainState.create(apply_fn=model.apply,
                                       params=params,
                                       batch_stats=batch_stats,
                                       dropout_rng=dropout_key,
                                       tx=model.configure_optimizers())
        self.epoch = 0
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        raise NotImplementedError

Trainer 类使用在 DataModule 中指定的数据来训练 Module 类中的可学习参数。关键方法是 fit,它接受两个参数:modelModule 的实例)和 dataDataModule 的实例)。然后,它会遍历整个数据集 max_epochs 次来训练模型。和之前一样,我们将把这个方法的实现推迟到后面的章节。

class Trainer(d2l.HyperParameters):  #@save
    """The base class for training models with data."""
    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        self.save_hyperparameters()
        assert num_gpus == 0, 'No GPU support yet'

    def prepare_data(self, data):
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        self.num_val_batches = (len(self.val_dataloader)
                                if self.val_dataloader is not None else 0)

    def prepare_model(self, model):
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        self.model = model

    def fit(self, model, data):
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()
        self.epoch = 0
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        raise NotImplementedError

3.2.5. 小结

为了突出我们未来深度学习实现的面向对象设计,上述类仅展示了它们的对象如何存储数据以及如何相互交互。我们将在本书的其余部分继续丰富这些类的实现,例如通过 @add_to_class。此外,这些完全实现的类保存在 D2L 库中,这是一个*轻量级工具包*,可以轻松地为深度学习进行结构化建模。特别是,它有助于在项目之间重用许多组件而无需太多改动。例如,我们可以只替换优化器、只替换模型、只替换数据集等;这种模块化程度在整本书中以简洁性和简单性方面带来了回报(这就是我们添加它的原因),它也可以为你的项目做同样的事情。

3.2.6. 练习

  1. 找到保存在 D2L 库中上述类的完整实现。一旦你对深度学习建模有了更多的了解,我们强烈建议你详细查看实现。

  2. B 类中移除 save_hyperparameters 语句。你还能打印 self.aself.b 吗?可选:如果你已经深入研究了 HyperParameters 类的完整实现,你能解释为什么吗?