20.1. 生成对抗网络
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

在本书的大部分内容中,我们都在讨论如何做出预测。我们以某种形式使用深度神经网络来学习从数据样本到标签的映射。这种学习被称为判别式学习(discriminative learning),就像我们希望能够区分猫和狗的照片一样。分类器和回归器都是判别式学习的例子。在大型复杂数据集上,通过反向传播训练的神经网络颠覆了我们对判别式学习的所有认知。高分辨率图像的分类准确率在短短5-6年内就从毫无用处提升到了人类水平(有一些注意事项)。我们将不再赘述深度神经网络在所有其他判别式任务中表现惊人的例子。

但机器学习不仅仅是解决判别式任务。例如,给定一个没有任何标签的大型数据集,我们可能希望学习一个能简洁地捕捉这些数据特征的模型。有了这样的模型,我们就可以采样出类似于训练数据分布的合成数据样本。例如,给定一个大型人脸照片语料库,我们可能希望能够生成一张新的、看起来像是从同一数据集中 plausibly 得到的逼真图像。这种学习方式被称为生成式建模(generative modeling)。

直到最近,我们还没有方法可以合成新颖的逼真图像。但是,深度神经网络在判别式学习方面的成功开辟了新的可能性。过去三年的一大趋势是,应用判别式深度网络来克服我们通常不认为是监督学习问题的难题。循环神经网络语言模型就是一个例子,它使用一个判别式网络(训练来预测下一个字符)在训练后可以作为一个生成模型。

