7.2. 图像卷积
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

现在我们了解了卷积层在理论上的工作原理,我们准备好看看它们在实践中是如何工作的。基于我们将卷积神经网络作为有效探索图像数据中结构的架构的动机,我们仍以图像作为运行例子。

import torch
from torch import nn
from d2l import torch as d2l
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l

7.2.1. 互相关运算

回想一下,严格来说,卷积层是个有点用词不当的名字,因为它们所表达的运算更准确地应描述为*互相关*(cross-correlation)。根据我们在 7.1节 中对卷积层的描述,在这样的层中,输入张量和核张量通过互相关运算组合产生输出张量。

我们暂时忽略通道,看看这如何处理二维数据和隐藏表示。在 图 7.2.1中,输入是一个高度为3、宽度为3的二维张量。我们将张量的形状标记为 \(3 \times 3\) 或 (\(3\), \(3\))。核的高度和宽度都是2。*核窗口*(或*卷积窗口*)的形状由核的高度和宽度决定(这里是 \(2 \times 2\))。

../_images/correlation.svg

图 7.2.1 二维互相关运算。阴影部分是第一个输出元素以及用于输出计算的输入和核张量元素:\(0\times0+1\times1+3\times2+4\times3=19\)

在二维互相关运算中,我们从位于输入张量左上角的卷积窗口开始,并从左到右、从上到下滑动它。当卷积窗口滑动到某个位置时,该窗口中包含的输入子张量和核张量按元素相乘,得到的张量求和得到一个标量值。这个结果给出了相应位置的输出张量的值。在这里,输出张量的高度为2,宽度为2,这四个元素是由二维互相关运算得出的:

(7.2.1)\[\begin{split}0\times0+1\times1+3\times2+4\times3=19,\\ 1\times0+2\times1+4\times2+5\times3=25,\\ 3\times0+4\times1+6\times2+7\times3=37,\\ 4\times0+5\times1+7\times2+8\times3=43.\end{split}\]

请注意,沿每个轴,输出大小比输入大小略小。因为核的宽度和高度大于 \(1\),我们只能在核完全位于图像内部的位置正确计算互相关,输出大小由输入大小 \(n_\textrm{h} \times n_\textrm{w}\) 减去卷积核大小 \(k_\textrm{h} \times k_\textrm{w}\) 给出:

(7.2.2)\[(n_\textrm{h}-k_\textrm{h}+1) \times (n_\textrm{w}-k_\textrm{w}+1).\]

这是因为我们需要足够的空间在图像上“移动”卷积核。稍后我们将看到如何通过在图像边界周围填充零来保持大小不变,以便有足够的空间移动核。接下来,我们在 corr2d 函数中实现这个过程,该函数接受一个输入张量 X 和一个核张量 K,并返回一个输出张量 Y

def corr2d(X, K):  #@save
    """Compute 2D cross-correlation."""
    h, w = K.shape
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y
def corr2d(X, K):  #@save
    """Compute 2D cross-correlation."""
    h, w = K.shape
    Y = np.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y
def corr2d(X, K):  #@save
    """Compute 2D cross-correlation."""
    h, w = K.shape
    Y = jnp.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y = Y.at[i, j].set((X[i:i + h, j:j + w] * K).sum())
    return Y
def corr2d(X, K):  #@save
    """Compute 2D cross-correlation."""
    h, w = K.shape
    Y = tf.Variable(tf.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1)))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i, j].assign(tf.reduce_sum(
                X[i: i + h, j: j + w] * K))
    return Y

我们可以从图 7.2.1中构造输入张量X和核张量K,以验证上述二维互相关运算实现的输出。

X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
corr2d(X, K)
tensor([[19., 25.],
        [37., 43.]])
X = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
K = np.array([[0.0, 1.0], [2.0, 3.0]])
corr2d(X, K)
[22:11:22] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
array([[19., 25.],
       [37., 43.]])
X = jnp.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
K = jnp.array([[0.0, 1.0], [2.0, 3.0]])
corr2d(X, K)
Array([[19., 25.],
       [37., 43.]], dtype=float32)
