今天pytorch官网更新了pytorch2.0稳定版,迫不及待的我直接更新了,确实像官方所说,只需加入model=torch.compile(model)一行代码即可加速,加入的位置如下。
cpu训练:
model=UNet(deep_supervision=True) model=torch.compile(model)
单卡训练:
model=UNet(deep_supervision=True) model.to(Device) model=torch.compile(model)
多卡训练:
model=UNet(deep_supervision=True) model.to(Device) model=nn.parallel.DistributedDataParallel( model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False, ) model=torch.compile(model)
注意 model = torch.compile(model) 这句话的位置对了就可以了,其他的不用改!!
多卡训练官方教程:https://pytorch.org/docs/stable/notes/ddp.html#distributed-data-parallel
torch.compile官方教程:https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html?utm_source=whats_new_tutorials&utm_medium=torch_compile