2014年,一篇突破性的论文介绍了生成对抗网络(GANs)Goodfellow et al., 2014,这是一种巧妙的新方法,利用判别式模型的力量来获得好的生成式模型。GANs的核心思想是,如果我们无法区分假数据和真数据,那么数据生成器就是好的。在统计学中,这被称为双样本检验——一个用来回答数据集 \(X=\{x_1,\ldots, x_n\}\)\(X'=\{x'_1,\ldots, x'_n\}\) 是否来自同一分布的问题的检验。大多数统计学论文和GANs之间的主要区别在于,后者以一种建设性的方式使用这个想法。换句话说,他们不仅仅是训练一个模型来说“嘿,这两个数据集看起来不像是来自同一个分布”,而是使用双样本检验为生成模型提供训练信号。这使我们能够改进数据生成器,直到它生成的东西类似于真实数据。至少,它需要能骗过分类器,即使我们的分类器是目前最先进的深度神经网络。

../_images/gan.svg

图 20.1.1 生成对抗网络

GAN的架构如 图 20.1.1 所示。如你所见,GAN架构中有两个部分——首先,我们需要一个设备(比如一个深度网络,但它实际上可以是任何东西,比如一个游戏渲染引擎),它可能能够生成看起来和真实数据一样的数据。如果我们处理的是图像,它需要生成图像。如果我们处理的是语音,它需要生成音频序列,等等。我们称之为生成器网络。第二个组件是判别器网络。它试图区分假数据和真实数据。这两个网络相互竞争。生成器网络试图欺骗判别器网络。此时,判别器网络适应新的假数据。这个信息反过来又被用来改进生成器网络,依此类推。

判别器是一个二元分类器,用于区分输入\(x\)是真实的(来自真实数据)还是伪造的(来自生成器)。通常,判别器对输入\(\mathbf x\)输出一个标量预测\(o\in\mathbb R\),例如使用一个隐藏大小为1的全连接层,然后应用sigmoid函数得到预测概率\(D(\mathbf x) = 1/(1+e^{-o})\)。假设真实数据的标签\(y\)\(1\),伪造数据的标签为\(0\)。我们训练判别器以最小化交叉熵损失,*即*,

(20.1.1)\[\min_D \{ - y \log D(\mathbf x) - (1-y)\log(1-D(\mathbf x)) \},\]

对于生成器,它首先从一个随机源(*例如*,正态分布 \(\mathbf z \sim \mathcal{N} (0, 1)\))中抽取一些参数 \(\mathbf z\in\mathbb R^d\)。我们通常称 \(\mathbf z\) 为潜变量。然后它应用一个函数来生成 \(\mathbf x'=G(\mathbf z)\)。生成器的目标是欺骗判别器将 \(\mathbf x'=G(\mathbf z)\) 分类为真实数据,*即*,我们希望 \(D( G(\mathbf z)) \approx 1\)。换句话说,对于一个给定的判别器 \(D\),我们更新生成器 \(G\) 的参数来最大化当 \(y=0\) 时的交叉熵损失,*即*,

(20.1.2)\[\max_G \{ - (1-y) \log(1-D(G(\mathbf z))) \} = \max_G \{ - \log(1-D(G(\mathbf z))) \}.\]

如果生成器表现完美,那么 \(D(\mathbf x')\approx 1\),因此上述损失接近0,这会导致梯度过小,无法为判别器带来好的进展。所以通常情况下,我们最小化以下损失:

(20.1.3)\[\min_G \{ - y \log(D(G(\mathbf z))) \} = \min_G \{ - \log(D(G(\mathbf z))) \},\]

这只是将 \(\mathbf x'=G(\mathbf z)\) 输入到判别器中,但给定的标签是 \(y=1\)

总而言之,\(D\)\(G\) 正在玩一个“minimax”博弈,其综合目标函数为:

(20.1.4)\[\min_D \max_G \{ -E_{x \sim \textrm{Data}} \log D(\mathbf x) - E_{z \sim \textrm{Noise}} \log(1 - D(G(\mathbf z))) \}.\]

许多GAN的应用都在图像领域。作为演示目的,我们将首先满足于拟合一个简单得多的分布。我们将展示如果使用GAN来构建世界上最低效的高斯分布参数估计器会发生什么。让我们开始吧。

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
%matplotlib inline
from mxnet import autograd, gluon, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()
import tensorflow as tf
from d2l import tensorflow as d2l

20.1.1. 生成一些“真实”数据

因为这将是世界上最无聊的例子,我们简单地生成从高斯分布中抽取的数据。

X = torch.normal(0.0, 1, (1000, 2))
A = torch.tensor([[1, 2], [-0.1, 0.5]])
b = torch.tensor([1, 2])
data = torch.matmul(X, A) + b
X = np.random.normal(0.0, 1, (1000, 2))
A = np.array([[1, 2], [-0.1, 0.5]])
b = np.array([1, 2])
data = np.dot(X, A) + b
[21:54:27] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
X = tf.random.normal((1000, 2), 0.0, 1)
A = tf.constant([[1, 2], [-0.1, 0.5]])
b = tf.constant([1, 2], tf.float32)
data = tf.matmul(X, A) + b

让我们看看我们得到了什么。这应该是一个以某种相当随意的方式平移的高斯分布,其均值为\(b\),协方差矩阵为\(A^TA\)

d2l.set_figsize()
d2l.plt.scatter(data[:100, (0)].detach().numpy(), data[:100, (1)].detach().numpy());
print(f'The covariance matrix is\n{torch.matmul(A.T, A)}')
The covariance matrix is
tensor([[1.0100, 1.9500],
        [1.9500, 4.2500]])
../_images/output_gan_bc32f8_27_1.svg
batch_size = 8
data_iter = d2l.load_array((data,), batch_size)
d2l.set_figsize()
d2l.plt.scatter(data[:100, (0)].asnumpy(), data[:100, (1)].asnumpy());
print(f'The covariance matrix is\n{np.dot(A.T, A)}')
The covariance matrix is
[[1.01 1.95]
 [1.95 4.25]]
../_images/output_gan_bc32f8_31_1.svg
batch_size = 8
data_iter = d2l.load_array((data,), batch_size)
d2l.set_figsize()
d2l.plt.scatter(data[:100, 0].numpy(), data[:100, 1].numpy());
print(f'The covariance matrix is\n{tf.matmul(A, A, transpose_a=True)}')
The covariance matrix is
[[1.01 1.95]
 [1.95 4.25]]
../_images/output_gan_bc32f8_35_1.svg
batch_size = 8
data_iter = d2l.load_array((data,), batch_size)

20.1.2. 生成器

我们的生成器网络将是可能的最简单的网络——一个单层线性模型。这是因为我们将用一个高斯数据生成器来驱动那个线性网络。因此,它实际上只需要学习参数就能完美地伪造数据。

net_G = nn.Sequential(nn.Linear(2, 2))
net_G = nn.Sequential()
net_G.add(nn.Dense(2))
net_G = tf.keras.layers.Dense(2)

20.1.3. 判别器

对于判别器,我们将更具辨别力:我们将使用一个带3层的MLP,让事情变得更有趣一些。

net_D = nn.Sequential(
    nn.Linear(2, 5), nn.Tanh(),
    nn.Linear(5, 3), nn.Tanh(),
    nn.Linear(3, 1))
net_D = nn.Sequential()
net_D.add(nn.Dense(5, activation='tanh'),
          nn.Dense(3, activation='tanh'),
          nn.Dense(1))
net_D = tf.keras.models.Sequential([
    tf.keras.layers.Dense(5, activation="tanh", input_shape=(2,)),
    tf.keras.layers.Dense(3, activation="tanh"),
    tf.keras.layers.Dense(1)
])

20.1.4. 训练

首先我们定义一个更新判别器的函数。

#@save
def update_D(X, Z, net_D, net_G, loss, trainer_D):
    """Update discriminator."""
    batch_size = X.shape[0]
    ones = torch.ones((batch_size,), device=X.device)
    zeros = torch.zeros((batch_size,), device=X.device)
    trainer_D.zero_grad()
    real_Y = net_D(X)
    fake_X = net_G(Z)
    # Do not need to compute gradient for `net_G`, detach it from
    # computing gradients.
    fake_Y = net_D(fake_X.detach())
    loss_D = (loss(real_Y, ones.reshape(real_Y.shape)) +
              loss(fake_Y, zeros.reshape(fake_Y.shape))) / 2
    loss_D.backward()
    trainer_D.step()
    return loss_D
#@save
def update_D(X, Z, net_D, net_G, loss, trainer_D):
    """Update discriminator."""
    batch_size = X.shape[0]
    ones = np.ones((batch_size,), ctx=X.ctx)
    zeros = np.zeros((batch_size,), ctx=X.ctx)
    with autograd.record():
        real_Y = net_D(X)
        fake_X = net_G(Z)
        # Do not need to compute gradient for `net_G`, detach it from
        # computing gradients.
        fake_Y = net_D(fake_X.detach())
        loss_D = (loss(real_Y, ones) + loss(fake_Y, zeros)) / 2
    loss_D.backward()
    trainer_D.step(batch_size)
    return float(loss_D.sum())
#@save
def update_D(X, Z, net_D, net_G, loss, optimizer_D):
    """Update discriminator."""
    batch_size = X.shape[0]
    ones = tf.ones((batch_size,)) # Labels corresponding to real data
    zeros = tf.zeros((batch_size,)) # Labels corresponding to fake data
    # Do not need to compute gradient for `net_G`, so it is outside GradientTape
    fake_X = net_G(Z)
    with tf.GradientTape() as tape:
        real_Y = net_D(X)
        fake_Y = net_D(fake_X)
        # We multiply the loss by batch_size to match PyTorch's BCEWithLogitsLoss
        loss_D = (loss(ones, tf.squeeze(real_Y)) + loss(
            zeros, tf.squeeze(fake_Y))) * batch_size / 2
    grads_D = tape.gradient(loss_D, net_D.trainable_variables)
    optimizer_D.apply_gradients(zip(grads_D, net_D.trainable_variables))
    return loss_D

生成器的更新方式类似。这里我们重用交叉熵损失,但将假数据的标签从\(0\)改为\(1\)

#@save
def update_G(Z, net_D, net_G, loss, trainer_G):
    """Update generator."""
    batch_size = Z.shape[0]
    ones = torch.ones((batch_size,), device=Z.device)
    trainer_G.zero_grad()
    # We could reuse `fake_X` from `update_D` to save computation
    fake_X = net_G(Z)
    # Recomputing `fake_Y` is needed since `net_D` is changed
    fake_Y = net_D(fake_X)
    loss_G = loss(fake_Y, ones.reshape(fake_Y.shape))
    loss_G.backward()
    trainer_G.step()
    return loss_G
#@save
def update_G(Z, net_D, net_G, loss, trainer_G):
    """Update generator."""
    batch_size = Z.shape[0]
    ones = np.ones((batch_size,), ctx=Z.ctx)
    with autograd.record():
        # We could reuse `fake_X` from `update_D` to save computation
        fake_X = net_G(Z)
        # Recomputing `fake_Y` is needed since `net_D` is changed
        fake_Y = net_D(fake_X)
        loss_G = loss(fake_Y, ones)
    loss_G.backward()
    trainer_G.step(batch_size)
    return float(loss_G.sum())
#@save
def update_G(Z, net_D, net_G, loss, optimizer_G):
    """Update generator."""
    batch_size = Z.shape[0]
    ones = tf.ones((batch_size,))
    with tf.GradientTape() as tape:
        # We could reuse `fake_X` from `update_D` to save computation
        fake_X = net_G(Z)
        # Recomputing `fake_Y` is needed since `net_D` is changed
        fake_Y = net_D(fake_X)
        # We multiply the loss by batch_size to match PyTorch's BCEWithLogits loss
        loss_G = loss(ones, tf.squeeze(fake_Y)) * batch_size
    grads_G = tape.gradient(loss_G, net_G.trainable_variables)
    optimizer_G.apply_gradients(zip(grads_G, net_G.trainable_variables))
    return loss_G

判别器和生成器都使用交叉熵损失进行二元逻辑回归。我们使用Adam来平滑训练过程。在每次迭代中,我们首先更新判别器,然后更新生成器。我们可视化损失和生成的样本。

def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
    loss = nn.BCEWithLogitsLoss(reduction='sum')
    for w in net_D.parameters():
        nn.init.normal_(w, 0, 0.02)
    for w in net_G.parameters():
        nn.init.normal_(w, 0, 0.02)
    trainer_D = torch.optim.Adam(net_D.parameters(), lr=lr_D)
    trainer_G = torch.optim.Adam(net_G.parameters(), lr=lr_G)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                            legend=['discriminator', 'generator'])
    animator.fig.subplots_adjust(hspace=0.3)
    for epoch in range(num_epochs):
        # Train one epoch
        timer = d2l.Timer()
        metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
        for (X,) in data_iter:
            batch_size = X.shape[0]
            Z = torch.normal(0, 1, size=(batch_size, latent_dim))
            metric.add(update_D(X, Z, net_D, net_G, loss, trainer_D),
                       update_G(Z, net_D, net_G, loss, trainer_G),
                       batch_size)
        # Visualize generated examples
        Z = torch.normal(0, 1, size=(100, latent_dim))
        fake_X = net_G(Z).detach().numpy()
        animator.axes[1].cla()
        animator.axes[1].scatter(data[:, 0], data[:, 1])
        animator.axes[1].scatter(fake_X[:, 0], fake_X[:, 1])
        animator.axes[1].legend(['real', 'generated'])
        # Show the losses
        loss_D, loss_G = metric[0]/metric[2], metric[1]/metric[2]
        animator.add(epoch + 1, (loss_D, loss_G))
    print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
          f'{metric[2] / timer.stop():.1f} examples/sec')
def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
    loss = gluon.loss.SigmoidBCELoss()
    net_D.initialize(init=init.Normal(0.02), force_reinit=True)
    net_G.initialize(init=init.Normal(0.02), force_reinit=True)
    trainer_D = gluon.Trainer(net_D.collect_params(),
                              'adam', {'learning_rate': lr_D})
    trainer_G = gluon.Trainer(net_G.collect_params(),
                              'adam', {'learning_rate': lr_G})
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                            legend=['discriminator', 'generator'])
    animator.fig.subplots_adjust(hspace=0.3)
    for epoch in range(num_epochs):
        # Train one epoch
        timer = d2l.Timer()
        metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
        for X in data_iter:
            batch_size = X.shape[0]
            Z = np.random.normal(0, 1, size=(batch_size, latent_dim))
            metric.add(update_D(X, Z, net_D, net_G, loss, trainer_D),
                       update_G(Z, net_D, net_G, loss, trainer_G),
                       batch_size)
        # Visualize generated examples
        Z = np.random.normal(0, 1, size=(100, latent_dim))
        fake_X = net_G(Z).asnumpy()
        animator.axes[1].cla()
        animator.axes[1].scatter(data[:, 0], data[:, 1])
        animator.axes[1].scatter(fake_X[:, 0], fake_X[:, 1])
        animator.axes[1].legend(['real', 'generated'])
        # Show the losses
        loss_D, loss_G = metric[0]/metric[2], metric[1]/metric[2]
        animator.add(epoch + 1, (loss_D, loss_G))
    print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
          f'{metric[2] / timer.stop():.1f} examples/sec')
