13.6. 多GPU的简洁实现
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

为每个新模型从零开始实现并行化是无趣的。此外,优化同步工具以获得高性能有明显的好处。下面,我们将展示如何使用深度学习框架的高级API来实现这一点。数学和算法与 第 13.5 节中的相同。不出所料,您将需要至少两个GPU来运行本节的代码。

import torch
from torch import nn
from d2l import torch as d2l
from mxnet import autograd, gluon, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

13.6.1. 一个玩具网络

让我们使用一个比 第 13.5 节 中的LeNet更有意义的网络,它仍然足够简单和快速来训练。我们选择ResNet-18的一个变体 (He et al., 2016)。由于输入图像很小,我们对其稍作修改。具体来说,与 第 8.6 节 的不同之处在于,我们在开始时使用了更小的卷积核、步幅和填充。此外,我们移除了最大汇聚层。

#@save
def resnet18(num_classes, in_channels=1):
    """A slightly modified ResNet-18 model."""
    def resnet_block(in_channels, out_channels, num_residuals,
                     first_block=False):
        blk = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.append(d2l.Residual(out_channels, use_1x1conv=True,
                                        strides=2))
            else:
                blk.append(d2l.Residual(out_channels))
        return nn.Sequential(*blk)

    # This model uses a smaller convolution kernel, stride, and padding and
    # removes the max-pooling layer
    net = nn.Sequential(
        nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU())
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("resnet_block4", resnet_block(256, 512, 2))
    net.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1,1)))
    net.add_module("fc", nn.Sequential(nn.Flatten(),
                                       nn.Linear(512, num_classes)))
    return net
#@save
def resnet18(num_classes):
    """A slightly modified ResNet-18 model."""
    def resnet_block(num_channels, num_residuals, first_block=False):
        blk = nn.Sequential()
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.add(d2l.Residual(
                    num_channels, use_1x1conv=True, strides=2))
            else:
                blk.add(d2l.Residual(num_channels))
        return blk

    net = nn.Sequential()
    # This model uses a smaller convolution kernel, stride, and padding and
    # removes the max-pooling layer
    net.add(nn.Conv2D(64, kernel_size=3, strides=1, padding=1),
            nn.BatchNorm(), nn.Activation('relu'))
    net.add(resnet_block(64, 2, first_block=True),
            resnet_block(128, 2),
            resnet_block(256, 2),
            resnet_block(512, 2))
    net.add(nn.GlobalAvgPool2D(), nn.Dense(num_classes))
    return net

13.6.2. 网络初始化

我们将在训练循环中初始化网络。要回顾初始化方法,请参阅 第 5.4 节

net = resnet18(10)
# Get a list of GPUs
devices = d2l.try_all_gpus()
# We will initialize the network inside the training loop

initialize 函数允许我们在选择的设备上初始化参数。要回顾初始化方法,请参阅 第 5.4 节。特别方便的是,它还允许我们同时在*多个*设备上初始化网络。让我们看看这在实践中是如何工作的。

net = resnet18(10)
# Get a list of GPUs
devices = d2l.try_all_gpus()
# Initialize all the parameters of the network
net.initialize(init=init.Normal(sigma=0.01), ctx=devices)
[22:06:27] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
[22:06:27] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
[22:06:27] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU

使用在 第 13.5 节 中介绍的 split_and_load 函数,我们可以划分一个小批量数据并将部分复制到由 devices 变量提供的设备列表中。网络实例*自动*使用适当的GPU来计算前向传播的值。这里我们生成4个观测值并将它们分配到各个GPU上。

