3.4. 线性回归的从零开始实现
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

我们现在准备好动手实现一个功能完整的线性回归。在本节中,我们将从零开始实现整个方法,包括 (i) 模型;(ii) 损失函数;(iii) 小批量随机梯度下降优化器;以及 (iv) 将所有这些部分拼接在一起的训练函数。最后,我们将运行 3.3节 中的合成数据生成器,并将我们的模型应用到生成的数据集上。虽然现代深度学习框架可以自动完成几乎所有这些工作,但从零开始实现是确保你真正理解自己在做什么的唯一方法。此外,当需要自定义模型、定义我们自己的层或损失函数时,了解底层的工作原理将非常有用。在本节中,我们将只依赖张量和自动微分。稍后,我们将介绍一种更简洁的实现方式,利用深度学习框架的各种便利功能,同时保留下面将要介绍的结构。

%matplotlib inline
import torch
from d2l import torch as d2l
%matplotlib inline
from mxnet import autograd, np, npx
from d2l import mxnet as d2l

npx.set_np()
%matplotlib inline
import jax
import optax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
%matplotlib inline
import tensorflow as tf
from d2l import tensorflow as d2l

3.4.1. 定义模型

在我们开始通过小批量随机梯度下降优化模型参数之前,我们首先需要有一些参数。在下面的代码中,我们通过从均值为0、标准差为0.01的正态分布中抽取随机数来初始化权重。神奇数字0.01在实践中通常效果很好,但你可以通过参数 sigma 指定一个不同的值。此外,我们将偏置设置为0。请注意,为了进行面向对象设计,我们将代码添加到一个 d2l.Module 的子类的 __init__ 方法中(在 3.2.2节 中介绍)。

