CUDA 与 DataParallel:为什么有区别?

2024-06-19

我有一个简单的神经网络模型,我应用cuda() or DataParallel()在模型上如下所示。

model = torch.nn.DataParallel(model).cuda()

OR,

model = model.cuda()

当我不使用 DataParallel 时,只需将模型转换为cuda(),我需要将批量输入显式转换为cuda()然后将其交给模型,否则返回以下错误。

torch.index_select 收到无效的参数组合 - got (torch.cuda.FloatTensor, int, torch.LongTensor)

但是使用 DataParallel,代码可以正常工作。其余的其他事情都是一样的。为什么会出现这种情况?为什么当我使用 DataParallel 时,我不需要将批量输入显式转换为cuda()?


因为,DataParallel 允许 CPU 输入,因为它的第一步是将输入传输到适当的 GPU。

信息来源:https://discuss.pytorch.org/t/cuda-vs-dataparallel-why-the-difference/4062/3 https://discuss.pytorch.org/t/cuda-vs-dataparallel-why-the-difference/4062/3

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CUDA 与 DataParallel:为什么有区别? 的相关文章

随机推荐