x = np.random.uniform(size=(4, 1, 28, 28))
x_shards = gluon.utils.split_and_load(x, devices)
net(x_shards[0]), net(x_shards[1])
[22:06:28] ../src/operator/cudnn_ops.cc:318: Auto-tuning cuDNN op, set MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable
[22:06:28] ../src/operator/cudnn_ops.cc:318: Auto-tuning cuDNN op, set MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable
[22:06:28] ../src/operator/cudnn_ops.cc:318: Auto-tuning cuDNN op, set MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable
[22:06:28] ../src/operator/cudnn_ops.cc:318: Auto-tuning cuDNN op, set MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable
(array([[ 2.2610207e-06,  2.2045981e-06, -5.4046786e-06,  1.2869955e-06,
          5.1373163e-06, -3.8297967e-06,  1.4339059e-07,  5.4683451e-06,
         -2.8279192e-06, -3.9651104e-06],
        [ 2.0698672e-06,  2.0084667e-06, -5.6382510e-06,  1.0498458e-06,
          5.5506434e-06, -4.1065491e-06,  6.0830087e-07,  5.4521784e-06,
         -3.7365021e-06, -4.1891640e-06]], ctx=gpu(0)),
 array([[ 2.4629783e-06,  2.6015525e-06, -5.4362617e-06,  1.2938218e-06,
          5.6387889e-06, -4.1360108e-06,  3.5758853e-07,  5.5125256e-06,
         -3.1957325e-06, -4.2976326e-06],
        [ 1.9431673e-06,  2.2600434e-06, -5.2698201e-06,  1.4807417e-06,
          5.4830934e-06, -3.9678889e-06,  7.5751018e-08,  5.6764356e-06,
         -3.2530229e-06, -4.0943951e-06]], ctx=gpu(1)))

一旦数据通过网络,相应的参数就会*在数据通过的设备上*被初始化。这意味着初始化是基于每个设备进行的。由于我们选择了GPU 0和GPU 1进行初始化,网络只在这两个设备上被初始化,而不在CPU上。事实上,这些参数在CPU上甚至不存在。我们可以通过打印出参数并观察可能出现的任何错误来验证这一点。

weight = net[0].params.get('weight')

try:
    weight.data()
except RuntimeError:
    print('not initialized on cpu')
weight.data(devices[0])[0], weight.data(devices[1])[0]
not initialized on cpu
(array([[[ 0.01382882, -0.01183044,  0.01417865],
         [-0.00319718,  0.00439528,  0.02562625],
         [-0.00835081,  0.01387452, -0.01035946]]], ctx=gpu(0)),
 array([[[ 0.01382882, -0.01183044,  0.01417865],
         [-0.00319718,  0.00439528,  0.02562625],
         [-0.00835081,  0.01387452, -0.01035946]]], ctx=gpu(1)))

接下来,让我们用一个可以在多个设备上并行工作的代码来替换评估准确率的代码。这可以替代 第 7.6 节 中的 evaluate_accuracy_gpu 函数。主要区别在于我们在调用网络之前先分割一个小批量。其他部分基本相同。

#@save
def evaluate_accuracy_gpus(net, data_iter, split_f=d2l.split_batch):
    """Compute the accuracy for a model on a dataset using multiple GPUs."""
    # Query the list of devices
    devices = list(net.collect_params().values())[0].list_ctx()
    # No. of correct predictions, no. of predictions
    metric = d2l.Accumulator(2)
    for features, labels in data_iter:
        X_shards, y_shards = split_f(features, labels, devices)
        # Run in parallel
        pred_shards = [net(X_shard) for X_shard in X_shards]
        metric.add(sum(float(d2l.accuracy(pred_shard, y_shard)) for
                       pred_shard, y_shard in zip(
                           pred_shards, y_shards)), labels.size)
    return metric[0] / metric[1]

13.6.3. 训练

和以前一样,训练代码需要执行几个基本功能以实现高效并行化:

  • 网络参数需要在所有设备上进行初始化。

  • 在迭代数据集时,小批量数据需要被分配到所有设备上。

  • 我们在所有设备上并行计算损失及其梯度。

  • 梯度被聚合起来,并相应地更新参数。

最后,我们计算准确率(同样是并行的)来报告网络的最终性能。训练例程与前面章节的实现非常相似,只是我们需要分割和聚合数据。

