在服务器上用的GPU训练,然后保存模型,本地测试的时候load模型遇到的问题。
应该是GPU所用的卡号不同,服务器是’cuda:1',本地是'cuda:0',所以会遇到这个问题,但其实在服务器上直接运行相同的测试代码是没有问题的。但也解决一下吧。
原来代码:
rnn = torch.load('lstm-classification.pth')
改为:
rnn = torch.load('lstm-classification.pth',map_location='cuda:0')
如果是多张卡,如3卡到2卡,就在后面加上,变成:map_location={'cuda:1': 'cuda:0'}