并不是所有的torch方法都支持半精度计算。测试半精度计算需要在cuda上,cpu不支持半精度。因此首先需要创建半精度变量,并放到cuda设备上。部分方法在低版本不支持,在高版本支持半精度计算,部分方法一直不支持。例如行列式计算torch.linalg.det()不支持半精度。看如下代码:
import torch
a = torch.randn((4,4), dtype=torch.float16).cuda()
b = a.float()
c = b.det()
d = a.det()
c = b.det()是单精度计算,正常;
d = a.det()是半精度计算,出错,***not implemented for 'Half'
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)