6.7. GPU¶ 在 SageMaker Studio Lab 中打开 Notebook
在 tab_intro_decade
中,我们⽰范了过去⼆⼗年来计算能力的快速增⻓。简⽽⾔之,自2000年以来,GPU性能每⼗年增⻓1000倍。这提供了巨⼤的机会,但也意味着对这样的性能有很⼤的需求。
在本节中,我们将讨论如何利⽤这种计算性能进⾏研究。⾸先是使⽤单个GPU,稍后将介绍如何使⽤多个GPU和多个服务器(具有多个GPU)。
具体来说,我们将讨论如何使用单个NVIDIA GPU进行计算。首先,请确保你至少安装了一块NVIDIA GPU。然后,下载NVIDIA驱动和CUDA并按照提示设置适当的路径。当这些准备工作完成后,就可以使用nvidia-smi
命令来查看显卡信息。
在PyTorch中,每个数组都有一个设备;我们通常将其称为*上下文*(context)。到目前为止,默认情况下,所有变量和相关的计算都已分配给CPU。通常,其他上下文可能是各种GPU。当我们在多个服务器上部署作业时,事情会变得更加棘手。通过智能地将数组分配给上下文,我们可以最大限度地减少设备之间传输数据所花费的时间。例如,当在具有GPU的服务器上训练神经网络时,我们通常希望模型的参数位于GPU上。
你可能已经注意到,MXNet张量看起来与NumPy ndarray
几乎完全相同。但有几个关键区别。MXNet与NumPy区别开来的一个关键特性是它对不同硬件设备的支持。
在MXNet中,每个数组都有一个上下文。到目前为止,默认情况下,所有变量和相关的计算都已分配给CPU。通常,其他上下文可能是各种GPU。当我们在多个服务器上部署作业时,事情会变得更加棘手。通过智能地将数组分配给上下文,我们可以最大限度地减少设备之间传输数据所花费的时间。例如,当在具有GPU的服务器上训练神经网络时,我们通常希望模型的参数位于GPU上。
接下来,我们需要确认已安装了MXNet的GPU版本。如果已经安装了MXNet的CPU版本,我们需要先卸载它。例如,使用pip uninstall mxnet
命令,然后根据您的CUDA版本安装相应的MXNet版本。假设您已安装CUDA 10.0,您可以通过pip install mxnet-cu100
安装支持CUDA 10.0的MXNet版本。
要运行本节中的程序,您至少需要两个GPU。请注意,这对于大多数台式计算机来说可能过于奢侈,但在云端很容易获得,例如,通过使用AWS EC2多GPU实例。几乎所有其他部分都*不*需要多个GPU,但在这里我们只是想说明不同设备之间的数据流。
import torch
from torch import nn
from d2l import torch as d2l
from mxnet import np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
import tensorflow as tf
from d2l import tensorflow as d2l
6.7.1. 计算设备¶
我们可以为存储和计算指定设备,例如CPU和GPU。默认情况下,张量是在主内存中创建的,然后使用CPU进行计算。
在PyTorch中,CPU和GPU可以通过torch.device('cpu')
和torch.device('cuda')
来表示。需要注意的是,cpu
设备意味着所有的物理CPU和内存。这意味着PyTorch的计算将尝试使用所有CPU核心。但是,一个gpu
设备只代表一张卡和相应的内存。如果有多个GPU,我们使用torch.device(f'cuda:{i}')
来表示第\(i\)块GPU(\(i\)从0开始)。另外,gpu:0
和gpu
是等价的。
def cpu(): #@save
"""Get the CPU device."""
return torch.device('cpu')
def gpu(i=0): #@save
"""Get a GPU device."""
return torch.device(f'cuda:{i}')
cpu(), gpu(), gpu(1)
(device(type='cpu'),
device(type='cuda', index=0),
device(type='cuda', index=1))
在MXNet中,CPU和GPU可以由cpu()
和gpu()
来表示。需要注意的是,cpu()
(或括号中的任何整数)意味着所有物理CPU和内存。这意味着MXNet的计算将尝试使用所有CPU核心。然而,gpu()
只代表一张卡和相应的内存。如果有多个GPU,我们使用gpu(i)
来表示第\(i\)个GPU(\(i\)从0开始)。此外,gpu(0)
和gpu()
是等价的。
def cpu(): #@save
"""Get the CPU device."""
return npx.cpu()
def gpu(i=0): #@save
"""Get a GPU device."""
return npx.gpu(i)
cpu(), gpu(), gpu(1)
(cpu(0), gpu(0), gpu(1))
def cpu(): #@save
"""Get the CPU device."""
return jax.devices('cpu')[0]
def gpu(i=0): #@save
"""Get a GPU device."""
return jax.devices('gpu')[i]
cpu(), gpu(), gpu(1)
(CpuDevice(id=0), gpu(id=0), gpu(id=1))
def cpu(): #@save
"""Get the CPU device."""
return tf.device('/CPU:0')
def gpu(i=0): #@save
"""Get a GPU device."""
return tf.device(f'/GPU:{i}')
cpu(), gpu(), gpu(1)
(<tensorflow.python.eager.context._EagerDeviceContext at 0x7fc696ae8980>,
<tensorflow.python.eager.context._EagerDeviceContext at 0x7fc696c86d80>,
<tensorflow.python.eager.context._EagerDeviceContext at 0x7fc696c5e100>)
我们可以查询可用的GPU数量。
def num_gpus(): #@save
"""Get the number of available GPUs."""
return torch.cuda.device_count()
num_gpus()
2
def num_gpus(): #@save
"""Get the number of available GPUs."""
return npx.num_gpus()
num_gpus()
2
def num_gpus(): #@save
"""Get the number of available GPUs."""
try:
return jax.device_count('gpu')
except:
return 0 # No GPU backend found
num_gpus()
2
def num_gpus(): #@save
"""Get the number of available GPUs."""
return len(tf.config.experimental.list_physical_devices('GPU'))
num_gpus()
2
现在我们定义两个方便的函数,即使请求的GPU不存在,也允许我们运行代码。
def try_gpu(i=0): #@save
"""Return gpu(i) if exists, otherwise return cpu()."""
if num_gpus() >= i + 1:
return gpu(i)
return cpu()
def try_all_gpus(): #@save
"""Return all available GPUs, or [cpu(),] if no GPU exists."""
return [gpu(i) for i in range(num_gpus())]
try_gpu(), try_gpu(10), try_all_gpus()
(device(type='cuda', index=0),
device(type='cpu'),
[device(type='cuda', index=0), device(type='cuda', index=1)])
def try_gpu(i=0): #@save
"""Return gpu(i) if exists, otherwise return cpu()."""
if num_gpus() >= i + 1:
return gpu(i)
return cpu()
def try_all_gpus(): #@save
"""Return all available GPUs, or [cpu(),] if no GPU exists."""
return [gpu(i) for i in range(num_gpus())]
try_gpu(), try_gpu(10), try_all_gpus()
(gpu(0), cpu(0), [gpu(0), gpu(1)])
def try_gpu(i=0): #@save
"""Return gpu(i) if exists, otherwise return cpu()."""
if num_gpus() >= i + 1:
return gpu(i)
return cpu()
def try_all_gpus(): #@save
"""Return all available GPUs, or [cpu(),] if no GPU exists."""
return [gpu(i) for i in range(num_gpus())]
try_gpu(), try_gpu(10), try_all_gpus()
(gpu(id=0), CpuDevice(id=0), [gpu(id=0), gpu(id=1)])
def try_gpu(i=0): #@save
"""Return gpu(i) if exists, otherwise return cpu()."""
if num_gpus() >= i + 1:
return gpu(i)
return cpu()
def try_all_gpus(): #@save
"""Return all available GPUs, or [cpu(),] if no GPU exists."""
return [gpu(i) for i in range(num_gpus())]
try_gpu(), try_gpu(10), try_all_gpus()
(<tensorflow.python.eager.context._EagerDeviceContext at 0x7fc69696ad80>,
<tensorflow.python.eager.context._EagerDeviceContext at 0x7fc69696a500>,
[<tensorflow.python.eager.context._EagerDeviceContext at 0x7fc69696b140>,
<tensorflow.python.eager.context._EagerDeviceContext at 0x7fc69696b8c0>])
6.7.2. 张量和GPU¶
默认情况下,张量是在CPU上创建的。我们可以查询张量所在的设备。
x = torch.tensor([1, 2, 3])
x.device
device(type='cpu')
默认情况下,张量是在CPU上创建的。我们可以查询张量所在的设备。
x = np.array([1, 2, 3])
x.ctx
[22:01:52] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
cpu(0)
默认情况下,如果GPU/TPU可用,张量将在其上创建,否则如果不可用则使用CPU。我们可以查询张量所在的设备。
x = jnp.array([1, 2, 3])
x.device()
gpu(id=0)
默认情况下,如果GPU/TPU可用,张量将在其上创建,否则如果不可用则使用CPU。我们可以查询张量所在的设备。
x = tf.constant([1, 2, 3])
x.device
'/job:localhost/replica:0/task:0/device:GPU:0'
需要注意的是,无论何时我们想要对多个项进行操作,它们都需要在同一个设备上。例如,如果我们对两个张量求和,我们需要确保两个参数都位于同一个设备上——否则框架将不知道在哪里存储结果,甚至不知道如何决定在哪里执行计算。
6.7.2.1. 在GPU上存储¶
有几种方法可以在GPU上存储张量。例如,我们可以在创建张量时指定存储设备。接下来,我们在第一个gpu
上创建张量变量X
。在GPU上创建的张量只消耗该GPU的内存。我们可以使用nvidia-smi
命令查看GPU内存使用情况。一般来说,我们需要确保我们创建的数据不超过GPU内存限制。
X = torch.ones(2, 3, device=try_gpu())
X
tensor([[1., 1., 1.],
[1., 1., 1.]], device='cuda:0')
X = np.ones((2, 3), ctx=try_gpu())
X
[22:01:53] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
array([[1., 1., 1.],
[1., 1., 1.]], ctx=gpu(0))
# By default JAX puts arrays to GPUs or TPUs if available
X = jax.device_put(jnp.ones((2, 3)), try_gpu())
X
Array([[1., 1., 1.],
[1., 1., 1.]], dtype=float32)
with try_gpu():
X = tf.ones((2, 3))
X
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 1.],
[1., 1., 1.]], dtype=float32)>
假设您至少有两个GPU,以下代码将在第二个GPU上创建一个随机张量Y
。
Y = torch.rand(2, 3, device=try_gpu(1))
Y
tensor([[0.0022, 0.5723, 0.2890],
[0.1456, 0.3537, 0.7359]], device='cuda:1')
Y = np.random.uniform(size=(2, 3), ctx=try_gpu(1))
Y
[22:01:54] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
array([[0.67478997, 0.07540122, 0.9956977 ],
[0.09488854, 0.415456 , 0.11231736]], ctx=gpu(1))
Y = jax.device_put(jax.random.uniform(jax.random.PRNGKey(0), (2, 3)),
try_gpu(1))
Y
Array([[0.57450044, 0.09968603, 0.7419659 ],
[0.8941783 , 0.59656656, 0.45325184]], dtype=float32)
with try_gpu(1):
Y = tf.random.uniform((2, 3))
Y
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0.25437534, 0.4355222 , 0.8891233 ],
[0.9142593 , 0.06548178, 0.87763405]], dtype=float32)>
6.7.2.2. 复制¶
如果我们想计算 X + Y
,我们需要决定在哪里执行这个操作。例如,如 图 6.7.1 所示,我们可以将 X
传输到第二个 GPU 并在那里执行操作。*不要*简单地将 X
和 Y
相加,因为这会导致异常。运行时引擎不知道该怎么做:它在同一设备上找不到数据,因此失败了。由于 Y
位于第二个 GPU 上,我们需要先将 X
移动到那里,然后才能将两者相加。
图 6.7.1 复制数据以在同一设备上执行操作。¶
Z = X.cuda(1)
print(X)
print(Z)
tensor([[1., 1., 1.],
[1., 1., 1.]], device='cuda:0')
tensor([[1., 1., 1.],
[1., 1., 1.]], device='cuda:1')
Z = X.copyto(try_gpu(1))
print(X)
print(Z)
[[1. 1. 1.]
[1. 1. 1.]] @gpu(0)
[[1. 1. 1.]
[1. 1. 1.]] @gpu(1)
Z = jax.device_put(X, try_gpu(1))
print(X)
print(Z)
[[1. 1. 1.]
[1. 1. 1.]]
[[1. 1. 1.]
[1. 1. 1.]]
with try_gpu(1):
Z = X
print(X)
print(Z)
tf.Tensor(
[[1. 1. 1.]
[1. 1. 1.]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[1. 1. 1.]
[1. 1. 1.]], shape=(2, 3), dtype=float32)
现在数据(Z
和 Y
)都在同一个GPU上,我们可以将它们相加了。
Y + Z
tensor([[1.0022, 1.5723, 1.2890],
[1.1456, 1.3537, 1.7359]], device='cuda:1')
但是如果你的变量Z
已经存在于你的第二个GPU上呢?如果我们仍然调用Z.cuda(1)
会发生什么?它将返回Z
,而不是创建一个副本并分配新内存。
Z.cuda(1) is Z
True
Y + Z
array([[1.6747899, 1.0754012, 1.9956977],
[1.0948886, 1.415456 , 1.1123173]], ctx=gpu(1))
想象一下,你的变量Z
已经存在于你的第二个GPU上。如果我们仍然调用Z.copyto(gpu(1))
会发生什么?它会创建一个副本并分配新的内存,即使该变量已经存在于所需的设备上。有时,根据我们的代码运行的环境,两个变量可能已经存在于同一个设备上。所以我们只想在变量当前位于不同设备时才进行复制。在这些情况下,我们可以调用as_in_ctx
。如果变量已经存在于指定的设备中,那么这是一个空操作。除非你特别想创建一个副本,否则as_in_ctx
是首选方法。
Z.as_in_ctx(try_gpu(1)) is Z
True
Y + Z
Array([[1.5745004, 1.099686 , 1.7419659],
[1.8941783, 1.5965666, 1.4532518]], dtype=float32)
想象一下,你的变量Z
已经存在于你的第二个GPU上。如果我们仍然在相同的设备范围内调用Z2 = Z
会发生什么?它将返回Z
,而不是创建一个副本并分配新内存。
Z2 = jax.device_put(Z, try_gpu(1))
Z2 is Z
False
Y + Z
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1.2543753, 1.4355222, 1.8891233],
[1.9142593, 1.0654818, 1.877634 ]], dtype=float32)>
想象一下,你的变量Z
已经存在于你的第二个GPU上。如果我们仍然在相同的设备范围内调用Z2 = Z
会发生什么?它将返回Z
,而不是创建一个副本并分配新内存。
with try_gpu(1):
Z2 = Z
Z2 is Z
True
6.7.2.3. 旁注¶
人们使用GPU进行机器学习是因为他们期望GPU速度快。但在设备之间传输变量很慢:比计算慢得多。所以我们希望你在做一些慢速操作之前百分之百确定你想这么做。如果深度学习框架只是自动进行复制而不崩溃,那么你可能不会意识到你写了一些慢速的代码。
传输数据不仅慢,而且还使并行化变得更加困难,因为我们必须等待数据被发送(或者更确切地说是被接收),然后才能继续进行更多的操作。这就是为什么复制操作应该非常小心。根据经验,许多小的操作比一个大的操作要差得多。此外,除非你知道你在做什么,否则一次进行几个操作比在代码中穿插许多单个操作要好得多。这是因为如果一个设备必须等待另一个设备才能做其他事情,这些操作可能会阻塞。这有点像在队列中点咖啡,而不是通过电话预订,然后在你到达时发现咖啡已经准备好了。
最后,当我们打印张量或将张量转换为NumPy格式时,如果数据不在主内存中,框架会先将其复制到主内存中,导致额外的传输开销。更糟糕的是,它现在受到可怕的全局解释器锁的限制,这使得所有事情都等待Python完成。
6.7.3. 神经网络和GPU¶
同样,神经网络模型可以指定设备。以下代码将模型参数放在GPU上。
net = nn.Sequential(nn.LazyLinear(1))
net = net.to(device=try_gpu())
net = nn.Sequential()
net.add(nn.Dense(1))
net.initialize(ctx=try_gpu())
net = nn.Sequential([nn.Dense(1)])
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (10,)) # Dummy input
params = net.init(key2, x) # Initialization call
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
net = tf.keras.models.Sequential([
tf.keras.layers.Dense(1)])
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
在接下来的章节中,我们将看到更多关于如何在GPU上运行模型的例子,仅仅因为这些模型将变得在计算上更加密集。
例如,当输入是GPU上的张量时,模型将在同一个GPU上计算结果。
net(X)
tensor([[0.7802],
[0.7802]], device='cuda:0', grad_fn=<AddmmBackward0>)
net(X)
array([[0.04995865],
[0.04995865]], ctx=gpu(0))
net.apply(params, x)
Array([-1.2849933], dtype=float32)
net(X)
<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[1.1455073],
[1.1455073]], dtype=float32)>
让我们确认模型参数存储在同一个GPU上。
net[0].weight.data.device
device(type='cuda', index=0)
net[0].weight.data().ctx
gpu(0)
print(jax.tree_util.tree_map(lambda x: x.device(), params))
FrozenDict({
params: {
layers_0: {
bias: gpu(id=0),
kernel: gpu(id=0),
},
},
})
net.layers[0].weights[0].device, net.layers[0].weights[1].device
('/job:localhost/replica:0/task:0/device:GPU:0',
'/job:localhost/replica:0/task:0/device:GPU:0')
让训练器支持GPU。
@d2l.add_to_class(d2l.Trainer) #@save
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
self.save_hyperparameters()
self.gpus = [d2l.gpu(i) for i in range(min(num_gpus, d2l.num_gpus()))]
@d2l.add_to_class(d2l.Trainer) #@save
def prepare_batch(self, batch):
if self.gpus:
batch = [a.to(self.gpus[0]) for a in batch]
return batch
@d2l.add_to_class(d2l.Trainer) #@save
def prepare_model(self, model):
model.trainer = self
model.board.xlim = [0, self.max_epochs]
if self.gpus:
model.to(self.gpus[0])
self.model = model
@d2l.add_to_class(d2l.Module) #@save
def set_scratch_params_device(self, device):
for attr in dir(self):
a = getattr(self, attr)
if isinstance(a, np.ndarray):
with autograd.record():
setattr(self, attr, a.as_in_ctx(device))
getattr(self, attr).attach_grad()
if isinstance(a, d2l.Module):
a.set_scratch_params_device(device)
if isinstance(a, list):
for elem in a:
elem.set_scratch_params_device(device)
@d2l.add_to_class(d2l.Trainer) #@save
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
self.save_hyperparameters()
self.gpus = [d2l.gpu(i) for i in range(min(num_gpus, d2l.num_gpus()))]
@d2l.add_to_class(d2l.Trainer) #@save
def prepare_batch(self, batch):
if self.gpus:
batch = [a.as_in_context(self.gpus[0]) for a in batch]
return batch
@d2l.add_to_class(d2l.Trainer) #@save
def prepare_model(self, model):
model.trainer = self
model.board.xlim = [0, self.max_epochs]
if self.gpus:
model.collect_params().reset_ctx(self.gpus[0])
model.set_scratch_params_device(self.gpus[0])
self.model = model
@d2l.add_to_class(d2l.Trainer) #@save
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
self.save_hyperparameters()
self.gpus = [d2l.gpu(i) for i in range(min(num_gpus, d2l.num_gpus()))]
@d2l.add_to_class(d2l.Trainer) #@save
def prepare_batch(self, batch):
if self.gpus:
batch = [jax.device_put(a, self.gpus[0]) for a in batch]
return batch
简而言之,只要所有数据和参数都在同一个设备上,我们就可以高效地学习模型。在接下来的章节中,我们将看到几个这样的例子。
6.7.4. 小结¶
我们可以为存储和计算指定设备,例如CPU或GPU。默认情况下,数据在主内存中创建,然后使用CPU进行计算。深度学习框架要求用于计算的所有输入数据都在同一个设备上,无论是CPU还是同一个GPU。不小心移动数据可能会导致性能显著下降。一个典型的错误如下:在GPU上为每个小批量计算损失,并将其报告给命令行上的用户(或记录在NumPy ndarray
中)将触发一个全局解释器锁,从而使所有GPU停顿。更好的做法是在GPU内部为日志分配内存,并且只移动较大的日志。