X = tf.constant([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
K = tf.constant([[0.0, 1.0], [2.0, 3.0]])
corr2d(X, K)
<tf.Variable 'Variable:0' shape=(2, 2) dtype=float32, numpy=
array([[19., 25.],
       [37., 43.]], dtype=float32)>

7.2.2. 卷积层

卷积层对输入和核进行互相关运算,并加上一个标量偏置以产生输出。卷积层的两个参数是核和标量偏置。在训练基于卷积层的模型时,我们通常随机初始化核,就像我们对全连接层所做的那样。

我们现在准备好基于上面定义的 corr2d 函数来实现一个二维卷积层。在 __init__ 构造函数方法中,我们将 weightbias 声明为两个模型参数。前向传播方法调用 corr2d 函数并加上偏置。

class Conv2D(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(kernel_size))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return corr2d(x, self.weight) + self.bias
class Conv2D(nn.Block):
    def __init__(self, kernel_size, **kwargs):
        super().__init__(**kwargs)
        self.weight = self.params.get('weight', shape=kernel_size)
        self.bias = self.params.get('bias', shape=(1,))

    def forward(self, x):
        return corr2d(x, self.weight.data()) + self.bias.data()
class Conv2D(nn.Module):
    kernel_size: int

    def setup(self):
        self.weight = nn.param('w', nn.initializers.uniform, self.kernel_size)
        self.bias = nn.param('b', nn.initializers.zeros, 1)

    def forward(self, x):
        return corr2d(x, self.weight) + self.bias
class Conv2D(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def build(self, kernel_size):
        initializer = tf.random_normal_initializer()
        self.weight = self.add_weight(name='w', shape=kernel_size,
                                      initializer=initializer)
        self.bias = self.add_weight(name='b', shape=(1, ),
                                    initializer=initializer)

    def call(self, inputs):
        return corr2d(inputs, self.weight) + self.bias

\(h \times w\) 卷积或 \(h \times w\) 卷积核中,卷积核的高度和宽度分别为 \(h\)\(w\)。我们也简单地将具有 \(h \times w\) 卷积核的卷积层称为 \(h \times w\) 卷积层。

7.2.3. 图像中的物体边缘检测

让我们花点时间来解析一个卷积层的简单应用:通过找到像素变化的位置来检测图像中物体的边缘。首先,我们构造一个 \(6\times 8\) 像素的“图像”。中间四列是黑色(\(0\)),其余是白色(\(1\))。

X = torch.ones((6, 8))
X[:, 2:6] = 0
X
tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])
X = np.ones((6, 8))
X[:, 2:6] = 0
X
array([[1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.]])
X = jnp.ones((6, 8))
X = X.at[:, 2:6].set(0)
X
Array([[1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.]], dtype=float32)
X = tf.Variable(tf.ones((6, 8)))
X[:, 2:6].assign(tf.zeros(X[:, 2:6].shape))
X
<tf.Variable 'Variable:0' shape=(6, 8) dtype=float32, numpy=
array([[1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.],
       [1., 1., 0., 0., 0., 0., 1., 1.]], dtype=float32)>

接下来,我们构造一个高度为1,宽度为2的核K。当我们用输入进行互相关运算时,如果水平相邻的元素相同,则输出为0。否则,输出为非零。注意,这个核是有限差分算子的一个特例。在位置\((i,j)\),它计算\(x_{i,j} - x_{(i+1),j}\),即它计算水平相邻像素值之间的差异。这是水平方向一阶导数的离散近似。毕竟,对于函数\(f(i,j)\),其导数\(-\partial_i f(i,j) = \lim_{\epsilon \to 0} \frac{f(i,j) - f(i+\epsilon,j)}{\epsilon}\)。让我们看看这在实践中是如何工作的。

K = torch.tensor([[1.0, -1.0]])
K = np.array([[1.0, -1.0]])
K = jnp.array([[1.0, -1.0]])
K = tf.constant([[1.0, -1.0]])

我们准备好用参数 X(我们的输入)和 K(我们的核)进行互相关运算。如你所见,我们检测到从白到黑的边缘为 \(1\),从黑到白的边缘为 \(-1\)。所有其他输出的值都为 \(0\)

Y = corr2d(X, K)
Y
tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])
Y = corr2d(X, K)
Y
array([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])
Y = corr2d(X, K)
Y
Array([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.]], dtype=float32)
Y = corr2d(X, K)
Y
<tf.Variable 'Variable:0' shape=(6, 7) dtype=float32, numpy=
array([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
       [ 0.,  1.,  0.,  0.,  0., -1.,  0.]], dtype=float32)>

