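I am curious about the memory usage of transformers.BertModel. I want to use a pretrained model to transform text and save the output of the [CLS] token. There is no training, only inference.
My input to BERT is 511 tokens long. With a batch size of 16, my code runs out of memory. The GPU has 32GB of memory. My question is how to estimate BERT's memory usage.
Strangely, another job with a batch size of 32 and the same setup finished successfully. My code is below.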

    import torch
    from torch.utils.data import ConcatDataset, DataLoader, RandomSampler
    from transformers import BertModel

    # Create dataloader (train_data and valid_data are defined elsewhere)
    bs = 16
    train_comb = ConcatDataset([train_data, valid_data])
    # Sample from the combined dataset so that valid_data is visited too
    train_dl = DataLoader(train_comb, sampler=RandomSampler(train_comb), batch_size=bs)

    model = BertModel.from_pretrained('/my_dir/bert_base_uncased/',
                                      output_attentions=False,
                                      output_hidden_states=False)
    model.cuda()
    out_list = []
    model.eval()
    with torch.no_grad():
        for d in train_dl:
            d = [i.cuda() for i in d]     # d = [input_ids, attention_mask, token_type_ids, labels]
            inputs, labels = d[:3], d[3]  # input_ids has shape 16 x 511
            output = model(*inputs)[0][:, 0, :]
            out_list.append(output)

    outputs = torch.cat(out_list)

Later I changed the for loop to the following:

    with torch.no_grad():
        for d in train_dl:
            d = [i.cuda() for i in d[:3]]           # don't care about the labels
            out_list.append(model(*d)[0][:, 0, :])  # remove the intermediary variables
            del d

In summary, my question is how to estimate BERT's memory usage.

Best answer
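After some searching, I found that the error was caused by appending the outputs to a list that lives on the GPU. With the following code, the error disappeared.

    with torch.no_grad():
        for d in train_dl:
            d = [i.cuda() for i in d[:3]]
            # Move each batch's [CLS] output to the CPU before storing it
            out_list.append(model(*d)[0][:, 0, :].cpu())
            del d
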
Without .cpu(), the allocated memory keeps growing:

    Tensor size: torch.Size([4, 511]), Memory allocated: 418.7685546875MB
    Tensor size: torch.Size([4, 768]), Memory allocated: 424.7568359375MB
    Tensor size: torch.Size([4, 511]), Memory allocated: 424.7568359375MB
    Tensor size: torch.Size([4, 768]), Memory allocated: 430.7451171875MB
    Tensor size: torch.Size([4, 511]), Memory allocated: 430.7451171875MB
    Tensor size: torch.Size([4, 768]), Memory allocated: 436.7333984375MB

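The step size in this log can be checked by hand. A likely explanation, assuming float32 activations and a hidden size of 768, is that `model(*d)[0][:, 0, :]` is a view into the full `[batch, seq_len, 768]` hidden-state tensor, so keeping the view in `out_list` keeps the whole tensor alive on the GPU. The sketch below computes that size; the helper name `retained_mib` is illustrative and not part of any library, and the 32 GiB bound ignores the weights and any workspace memory.

```python
def retained_mib(batch_size, seq_len, hidden_size=768, bytes_per_elem=4):
    """MiB kept on the GPU per batch when a view of the full hidden state is stored."""
    return batch_size * seq_len * hidden_size * bytes_per_elem / 2**20

# Batch of 4 sequences with 511 tokens, as in the log above
per_step = retained_mib(4, 511)
print(per_step)  # 5.98828125, the same as 424.7568359375 - 418.7685546875

# Batch of 16: upper bound on the number of batches that fit in 32 GiB
bs16 = retained_mib(16, 511)
steps_to_fill_32gib = int(32 * 1024 / bs16)
print(bs16, steps_to_fill_32gib)
```

If this reading is right, anything that copies the slice out of the large tensor (such as `.cpu()`, or `.clone()` if the results should stay on the GPU) lets the full hidden state be freed after each batch.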
With .cpu(), the allocated memory does not change:

    Tensor size: torch.Size([128, 511]), Memory allocated: 420.21875MB
    Tensor size: torch.Size([128, 768]), Memory allocated: 420.21875MB
    Tensor size: torch.Size([128, 511]), Memory allocated: 420.21875MB
    Tensor size: torch.Size([128, 768]), Memory allocated: 420.21875MB
    Tensor size: torch.Size([128, 511]), Memory allocated: 420.21875MB
    Tensor size: torch.Size([128, 768]), Memory allocated: 420.21875MB

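The baseline of roughly 420 MB in both logs is close to what the model weights alone would need. As a rough sketch, assuming the standard bert-base-uncased configuration (vocabulary 30522, hidden size 768, 12 layers, intermediate size 3072, 512 positions, 2 token types) and float32 weights, the parameter memory can be estimated like this. The function name `bert_param_count` is illustrative, and the local checkpoint in the question is only assumed to match this configuration.

```python
def bert_param_count(vocab=30522, hidden=768, layers=12, intermediate=3072,
                     max_pos=512, type_vocab=2):
    """Count the parameters of a BERT encoder with a pooler (BertModel layout)."""
    layer_norm = 2 * hidden                                   # weight + bias
    embeddings = (vocab + max_pos + type_vocab) * hidden + layer_norm
    attention = 4 * (hidden * hidden + hidden) + layer_norm   # Q, K, V, output projection
    feed_forward = (hidden * intermediate + intermediate      # up projection
                    + intermediate * hidden + hidden          # down projection
                    + layer_norm)
    pooler = hidden * hidden + hidden
    return embeddings + layers * (attention + feed_forward) + pooler

n_params = bert_param_count()
weights_mib = n_params * 4 / 2**20   # 4 bytes per float32 parameter
print(n_params, round(weights_mib, 1))  # 109482240 417.6
```

On top of the weights, inference needs activation memory that scales with batch size and sequence length. Under `torch.no_grad()` that memory can be reused from one batch to the next, provided nothing (such as a list of GPU tensors) still references it.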

Source: the question "memory-management - How to calculate the memory requirement of BERT?" on Stack Overflow: https://stackoverflow.com/questions/63076190/