def train(net, num_gpus, batch_size, lr):
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    devices = [d2l.try_gpu(i) for i in range(num_gpus)]
    def init_weights(module):
        if type(module) in [nn.Linear, nn.Conv2d]:
            nn.init.normal_(module.weight, std=0.01)
    net.apply(init_weights)
    # Set the model on multiple GPUs
    net = nn.DataParallel(net, device_ids=devices)
    trainer = torch.optim.SGD(net.parameters(), lr)
    loss = nn.CrossEntropyLoss()
    timer, num_epochs = d2l.Timer(), 10
    animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    for epoch in range(num_epochs):
        net.train()
        timer.start()
        for X, y in train_iter:
            trainer.zero_grad()
            X, y = X.to(devices[0]), y.to(devices[0])
            l = loss(net(X), y)
            l.backward()
            trainer.step()
        timer.stop()
        animator.add(epoch + 1, (d2l.evaluate_accuracy_gpu(net, test_iter),))
    print(f'test acc: {animator.Y[0][-1]:.2f}, {timer.avg():.1f} sec/epoch '
          f'on {str(devices)}')
def train(num_gpus, batch_size, lr):
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    ctx = [d2l.try_gpu(i) for i in range(num_gpus)]
    net.initialize(init=init.Normal(sigma=0.01), ctx=ctx, force_reinit=True)
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr})
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    timer, num_epochs = d2l.Timer(), 10
    animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    for epoch in range(num_epochs):
        timer.start()
        for features, labels in train_iter:
            X_shards, y_shards = d2l.split_batch(features, labels, ctx)
            with autograd.record():
                ls = [loss(net(X_shard), y_shard) for X_shard, y_shard
                      in zip(X_shards, y_shards)]
            for l in ls:
                l.backward()
            trainer.step(batch_size)
        npx.waitall()
        timer.stop()
        animator.add(epoch + 1, (evaluate_accuracy_gpus(net, test_iter),))
    print(f'test acc: {animator.Y[0][-1]:.2f}, {timer.avg():.1f} sec/epoch '
          f'on {str(ctx)}')

让我们看看这在实践中是如何工作的。作为热身,我们在单个GPU上训练网络。

train(net, num_gpus=1, batch_size=256, lr=0.1)
test acc: 0.91, 12.2 sec/epoch on [device(type='cuda', index=0)]
../_images/output_multiple-gpus-concise_2e111f_47_1.svg
train(num_gpus=1, batch_size=256, lr=0.1)
test acc: 0.93, 14.0 sec/epoch on [gpu(0)]
../_images/output_multiple-gpus-concise_2e111f_50_1.svg

接下来我们使用2个GPU进行训练。与 第 13.5 节 中评估的LeNet相比,ResNet-18的模型要复杂得多。这正是并行化显示其优势的地方。计算时间明显大于同步参数的时间。这提高了可伸缩性,因为并行化的开销变得不那么重要了。

train(net, num_gpus=2, batch_size=512, lr=0.2)
test acc: 0.73, 7.5 sec/epoch on [device(type='cuda', index=0), device(type='cuda', index=1)]
../_images/output_multiple-gpus-concise_2e111f_56_1.svg
train(num_gpus=2, batch_size=512, lr=0.2)
test acc: 0.92, 11.1 sec/epoch on [gpu(0), gpu(1)]
../_images/output_multiple-gpus-concise_2e111f_59_1.svg

13.6.4. 小结

  • Gluon通过提供上下文列表,为跨多个设备进行模型初始化提供了基本功能。

  • 数据会自动在可以找到数据的设备上进行评估。

  • 在尝试访问某个设备上的参数之前,请务必在该设备上初始化网络。否则会遇到错误。

  • 优化算法会自动在多个GPU上进行聚合。

13.6.5. 练习

  1. 本节使用了ResNet-18。尝试不同的训练周期、批量大小和学习率。使用更多的GPU进行计算。如果用16个GPU(例如,在一个AWS p2.16xlarge实例上)来尝试会发生什么?

  2. 有时,不同的设备提供不同的计算能力。我们可以同时使用GPU和CPU。我们应该如何分配工作?这样做值得吗?为什么值得?为什么不值得?

讨论

  1. 本节使用了ResNet-18。尝试不同的训练周期、批量大小和学习率。使用更多的GPU进行计算。如果用16个GPU(例如,在一个AWS p2.16xlarge实例上)来尝试会发生什么?

  2. 有时,不同的设备提供不同的计算能力。我们可以同时使用GPU和CPU。我们应该如何分配工作?这样做值得吗?为什么值得?为什么不值得?

  3. 如果我们去掉 npx.waitall() 会发生什么?您将如何修改训练,使其并行化时有多达两步的重叠?

讨论