3.2. 面向实现的面向对象设计¶ 在 SageMaker Studio Lab 中打开 Notebook
在线性回归的介绍中,我们演练了各种组件,包括数据、模型、损失函数和优化算法。的确,线性回归是机器学习中最简单的模型之一。然而,训练它需要使用本书中其他模型所需的许多相同组件。因此,在深入研究实现细节之前,有必要设计一些我们在本书中使用的API。将深度学习中的组件视为对象,我们可以从为这些对象及其交互定义类开始。这种面向实现的面向对象设计将极大地简化演示,你甚至可能想在你的项目中使用它。
受 PyTorch Lightning 等开源库的启发,在高层次上,我们希望有三个类:(i)Module
包含模型、损失和优化方法;(ii)DataModule
提供用于训练和验证的数据加载器;(iii)这两个类都使用 Trainer
类进行组合,该类允许我们在各种硬件平台上训练模型。本书中的大多数代码都改编自 Module
和 DataModule
。我们只在讨论GPU、CPU、并行训练和优化算法时才会涉及到 Trainer
类。
import time
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l
import time
import numpy as np
from mxnet.gluon import nn
from d2l import mxnet as d2l
import time
from dataclasses import field
from typing import Any
import jax
import numpy as np
from flax import linen as nn
from flax.training import train_state
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import time
import numpy as np
import tensorflow as tf
from d2l import tensorflow as d2l
3.2.1. 实用工具¶
我们需要一些实用工具来简化Jupyter笔记本中的面向对象编程。其中一个挑战是,类定义往往是相当长的代码块。笔记本的可读性要求代码片段要短,并穿插着解释,这一要求与Python库常见的编程风格不兼容。第一个实用工具函数允许我们在类创建*之后*将函数注册为类中的方法。事实上,我们甚至可以在创建了类的实例*之后*再这样做!它允许我们将一个类的实现拆分成多个代码块。
def add_to_class(Class): #@save
"""Register functions as methods in created class."""
def wrapper(obj):
setattr(Class, obj.__name__, obj)
return wrapper
让我们快速看一下如何使用它。我们计划实现一个类 A
,它有一个方法 do
。我们不必将 A
和 do
的代码放在同一个代码块中,而是可以先声明类 A
并创建一个实例 a
。
class A:
def __init__(self):
self.b = 1
a = A()
接下来,我们像通常那样定义方法 do
,但不在类 A
的作用域内。相反,我们用 add_to_class
装饰这个方法,并将类 A
作为其参数。这样做,该方法就能够访问 A
的成员变量,就像我们期望它作为 A
定义的一部分一样。让我们看看当我们为实例 a
调用它时会发生什么。
@add_to_class(A)
def do(self):
print('Class attribute "b" is', self.b)
a.do()
Class attribute "b" is 1
@add_to_class(A)
def do(self):
print('Class attribute "b" is', self.b)
a.do()
Class attribute "b" is 1
@add_to_class(A)
def do(self):
print('Class attribute "b" is', self.b)
a.do()
Class attribute "b" is 1
@add_to_class(A)
def do(self):
print('Class attribute "b" is', self.b)
a.do()
Class attribute "b" is 1
第二个是一个实用工具类,它将一个类的 __init__
方法中的所有参数保存为类属性。这使我们能够隐式地扩展构造函数调用签名而无需额外的代码。
class HyperParameters: #@save
"""The base class of hyperparameters."""
def save_hyperparameters(self, ignore=[]):
raise NotImplemented
我们将其实现推迟到 第 23.7 节。要使用它,我们定义我们的类,它继承自 HyperParameters
并在 __init__
方法中调用 save_hyperparameters
。
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
def __init__(self, a, b, c):
self.save_hyperparameters(ignore=['c'])
print('self.a =', self.a, 'self.b =', self.b)
print('There is no self.c =', not hasattr(self, 'c'))
b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
def __init__(self, a, b, c):
self.save_hyperparameters(ignore=['c'])
print('self.a =', self.a, 'self.b =', self.b)
print('There is no self.c =', not hasattr(self, 'c'))
b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
def __init__(self, a, b, c):
self.save_hyperparameters(ignore=['c'])
print('self.a =', self.a, 'self.b =', self.b)
print('There is no self.c =', not hasattr(self, 'c'))
b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
def __init__(self, a, b, c):
self.save_hyperparameters(ignore=['c'])
print('self.a =', self.a, 'self.b =', self.b)
print('There is no self.c =', not hasattr(self, 'c'))
b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
最后一个实用工具允许我们在实验进行时以交互方式绘制实验进度。为了向功能更强大(也更复杂)的 TensorBoard 致敬,我们将其命名为 ProgressBoard
。其实现推迟到 第 23.7 节。现在,让我们简单地看看它的实际效果。
draw
方法在图中绘制一个点 (x, y)
,并在图例中指定 label
。可选的 every_n
通过仅在图中显示 \(1/n\) 个点来平滑线条。它们的值是从原始图中 \(n\) 个相邻点的平均值计算出来的。
class ProgressBoard(d2l.HyperParameters): #@save
"""The board that plots data points in animation."""
def __init__(self, xlabel=None, ylabel=None, xlim=None,
ylim=None, xscale='linear', yscale='linear',
ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
fig=None, axes=None, figsize=(3.5, 2.5), display=True):
self.save_hyperparameters()
def draw(self, x, y, label, every_n=1):
raise NotImplemented
在下面的例子中,我们用不同的平滑度绘制 sin
和 cos
。如果你运行这个代码块,你会看到线条以动画形式增长。
board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
board.draw(x, np.sin(x), 'sin', every_n=2)
board.draw(x, np.cos(x), 'cos', every_n=10)
board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
board.draw(x, np.sin(x), 'sin', every_n=2)
board.draw(x, np.cos(x), 'cos', every_n=10)
board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
board.draw(x, np.sin(x), 'sin', every_n=2)
board.draw(x, np.cos(x), 'cos', every_n=10)
board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
board.draw(x, np.sin(x), 'sin', every_n=2)
board.draw(x, np.cos(x), 'cos', every_n=10)
3.2.2. 模型¶
Module
类是我们所有将要实现的模型的基类。我们至少需要三个方法。第一个是 __init__
,它存储可学习的参数;training_step
方法接受一个数据批次以返回损失值;最后,configure_optimizers
返回用于更新可学习参数的优化方法,或一个优化方法列表。可选地,我们可以定义 validation_step
来报告评估指标。有时我们会将计算输出的代码放入一个单独的 forward
方法中,以使其更具可重用性。
class Module(nn.Module, d2l.HyperParameters): #@save
"""The base class of models."""
def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
super().__init__()
self.save_hyperparameters()
self.board = ProgressBoard()
def loss(self, y_hat, y):
raise NotImplementedError
def forward(self, X):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X)
def plot(self, key, value, train):
"""Plot a point in animation."""
assert hasattr(self, 'trainer'), 'Trainer is not inited'
self.board.xlabel = 'epoch'
if train:
x = self.trainer.train_batch_idx / \
self.trainer.num_train_batches
n = self.trainer.num_train_batches / \
self.plot_train_per_epoch
else:
x = self.trainer.epoch + 1
n = self.trainer.num_val_batches / \
self.plot_valid_per_epoch
self.board.draw(x, value.to(d2l.cpu()).detach().numpy(),
('train_' if train else 'val_') + key,
every_n=int(n))
def training_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=False)
def configure_optimizers(self):
raise NotImplementedError
你可能会注意到 Module
是 nn.Module
的一个子类,后者是 PyTorch 中神经网络的基类。它为处理神经网络提供了便利的功能。例如,如果我们定义一个 forward
方法,比如 forward(self, X)
,那么对于一个实例 a
,我们可以通过 a(X)
调用这个方法。这是因为它在内置的 __call__
方法中调用了 forward
方法。你可以在 第 6.1 节 中找到更多关于 nn.Module
的细节和例子。
class Module(nn.Block, d2l.HyperParameters): #@save
"""The base class of models."""
def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
super().__init__()
self.save_hyperparameters()
self.board = ProgressBoard()
def loss(self, y_hat, y):
raise NotImplementedError
def forward(self, X):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X)
def plot(self, key, value, train):
"""Plot a point in animation."""
assert hasattr(self, 'trainer'), 'Trainer is not inited'
self.board.xlabel = 'epoch'
if train:
x = self.trainer.train_batch_idx / \
self.trainer.num_train_batches
n = self.trainer.num_train_batches / \
self.plot_train_per_epoch
else:
x = self.trainer.epoch + 1
n = self.trainer.num_val_batches / \
self.plot_valid_per_epoch
self.board.draw(x, value.asnumpy(), (
'train_' if train else 'val_') + key, every_n=int(n))
def training_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=False)
def configure_optimizers(self):
raise NotImplementedError
您可能注意到,Module
是 nn.Block
的一个子类,后者是 Gluon 中神经网络的基类。它为处理神经网络提供了便利的功能。例如,如果我们定义一个 forward
方法,比如 forward(self, X)
,那么对于一个实例 a
,我们可以通过 a(X)
调用这个方法。这是因为它在内置的 __call__
方法中调用了 forward
方法。您可以在 第 6.1 节 中找到更多关于 nn.Block
的细节和示例。
随着 Python 3.7 中 dataclasses 的引入,用 @dataclass
装饰的类会自动添加诸如 __init__
和 __repr__
之类的魔法方法。成员变量使用类型注解来定义。所有的 Flax 模块都是 Python 3.7 的数据类。
class Module(nn.Module, d2l.HyperParameters): #@save
"""The base class of models."""
# No need for save_hyperparam when using Python dataclass
plot_train_per_epoch: int = field(default=2, init=False)
plot_valid_per_epoch: int = field(default=1, init=False)
# Use default_factory to make sure new plots are generated on each run
board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
init=False)
def loss(self, y_hat, y):
raise NotImplementedError
# JAX & Flax do not have a forward-method-like syntax. Flax uses setup
# and built-in __call__ magic methods for forward pass. Adding here
# for consistency
def forward(self, X, *args, **kwargs):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X, *args, **kwargs)
def __call__(self, X, *args, **kwargs):
return self.forward(X, *args, **kwargs)
def plot(self, key, value, train):
"""Plot a point in animation."""
assert hasattr(self, 'trainer'), 'Trainer is not inited'
self.board.xlabel = 'epoch'
if train:
x = self.trainer.train_batch_idx / \
self.trainer.num_train_batches
n = self.trainer.num_train_batches / \
self.plot_train_per_epoch
else:
x = self.trainer.epoch + 1
n = self.trainer.num_val_batches / \
self.plot_valid_per_epoch
self.board.draw(x, jax.device_put(value, d2l.cpu()),
('train_' if train else 'val_') + key,
every_n=int(n))
def training_step(self, params, batch, state):
l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
batch[-1], state)
self.plot("loss", l, train=True)
return l, grads
def validation_step(self, params, batch, state):
l = self.loss(params, batch[:-1], batch[-1], state)
self.plot('loss', l, train=False)
def apply_init(self, dummy_input, key):
"""To be defined later in :numref:`sec_lazy_init`"""
raise NotImplementedError
def configure_optimizers(self):
raise NotImplementedError
您可能注意到,Module
是 linen.Module
的一个子类,后者是 Flax 中神经网络的基类。它为处理神经网络提供了便利的功能。例如,它处理模型参数,提供 nn.compact
装饰器来简化代码,调用 __call__
方法等等。这里我们也将 __call__
重定向到 forward
方法。我们这样做是为了使我们的代码与其他框架的实现更相似。
class Module(tf.keras.Model, d2l.HyperParameters): #@save
"""The base class of models."""
def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
super().__init__()
self.save_hyperparameters()
self.board = ProgressBoard()
self.training = None
def loss(self, y_hat, y):
raise NotImplementedError
def forward(self, X):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X)
def call(self, X, *args, **kwargs):
if kwargs and "training" in kwargs:
self.training = kwargs['training']
return self.forward(X, *args)
def plot(self, key, value, train):
"""Plot a point in animation."""
assert hasattr(self, 'trainer'), 'Trainer is not inited'
self.board.xlabel = 'epoch'
if train:
x = self.trainer.train_batch_idx / \
self.trainer.num_train_batches
n = self.trainer.num_train_batches / \
self.plot_train_per_epoch
else:
x = self.trainer.epoch + 1
n = self.trainer.num_val_batches / \
self.plot_valid_per_epoch
self.board.draw(x, value.numpy(), (
'train_' if train else 'val_') + key, every_n=int(n))
def training_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=False)
def configure_optimizers(self):
raise NotImplementedError
您可能注意到,Module
是 tf.keras.Model
的一个子类,后者是 TensorFlow 中神经网络的基类。它为处理神经网络提供了便利的功能。例如,它在内置的 __call__
方法中调用 call
方法。这里我们将 call
重定向到 forward
方法,并将其参数保存为类属性。我们这样做是为了使我们的代码与其他框架的实现更相似。
3.2.3. 数据¶
DataModule
类是数据的基类。通常,__init__
方法用于准备数据。这包括在需要时下载和预处理。 train_dataloader
返回训练数据集的数据加载器。数据加载器是一个(Python)生成器,每次使用时都会产生一个数据批次。然后将此批次送入 Module
的 training_step
方法以计算损失。有一个可选的 val_dataloader
用于返回验证数据集加载器。它的行为方式相同,只是它为 Module
中的 validation_step
方法产生数据批次。
class DataModule(d2l.HyperParameters): #@save
"""The base class of data."""
def __init__(self, root='../data', num_workers=4):
self.save_hyperparameters()
def get_dataloader(self, train):
raise NotImplementedError
def train_dataloader(self):
return self.get_dataloader(train=True)
def val_dataloader(self):
return self.get_dataloader(train=False)
class DataModule(d2l.HyperParameters): #@save
"""The base class of data."""
def __init__(self, root='../data', num_workers=4):
self.save_hyperparameters()
def get_dataloader(self, train):
raise NotImplementedError
def train_dataloader(self):
return self.get_dataloader(train=True)
def val_dataloader(self):
return self.get_dataloader(train=False)
class DataModule(d2l.HyperParameters): #@save
"""The base class of data."""
def __init__(self, root='../data'):
self.save_hyperparameters()
def get_dataloader(self, train):
raise NotImplementedError
def train_dataloader(self):
return self.get_dataloader(train=True)
def val_dataloader(self):
return self.get_dataloader(train=False)
class DataModule(d2l.HyperParameters): #@save
"""The base class of data."""
def __init__(self, root='../data'):
self.save_hyperparameters()
def get_dataloader(self, train):
raise NotImplementedError
def train_dataloader(self):
return self.get_dataloader(train=True)
def val_dataloader(self):
return self.get_dataloader(train=False)
3.2.4. 训练¶
Trainer
类使用在 DataModule
中指定的数据来训练 Module
类中的可学习参数。关键方法是 fit
,它接受两个参数:model
(Module
的实例)和 data
(DataModule
的实例)。然后,它会遍历整个数据集 max_epochs
次来训练模型。和之前一样,我们将把这个方法的实现推迟到后面的章节。
class Trainer(d2l.HyperParameters): #@save
"""The base class for training models with data."""
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
self.save_hyperparameters()
assert num_gpus == 0, 'No GPU support yet'
def prepare_data(self, data):
self.train_dataloader = data.train_dataloader()
self.val_dataloader = data.val_dataloader()
self.num_train_batches = len(self.train_dataloader)
self.num_val_batches = (len(self.val_dataloader)
if self.val_dataloader is not None else 0)
def prepare_model(self, model):
model.trainer = self
model.board.xlim = [0, self.max_epochs]
self.model = model
def fit(self, model, data):
self.prepare_data(data)
self.prepare_model(model)
self.optim = model.configure_optimizers()
self.epoch = 0
self.train_batch_idx = 0
self.val_batch_idx = 0
for self.epoch in range(self.max_epochs):
self.fit_epoch()
def fit_epoch(self):
raise NotImplementedError
Trainer
类使用在 DataModule
中指定的数据来训练 Module
类中的可学习参数。关键方法是 fit
,它接受两个参数:model
(Module
的实例)和 data
(DataModule
的实例)。然后,它会遍历整个数据集 max_epochs
次来训练模型。和之前一样,我们将把这个方法的实现推迟到后面的章节。
class Trainer(d2l.HyperParameters): #@save
"""The base class for training models with data."""
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
self.save_hyperparameters()
assert num_gpus == 0, 'No GPU support yet'
def prepare_data(self, data):
self.train_dataloader = data.train_dataloader()
self.val_dataloader = data.val_dataloader()
self.num_train_batches = len(self.train_dataloader)
self.num_val_batches = (len(self.val_dataloader)
if self.val_dataloader is not None else 0)
def prepare_model(self, model):
model.trainer = self
model.board.xlim = [0, self.max_epochs]
self.model = model
def fit(self, model, data):
self.prepare_data(data)
self.prepare_model(model)
self.optim = model.configure_optimizers()
self.epoch = 0
self.train_batch_idx = 0
self.val_batch_idx = 0
for self.epoch in range(self.max_epochs):
self.fit_epoch()
def fit_epoch(self):
raise NotImplementedError
Trainer
类使用在 DataModule
中指定的数据训练可学习参数 params
。关键方法是 fit
,它接受三个参数:model
(Module
的实例),data
(DataModule
的实例),以及 key
(一个 JAX PRNGKeyArray
)。我们在这里让 key
参数可选以简化接口,但在 JAX 和 Flax 中,始终建议传递并使用根密钥初始化模型参数。然后它会遍历整个数据集 max_epochs
次来训练模型。和之前一样,我们将把这个方法的实现推迟到后面的章节。
class Trainer(d2l.HyperParameters): #@save
"""The base class for training models with data."""
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
self.save_hyperparameters()
assert num_gpus == 0, 'No GPU support yet'
def prepare_data(self, data):
self.train_dataloader = data.train_dataloader()
self.val_dataloader = data.val_dataloader()
self.num_train_batches = len(self.train_dataloader)
self.num_val_batches = (len(self.val_dataloader)
if self.val_dataloader is not None else 0)
def prepare_model(self, model):
model.trainer = self
model.board.xlim = [0, self.max_epochs]
self.model = model
def fit(self, model, data, key=None):
self.prepare_data(data)
self.prepare_model(model)
self.optim = model.configure_optimizers()
if key is None:
root_key = d2l.get_key()
else:
root_key = key
params_key, dropout_key = jax.random.split(root_key)
key = {'params': params_key, 'dropout': dropout_key}
dummy_input = next(iter(self.train_dataloader))[:-1]
variables = model.apply_init(dummy_input, key=key)
params = variables['params']
if 'batch_stats' in variables.keys():
# Here batch_stats will be used later (e.g., for batch norm)
batch_stats = variables['batch_stats']
else:
batch_stats = {}
# Flax uses optax under the hood for a single state obj TrainState.
# More will be discussed later in the dropout and batch
# normalization section
class TrainState(train_state.TrainState):
batch_stats: Any
dropout_rng: jax.random.PRNGKeyArray
self.state = TrainState.create(apply_fn=model.apply,
params=params,
batch_stats=batch_stats,
dropout_rng=dropout_key,
tx=model.configure_optimizers())
self.epoch = 0
self.train_batch_idx = 0
self.val_batch_idx = 0
for self.epoch in range(self.max_epochs):
self.fit_epoch()
def fit_epoch(self):
raise NotImplementedError
Trainer
类使用在 DataModule
中指定的数据来训练 Module
类中的可学习参数。关键方法是 fit
,它接受两个参数:model
(Module
的实例)和 data
(DataModule
的实例)。然后,它会遍历整个数据集 max_epochs
次来训练模型。和之前一样,我们将把这个方法的实现推迟到后面的章节。
class Trainer(d2l.HyperParameters): #@save
"""The base class for training models with data."""
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
self.save_hyperparameters()
assert num_gpus == 0, 'No GPU support yet'
def prepare_data(self, data):
self.train_dataloader = data.train_dataloader()
self.val_dataloader = data.val_dataloader()
self.num_train_batches = len(self.train_dataloader)
self.num_val_batches = (len(self.val_dataloader)
if self.val_dataloader is not None else 0)
def prepare_model(self, model):
model.trainer = self
model.board.xlim = [0, self.max_epochs]
self.model = model
def fit(self, model, data):
self.prepare_data(data)
self.prepare_model(model)
self.optim = model.configure_optimizers()
self.epoch = 0
self.train_batch_idx = 0
self.val_batch_idx = 0
for self.epoch in range(self.max_epochs):
self.fit_epoch()
def fit_epoch(self):
raise NotImplementedError