4.3. 分类模型基类
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

你可能已经注意到,在回归问题中,从零开始的实现和使用框架功能的简洁实现非常相似。分类问题也是如此。由于本书中的许多模型都处理分类问题,因此有必要添加专门支持此设置的功能。本节提供了一个分类模型的基类,以简化未来的代码。

import torch
from d2l import torch as d2l
from mxnet import autograd, gluon, np, npx
from d2l import mxnet as d2l

npx.set_np()
from functools import partial
import jax
import optax
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l

4.3.1. Classifier

我们在下面定义了 Classifier 类。在 validation_step 中,我们报告验证批次上的损失值和分类准确率。我们每隔 num_val_batches 个批次进行一次更新。这样做的好处是可以生成整个验证数据上的平均损失和准确率。如果最后一个批次包含的样本较少,这些平均数并不完全准确,但为了保持代码简洁,我们忽略了这个微小的差异。

class Classifier(d2l.Module):  #@save
    """The base class of classification models."""
    def validation_step(self, batch):
        Y_hat = self(*batch[:-1])
        self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
        self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

我们在下面定义了 Classifier 类。在 validation_step 中,我们报告验证批次上的损失值和分类准确率。我们每隔 num_val_batches 个批次进行一次更新。这样做的好处是可以生成整个验证数据上的平均损失和准确率。如果最后一个批次包含的样本较少,这些平均数并不完全准确,但为了保持代码简洁,我们忽略了这个微小的差异。

class Classifier(d2l.Module):  #@save
    """The base class of classification models."""
    def validation_step(self, batch):
        Y_hat = self(*batch[:-1])
        self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
        self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

我们在下面定义了 Classifier 类。在 validation_step 中,我们报告验证批次上的损失值和分类准确率。我们每隔 num_val_batches 个批次进行一次更新。这样做的好处是可以生成整个验证数据上的平均损失和准确率。如果最后一个批次包含的样本较少,这些平均数并不完全准确,但为了保持代码简洁,我们忽略了这个微小的差异。

我们还为JAX重新定义了 training_step 方法,因为之后所有将继承 Classifier 的模型都会有一个返回辅助数据的损失函数。这些辅助数据可以用于带批量归一化的模型(将在 8.5节 中解释),而在所有其他情况下,我们将使损失函数也返回一个占位符(空字典)来表示辅助数据。

class Classifier(d2l.Module):  #@save
    """The base class of classification models."""
    def training_step(self, params, batch, state):
        # Here value is a tuple since models with BatchNorm layers require
        # the loss to return auxiliary data
        value, grads = jax.value_and_grad(
            self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
        l, _ = value
        self.plot("loss", l, train=True)
        return value, grads

    def validation_step(self, params, batch, state):
        # Discard the second returned value. It is used for training models
        # with BatchNorm layers since loss also returns auxiliary data
        l, _ = self.loss(params, batch[:-1], batch[-1], state)
        self.plot('loss', l, train=False)
        self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state),
                  train=False)

我们在下面定义了 Classifier 类。在 validation_step 中,我们报告验证批次上的损失值和分类准确率。我们每隔 num_val_batches 个批次进行一次更新。这样做的好处是可以生成整个验证数据上的平均损失和准确率。如果最后一个批次包含的样本较少,这些平均数并不完全准确,但为了保持代码简洁,我们忽略了这个微小的差异。

class Classifier(d2l.Module):  #@save
    """The base class of classification models."""
    def validation_step(self, batch):
        Y_hat = self(*batch[:-1])
        self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
        self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

默认情况下,我们使用随机梯度下降优化器,对小批量数据进行操作,就像我们在线性回归的上下文中做的那样。

@d2l.add_to_class(d2l.Module)  #@save
def configure_optimizers(self):
    return torch.optim.SGD(self.parameters(), lr=self.lr)
@d2l.add_to_class(d2l.Module)  #@save
def configure_optimizers(self):
    params = self.parameters()
    if isinstance(params, list):
        return d2l.SGD(params, self.lr)
    return gluon.Trainer(params, 'sgd', {'learning_rate': self.lr})
@d2l.add_to_class(d2l.Module)  #@save
def configure_optimizers(self):
    return optax.sgd(self.lr)
@d2l.add_to_class(d2l.Module)  #@save
def configure_optimizers(self):
    return tf.keras.optimizers.SGD(self.lr)

4.3.2. 准确率

给定预测的概率分布 y_hat,当我们必须输出一个硬性预测时,我们通常选择预测概率最高的类别。事实上,许多应用都要求我们做出选择。例如,Gmail必须将一封电子邮件分类为“主要”、“社交”、“推广”、“论坛”或“垃圾邮件”。它可能在内部估计概率,但最终必须在这些类别中选择一个。

