如何在 pytorch 中正确保存 torch.nn.Sequential 模型?

2023-12-09

我非常清楚加载字典,然后使用旧的参数字典加载一个实例(例如这个很棒的问题和答案)。不幸的是,当我有一个torch.nn.Sequential我当然没有它的类定义。

所以我想仔细检查一下,正确的方法是什么。我相信torch.save就足够了(到目前为止我的代码还没有崩溃),尽管这些事情可能比人们想象的更微妙(例如,当我使用 pickle 时,我收到一条警告,但是torch.save在内部使用它,所以很混乱)。此外,numpy 有它自己的保存函数(例如,参见这个答案)这往往更有效,因此可能存在我可能会忽略的微妙权衡。


我的测试代码:


# creating data and running through a nn and saving it

import torch
import torch.nn as nn

from pathlib import Path
from collections import OrderedDict

import numpy as np

import pickle

path = Path('~/data/tmp/').expanduser()
path.mkdir(parents=True, exist_ok=True)

num_samples = 3
Din, Dout = 1, 1
lb, ub = -1, 1

x = torch.torch.distributions.Uniform(low=lb, high=ub).sample((num_samples, Din))

f = nn.Sequential(OrderedDict([
    ('f1', nn.Linear(Din,Dout)),
    ('out', nn.SELU())
]))
y = f(x)

# save data torch to numpy
x_np, y_np = x.detach().cpu().numpy(), y.detach().cpu().numpy()
np.savez(path / 'db', x=x_np, y=y_np)

print(x_np)
# save model
with open('db_saving_seq', 'wb') as file:
    pickle.dump({'f': f}, file)

# load model
with open('db_saving_seq', 'rb') as file:
    db = pickle.load(file)
    f2 = db['f']

# test that it outputs the right thing
y2 = f2(x)

y_eq_y2 = y == y2
print(y_eq_y2)

db2 = {'f': f, 'x': x, 'y': y}
torch.save(db2, path / 'db_f_x_y')

print('Done')

db3 = torch.load(path / 'db_f_x_y')
f3 = db3['f']
x3 = db3['x']
y3 = db3['y']
yy3 = f3(x3)

y_eq_y3 = y == y3
print(y_eq_y3)

y_eq_yy3 = y == yy3
print(y_eq_yy3)

Related:

  • 论坛的相关问题:https://discuss.pytorch.org/t/how-to-save-nn-sequential-as-a-model/89117/14

从代码中可以看出torch.nn.Sequential是基于torch.nn.Module: https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential

所以你可以使用

f = torch.nn.Sequential(...)
torch.save(f.state_dict(), path)

就像其他任何事情一样torch.nn.Module.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在 pytorch 中正确保存 torch.nn.Sequential 模型? 的相关文章

随机推荐