def train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G, latent_dim, data):
    loss = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
    for w in net_D.trainable_variables:
        w.assign(tf.random.normal(mean=0, stddev=0.02, shape=w.shape))
    for w in net_G.trainable_variables:
        w.assign(tf.random.normal(mean=0, stddev=0.02, shape=w.shape))
    optimizer_D = tf.keras.optimizers.Adam(learning_rate=lr_D)
    optimizer_G = tf.keras.optimizers.Adam(learning_rate=lr_G)
    animator = d2l.Animator(
        xlabel="epoch", ylabel="loss", xlim=[1, num_epochs], nrows=2,
        figsize=(5, 5), legend=["discriminator", "generator"])
    animator.fig.subplots_adjust(hspace=0.3)
    for epoch in range(num_epochs):
        # Train one epoch
        timer = d2l.Timer()
        metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
        for (X,) in data_iter:
            batch_size = X.shape[0]
            Z = tf.random.normal(
                mean=0, stddev=1, shape=(batch_size, latent_dim))
            metric.add(update_D(X, Z, net_D, net_G, loss, optimizer_D),
                       update_G(Z, net_D, net_G, loss, optimizer_G),
                       batch_size)
        # Visualize generated examples
        Z = tf.random.normal(mean=0, stddev=1, shape=(100, latent_dim))
        fake_X = net_G(Z)
        animator.axes[1].cla()
        animator.axes[1].scatter(data[:, 0], data[:, 1])
        animator.axes[1].scatter(fake_X[:, 0], fake_X[:, 1])
        animator.axes[1].legend(["real", "generated"])

        # Show the losses
        loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]
        animator.add(epoch + 1, (loss_D, loss_G))

    print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
          f'{metric[2] / timer.stop():.1f} examples/sec')