当预测与标签类别 y 一致时,它们就是正确的。分类准确率是所有预测中正确预测的比例。虽然直接优化准确率可能很困难(因为它不可微),但它通常是我们最关心的性能度量。它通常是基准测试中的相关指标。因此,在训练分类器时,我们几乎总是会报告它。

准确率的计算方式如下。首先,如果 y_hat 是一个矩阵,我们假设第二维度存储了每个类别的预测分数。我们使用 argmax 来通过每行中最大条目的索引获得预测的类别。然后,我们将预测的类别与真实标签 y 进行逐元素比较。由于等号运算符 == 对数据类型敏感,我们将 y_hat 的数据类型转换为与 y 匹配。结果是一个包含0(假)和1(真)条目的张量。求和得到正确预测的数量。

@d2l.add_to_class(Classifier)  #@save
def accuracy(self, Y_hat, Y, averaged=True):
    """Compute the number of correct predictions."""
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    preds = Y_hat.argmax(axis=1).type(Y.dtype)
    compare = (preds == Y.reshape(-1)).type(torch.float32)
    return compare.mean() if averaged else compare
@d2l.add_to_class(Classifier)  #@save
def accuracy(self, Y_hat, Y, averaged=True):
    """Compute the number of correct predictions."""
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    preds = Y_hat.argmax(axis=1).astype(Y.dtype)
    compare = (preds == Y.reshape(-1)).astype(np.float32)
    return compare.mean() if averaged else compare

@d2l.add_to_class(d2l.Module)  #@save
def get_scratch_params(self):
    params = []
    for attr in dir(self):
        a = getattr(self, attr)
        if isinstance(a, np.ndarray):
            params.append(a)
        if isinstance(a, d2l.Module):
            params.extend(a.get_scratch_params())
    return params

@d2l.add_to_class(d2l.Module)  #@save
def parameters(self):
    params = self.collect_params()
    return params if isinstance(params, gluon.parameter.ParameterDict) and len(
        params.keys()) else self.get_scratch_params()
@d2l.add_to_class(Classifier)  #@save
@partial(jax.jit, static_argnums=(0, 5))
def accuracy(self, params, X, Y, state, averaged=True):
    """Compute the number of correct predictions."""
    Y_hat = state.apply_fn({'params': params,
                            'batch_stats': state.batch_stats},  # BatchNorm Only
                           *X)
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    preds = Y_hat.argmax(axis=1).astype(Y.dtype)
    compare = (preds == Y.reshape(-1)).astype(jnp.float32)
    return compare.mean() if averaged else compare
@d2l.add_to_class(Classifier)  #@save
def accuracy(self, Y_hat, Y, averaged=True):
    """Compute the number of correct predictions."""
    Y_hat = tf.reshape(Y_hat, (-1, Y_hat.shape[-1]))
    preds = tf.cast(tf.argmax(Y_hat, axis=1), Y.dtype)
    compare = tf.cast(preds == tf.reshape(Y, -1), tf.float32)
    return tf.reduce_mean(compare) if averaged else compare

4.3.3. 小结

分类是一个足够常见的问题,值得拥有自己的便捷函数。在分类中至关重要的是分类器的*准确率*。请注意,虽然我们通常主要关心准确率,但出于统计和计算的原因,我们训练分类器以优化各种其他目标。然而,无论在训练期间最小化了哪个损失函数,拥有一个便捷的方法来凭经验评估我们分类器的准确率都是有用的。

4.3.4. 练习

  1. \(L_\textrm{v}\) 表示验证损失,令 \(L_\textrm{v}^\textrm{q}\) 为本节中通过损失函数平均计算的粗略估计值。最后,用 \(l_\textrm{v}^\textrm{b}\) 表示最后一个小批量的损失。用 \(L_\textrm{v}^\textrm{q}\)\(l_\textrm{v}^\textrm{b}\) 以及样本和批量大小来表示 \(L_\textrm{v}\)

  2. 证明粗略估计 \(L_\textrm{v}^\textrm{q}\) 是无偏的。也就是说,证明 \(E[L_\textrm{v}] = E[L_\textrm{v}^\textrm{q}]\)。为什么你仍然想使用 \(L_\textrm{v}\) 呢?

  3. 给定一个多类分类损失,用 \(l(y,y')\) 表示当我们观察到 \(y\) 时估计为 \(y'\) 的惩罚,并给定一个概率 \(p(y \mid x)\),制定一个最优选择 \(y'\) 的规则。提示:使用 \(l\)\(p(y \mid x)\) 表示期望损失。