4.2. 图像分类数据集¶ 在 SageMaker Studio Lab 中打开 Notebook
一个广泛用于图像分类的数据集是手写数字的MNIST数据集 (LeCun et al., 1998)。在20世纪90年代发布时,它由60000张分辨率为\(28 \times 28\)像素的图像(外加一个包含10000张图像的测试数据集)组成,对大多数机器学习算法构成了巨大的挑战。从当时的角度看,1995年,一台拥有高达64MB内存和惊人的5 MFLOPs性能的Sun SPARCStation 5工作站,在AT&T贝尔实验室被认为是机器学习的顶尖设备。在20世纪90年代,实现高手写数字识别准确率是美国邮政服务(USPS)信件分拣自动化的关键组成部分。像LeNet-5 (LeCun et al., 1995)这样的深度网络、具有不变性的支持向量机 (Schölkopf et al., 1996)以及切线距离分类器 (Simard et al., 1998)都能达到低于1%的错误率。
在十多年的时间里,MNIST一直是比较机器学习算法的参考标准。尽管它作为一个基准数据集有过辉煌的历史,但即便是按今天的标准来看很简单的模型,也能达到超过95%的分类准确率,这使得它不适合用来区分强模型和弱模型。更有甚者,该数据集允许达到*极*高的准确率,这在许多分类问题中并不常见。这使得算法的发展偏向于那些能利用干净数据集的特定算法家族,如活性集方法和寻求边界的活性集算法。如今,MNIST更多地被用作健全性检查,而非基准测试。ImageNet (Deng et al., 2009)提出了一个更具现实意义的挑战。不幸的是,ImageNet对于本书中的许多示例和说明来说太大了,因为训练时间太长,无法使示例具有互动性。作为替代,在接下来的章节中,我们将重点讨论性质相似但规模小得多的Fashion-MNIST数据集(Xiao et al., 2017),该数据集于2017年发布。它包含10个类别的服装图像,分辨率为\(28 \times 28\)像素。
%matplotlib inline
import time
import torch
import torchvision
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()
%matplotlib inline
import time
from mxnet import gluon, npx
from mxnet.gluon.data.vision import transforms
from d2l import mxnet as d2l
npx.set_np()
d2l.use_svg_display()
%matplotlib inline
import time
import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from jax import numpy as jnp
from d2l import jax as d2l
d2l.use_svg_display()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
%matplotlib inline
import time
import tensorflow as tf
from d2l import tensorflow as d2l
d2l.use_svg_display()
4.2.1. 加载数据集¶
由于Fashion-MNIST数据集非常有用,所有主流框架都提供了其预处理版本。我们可以使用内置的框架工具下载并将其读入内存。
class FashionMNIST(d2l.DataModule): #@save
"""The Fashion-MNIST dataset."""
def __init__(self, batch_size=64, resize=(28, 28)):
super().__init__()
self.save_hyperparameters()
trans = transforms.Compose([transforms.Resize(resize),
transforms.ToTensor()])
self.train = torchvision.datasets.FashionMNIST(
root=self.root, train=True, transform=trans, download=True)
self.val = torchvision.datasets.FashionMNIST(
root=self.root, train=False, transform=trans, download=True)
class FashionMNIST(d2l.DataModule): #@save
"""The Fashion-MNIST dataset."""
def __init__(self, batch_size=64, resize=(28, 28)):
super().__init__()
self.save_hyperparameters()
trans = transforms.Compose([transforms.Resize(resize),
transforms.ToTensor()])
self.train = gluon.data.vision.FashionMNIST(
train=True).transform_first(trans)
self.val = gluon.data.vision.FashionMNIST(
train=False).transform_first(trans)
class FashionMNIST(d2l.DataModule): #@save
"""The Fashion-MNIST dataset."""
def __init__(self, batch_size=64, resize=(28, 28)):
super().__init__()
self.save_hyperparameters()
self.train, self.val = tf.keras.datasets.fashion_mnist.load_data()
class FashionMNIST(d2l.DataModule): #@save
"""The Fashion-MNIST dataset."""
def __init__(self, batch_size=64, resize=(28, 28)):
super().__init__()
self.save_hyperparameters()
self.train, self.val = tf.keras.datasets.fashion_mnist.load_data()
Fashion-MNIST包含10个类别的图像,每个类别在训练数据集中有6000张图像,在测试数据集中有1000张图像。*测试数据集*用于评估模型性能(不能用于训练)。因此,训练集和测试集分别包含60,000张和10,000张图像。
data = FashionMNIST(resize=(32, 32))
len(data.train), len(data.val)
(60000, 10000)
data = FashionMNIST(resize=(32, 32))
len(data.train), len(data.val)
[22:09:46] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
(60000, 10000)
data = FashionMNIST(resize=(32, 32))
len(data.train[0]), len(data.val[0])
(60000, 10000)
data = FashionMNIST(resize=(32, 32))
len(data.train[0]), len(data.val[0])
(60000, 10000)
这些图像是灰度图,并且在上面被放大到\(32 \times 32\)像素分辨率。这与原始的MNIST数据集相似,后者由(二值的)黑白图像组成。但请注意,大多数现代图像数据有三个通道(红、绿、蓝),而高光谱图像可以有超过100个通道(HyMap传感器有126个通道)。按照惯例,我们将图像存储为一个\(c \times h \times w\)张量,其中\(c\)是颜色通道数,\(h\)是高度,\(w\)是宽度。
data.train[0][0].shape
torch.Size([1, 32, 32])
data.train[0][0].shape
(1, 32, 32)
data.train[0][0].shape
(28, 28)
data.train[0][0].shape
(28, 28)
Fashion-MNIST的类别有易于人类理解的名称。下面的便捷方法可以在数字标签和其名称之间进行转换。
@d2l.add_to_class(FashionMNIST) #@save
def text_labels(self, indices):
"""Return text labels."""
labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [labels[int(i)] for i in indices]
4.2.2. 读取小批量¶
为了方便我们从训练集和测试集读取数据,我们使用内置的数据迭代器,而不是从头创建一个。回想一下,在每次迭代中,数据迭代器会读取一个大小为batch_size
的小批量数据。我们还为训练数据迭代器随机打乱了样本。
@d2l.add_to_class(FashionMNIST) #@save
def get_dataloader(self, train):
data = self.train if train else self.val
return torch.utils.data.DataLoader(data, self.batch_size, shuffle=train,
num_workers=self.num_workers)
@d2l.add_to_class(FashionMNIST) #@save
def get_dataloader(self, train):
data = self.train if train else self.val
return gluon.data.DataLoader(data, self.batch_size, shuffle=train,
num_workers=self.num_workers)
@d2l.add_to_class(FashionMNIST) #@save
def get_dataloader(self, train):
data = self.train if train else self.val
process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
tf.cast(y, dtype='int32'))
resize_fn = lambda X, y: (tf.image.resize_with_pad(X, *self.resize), y)
shuffle_buf = len(data[0]) if train else 1
return tfds.as_numpy(
tf.data.Dataset.from_tensor_slices(process(*data)).batch(
self.batch_size).map(resize_fn).shuffle(shuffle_buf))
@d2l.add_to_class(FashionMNIST) #@save
def get_dataloader(self, train):
data = self.train if train else self.val
process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
tf.cast(y, dtype='int32'))
resize_fn = lambda X, y: (tf.image.resize_with_pad(X, *self.resize), y)
shuffle_buf = len(data[0]) if train else 1
return tf.data.Dataset.from_tensor_slices(process(*data)).batch(
self.batch_size).map(resize_fn).shuffle(shuffle_buf)
为了看看它是如何工作的,让我们通过调用train_dataloader
方法来加载一个小批量的图像。它包含64张图像。
X, y = next(iter(data.train_dataloader()))
print(X.shape, X.dtype, y.shape, y.dtype)
torch.Size([64, 1, 32, 32]) torch.float32 torch.Size([64]) torch.int64
X, y = next(iter(data.train_dataloader()))
print(X.shape, X.dtype, y.shape, y.dtype)
(64, 1, 32, 32) float32 (64,) int32
X, y = next(iter(data.train_dataloader()))
print(X.shape, X.dtype, y.shape, y.dtype)
(64, 32, 32, 1) float32 (64,) int32
X, y = next(iter(data.train_dataloader()))
print(X.shape, X.dtype, y.shape, y.dtype)
(64, 32, 32, 1) <dtype: 'float32'> (64,) <dtype: 'int32'>
我们来看看读取图像所需的时间。尽管它是一个内置的加载器,但速度并非极快。尽管如此,这已经足够了,因为用深度网络处理图像需要更长的时间。因此,只要训练网络不会受到I/O的限制,这就足够好了。
tic = time.time()
for X, y in data.train_dataloader():
continue
f'{time.time() - tic:.2f} sec'
'4.69 sec'
tic = time.time()
for X, y in data.train_dataloader():
continue
f'{time.time() - tic:.2f} sec'
'4.60 sec'
tic = time.time()
for X, y in data.train_dataloader():
continue
f'{time.time() - tic:.2f} sec'
'1.02 sec'
tic = time.time()
for X, y in data.train_dataloader():
continue
f'{time.time() - tic:.2f} sec'
'0.87 sec'
4.2.3. 可视化¶
我们将经常使用Fashion-MNIST数据集。一个便捷函数show_images
可以用来可视化图像及其相关标签。我们跳过实现细节,只在下面展示其接口:我们只需要知道如何调用d2l.show_images
,而不需要知道这类实用函数是如何工作的。
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""Plot a list of images."""
raise NotImplementedError
让我们好好利用它。总的来说,可视化和检查你正在训练的数据是一个好主意。人类非常擅长发现异常,因此,可视化可以作为一种额外的保障,防止实验设计中的错误和失误。以下是训练数据集中前几个样本的图像及其对应的(文本)标签。
@d2l.add_to_class(FashionMNIST) #@save
def visualize(self, batch, nrows=1, ncols=8, labels=[]):
X, y = batch
if not labels:
labels = self.text_labels(y)
d2l.show_images(X.squeeze(1), nrows, ncols, titles=labels)
batch = next(iter(data.val_dataloader()))
data.visualize(batch)
@d2l.add_to_class(FashionMNIST) #@save
def visualize(self, batch, nrows=1, ncols=8, labels=[]):
X, y = batch
if not labels:
labels = self.text_labels(y)
d2l.show_images(X.squeeze(1), nrows, ncols, titles=labels)
batch = next(iter(data.val_dataloader()))
data.visualize(batch)
@d2l.add_to_class(FashionMNIST) #@save
def visualize(self, batch, nrows=1, ncols=8, labels=[]):
X, y = batch
if not labels:
labels = self.text_labels(y)
d2l.show_images(jnp.squeeze(X), nrows, ncols, titles=labels)
batch = next(iter(data.val_dataloader()))
data.visualize(batch)
@d2l.add_to_class(FashionMNIST) #@save
def visualize(self, batch, nrows=1, ncols=8, labels=[]):
X, y = batch
if not labels:
labels = self.text_labels(y)
d2l.show_images(tf.squeeze(X), nrows, ncols, titles=labels)
batch = next(iter(data.val_dataloader()))
data.visualize(batch)
我们现在已经准备好在接下来的章节中使用Fashion-MNIST数据集了。
4.2.4. 小结¶
我们现在有一个稍微更现实的数据集用于分类。Fashion-MNIST是一个服装分类数据集,由代表10个类别的图像组成。我们将在后续的章节中使用这个数据集来评估各种网络设计,从简单的线性模型到高级的残差网络。正如我们通常处理图像一样,我们将其读取为形状为(批量大小,通道数,高度,宽度)的张量。目前,我们只有一个通道,因为图像是灰度图(上面的可视化为了提高可见性使用了伪彩色调色板)。
最后,数据迭代器是实现高效性能的关键组成部分。例如,我们可能会使用GPU进行高效的图像解压、视频转码或其他预处理。只要可能,您都应该依赖于实现良好的数据迭代器,它们利用高性能计算来避免减慢您的训练循环。