我们现在可以将核应用于转置后的图像。正如预期的那样,它消失了。核 K 只检测垂直边缘。

corr2d(X.t(), K)
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
corr2d(d2l.transpose(X), K)
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])
corr2d(d2l.transpose(X), K)
Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
corr2d(tf.transpose(X), K)
<tf.Variable 'Variable:0' shape=(8, 5) dtype=float32, numpy=
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)>

7.2.4. 学习一个核

如果我们知道这正是我们想要的,那么用有限差分[1, -1]设计一个边缘检测器是很巧妙的。然而,当我们看更大的核,并考虑连续的卷积层时,可能无法精确地手动指定每个滤波器应该做什么。

现在让我们看看是否能仅通过观察输入-输出对,从X学习生成Y的核。我们首先构造一个卷积层,并将其核初始化为一个随机张量。接下来,在每次迭代中,我们将使用平方误差来比较Y和卷积层的输出。然后我们可以计算梯度来更新核。为了简单起见,在下文中我们使用内置的二维卷积层类,并忽略偏置。

# Construct a two-dimensional convolutional layer with 1 output channel and a
# kernel of shape (1, 2). For the sake of simplicity, we ignore the bias here
conv2d = nn.LazyConv2d(1, kernel_size=(1, 2), bias=False)

# The two-dimensional convolutional layer uses four-dimensional input and
# output in the format of (example, channel, height, width), where the batch
# size (number of examples in the batch) and the number of channels are both 1
X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))
lr = 3e-2  # Learning rate

for i in range(10):
    Y_hat = conv2d(X)
    l = (Y_hat - Y) ** 2
    conv2d.zero_grad()
    l.sum().backward()
    # Update the kernel
    conv2d.weight.data[:] -= lr * conv2d.weight.grad
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {l.sum():.3f}')
epoch 2, loss 16.481
epoch 4, loss 5.069
epoch 6, loss 1.794
epoch 8, loss 0.688
epoch 10, loss 0.274
# Construct a two-dimensional convolutional layer with 1 output channel and a
# kernel of shape (1, 2). For the sake of simplicity, we ignore the bias here
conv2d = nn.Conv2D(1, kernel_size=(1, 2), use_bias=False)
conv2d.initialize()

# The two-dimensional convolutional layer uses four-dimensional input and
# output in the format of (example, channel, height, width), where the batch
# size (number of examples in the batch) and the number of channels are both 1
X = X.reshape(1, 1, 6, 8)
Y = Y.reshape(1, 1, 6, 7)
lr = 3e-2  # Learning rate

for i in range(10):
    with autograd.record():
        Y_hat = conv2d(X)
        l = (Y_hat - Y) ** 2
    l.backward()
    # Update the kernel
    conv2d.weight.data()[:] -= lr * conv2d.weight.grad()
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {float(l.sum()):.3f}')
epoch 2, loss 4.949
epoch 4, loss 0.831
epoch 6, loss 0.140
epoch 8, loss 0.024
epoch 10, loss 0.004
[22:11:22] ../src/base.cc:48: GPU context requested, but no GPUs found.
# Construct a two-dimensional convolutional layer with 1 output channel and a
# kernel of shape (1, 2). For the sake of simplicity, we ignore the bias here
conv2d = nn.Conv(1, kernel_size=(1, 2), use_bias=False, padding='VALID')

# The two-dimensional convolutional layer uses four-dimensional input and
# output in the format of (example, height, width, channel), where the batch
# size (number of examples in the batch) and the number of channels are both 1
X = X.reshape((1, 6, 8, 1))
Y = Y.reshape((1, 6, 7, 1))
lr = 3e-2  # Learning rate

params = conv2d.init(jax.random.PRNGKey(d2l.get_seed()), X)

def loss(params, X, Y):
    Y_hat = conv2d.apply(params, X)
    return ((Y_hat - Y) ** 2).sum()

