8.4. 多分支网络 (GoogLeNet)¶ 在 SageMaker Studio Lab 中打开 Notebook
2014年,*GoogLeNet* 在ImageNet挑战赛中一举夺魁 (Szegedy et al., 2015)。它使用了一种结合了NiN (Lin et al., 2013)、重复块 (Simonyan and Zisserman, 2014) 和各种卷积核优点的结构。它也可以说是第一个在CNN中对主干(数据摄取)、主体(数据处理)和头部(预测)有明确区分的网络。这种设计模式从此在深度网络的设计中一直沿用至今:*主干*由处理图像的前两到三个卷积层组成。它们从底层图像中提取低级特征。接下来是一个由卷积块组成的*主体*。最后,*头部*将目前为止获得的特征映射到手头的分类、分割、检测或跟踪问题所需的输出。
GoogLeNet的关键贡献是网络主体的设计。它以一种巧妙的方式解决了选择卷积核的问题。在其他工作试图确定哪种从 \(1 \times 1\) 到 \(11 \times 11\) 的卷积是最好的时候,它简单地*拼接*了多分支卷积。接下来,我们将介绍一个稍微简化的GoogLeNet版本:原始设计包含了一些通过中间损失函数来稳定训练的技巧,这些技巧应用于网络的多个层。由于有了改进的训练算法,这些技巧不再是必需的。
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
from mxnet import init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
import tensorflow as tf
from d2l import tensorflow as d2l
8.4.1. Inception块¶
GoogLeNet中的基本卷积块被称为 *Inception块*,这个名字来源于电影《盗梦空间》中的梗“we need to go deeper”。
图 8.4.1 Inception块的结构。¶
如 图 8.4.1 所示,Inception块由四个平行的分支组成。前三个分支使用窗口大小为 \(1\times 1\)、\(3\times 3\) 和 \(5\times 5\) 的卷积层,以从不同的空间尺寸中提取信息。中间两个分支还在输入上添加了一个 \(1\times 1\) 卷积来减少通道数,从而降低模型的复杂性。第四个分支使用一个 \(3\times 3\) 的最大汇聚层,然后是一个 \(1\times 1\) 的卷积层来改变通道数。所有四个分支都使用适当的填充来使输入和输出具有相同的高度和宽度。最后,每个分支的输出在通道维度上被拼接起来,构成该块的输出。Inception块中通常调整的超参数是每层的输出通道数,即如何在不同大小的卷积之间分配容量。
class Inception(nn.Module):
# c1--c4 are the number of output channels for each branch
def __init__(self, c1, c2, c3, c4, **kwargs):
super(Inception, self).__init__(**kwargs)
# Branch 1
self.b1_1 = nn.LazyConv2d(c1, kernel_size=1)
# Branch 2
self.b2_1 = nn.LazyConv2d(c2[0], kernel_size=1)
self.b2_2 = nn.LazyConv2d(c2[1], kernel_size=3, padding=1)
# Branch 3
self.b3_1 = nn.LazyConv2d(c3[0], kernel_size=1)
self.b3_2 = nn.LazyConv2d(c3[1], kernel_size=5, padding=2)
# Branch 4
self.b4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
self.b4_2 = nn.LazyConv2d(c4, kernel_size=1)
def forward(self, x):
b1 = F.relu(self.b1_1(x))
b2 = F.relu(self.b2_2(F.relu(self.b2_1(x))))
b3 = F.relu(self.b3_2(F.relu(self.b3_1(x))))
b4 = F.relu(self.b4_2(self.b4_1(x)))
return torch.cat((b1, b2, b3, b4), dim=1)
class Inception(nn.Block):
# c1--c4 are the number of output channels for each branch
def __init__(self, c1, c2, c3, c4, **kwargs):
super(Inception, self).__init__(**kwargs)
# Branch 1
self.b1_1 = nn.Conv2D(c1, kernel_size=1, activation='relu')
# Branch 2
self.b2_1 = nn.Conv2D(c2[0], kernel_size=1, activation='relu')
self.b2_2 = nn.Conv2D(c2[1], kernel_size=3, padding=1,
activation='relu')
# Branch 3
self.b3_1 = nn.Conv2D(c3[0], kernel_size=1, activation='relu')
self.b3_2 = nn.Conv2D(c3[1], kernel_size=5, padding=2,
activation='relu')
# Branch 4
self.b4_1 = nn.MaxPool2D(pool_size=3, strides=1, padding=1)
self.b4_2 = nn.Conv2D(c4, kernel_size=1, activation='relu')
def forward(self, x):
b1 = self.b1_1(x)
b2 = self.b2_2(self.b2_1(x))
b3 = self.b3_2(self.b3_1(x))
b4 = self.b4_2(self.b4_1(x))
return np.concatenate((b1, b2, b3, b4), axis=1)
class Inception(nn.Module):
# `c1`--`c4` are the number of output channels for each branch
c1: int
c2: tuple
c3: tuple
c4: int
def setup(self):
# Branch 1
self.b1_1 = nn.Conv(self.c1, kernel_size=(1, 1))
# Branch 2
self.b2_1 = nn.Conv(self.c2[0], kernel_size=(1, 1))
self.b2_2 = nn.Conv(self.c2[1], kernel_size=(3, 3), padding='same')
# Branch 3
self.b3_1 = nn.Conv(self.c3[0], kernel_size=(1, 1))
self.b3_2 = nn.Conv(self.c3[1], kernel_size=(5, 5), padding='same')
# Branch 4
self.b4_1 = lambda x: nn.max_pool(x, window_shape=(3, 3),
strides=(1, 1), padding='same')
self.b4_2 = nn.Conv(self.c4, kernel_size=(1, 1))
def __call__(self, x):
b1 = nn.relu(self.b1_1(x))
b2 = nn.relu(self.b2_2(nn.relu(self.b2_1(x))))
b3 = nn.relu(self.b3_2(nn.relu(self.b3_1(x))))
b4 = nn.relu(self.b4_2(self.b4_1(x)))
return jnp.concatenate((b1, b2, b3, b4), axis=-1)
class Inception(tf.keras.Model):
# c1--c4 are the number of output channels for each branch
def __init__(self, c1, c2, c3, c4):
super().__init__()
self.b1_1 = tf.keras.layers.Conv2D(c1, 1, activation='relu')
self.b2_1 = tf.keras.layers.Conv2D(c2[0], 1, activation='relu')
self.b2_2 = tf.keras.layers.Conv2D(c2[1], 3, padding='same',
activation='relu')
self.b3_1 = tf.keras.layers.Conv2D(c3[0], 1, activation='relu')
self.b3_2 = tf.keras.layers.Conv2D(c3[1], 5, padding='same',
activation='relu')
self.b4_1 = tf.keras.layers.MaxPool2D(3, 1, padding='same')
self.b4_2 = tf.keras.layers.Conv2D(c4, 1, activation='relu')
def call(self, x):
b1 = self.b1_1(x)
b2 = self.b2_2(self.b2_1(x))
b3 = self.b3_2(self.b3_1(x))
b4 = self.b4_2(self.b4_1(x))
return tf.keras.layers.Concatenate()([b1, b2, b3, b4])
为了直观地理解为什么这个网络效果这么好,我们来考虑一下这些滤波器的组合。它们以各种滤波器尺寸探索图像。这意味着不同大小的滤波器可以有效地识别不同范围的细节。同时,我们可以为不同的滤波器分配不同数量的参数。
8.4.2. GoogLeNet模型¶
如 图 8.4.2 所示,GoogLeNet使用了总共9个Inception块的堆栈,这些块被分成了三组,中间有最大汇聚层,其头部使用全局平均汇聚层来生成其估计。Inception块之间的最大汇聚层降低了维度。在其主干部分,第一个模块类似于AlexNet和LeNet。
图 8.4.2 GoogLeNet架构。¶
现在我们可以一步步地实现GoogLeNet。让我们从主干开始。第一个模块使用一个64通道的 \(7\times 7\) 卷积层。
class GoogleNet(d2l.Classifier):
def b1(self):
return nn.Sequential(
nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
class GoogleNet(d2l.Classifier):
def b1(self):
net = nn.Sequential()
net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3,
activation='relu'),
nn.MaxPool2D(pool_size=3, strides=2, padding=1))
return net
class GoogleNet(d2l.Classifier):
lr: float = 0.1
num_classes: int = 10
def setup(self):
self.net = nn.Sequential([self.b1(), self.b2(), self.b3(), self.b4(),
self.b5(), nn.Dense(self.num_classes)])
def b1(self):
return nn.Sequential([
nn.Conv(64, kernel_size=(7, 7), strides=(2, 2), padding='same'),
nn.relu,
lambda x: nn.max_pool(x, window_shape=(3, 3), strides=(2, 2),
padding='same')])
class GoogleNet(d2l.Classifier):
def b1(self):
return tf.keras.models.Sequential([
tf.keras.layers.Conv2D(64, 7, strides=2, padding='same',
activation='relu'),
tf.keras.layers.MaxPool2D(pool_size=3, strides=2,
padding='same')])
第二个模块使用两个卷积层:首先是一个64通道的 \(1\times 1\) 卷积层,然后是一个将通道数增加三倍的 \(3\times 3\) 卷积层。这对应于Inception块中的第二个分支,并完成了主体的设计。此时我们有192个通道。
@d2l.add_to_class(GoogleNet)
def b2(self):
return nn.Sequential(
nn.LazyConv2d(64, kernel_size=1), nn.ReLU(),
nn.LazyConv2d(192, kernel_size=3, padding=1), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
@d2l.add_to_class(GoogleNet)
def b2(self):
net = nn.Sequential()
net.add(nn.Conv2D(64, kernel_size=1, activation='relu'),
nn.Conv2D(192, kernel_size=3, padding=1, activation='relu'),
nn.MaxPool2D(pool_size=3, strides=2, padding=1))
return net
@d2l.add_to_class(GoogleNet)
def b2(self):
return nn.Sequential([nn.Conv(64, kernel_size=(1, 1)),
nn.relu,
nn.Conv(192, kernel_size=(3, 3), padding='same'),
nn.relu,
lambda x: nn.max_pool(x, window_shape=(3, 3),
strides=(2, 2),
padding='same')])
@d2l.add_to_class(GoogleNet)
def b2(self):
return tf.keras.Sequential([
tf.keras.layers.Conv2D(64, 1, activation='relu'),
tf.keras.layers.Conv2D(192, 3, padding='same', activation='relu'),
tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')])
第三个模块串联了两个完整的Inception块。第一个Inception块的输出通道数是 \(64+128+32+32=256\)。这相当于四个分支之间输出通道数的比例为 \(2:4:1:1\)。为了实现这一点,我们首先在第二和第三个分支中分别将输入维度减少 \(\frac{1}{2}\) 和 \(\frac{1}{12}\),以分别得到 \(96 = 192/2\) 和 \(16 = 192/12\) 个通道。
第二个Inception块的输出通道数增加到 \(128+192+96+64=480\),得到的比例为 \(128:192:96:64 = 4:6:3:2\)。和之前一样,我们需要减少第二和第三个通道的中间维度。分别缩放 \(\frac{1}{2}\) 和 \(\frac{1}{8}\) 就足够了,分别得到 \(128\) 和 \(32\) 个通道。这由以下 Inception
块构造函数的参数所体现。
@d2l.add_to_class(GoogleNet)
def b3(self):
return nn.Sequential(Inception(64, (96, 128), (16, 32), 32),
Inception(128, (128, 192), (32, 96), 64),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
@d2l.add_to_class(GoogleNet)
def b3(self):
net = nn.Sequential()
net.add(Inception(64, (96, 128), (16, 32), 32),
Inception(128, (128, 192), (32, 96), 64),
nn.MaxPool2D(pool_size=3, strides=2, padding=1))
return net
@d2l.add_to_class(GoogleNet)
def b3(self):
return nn.Sequential([Inception(64, (96, 128), (16, 32), 32),
Inception(128, (128, 192), (32, 96), 64),
lambda x: nn.max_pool(x, window_shape=(3, 3),
strides=(2, 2),
padding='same')])
@d2l.add_to_class(GoogleNet)
def b3(self):
return tf.keras.models.Sequential([
Inception(64, (96, 128), (16, 32), 32),
Inception(128, (128, 192), (32, 96), 64),
tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')])
第四个模块更复杂。它串联了五个Inception块,它们的输出通道数分别为 \(192+208+48+64=512\)、\(160+224+64+64=512\)、\(128+256+64+64=512\)、\(112+288+64+64=528\) 和 \(256+320+128+128=832\)。分配给这些分支的通道数与第三个模块中的相似:带有 \(3\times 3\) 卷积层的第二个分支输出最多的通道,其次是仅有 \(1\times 1\) 卷积层的第一个分支,带有 \(5\times 5\) 卷积层的第三个分支,以及带有 \(3\times 3\) 最大汇聚层的第四个分支。第二和第三个分支将首先根据比例减少通道数。这些比例在不同的Inception块中略有不同。
@d2l.add_to_class(GoogleNet)
def b4(self):
return nn.Sequential(Inception(192, (96, 208), (16, 48), 64),
Inception(160, (112, 224), (24, 64), 64),
Inception(128, (128, 256), (24, 64), 64),
Inception(112, (144, 288), (32, 64), 64),
Inception(256, (160, 320), (32, 128), 128),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
@d2l.add_to_class(GoogleNet)
def b4(self):
net = nn.Sequential()
net.add(Inception(192, (96, 208), (16, 48), 64),
Inception(160, (112, 224), (24, 64), 64),
Inception(128, (128, 256), (24, 64), 64),
Inception(112, (144, 288), (32, 64), 64),
Inception(256, (160, 320), (32, 128), 128),
nn.MaxPool2D(pool_size=3, strides=2, padding=1))
return net
@d2l.add_to_class(GoogleNet)
def b4(self):
return nn.Sequential([Inception(192, (96, 208), (16, 48), 64),
Inception(160, (112, 224), (24, 64), 64),
Inception(128, (128, 256), (24, 64), 64),
Inception(112, (144, 288), (32, 64), 64),
Inception(256, (160, 320), (32, 128), 128),
lambda x: nn.max_pool(x, window_shape=(3, 3),
strides=(2, 2),
padding='same')])
@d2l.add_to_class(GoogleNet)
def b4(self):
return tf.keras.Sequential([
Inception(192, (96, 208), (16, 48), 64),
Inception(160, (112, 224), (24, 64), 64),
Inception(128, (128, 256), (24, 64), 64),
Inception(112, (144, 288), (32, 64), 64),
Inception(256, (160, 320), (32, 128), 128),
tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')])
第五个模块有两个Inception块,输出通道数分别为 \(256+320+128+128=832\) 和 \(384+384+128+128=1024\)。分配给每个分支的通道数与第三和第四个模块中的相同,但具体数值不同。值得注意的是,第五个块后面是输出层。这个块使用全局平均汇聚层将每个通道的高度和宽度变为1,就像在NiN中一样。最后,我们将输出转换成一个二维数组,然后是一个全连接层,其输出数量是标签类的数量。
@d2l.add_to_class(GoogleNet)
def b5(self):
return nn.Sequential(Inception(256, (160, 320), (32, 128), 128),
Inception(384, (192, 384), (48, 128), 128),
nn.AdaptiveAvgPool2d((1,1)), nn.Flatten())
@d2l.add_to_class(GoogleNet)
def b5(self):
net = nn.Sequential()
net.add(Inception(256, (160, 320), (32, 128), 128),
Inception(384, (192, 384), (48, 128), 128),
nn.GlobalAvgPool2D())
return net
@d2l.add_to_class(GoogleNet)
def b5(self):
return nn.Sequential([Inception(256, (160, 320), (32, 128), 128),
Inception(384, (192, 384), (48, 128), 128),
# Flax does not provide a GlobalAvgPool2D layer
lambda x: nn.avg_pool(x,
window_shape=x.shape[1:3],
strides=x.shape[1:3],
padding='valid'),
lambda x: x.reshape((x.shape[0], -1))])
@d2l.add_to_class(GoogleNet)
def b5(self):
return tf.keras.Sequential([
Inception(256, (160, 320), (32, 128), 128),
Inception(384, (192, 384), (48, 128), 128),
tf.keras.layers.GlobalAvgPool2D(),
tf.keras.layers.Flatten()])
现在我们已经定义了从 b1
到 b5
的所有块,只需将它们全部组装成一个完整的网络即可。
@d2l.add_to_class(GoogleNet)
def __init__(self, lr=0.1, num_classes=10):
super(GoogleNet, self).__init__()
self.save_hyperparameters()
self.net = nn.Sequential(self.b1(), self.b2(), self.b3(), self.b4(),
self.b5(), nn.LazyLinear(num_classes))
self.net.apply(d2l.init_cnn)
@d2l.add_to_class(GoogleNet)
def __init__(self, lr=0.1, num_classes=10):
super(GoogleNet, self).__init__()
self.save_hyperparameters()
self.net = nn.Sequential()
self.net.add(self.b1(), self.b2(), self.b3(), self.b4(), self.b5(),
nn.Dense(num_classes))
self.net.initialize(init.Xavier())
@d2l.add_to_class(GoogleNet)
def __init__(self, lr=0.1, num_classes=10):
super(GoogleNet, self).__init__()
self.save_hyperparameters()
self.net = tf.keras.Sequential([
self.b1(), self.b2(), self.b3(), self.b4(), self.b5(),
tf.keras.layers.Dense(num_classes)])
GoogLeNet模型在计算上很复杂。请注意,在选择的通道数、降维前的块数、跨通道的容量相对划分等方面存在大量相对任意的超参数。这很大程度上是由于在GoogLeNet被提出时,用于网络定义或设计探索的自动化工具尚未出现。例如,现在我们理所当然地认为一个称职的深度学习框架能够自动推断输入张量的维度。而在当时,许多这样的配置都必须由实验者明确指定,从而常常减慢了积极的实验过程。此外,用于自动探索的工具仍在不断变化中,最初的实验基本上是昂贵的暴力探索、遗传算法和类似策略。
目前我们将进行的唯一修改是将输入的高度和宽度从224减少到96,以便在Fashion-MNIST上有一个合理的训练时间。这简化了计算。让我们看看各个模块之间输出形状的变化。
model = GoogleNet().layer_summary((1, 1, 96, 96))
Sequential output shape: torch.Size([1, 64, 24, 24])
Sequential output shape: torch.Size([1, 192, 12, 12])
Sequential output shape: torch.Size([1, 480, 6, 6])
Sequential output shape: torch.Size([1, 832, 3, 3])
Sequential output shape: torch.Size([1, 1024])
Linear output shape: torch.Size([1, 10])
model = GoogleNet().layer_summary((1, 1, 96, 96))
[22:26:25] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
Sequential output shape: (1, 64, 24, 24)
Sequential output shape: (1, 192, 12, 12)
Sequential output shape: (1, 480, 6, 6)
Sequential output shape: (1, 832, 3, 3)
Sequential output shape: (1, 1024, 1, 1)
Dense output shape: (1, 10)
model = GoogleNet().layer_summary((1, 96, 96, 1))
Sequential output shape: (1, 24, 24, 64)
Sequential output shape: (1, 12, 12, 192)
Sequential output shape: (1, 6, 6, 480)
Sequential output shape: (1, 3, 3, 832)
Sequential output shape: (1, 1024)
Dense output shape: (1, 10)
model = GoogleNet().layer_summary((1, 96, 96, 1))
Sequential output shape: (1, 24, 24, 64)
Sequential output shape: (1, 12, 12, 192)
Sequential output shape: (1, 6, 6, 480)
Sequential output shape: (1, 3, 3, 832)
Sequential output shape: (1, 1024)
Dense output shape: (1, 10)
8.4.3. 训练¶
和之前一样,我们使用Fashion-MNIST数据集来训练我们的模型。在调用训练过程之前,我们将其转换为 \(96 \times 96\) 像素分辨率。
model = GoogleNet(lr=0.01)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
trainer.fit(model, data)
model = GoogleNet(lr=0.01)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
trainer.fit(model, data)
model = GoogleNet(lr=0.01)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
trainer.fit(model, data)
trainer = d2l.Trainer(max_epochs=10)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
with d2l.try_gpu():
model = GoogleNet(lr=0.01)
trainer.fit(model, data)
8.4.4. 讨论¶
GoogLeNet的一个关键特征是,它的计算成本实际上比其前辈*更低*,同时提供了更高的准确性。这标志着一个更加深思熟虑的网络设计的开始,该设计权衡了评估网络的成本与误差的减少。它也标志着在块级别上对网络设计超参数进行实验的开始,尽管当时完全是手动的。我们将在 8.8节 中讨论网络结构探索策略时重新审视这个主题。
在接下来的章节中,我们将遇到一些设计选择(例如,批量归一化、残差连接和通道分组),这些选择可以让我们显著地改进网络。目前,你可以为实现了可以说是第一个真正现代的CNN而感到自豪。
8.4.5. 练习¶
GoogLeNet非常成功,以至于它经历了几次迭代,逐步提高了速度和准确性。尝试实现并运行其中一些。它们包括以下内容:
添加一个批量归一化层 (Ioffe and Szegedy, 2015),如稍后在 8.5节 中所述。
对Inception块进行调整(宽度、卷积的选择和顺序),如 Szegedy et al. (2016) 中所述。
使用标签平滑进行模型正则化,如 Szegedy et al. (2016) 中所述。
通过添加残差连接对Inception块进行进一步调整 (Szegedy et al., 2017),如稍后在 8.6节 中所述。
GoogLeNet工作所需的最小图像尺寸是多少?
你能设计一个适用于Fashion-MNIST原生分辨率 \(28 \times 28\) 像素的GoogLeNet变体吗?如果需要,你将如何改变网络的主干、主体和头部?
比较AlexNet、VGG、NiN和GoogLeNet的模型参数大小。后两种网络架构如何显著减少模型参数大小?
比较GoogLeNet和AlexNet中所需的计算量。这如何影响加速器芯片的设计,例如在内存大小、内存带宽、缓存大小、计算量以及专用操作的好处方面?