4.3. 分类模型基类¶ 在 SageMaker Studio Lab 中打开 Notebook
你可能已经注意到,在回归问题中,从零开始的实现和使用框架功能的简洁实现非常相似。分类问题也是如此。由于本书中的许多模型都处理分类问题,因此有必要添加专门支持此设置的功能。本节提供了一个分类模型的基类,以简化未来的代码。
import torch
from d2l import torch as d2l
from mxnet import autograd, gluon, np, npx
from d2l import mxnet as d2l
npx.set_np()
from functools import partial
import jax
import optax
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l
4.3.1. Classifier
类¶
我们在下面定义了 Classifier
类。在 validation_step
中,我们报告验证批次上的损失值和分类准确率。我们每隔 num_val_batches
个批次进行一次更新。这样做的好处是可以生成整个验证数据上的平均损失和准确率。如果最后一个批次包含的样本较少,这些平均数并不完全准确,但为了保持代码简洁,我们忽略了这个微小的差异。
class Classifier(d2l.Module): #@save
"""The base class of classification models."""
def validation_step(self, batch):
Y_hat = self(*batch[:-1])
self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)
我们在下面定义了 Classifier
类。在 validation_step
中,我们报告验证批次上的损失值和分类准确率。我们每隔 num_val_batches
个批次进行一次更新。这样做的好处是可以生成整个验证数据上的平均损失和准确率。如果最后一个批次包含的样本较少,这些平均数并不完全准确,但为了保持代码简洁,我们忽略了这个微小的差异。
class Classifier(d2l.Module): #@save
"""The base class of classification models."""
def validation_step(self, batch):
Y_hat = self(*batch[:-1])
self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)
我们在下面定义了 Classifier
类。在 validation_step
中,我们报告验证批次上的损失值和分类准确率。我们每隔 num_val_batches
个批次进行一次更新。这样做的好处是可以生成整个验证数据上的平均损失和准确率。如果最后一个批次包含的样本较少,这些平均数并不完全准确,但为了保持代码简洁,我们忽略了这个微小的差异。
我们还为JAX重新定义了 training_step
方法,因为之后所有将继承 Classifier
的模型都会有一个返回辅助数据的损失函数。这些辅助数据可以用于带批量归一化的模型(将在 8.5节 中解释),而在所有其他情况下,我们将使损失函数也返回一个占位符(空字典)来表示辅助数据。
class Classifier(d2l.Module): #@save
"""The base class of classification models."""
def training_step(self, params, batch, state):
# Here value is a tuple since models with BatchNorm layers require
# the loss to return auxiliary data
value, grads = jax.value_and_grad(
self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
l, _ = value
self.plot("loss", l, train=True)
return value, grads
def validation_step(self, params, batch, state):
# Discard the second returned value. It is used for training models
# with BatchNorm layers since loss also returns auxiliary data
l, _ = self.loss(params, batch[:-1], batch[-1], state)
self.plot('loss', l, train=False)
self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state),
train=False)
我们在下面定义了 Classifier
类。在 validation_step
中,我们报告验证批次上的损失值和分类准确率。我们每隔 num_val_batches
个批次进行一次更新。这样做的好处是可以生成整个验证数据上的平均损失和准确率。如果最后一个批次包含的样本较少,这些平均数并不完全准确,但为了保持代码简洁,我们忽略了这个微小的差异。
class Classifier(d2l.Module): #@save
"""The base class of classification models."""
def validation_step(self, batch):
Y_hat = self(*batch[:-1])
self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)
默认情况下,我们使用随机梯度下降优化器,对小批量数据进行操作,就像我们在线性回归的上下文中做的那样。
@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=self.lr)
@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
params = self.parameters()
if isinstance(params, list):
return d2l.SGD(params, self.lr)
return gluon.Trainer(params, 'sgd', {'learning_rate': self.lr})
@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
return optax.sgd(self.lr)
@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
return tf.keras.optimizers.SGD(self.lr)
4.3.2. 准确率¶
给定预测的概率分布 y_hat
,当我们必须输出一个硬性预测时,我们通常选择预测概率最高的类别。事实上,许多应用都要求我们做出选择。例如,Gmail必须将一封电子邮件分类为“主要”、“社交”、“推广”、“论坛”或“垃圾邮件”。它可能在内部估计概率,但最终必须在这些类别中选择一个。
当预测与标签类别 y
一致时,它们就是正确的。分类准确率是所有预测中正确预测的比例。虽然直接优化准确率可能很困难(因为它不可微),但它通常是我们最关心的性能度量。它通常是基准测试中的相关指标。因此,在训练分类器时,我们几乎总是会报告它。
准确率的计算方式如下。首先,如果 y_hat
是一个矩阵,我们假设第二维度存储了每个类别的预测分数。我们使用 argmax
来通过每行中最大条目的索引获得预测的类别。然后,我们将预测的类别与真实标签 y
进行逐元素比较。由于等号运算符 ==
对数据类型敏感,我们将 y_hat
的数据类型转换为与 y
匹配。结果是一个包含0(假)和1(真)条目的张量。求和得到正确预测的数量。
@d2l.add_to_class(Classifier) #@save
def accuracy(self, Y_hat, Y, averaged=True):
"""Compute the number of correct predictions."""
Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
preds = Y_hat.argmax(axis=1).type(Y.dtype)
compare = (preds == Y.reshape(-1)).type(torch.float32)
return compare.mean() if averaged else compare
@d2l.add_to_class(Classifier) #@save
def accuracy(self, Y_hat, Y, averaged=True):
"""Compute the number of correct predictions."""
Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
preds = Y_hat.argmax(axis=1).astype(Y.dtype)
compare = (preds == Y.reshape(-1)).astype(np.float32)
return compare.mean() if averaged else compare
@d2l.add_to_class(d2l.Module) #@save
def get_scratch_params(self):
params = []
for attr in dir(self):
a = getattr(self, attr)
if isinstance(a, np.ndarray):
params.append(a)
if isinstance(a, d2l.Module):
params.extend(a.get_scratch_params())
return params
@d2l.add_to_class(d2l.Module) #@save
def parameters(self):
params = self.collect_params()
return params if isinstance(params, gluon.parameter.ParameterDict) and len(
params.keys()) else self.get_scratch_params()
@d2l.add_to_class(Classifier) #@save
@partial(jax.jit, static_argnums=(0, 5))
def accuracy(self, params, X, Y, state, averaged=True):
"""Compute the number of correct predictions."""
Y_hat = state.apply_fn({'params': params,
'batch_stats': state.batch_stats}, # BatchNorm Only
*X)
Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
preds = Y_hat.argmax(axis=1).astype(Y.dtype)
compare = (preds == Y.reshape(-1)).astype(jnp.float32)
return compare.mean() if averaged else compare
@d2l.add_to_class(Classifier) #@save
def accuracy(self, Y_hat, Y, averaged=True):
"""Compute the number of correct predictions."""
Y_hat = tf.reshape(Y_hat, (-1, Y_hat.shape[-1]))
preds = tf.cast(tf.argmax(Y_hat, axis=1), Y.dtype)
compare = tf.cast(preds == tf.reshape(Y, -1), tf.float32)
return tf.reduce_mean(compare) if averaged else compare
4.3.3. 小结¶
分类是一个足够常见的问题,值得拥有自己的便捷函数。在分类中至关重要的是分类器的*准确率*。请注意,虽然我们通常主要关心准确率,但出于统计和计算的原因,我们训练分类器以优化各种其他目标。然而,无论在训练期间最小化了哪个损失函数,拥有一个便捷的方法来凭经验评估我们分类器的准确率都是有用的。
4.3.4. 练习¶
用 \(L_\textrm{v}\) 表示验证损失,令 \(L_\textrm{v}^\textrm{q}\) 为本节中通过损失函数平均计算的粗略估计值。最后,用 \(l_\textrm{v}^\textrm{b}\) 表示最后一个小批量的损失。用 \(L_\textrm{v}^\textrm{q}\)、\(l_\textrm{v}^\textrm{b}\) 以及样本和批量大小来表示 \(L_\textrm{v}\)。
证明粗略估计 \(L_\textrm{v}^\textrm{q}\) 是无偏的。也就是说,证明 \(E[L_\textrm{v}] = E[L_\textrm{v}^\textrm{q}]\)。为什么你仍然想使用 \(L_\textrm{v}\) 呢?
给定一个多类分类损失,用 \(l(y,y')\) 表示当我们观察到 \(y\) 时估计为 \(y'\) 的惩罚,并给定一个概率 \(p(y \mid x)\),制定一个最优选择 \(y'\) 的规则。提示:使用 \(l\) 和 \(p(y \mid x)\) 表示期望损失。