现在我们指定超参数来拟合高斯分布。

lr_D, lr_G, latent_dim, num_epochs = 0.05, 0.005, 2, 20
train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G,
      latent_dim, data[:100].detach().numpy())
loss_D 0.693, loss_G 0.693, 1020.0 examples/sec
../_images/output_gan_bc32f8_102_1.svg
lr_D, lr_G, latent_dim, num_epochs = 0.05, 0.005, 2, 20
train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G,
      latent_dim, data[:100].asnumpy())
loss_D 0.693, loss_G 0.693, 265.2 examples/sec
../_images/output_gan_bc32f8_105_1.svg
lr_D, lr_G, latent_dim, num_epochs = 0.05, 0.005, 2, 20
train(net_D, net_G, data_iter, num_epochs, lr_D, lr_G,
      latent_dim, data[:100].numpy())
loss_D 0.693, loss_G 0.693, 167.5 examples/sec
../_images/output_gan_bc32f8_108_1.svg

20.1.5. 小结

  • 生成对抗网络(GAN)由两个深度网络组成,即生成器和判别器。

  • 生成器生成尽可能接近真实图像的图像以欺骗判别器,通过最大化交叉熵损失,*即* \(\max \log(D(\mathbf{x'}))\)

  • 判别器试图区分生成的图像和真实的图像,通过最小化交叉熵损失,*即* \(\min - y \log D(\mathbf{x}) - (1-y)\log(1-D(\mathbf{x}))\)

20.1.6. 练习

  • 是否存在一个生成器获胜的平衡点,*即* 判别器最终无法在有限样本上区分这两种分布?