前言
以前用Keras用惯了,fit和fit_generator真的太好使了,模型断电保存搞个checkpoint回调函数就行了。近期使用pytorch进行训练,苦于没有类似的回调函数,写完网络进行训练的时候总不能每次都从头开始训练,于是乎就学了一下pytorch的模型相关操作。
训练过程
ArgumentParser解析器
argparse是一个Python模块:命令行选项、参数和子命令解析器。
主要有三个步骤:
- 创建 ArgumentParser() 对象
- 调用 add_argument() 方法添加参数
- 使用 parse_args() 解析添加的参数
如下:
parser = argparse.ArgumentParser()
parser.add_argument('--train-file', type=str, default='pre/91-image_x2.h5')
parser.add_argument('--eval-file', type=str, default='pre/Set5_x2.h5')
parser.add_argument('--outputs-dir', type=str, default='output/')
parser.add_argument('--weights-file', type=str, default='weight/')
parser.add_argument('--scale', type=int, default=2)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--batch-size', type=int, default=1024)
parser.add_argument('--num-epochs', type=int, default=1000)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--seed', type=int, default=123)
args = parser.parse_args()
使用参数的时候可以通过args.xxx进行调用就行了,这样的好处就是方便统一管理。 如果用编译器运行就得给赋默认值,如果命令行运行,在运行的时候命令后面给出参数就行,有默认值的会进行覆盖。
参数设置
包含网络模型的实例化,损失函数,优化器等一系列操作。
if not os.path.exists(args.outputs_dir):
os.makedirs(args.outputs_dir)
cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(args.seed)
model = FSRCNN(scale_factor=args.scale).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam([
{'params': model.first_part.parameters()},
{'params': model.mid_part.parameters()},
{'params': model.last_part.parameters(), 'lr': args.lr * 0.1}
], lr=args.lr)
模型加载/参数更新
加载上次训练生成的参数文件,通过update操作进行更新,并加载到现有模型中进行训练,这个也就是预训练参数,还要去掉参数中多余的k,v对。
model_dict = model.state_dict()
pre_dict = torch.load('训练参数文件.pth')
pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
model_dict.update(pre_dict)
model.load_state_dict(model_dict)
为了统一操作,可以将再训练也加入到 ArgumentParser中去,给一个布尔类型的参数,为True时加载指定权重进行继续训练,为False时就从头开始训练。比如这样:
parser.add_argument('--resume', type=bool, default=True)
parser.add_argument('--resumePath', type=str, default='test/x4/epoch_47_psnr30.44.pth')
if args.resume:
model_dict = model.state_dict()
# pre_dict = torch.load('test/x4/epoch_904_psnr30.35.pth',map_location='cpu')
pre_dict = torch.load(args.resumePath)
pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
model_dict.update(pre_dict)
model.load_state_dict(model_dict)
路径比较长的话,因为每次保存的权重路径一致,还可以用字符串进行拼接,也可以直接复制路径进行粘贴。
模型保存
torch.save就可以进行模型的保存,里面传入的参数不一样保存方式就不一样。
- 只保存参数(文件类型随意,pth,pkl,tar均可)
torch.save(model.state_dict(), path)
eg:
torch.save(model.state_dict(),
os.path.join(args.outputs_dir, 'epoch_{}_psnr{:.2f}.pth'.format(best_epoch, best_psnr)))
- 保存多个信息,如参数,优化器,epoch,可以用字典存起来保存
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)
这种方式加载的时候:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint(['epoch'])
torch.save(model, path)
可能报错
RuntimeError: xxx.pth is a zip archive (did you mean to use torch.jit.load()?)
这种类型的错误是因为版本原因导致的,在pytorch1.6及以上,pytorch默认使用zip文件格式来保存权重文件,导致这些权重文件无法直接被1.5及以下的pytorch加载。解决办法要么把pytorch升级到1.6及以上版本,要么就在保存的时候不适用zip格式进行保存。
state_dict = torch.load("xxx.pth")
torch.save(state_dict, "xxx.pth", _use_new_zipfile_serialization=False)
python AttributeError: 'module' object has no attribute 'dumps'解决办法
报这个错就是因为没有去除掉参数文件中多余的键值对,去掉就行。
pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is Fal....
pytorch1.6以下版本还不报错,1.6以上就报错了,原因是在加载的时候,没有GPU,只有CPU,必须在load中指明使用CPU才可以。
model = torch.load(model_path, map_location='cpu')
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)