for i in range(10):
    l, grads = jax.value_and_grad(loss)(params, X, Y)
    # Update the kernel
    params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {l:.3f}')
epoch 2, loss 23.292
epoch 4, loss 8.424
epoch 6, loss 3.263
epoch 8, loss 1.305
epoch 10, loss 0.529
# Construct a two-dimensional convolutional layer with 1 output channel and a
# kernel of shape (1, 2). For the sake of simplicity, we ignore the bias here
conv2d = tf.keras.layers.Conv2D(1, (1, 2), use_bias=False)

# The two-dimensional convolutional layer uses four-dimensional input and
# output in the format of (example, height, width, channel), where the batch
# size (number of examples in the batch) and the number of channels are both 1
X = tf.reshape(X, (1, 6, 8, 1))
Y = tf.reshape(Y, (1, 6, 7, 1))
lr = 3e-2  # Learning rate

Y_hat = conv2d(X)
for i in range(10):
    with tf.GradientTape(watch_accessed_variables=False) as g:
        g.watch(conv2d.weights[0])
        Y_hat = conv2d(X)
        l = (abs(Y_hat - Y)) ** 2
        # Update the kernel
        update = tf.multiply(lr, g.gradient(l, conv2d.weights[0]))
        weights = conv2d.get_weights()
        weights[0] = conv2d.weights[0] - update
        conv2d.set_weights(weights)
        if (i + 1) % 2 == 0:
            print(f'epoch {i + 1}, loss {tf.reduce_sum(l):.3f}')
epoch 2, loss 16.452
epoch 4, loss 3.301
epoch 6, loss 0.775
epoch 8, loss 0.221
epoch 10, loss 0.074

请注意,经过10次迭代后,误差已降至一个很小的值。现在我们来看看我们学到的核张量。

conv2d.weight.data.reshape((1, 2))
tensor([[ 1.0398, -0.9328]])
conv2d.weight.data().reshape((1, 2))
array([[ 0.9895   , -0.9873705]])
params['params']['kernel'].reshape((1, 2))
Array([[ 1.0637044, -0.9139465]], dtype=float32)
tf.reshape(conv2d.get_weights()[0], (1, 2))
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[ 1.0062983 , -0.95448595]], dtype=float32)>

确实,学习到的核张量与我们之前定义的核张量 K 非常接近。

7.2.5. 互相关和卷积

回想我们在 7.1节 中观察到的互相关和卷积运算之间的对应关系。这里我们继续考虑二维卷积层。如果这样的层执行的是 (7.1.6) 中定义的严格卷积运算而不是互相关运算,会怎么样?为了获得严格*卷积*运算的输出,我们只需要将二维核张量在水平和垂直方向上都翻转,然后与输入张量执行*互相关*运算。

值得注意的是,由于在深度学习中核是从数据中学习的,无论卷积层执行的是严格的卷积运算还是互相关运算,其输出都保持不变。

为了说明这一点,假设一个卷积层执行*互相关*运算,并学习了图 7.2.1中的核,这里表示为矩阵\(\mathbf{K}\)。假设其他条件保持不变,当这个层改为执行严格的*卷积*时,学习到的核\(\mathbf{K}'\)在水平和垂直翻转后将与\(\mathbf{K}\)相同。也就是说,当卷积层对图 7.2.1中的输入和\(\mathbf{K}'\)执行严格*卷积*时,将获得与图 7.2.1中相同的输出(输入和\(\mathbf{K}\)的互相关)。

为了与深度学习文献中的标准术语保持一致,我们将继续将互相关运算称为卷积,尽管严格来说,它们略有不同。此外,我们使用术语*元素*来指代表示层表示或卷积核的任何张量的条目(或分量)。

7.2.6. 特征图和感受野

7.1.4节中所述,图 7.2.1中的卷积层输出有时被称为*特征图*,因为它可以被看作是空间维度(例如,宽度和高度)上到后续层的学习表示(特征)。在卷积神经网络中,对于某一层中的任何元素\(x\),其*感受野*指的是在前向传播过程中可能影响\(x\)计算的所有(来自所有先前层的)元素。注意,感受野可能大于输入的实际大小。

