14.11. 全卷积网络¶ 在 SageMaker Studio Lab 中打开 Notebook
正如我们在 14.9节 中所讨论的,语义分割在像素级别上对图像进行分类。 **全卷积网络**(fully convolutional network,FCN)采用卷积神经网络实现了从图像像素到像素类别的变换 (Long et al., 2015)。 与我们之前在图像分类或目标检测中介绍的卷积神经网络不同,全卷积网络将中间层特征图的高度和宽度变换回输入图像的尺寸:这是通过在 14.10节 中引入的转置卷积层(transposed convolutional layer)实现的。 因此,输出的类别预测与输入图像在像素级别上具有一一对应关系:通道维的输出即是对应位置的输入像素的类别预测。
%matplotlib inline
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
%matplotlib inline
from mxnet import gluon, image, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
14.11.1. 模型¶
在这里,我们描述全卷积网络模型的基本设计。 如 图14.11.1 所示,该模型首先使用卷积神经网络来抽取图像特征,然后通过 \(1\times 1\) 卷积层将通道数转换为类别数,最后通过在 14.10节 中介绍的转置卷积层将特征图的高度和宽度变换为输入图像的尺寸。 因此,模型输出与输入图像的高和宽相同,其中输出通道包含对相同空间位置的输入像素的类别预测。
图 14.11.1 全卷积网络。¶
下面,我们使用在ImageNet数据集上预训练的ResNet-18模型来提取图像特征,并将该网络实例记为`pretrained_net`。 这个模型的最后几层包括一个全局平均汇聚层和一个全连接层:在全卷积网络中我们不需要它们。
pretrained_net = torchvision.models.resnet18(pretrained=True)
list(pretrained_net.children())[-3:]
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/ci/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 56.3MB/s]
[Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
),
AdaptiveAvgPool2d(output_size=(1, 1)),
Linear(in_features=512, out_features=1000, bias=True)]
pretrained_net = gluon.model_zoo.vision.resnet18_v2(pretrained=True)
pretrained_net.features[-3:], pretrained_net.output
[22:23:49] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
(HybridSequential(
(0): Activation(relu)
(1): GlobalAvgPool2D(size=(1, 1), stride=(1, 1), padding=(0, 0), ceil_mode=True, global_pool=True, pool_type=avg, layout=NCHW)
(2): Flatten
),
Dense(512 -> 1000, linear))
接下来,我们创建全卷积网络实例`net`。 它复制了ResNet-18中大部分的预训练层,除了最后的全局平均汇聚层和最接近输出的全连接层。
net = nn.Sequential(*list(pretrained_net.children())[:-2])
net = nn.HybridSequential()
for layer in pretrained_net.features[:-2]:
net.add(layer)
给定高度为320和宽度为480的输入,`net`的前向传播将输入的高和宽减小至原来的1/32,即10和15。
X = torch.rand(size=(1, 3, 320, 480))
net(X).shape
torch.Size([1, 512, 10, 15])
X = np.random.uniform(size=(1, 3, 320, 480))
net(X).shape
(1, 512, 10, 15)
接下来,我们使用 \(1\times 1\) 卷积层将输出通道数转换为Pascal VOC2012数据集的类数(21类)。 最后,我们需要将特征图的高度和宽度增加32倍,从而将其变回输入图像的高和宽。 回想一下 7.3节 中计算卷积层输出形状的方法。 由于 \((320-64+16\times2+32)/32=10\) 和 \((480-64+16\times2+32)/32=15\),我们构造一个步幅为 \(32\) 的转置卷积层,并将卷积核的高和宽设为 \(64\),填充为 \(16\)。 我们可以看到,如果步幅为 \(s\),填充为 \(s/2\)(假设 \(s/2\) 是整数),卷积核的高和宽为 \(2s\),转置卷积核会将输入的高和宽分别放大 \(s\) 倍。
num_classes = 21
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module('transpose_conv', nn.ConvTranspose2d(num_classes, num_classes,
kernel_size=64, padding=16, stride=32))
num_classes = 21
net.add(nn.Conv2D(num_classes, kernel_size=1),
nn.Conv2DTranspose(
num_classes, kernel_size=64, padding=16, strides=32))
14.11.2. 初始化转置卷积层¶
我们已经知道,转置卷积层可以增大特征图的高度和宽度。 在图像处理中,我们可能需要将图像放大,即*上采样*(upsampling)。 *双线性插值*(bilinear interpolation)是常用的上采样技术之一, 它也常用于初始化转置卷积层。
为了解释双线性插值,假设我们想要计算上采样输出图像中坐标 \((x, y)\) 处的像素。 首先,将 \((x, y)\) 映射到输入图像上的坐标 \((x', y')\), 例如,根据输入与输出的尺寸之比。 请注意,映射后的 \(x'\) 和 \(y'\) 是实数。 然后,在输入图像上找到与坐标 \((x', y')\) 最近的4个像素。 最后,基于输入图像上的这4个最近像素及其与 \((x', y')\) 的相对距离,来计算输出图像在坐标 \((x, y)\) 处的像素。
可以通过由以下 `bilinear_kernel` 函数构造的卷积核的转置卷积层来实现双线性插值的上采样。 由于篇幅限制,我们只给出下面 `bilinear_kernel` 函数的实现,不讨论算法的原理。
def bilinear_kernel(in_channels, out_channels, kernel_size):
factor = (kernel_size + 1) // 2
if kernel_size % 2 == 1:
center = factor - 1
else:
center = factor - 0.5
og = (torch.arange(kernel_size).reshape(-1, 1),
torch.arange(kernel_size).reshape(1, -1))
filt = (1 - torch.abs(og[0] - center) / factor) * \
(1 - torch.abs(og[1] - center) / factor)
weight = torch.zeros((in_channels, out_channels,
kernel_size, kernel_size))
weight[range(in_channels), range(out_channels), :, :] = filt
return weight
def bilinear_kernel(in_channels, out_channels, kernel_size):
factor = (kernel_size + 1) // 2
if kernel_size % 2 == 1:
center = factor - 1
else:
center = factor - 0.5
og = (np.arange(kernel_size).reshape(-1, 1),
np.arange(kernel_size).reshape(1, -1))
filt = (1 - np.abs(og[0] - center) / factor) * \
(1 - np.abs(og[1] - center) / factor)
weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size))
weight[range(in_channels), range(out_channels), :, :] = filt
return np.array(weight)
让我们用转置卷积层来实现双线性插值的上采样实验。我们构造一个将输入高和宽都放大2倍的转置卷积层,并将其卷积核用`bilinear_kernel`函数初始化。
conv_trans = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2,
bias=False)
conv_trans.weight.data.copy_(bilinear_kernel(3, 3, 4));
conv_trans = nn.Conv2DTranspose(3, kernel_size=4, padding=1, strides=2)
conv_trans.initialize(init.Constant(bilinear_kernel(3, 3, 4)))
读取图像`X`,并将上采样的结果赋值给`Y`。为了打印图像,我们需要调整通道维的位置。
img = torchvision.transforms.ToTensor()(d2l.Image.open('../img/catdog.jpg'))
X = img.unsqueeze(0)
Y = conv_trans(X)
out_img = Y[0].permute(1, 2, 0).detach()
img = image.imread('../img/catdog.jpg')
X = np.expand_dims(img.astype('float32').transpose(2, 0, 1), axis=0) / 255
Y = conv_trans(X)
out_img = Y[0].transpose(1, 2, 0)
正如我们所看到的,转置卷积层将图像的高度和宽度都放大了2倍。 除了坐标刻度不同,双线性插值放大的图像和在 14.3节 中打印出的原图看起来是一样的。
d2l.set_figsize()
print('input image shape:', img.permute(1, 2, 0).shape)
d2l.plt.imshow(img.permute(1, 2, 0));
print('output image shape:', out_img.shape)
d2l.plt.imshow(out_img);
input image shape: torch.Size([561, 728, 3])
output image shape: torch.Size([1122, 1456, 3])
d2l.set_figsize()
print('input image shape:', img.shape)
d2l.plt.imshow(img.asnumpy());
print('output image shape:', out_img.shape)
d2l.plt.imshow(out_img.asnumpy());
input image shape: (561, 728, 3)
output image shape: (1122, 1456, 3)
在全卷积网络中,我们用双线性插值的上采样来初始化转置卷积层。对于 \(1\times 1\) 卷积层,我们使用Xavier初始化。
W = bilinear_kernel(num_classes, num_classes, 64)
net.transpose_conv.weight.data.copy_(W);
W = bilinear_kernel(num_classes, num_classes, 64)
net[-1].initialize(init.Constant(W))
net[-2].initialize(init=init.Xavier())
14.11.3. 读取数据集¶
我们按照 14.9节 中介绍的步骤读取语义分割数据集。 随机裁剪的输出图像的形状指定为 \(320\times 480\):高和宽都可以被 \(32\) 整除。
batch_size, crop_size = 32, (320, 480)
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)
read 1114 examples
read 1078 examples
batch_size, crop_size = 32, (320, 480)
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)
read 1114 examples
read 1078 examples
14.11.4. 训练¶
现在我们可以训练我们构建的全卷积网络了。这里的损失函数和准确率计算与之前几章中的图像分类问题中的并没有本质上的不同。因为我们使用转置卷积层的通道来预测每个像素的类别,所以在损失计算中指定了通道维。此外,模型的准确率是根据所有像素的预测类别是否正确来计算的。
def loss(inputs, targets):
return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)
num_epochs, lr, wd, devices = 5, 0.001, 1e-3, d2l.try_all_gpus()
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
loss 0.449, train acc 0.861, test acc 0.852
226.7 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]
num_epochs, lr, wd, devices = 5, 0.1, 1e-3, d2l.try_all_gpus()
loss = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)
net.collect_params().reset_ctx(devices)
trainer = gluon.Trainer(net.collect_params(), 'sgd',
{'learning_rate': lr, 'wd': wd})
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
loss 0.322, train acc 0.894, test acc 0.851
132.1 examples/sec on [gpu(0), gpu(1)]
14.11.5. 预测¶
在预测时,我们需要将输入图像在各个通道上进行标准化,并将其转换为卷积神经网络所需要的四维输入格式。
def predict(img):
X = test_iter.dataset.normalize_image(img).unsqueeze(0)
pred = net(X.to(devices[0])).argmax(dim=1)
return pred.reshape(pred.shape[1], pred.shape[2])
def predict(img):
X = test_iter._dataset.normalize_image(img)
X = np.expand_dims(X.transpose(2, 0, 1), axis=0)
pred = net(X.as_in_ctx(devices[0])).argmax(axis=1)
return pred.reshape(pred.shape[1], pred.shape[2])
为了可视化预测的每个像素的类别,我们将预测的类别映射回它们在数据集中的标签颜色。
def label2image(pred):
colormap = torch.tensor(d2l.VOC_COLORMAP, device=devices[0])
X = pred.long()
return colormap[X, :]
def label2image(pred):
colormap = np.array(d2l.VOC_COLORMAP, ctx=devices[0], dtype='uint8')
X = pred.astype('int32')
return colormap[X, :]
测试数据集中的图像大小和形状各不相同。由于模型使用了步幅为32的转置卷积层,当输入图像的高度或宽度不能被32整除时,转置卷积层的输出高度或宽度将偏离输入图像的形状。为了解决这个问题,我们可以在图像中裁剪出多个高度和宽度为32的整数倍的矩形区域,并分别对这些区域中的像素进行前向传播。请注意,这些矩形区域的并集需要完全覆盖输入图像。当一个像素被多个矩形区域覆盖时,该像素在不同区域的转置卷积输出的平均值可以作为softmax运算的输入,用于预测类别。
为简单起见,我们只读取几张较大的测试图像,并从图像的左上角开始裁剪一个 \(320\times480\) 的区域用于预测。对于这些测试图像,我们逐行打印它们的裁剪区域、预测结果和真实值。
voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
test_images, test_labels = d2l.read_voc_images(voc_dir, False)
n, imgs = 4, []
for i in range(n):
crop_rect = (0, 0, 320, 480)
X = torchvision.transforms.functional.crop(test_images[i], *crop_rect)
pred = label2image(predict(X))
imgs += [X.permute(1,2,0), pred.cpu(),
torchvision.transforms.functional.crop(
test_labels[i], *crop_rect).permute(1,2,0)]
d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2);
voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
test_images, test_labels = d2l.read_voc_images(voc_dir, False)
n, imgs = 4, []
for i in range(n):
crop_rect = (0, 0, 480, 320)
X = image.fixed_crop(test_images[i], *crop_rect)
pred = label2image(predict(X))
imgs += [X, pred, image.fixed_crop(test_labels[i], *crop_rect)]
d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2);
14.11.6. 小结¶
全卷积网络首先使用卷积神经网络抽取图像特征,然后通过 \(1\times 1\) 卷积层将通道数变换为类别个数,最后通过转置卷积层将特征图的高和宽变换为输入图像的尺寸。
在全卷积网络中,我们可以使用双线性插值的上采样来初始化转置卷积层。
14.11.7. 练习¶
如果在实验中对转置卷积层使用Xavier初始化,结果会如何变化?
你能通过调整超参数来进一步提高模型的精度吗?
预测测试图像中所有像素的类别。
最初的全卷积网络论文还使用了某些中间卷积神经网络层的输出 (Long et al., 2015)。试着实现这个想法。