class LinearRegressionScratch(d2l.Module):  #@save
    """The linear regression model implemented from scratch."""
    def __init__(self, num_inputs, lr, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        self.w = torch.normal(0, sigma, (num_inputs, 1), requires_grad=True)
        self.b = torch.zeros(1, requires_grad=True)
class LinearRegressionScratch(d2l.Module):  #@save
    """The linear regression model implemented from scratch."""
    def __init__(self, num_inputs, lr, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        self.w = np.random.normal(0, sigma, (num_inputs, 1))
        self.b = np.zeros(1)
        self.w.attach_grad()
        self.b.attach_grad()
class LinearRegressionScratch(d2l.Module):  #@save
    """The linear regression model implemented from scratch."""
    num_inputs: int
    lr: float
    sigma: float = 0.01

    def setup(self):
        self.w = self.param('w', nn.initializers.normal(self.sigma),
                            (self.num_inputs, 1))
        self.b = self.param('b', nn.initializers.zeros, (1))
class LinearRegressionScratch(d2l.Module):  #@save
    """The linear regression model implemented from scratch."""
    def __init__(self, num_inputs, lr, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()
        w = tf.random.normal((num_inputs, 1), mean=0, stddev=0.01)
        b = tf.zeros(1)
        self.w = tf.Variable(w, trainable=True)
        self.b = tf.Variable(b, trainable=True)

接下来,我们必须定义我们的模型,将其输入和参数与其输出关联起来。使用与 (3.1.4) 中相同的符号表示我们的线性模型,我们只需将输入特征 \(\mathbf{X}\) 和模型权重 \(\mathbf{w}\) 进行矩阵-向量乘积,并为每个样本加上偏移量 \(b\)。乘积 \(\mathbf{Xw}\) 是一个向量,而 \(b\) 是一个标量。由于广播机制(见 2.1.4节),当我们相加一个向量和一个标量时,标量会被加到向量的每个分量上。得到的 forward 方法通过 add_to_class(在 3.2.1节 中介绍)注册在 LinearRegressionScratch 类中。

@d2l.add_to_class(LinearRegressionScratch)  #@save
def forward(self, X):
    return torch.matmul(X, self.w) + self.b
@d2l.add_to_class(LinearRegressionScratch)  #@save
def forward(self, X):
    return np.dot(X, self.w) + self.b
@d2l.add_to_class(LinearRegressionScratch)  #@save
def forward(self, X):
    return jnp.matmul(X, self.w) + self.b
@d2l.add_to_class(LinearRegressionScratch)  #@save
def forward(self, X):
    return tf.matmul(X, self.w) + self.b

3.4.2. 定义损失函数

因为更新我们的模型需要计算损失函数的梯度,所以我们应该首先定义损失函数。这里我们使用 (3.1.5) 中的平方损失函数。在实现中,我们需要将真实值 y 转换成预测值 y_hat 的形状。以下方法返回的结果也将具有与 y_hat 相同的形状。我们还返回了小批量中所有样本的平均损失值。

@d2l.add_to_class(LinearRegressionScratch)  #@save
def loss(self, y_hat, y):
    l = (y_hat - y) ** 2 / 2
    return l.mean()
@d2l.add_to_class(LinearRegressionScratch)  #@save
def loss(self, y_hat, y):
    l = (y_hat - y) ** 2 / 2
    return l.mean()
@d2l.add_to_class(LinearRegressionScratch)  #@save
def loss(self, params, X, y, state):
    y_hat = state.apply_fn({'params': params}, *X)  # X unpacked from a tuple
    l = (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
    return l.mean()
@d2l.add_to_class(LinearRegressionScratch)  #@save
def loss(self, y_hat, y):
    l = (y_hat - y) ** 2 / 2
    return tf.reduce_mean(l)

3.4.3. 定义优化算法

正如在 3.1节 中讨论的,线性回归有一个闭式解。然而,我们这里的目标是展示如何训练更通用的神经网络,这需要我们教你如何使用小批量随机梯度下降。因此,我们将借此机会介绍你的第一个可工作的SGD示例。在每一步中,使用从我们的数据集中随机抽取的一个小批量,我们估计损失关于参数的梯度。接下来,我们沿着可能减少损失的方向更新参数。

以下代码在给定一组参数和学习率 lr 的情况下应用更新。由于我们的损失是作为小批量的平均值计算的,我们不需要根据批量大小调整学习率。在后面的章节中,我们将研究在分布式大规模学习中出现的非常大的小批量时,应如何调整学习率。目前,我们可以忽略这种依赖性。

我们定义我们的 SGD 类,它是 d2l.HyperParameters(在 3.2.1节 中介绍)的一个子类,以具有与内置SGD优化器类似的API。我们在 step 方法中更新参数。zero_grad 方法将所有梯度设置为0,这必须在反向传播步骤之前运行。

class SGD(d2l.HyperParameters):  #@save
    """Minibatch stochastic gradient descent."""
    def __init__(self, params, lr):
        self.save_hyperparameters()

    def step(self):
        for param in self.params:
            param -= self.lr * param.grad

    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()

我们定义我们的 SGD 类,它是 d2l.HyperParameters(在 3.2.1节 中介绍)的一个子类,以具有与内置SGD优化器类似的API。我们在 step 方法中更新参数。它接受一个可以忽略的 batch_size 参数。

class SGD(d2l.HyperParameters):  #@save
    """Minibatch stochastic gradient descent."""
    def __init__(self, params, lr):
        self.save_hyperparameters()

    def step(self, _):
        for param in self.params:
            param -= self.lr * param.grad
class SGD(d2l.HyperParameters):  #@save
    """Minibatch stochastic gradient descent."""
    # The key transformation of Optax is the GradientTransformation
    # defined by two methods, the init and the update.
    # The init initializes the state and the update transforms the gradients.
    # https://github.com/deepmind/optax/blob/master/optax/_src/transform.py
    def __init__(self, lr):
        self.save_hyperparameters()

    def init(self, params):
        # Delete unused params
        del params
        return optax.EmptyState

    def update(self, updates, state, params=None):
        del params
        # When state.apply_gradients method is called to update flax's
        # train_state object, it internally calls optax.apply_updates method
        # adding the params to the update equation defined below.
        updates = jax.tree_util.tree_map(lambda g: -self.lr * g, updates)
        return updates, state

    def __call__():
        return optax.GradientTransformation(self.init, self.update)

我们定义我们的 SGD 类,它是 d2l.HyperParameters(在 3.2.1节 中介绍)的一个子类,以具有与内置SGD优化器类似的API。我们在 apply_gradients 方法中更新参数。它接受一个参数和梯度对的列表。

class SGD(d2l.HyperParameters):  #@save
    """Minibatch stochastic gradient descent."""
    def __init__(self, lr):
        self.save_hyperparameters()

    def apply_gradients(self, grads_and_vars):
        for grad, param in grads_and_vars:
            param.assign_sub(self.lr * grad)

接下来我们定义 configure_optimizers 方法,它返回一个 SGD 类的实例。

@d2l.add_to_class(LinearRegressionScratch)  #@save
def configure_optimizers(self):
    return SGD([self.w, self.b], self.lr)
@d2l.add_to_class(LinearRegressionScratch)  #@save
def configure_optimizers(self):
    return SGD([self.w, self.b], self.lr)
@d2l.add_to_class(LinearRegressionScratch)  #@save
def configure_optimizers(self):
    return SGD(self.lr)
@d2l.add_to_class(LinearRegressionScratch)  #@save
def configure_optimizers(self):
    return SGD(self.lr)

3.4.4. 训练

现在我们已经准备好了所有的部分(参数、损失函数、模型和优化器),我们可以实现主要的训练循环了。完全理解这段代码至关重要,因为你将为本书中涵盖的所有其他深度学习模型使用类似的训练循环。在每个*轮次*(epoch)中,我们遍历整个训练数据集,对每个样本过一遍(假设样本数量可以被批量大小整除)。在每次*迭代*(iteration)中,我们抓取一个小批量的训练样本,并通过模型的 training_step 方法计算其损失。然后我们计算关于每个参数的梯度。最后,我们将调用优化算法来更新模型参数。总而言之,我们将执行以下循环:

  • 初始化参数 \((\mathbf{w}, b)\)

  • 重复直到完成

    • 计算梯度 \(\mathbf{g} \leftarrow \partial_{(\mathbf{w},b)} \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} l(\mathbf{x}^{(i)}, y^{(i)}, \mathbf{w}, b)\)

    • 更新参数 \((\mathbf{w}, b) \leftarrow (\mathbf{w}, b) - \eta \mathbf{g}\)

回想一下,我们在 3.3节 中生成的合成回归数据集没有提供验证数据集。然而,在大多数情况下,我们希望有一个验证数据集来衡量我们的模型质量。这里,我们在每个轮次中对验证数据加载器过一遍,以衡量模型性能。遵循我们的面向对象设计,prepare_batchfit_epoch 方法注册在 d2l.Trainer 类中(在 3.2.4节 中介绍)。

@d2l.add_to_class(d2l.Trainer)  #@save
def prepare_batch(self, batch):
    return batch

@d2l.add_to_class(d2l.Trainer)  #@save
def fit_epoch(self):
    self.model.train()
    for batch in self.train_dataloader:
        loss = self.model.training_step(self.prepare_batch(batch))
        self.optim.zero_grad()
        with torch.no_grad():
            loss.backward()
            if self.gradient_clip_val > 0:  # To be discussed later
                self.clip_gradients(self.gradient_clip_val, self.model)
            self.optim.step()
        self.train_batch_idx += 1
    if self.val_dataloader is None:
        return
    self.model.eval()
    for batch in self.val_dataloader:
        with torch.no_grad():
            self.model.validation_step(self.prepare_batch(batch))
        self.val_batch_idx += 1
@d2l.add_to_class(d2l.Trainer)  #@save
def prepare_batch(self, batch):
    return batch

@d2l.add_to_class(d2l.Trainer)  #@save
def fit_epoch(self):
    for batch in self.train_dataloader:
        with autograd.record():
            loss = self.model.training_step(self.prepare_batch(batch))
        loss.backward()
        if self.gradient_clip_val > 0:
            self.clip_gradients(self.gradient_clip_val, self.model)
        self.optim.step(1)
        self.train_batch_idx += 1
    if self.val_dataloader is None:
        return
    for batch in self.val_dataloader:
        self.model.validation_step(self.prepare_batch(batch))
        self.val_batch_idx += 1
@d2l.add_to_class(d2l.Trainer)  #@save
def prepare_batch(self, batch):
    return batch

@d2l.add_to_class(d2l.Trainer)  #@save
def fit_epoch(self):
    self.model.training = True
    if self.state.batch_stats:
        # Mutable states will be used later (e.g., for batch norm)
        for batch in self.train_dataloader:
            (_, mutated_vars), grads = self.model.training_step(self.state.params,
                                                           self.prepare_batch(batch),
                                                           self.state)
            self.state = self.state.apply_gradients(grads=grads)
            # Can be ignored for models without Dropout Layers
            self.state = self.state.replace(
                dropout_rng=jax.random.split(self.state.dropout_rng)[0])
            self.state = self.state.replace(batch_stats=mutated_vars['batch_stats'])
            self.train_batch_idx += 1
    else:
        for batch in self.train_dataloader:
            _, grads = self.model.training_step(self.state.params,
                                                self.prepare_batch(batch),
                                                self.state)
            self.state = self.state.apply_gradients(grads=grads)
            # Can be ignored for models without Dropout Layers
            self.state = self.state.replace(
                dropout_rng=jax.random.split(self.state.dropout_rng)[0])
            self.train_batch_idx += 1

    if self.val_dataloader is None:
        return
    self.model.training = False
    for batch in self.val_dataloader:
        self.model.validation_step(self.state.params,
                                   self.prepare_batch(batch),
                                   self.state)
        self.val_batch_idx += 1
@d2l.add_to_class(d2l.Trainer)  #@save
def prepare_batch(self, batch):
    return batch

@d2l.add_to_class(d2l.Trainer)  #@save
def fit_epoch(self):
    self.model.training = True
    for batch in self.train_dataloader:
        with tf.GradientTape() as tape:
            loss = self.model.training_step(self.prepare_batch(batch))
        grads = tape.gradient(loss, self.model.trainable_variables)
        if self.gradient_clip_val > 0:
            grads = self.clip_gradients(self.gradient_clip_val, grads)
        self.optim.apply_gradients(zip(grads, self.model.trainable_variables))
        self.train_batch_idx += 1
    if self.val_dataloader is None:
        return
    self.model.training = False
    for batch in self.val_dataloader:
        self.model.validation_step(self.prepare_batch(batch))
        self.val_batch_idx += 1

我们几乎准备好训练模型了,但首先我们需要一些训练数据。这里我们使用 SyntheticRegressionData 类并传入一些真实参数。然后我们用学习率 lr=0.03max_epochs=3 来训练我们的模型。请注意,通常情况下,轮次数和学习率都是超参数。一般而言,设置超参数很棘手,我们通常会希望使用三路划分,一组用于训练,第二组用于超参数选择,第三组保留用于最终评估。我们现在省略这些细节,但稍后会重新讨论它们。

model = LinearRegressionScratch(2, lr=0.03)
data = d2l.SyntheticRegressionData(w=torch.tensor([2, -3.4]), b=4.2)
trainer = d2l.Trainer(max_epochs=3)
trainer.fit(model, data)
../_images/output_linear-regression-scratch_b94b4e_111_0.svg
model = LinearRegressionScratch(2, lr=0.03)
data = d2l.SyntheticRegressionData(w=np.array([2, -3.4]), b=4.2)
trainer = d2l.Trainer(max_epochs=3)
trainer.fit(model, data)
../_images/output_linear-regression-scratch_b94b4e_114_0.svg
model = LinearRegressionScratch(2, lr=0.03)
data = d2l.SyntheticRegressionData(w=jnp.array([2, -3.4]), b=4.2)
trainer = d2l.Trainer(max_epochs=3)
trainer.fit(model, data)
../_images/output_linear-regression-scratch_b94b4e_117_0.svg
model = LinearRegressionScratch(2, lr=0.03)
data = d2l.SyntheticRegressionData(w=tf.constant([2, -3.4]), b=4.2)
trainer = d2l.Trainer(max_epochs=3)
trainer.fit(model, data)
../_images/output_linear-regression-scratch_b94b4e_120_0.svg

因为我们是自己合成的数据集,所以我们精确地知道真实的参数是什么。因此,我们可以通过比较真实参数和我们通过训练循环学到的参数来评估我们训练的成功程度。事实上,它们最终非常接近。

with torch.no_grad():
    print(f'error in estimating w: {data.w - model.w.reshape(data.w.shape)}')
    print(f'error in estimating b: {data.b - model.b}')
error in estimating w: tensor([ 0.1408, -0.1493])
error in estimating b: tensor([0.2130])
print(f'error in estimating w: {data.w - model.w.reshape(data.w.shape)}')
print(f'error in estimating b: {data.b - model.b}')
error in estimating w: [ 0.11080897 -0.12691855]
error in estimating b: [0.19214153]
params = trainer.state.params
print(f"error in estimating w: {data.w - params['w'].reshape(data.w.shape)}")
print(f"error in estimating b: {data.b - params['b']}")
error in estimating w: [ 0.07147813 -0.19255161]
error in estimating b: [0.24309802]
print(f'error in estimating w: {data.w - tf.reshape(model.w, data.w.shape)}')
print(f'error in estimating b: {data.b - model.b}')
error in estimating w: [ 0.11899102 -0.21494818]
error in estimating b: [0.2501297]

我们不应该理所当然地认为能够精确地恢复真实参数。总的来说,对于深度模型,参数不存在唯一解,即使对于线性模型,也只有当没有特征与其他特征线性相关时,才可能精确恢复参数。然而,在机器学习中,我们通常更关心的是找到能够带来高精度预测的参数,而不是恢复真实的底层参数 (Vapnik, 1992)。幸运的是,即使在困难的优化问题上,随机梯度下降也常常能找到非常好的解,这部分是由于对于深度网络,存在许多导致高精度预测的参数配置。

3.4.5. 总结

在本节中,我们通过实现一个功能齐全的神经网络模型和训练循环,朝着设计深度学习系统迈出了重要的一步。在这个过程中,我们构建了一个数据加载器、一个模型、一个损失函数、一个优化过程以及一个可视化和监控工具。我们通过组合一个包含所有训练模型相关组件的Python对象来完成此操作。虽然这还不是一个专业级的实现,但它功能完善,像这样的代码已经可以帮助您快速解决小问题。在接下来的章节中,我们将看到如何*更简洁地*(避免样板代码)和*更高效地*(充分利用我们的GPU)来完成这项工作。

3.4.6. 练习

  1. 如果我们将权重初始化为零,会发生什么?算法还能工作吗?如果我们将参数初始化为方差为 \(1000\) 而不是 \(0.01\) 会怎样?

  2. 假设你是 格奥尔格·西蒙·欧姆,试图为电阻建立一个关联电压和电流的模型。你能使用自动微分来学习你的模型的参数吗?

  3. 你能否使用 普朗克定律 来确定一个物体的温度?作为参考,黑体辐射的光谱密度 \(B\)\(B(\lambda, T) = \frac{2 hc^2}{\lambda^5} \cdot \left(\exp \frac{h c}{\lambda k T} - 1\right)^{-1}\)。这里 \(\lambda\) 是波长,\(T\) 是温度,\(c\) 是光速,\(h\) 是普朗克常数,\(k\) 是玻尔兹曼常数。你测量了不同波长 \(\lambda\) 下的能量,现在你需要将光谱密度曲线拟合到普朗克定律。

  4. 如果你想计算损失的二阶导数,可能会遇到什么问题?你会如何解决它们?

  5. 为什么在 loss 函数中需要 reshape 方法?

  6. 尝试使用不同的学习率进行实验,看看损失函数值下降的速度有多快。你能通过增加训练的轮次数来减少误差吗?

  7. 如果样本数量不能被批量大小整除,在一个轮次结束时,data_iter 会发生什么?

  8. 尝试实现一个不同的损失函数,比如绝对值损失 (y_hat - d2l.reshape(y, y_hat.shape)).abs().sum()

    1. 检查对于常规数据会发生什么。

    2. 检查如果你主动扰动 \(\mathbf{y}\) 的一些条目,例如 \(y_5 = 10000\),行为是否会有所不同。

    3. 你能想出一个廉价的解决方案来结合平方损失和绝对值损失的最佳方面吗?提示:你如何避免非常大的梯度值?

  9. 我们为什么需要重新打乱数据集?你能设计一个案例,其中一个恶意构造的数据集会破坏优化算法吗?