AttributeError: Can't pickle local object 'pre_datasets.<locals>.<lambda>' when implementing a PyTorch framework

2023-12-01

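I am trying to implement a PyTorch framework for a CNN.
I'm confident the code is correct, because it comes from a tutorial and it works when I run it in a Jupyter Notebook on Google Drive.
But when I move it into a local .py file and run it, I get this error:
AttributeError: Can't pickle local object 'pre_datasets.<locals>.<lambda>'
I know this has something to do with referring to an object outside the function, but what exactly is the problem behind this error?
How should I fix it?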

Here is the main part of the code.

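from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import DatasetFolder

# CONFIG (paths and hyperparameters) is defined elsewhere in the script and is not shown here.

def pre_datasets():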
    TRAIN_TFM = transforms.Compose(
        [
            transforms.Resize(size=(128, 128)),
            # TODO
            transforms.ToTensor(),
        ]
    )
    train_set = DatasetFolder(
        root=CONFIG["train_set_path"],
        loader=lambda x: Image.open(x),  # lambda defined inside pre_datasets(): the object named in the error
        extensions="jpg",
        transform=TRAIN_TFM,
    )
    train_loader = DataLoader(
        dataset=train_set,
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        num_workers=CONFIG["num_workers"],
        pin_memory=True,
    )
    return train_loader

def train(train_loader):
    ...
    for epoch in range(CONFIG["num_epochs"]):
        ...
        for batch in train_loader:  # error happened here
            ...

if __name__ == "__main__":
    train_loader = pre_datasets()
    train(train_loader)

Here is the error message:

Traceback (most recent call last):
  File "HW03_byCRZ.py", line 197, in <module>
    train(train_loader, valid_loader)
  File "HW03_byCRZ.py", line 157, in train
    for batch in train_loader:
  File "/Users/ceezous/opt/anaconda3/envs/pytorch_env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 355, in __iter__
    return self._get_iterator()
  File "/Users/ceezous/opt/anaconda3/envs/pytorch_env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 301, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/Users/ceezous/opt/anaconda3/envs/pytorch_env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 914, in __init__
    w.start()
  File "/Users/ceezous/opt/anaconda3/envs/pytorch_env/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/Users/ceezous/opt/anaconda3/envs/pytorch_env/lib/python3.8/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/Users/ceezous/opt/anaconda3/envs/pytorch_env/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/Users/ceezous/opt/anaconda3/envs/pytorch_env/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/Users/ceezous/opt/anaconda3/envs/pytorch_env/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/Users/ceezous/opt/anaconda3/envs/pytorch_env/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/ceezous/opt/anaconda3/envs/pytorch_env/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'pre_datasets.<locals>.<lambda>'
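A likely explanation, based on the traceback: the call goes through popen_spawn_posix.py and reduction.dump, which means the DataLoader worker processes are started with the "spawn" start method (the default on macOS since Python 3.8). With spawn, the dataset, including its loader function, is pickled and sent to each worker. The standard pickle module saves a function by reference (its module plus its qualified name), and a lambda created inside pre_datasets() has the qualified name pre_datasets.<locals>.<lambda>, which cannot be looked up from outside the function, so pickling fails. If the notebook ran on Google Colab (a Linux environment), the default start method there is fork, which does not need to pickle the dataset; that would explain why the notebook version works. The stdlib-only sketch below reproduces the difference; make_loader_with_lambda, module_level_loader and can_pickle are hypothetical names used only for illustration.

```python
import pickle


def make_loader_with_lambda():
    # Hypothetical stand-in for pre_datasets(): the loader is a lambda created
    # inside a function, so its __qualname__ contains '<locals>'.
    return lambda path: path.upper()


def module_level_loader(path):
    # Hypothetical replacement defined at module top level: pickle can store it
    # as (module name, qualified name) and a worker process can look it up again.
    return path.upper()


def can_pickle(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # Older Pythons raise AttributeError("Can't pickle local object ...");
        # newer versions may raise PicklingError instead.
        return False


if __name__ == "__main__":
    local_fn = make_loader_with_lambda()
    print(local_fn.__qualname__)            # make_loader_with_lambda.<locals>.<lambda>
    print(can_pickle(local_fn))             # False
    print(can_pickle(module_level_loader))  # True
```

Applied to the question's code, this suggests replacing loader=lambda x: Image.open(x) with a function defined at the top level of the .py file (for example a hypothetical def open_image(path): return Image.open(path), passed as loader=open_image), or with torchvision's torchvision.datasets.folder.default_loader. Setting num_workers=0 should also avoid the error, because the data is then loaded in the main process and nothing has to be pickled, at the cost of slower loading.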

I had a similar problem and used dill, like this:

import dill as pickle

It worked out of the box!