让我们继续使用图 7.2.1来解释感受野。给定\(2 \times 2\)的卷积核,阴影输出元素(值为\(19\))的感受野是输入阴影部分的四个元素。现在让我们将\(2 \times 2\)的输出表示为\(\mathbf{Y}\),并考虑一个更深的CNN,它有一个额外的\(2 \times 2\)卷积层,以\(\mathbf{Y}\)作为输入,输出一个单一元素\(z\)。在这种情况下,\(z\)\(\mathbf{Y}\)上的感受野包括\(\mathbf{Y}\)的所有四个元素,而在输入上的感受野包括所有九个输入元素。因此,当特征图中的任何元素需要更大的感受野来检测更广区域的输入特征时,我们可以构建一个更深的网络。

感受野的名称来源于神经生理学。一系列在不同动物身上使用不同刺激的实验(Hubel and Wiesel, 1959, Hubel and Wiesel, 1962, Hubel and Wiesel, 1968)探索了所谓的视觉皮层对这些刺激的反应。总的来说,他们发现较低层次对边缘和相关形状有反应。后来,Field (1987)用只能称之为卷积核的东西在自然图像上展示了这种效应。我们在图 7.2.2中重印了一个关键的图来说明这些惊人的相似之处。

../_images/field-visual.png

图 7.2.2 图和标题摘自 Field (1987):一个使用六个不同通道进行编码的例子。(左) 与每个通道相关的六种传感器的示例。(右) (中)中的图像与(左)中所示的六个传感器的卷积。单个传感器的响应是通过以与传感器大小成比例的距离对这些过滤后的图像进行采样来确定的(用点表示)。此图仅显示了偶对称传感器的响应。

事实证明,这种关系甚至适用于在图像分类任务上训练的更深层网络计算出的特征,例如,Kuzovkin *et al.* (2018)所展示的。可以说,卷积已被证明是计算机视觉中一个极其强大的工具,无论是在生物学还是在代码中。因此,它们预示了深度学习近期的成功,这(事后看来)并不奇怪。

7.2.7. 总结

卷积层所需的核心计算是互相关运算。我们看到,一个简单的嵌套for循环就足以计算其值。如果我们有多个输入和多个输出通道,我们正在执行通道间的矩阵-矩阵运算。可以看出,计算是直接的,最重要的是,高度*局部化*。这为显著的硬件优化提供了可能,许多最近在计算机视觉领域的成果之所以可能,正是因为这一点。毕竟,这意味着芯片设计者在优化卷积时可以投资于快速计算而不是内存。虽然这可能不会为其他应用带来最优设计,但它确实为普及且经济实惠的计算机视觉打开了大门。

就卷积本身而言,它们可以用于多种目的,例如检测边缘和线条、模糊图像或锐化图像。最重要的是,统计学家(或工程师)不必发明合适的滤波器。相反,我们可以简单地从数据中*学习*它们。这用基于证据的统计学取代了特征工程的启发式方法。最后,而且相当令人愉快的是,这些滤波器不仅对于构建深度网络有利,而且它们还对应于大脑中的感受野和特征图。这让我们相信我们走在正确的轨道上。

7.2.8. 练习

  1. 构造一个具有对角线边缘的图像X

    1. 如果将本节中的核K应用于它,会发生什么?

    2. 如果对X进行转置会发生什么?

    3. 如果对K进行转置会发生什么?

  2. 手动设计一些核。

    1. 给定一个方向向量\(\mathbf{v} = (v_1, v_2)\),推导一个边缘检测核,用于检测与\(\mathbf{v}\)正交的边缘,即方向为\((v_2, -v_1)\)的边缘。

    2. 推导二阶导数的有限差分算子。与其相关的卷积核的最小尺寸是多少?图像中的哪些结构对其响应最强烈?

    3. 你会如何设计一个模糊核?你为什么可能想要使用这样的核?

    4. 要获得\(d\)阶导数,核的最小尺寸是多少?

  3. 当你尝试为我们创建的Conv2D类自动查找梯度时,你会看到什么样的错误消息?

  4. 你如何通过改变输入和核张量,将互相关运算表示为矩阵乘法?