pytorch dataloader - 运行时错误:堆栈期望每个张量大小相等,但在条目 0 处得到 [157],在条目 1 处得到 [154]

2024-06-22

我是 pytorch 的初学者。我正在尝试进行基于方面的情感分析。我面临着主题中提到的错误。我的代码如下:我请求帮助解决此错误。提前致谢。我将分享整个代码和错误堆栈。!pip install transformers

import transformers
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
%matplotlib inline
%config InlineBackend.figure_format='retina'
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

df = pd.read_csv("/Users/user1/Downloads/auto_bio_copy.csv")

我正在导入一个 csv 文件,其内容和标签如下所示:

df.head()

                     content                                      label
0   I told him I would leave the car and come back...   O O O O O O O O O O O O O O O O O O O O O O O ...
1   I had the ignition interlock device installed ...   O O O B-Negative I-Negative I-Negative O O O O...
2   Aug. 23 or 24 I went to Walmart auto service d...   O O O O O O O B-Negative I-Negative I-Negative...
3   Side note This is the same reaction I 'd gotte...   O O O O O O O O O O O O O O O O O O O O O O O ...
4   Locked out of my car . Called for help 215pm w...   O O O O O O O O O O O O O O O O O B-Negative O...

df.shape

(1999, 2)

我将标签值转换为整数,如下所示: O=零(0)、B-阳性=1、I-阳性=2、B-阴性=3、I-阴性=4、B-中性=5、I-中性=6、B-混合=7、I -混合=8

df['label'] = df.label.str.replace('O', '0')
df['label'] = df.label.str.replace('B-Positive', '1')
df['label'] = df.label.str.replace('I-Positive', '2')
df['label'] = df.label.str.replace('B-Negative', '3')
df['label'] = df.label.str.replace('I-Negative', '4')
df['label'] = df.label.str.replace('B-Neutral', '5')
df['label'] = df.label.str.replace('I-Neutral', '6')
df['label'] = df.label.str.replace('B-Mixed', '7')
df['label'] = df.label.str.replace('I-Mixed', '8')

接下来,将字符串转换为整数列表,如下所示:

df['label'] = df['label'].str.split(' ').apply(lambda s: list(map(int, s)))
df.head()
                     content                                         label
0   I told him I would leave the car and come back...   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
1   I had the ignition interlock device installed ...   [0, 0, 0, 3, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
2   Aug. 23 or 24 I went to Walmart auto service d...   [0, 0, 0, 0, 0, 0, 0, 3, 4, 4, 4, 0, 0, 0, 0, ...
3   Side note This is the same reaction I 'd gotte...   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
4   Locked out of my car . Called for help 215pm w...   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
token_lens = []
for txt in df.content:
  tokens = tokenizer.encode_plus(txt, max_length=512, add_special_tokens=True, truncation=True, return_attention_mask=True)
  token_lens.append(len(tokens))
MAX_LEN = 512
class Auto_Bio_Dataset(Dataset):
    def __init__(self, contents, labels, tokenizer, max_len):
        self.contents = contents
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
    def __len__(self):
        return len(self.contents)
    def __getitem__(self, item):
        content = str(self.contents[item])
        label = self.labels[item]
        encoding = self.tokenizer.encode_plus(
          content,
          add_special_tokens=True,
          max_length=self.max_len,
          return_token_type_ids=False,
          #padding='max_length',
          pad_to_max_length=True,
          truncation=True,
          return_attention_mask=True,
          return_tensors='pt'
        )
        return {
          'content_text': content,
          'input_ids': encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten(),
          'labels': torch.tensor(label)
        }
df_train, df_test = train_test_split(
  df,
  test_size=0.1,
  random_state=RANDOM_SEED
)
df_val, df_test = train_test_split(
  df_test,
  test_size=0.5,
  random_state=RANDOM_SEED
)
df_train.shape, df_val.shape, df_test.shape
((1799, 2), (100, 2), (100, 2))
def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = Auto_Bio_Dataset(
        contents=df.content.to_numpy(),
        labels=df.label.to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len
  )
    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=2
  )
BATCH_SIZE = 16
train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
data = next(iter(train_data_loader))
data.keys()

错误如下:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-71-e0a71018e473> in <module>
----> 1 data = next(iter(train_data_loader))
      2 data.keys()

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    528             if self._sampler_iter is None:
    529                 self._reset()
--> 530             data = self._next_data()
    531             self._num_yielded += 1
    532             if self._dataset_kind == _DatasetKind.Iterable and \

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1222             else:
   1223                 del self._task_info[idx]
-> 1224                 return self._process_data(data)
   1225 
   1226     def _try_put_index(self):

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1248         self._try_put_index()
   1249         if isinstance(data, ExceptionWrapper):
-> 1250             data.reraise()
   1251         return data
   1252 

~/opt/anaconda3/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
    455             # instantiate since we don't know how to
    456             raise RuntimeError(msg) from None
--> 457         raise exception
    458 
    459 

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 157, in default_collate
    return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 157, in <dictcomp>
    return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 138, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [157] at entry 0 and [154] at entry 1

我在一些github帖子中发现这个错误可能是因为batch size的原因,所以我将batch size改为8,然后错误如下:

BATCH_SIZE = 8
train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
data = next(iter(train_data_loader))
data.keys()
RuntimeError                              Traceback (most recent call last)
<ipython-input-73-e0a71018e473> in <module>
----> 1 data = next(iter(train_data_loader))
      2 data.keys()

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    528             if self._sampler_iter is None:
    529                 self._reset()
--> 530             data = self._next_data()
    531             self._num_yielded += 1
    532             if self._dataset_kind == _DatasetKind.Iterable and \

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1222             else:
   1223                 del self._task_info[idx]
-> 1224                 return self._process_data(data)
   1225 
   1226     def _try_put_index(self):

~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1248         self._try_put_index()
   1249         if isinstance(data, ExceptionWrapper):
-> 1250             data.reraise()
   1251         return data
   1252 

~/opt/anaconda3/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
    455             # instantiate since we don't know how to
    456             raise RuntimeError(msg) from None
--> 457         raise exception
    458 
    459 

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 157, in default_collate
    return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 157, in <dictcomp>
    return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
  File "/Users/namrathabhandarkar/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 137, in default_collate
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable

我不确定是什么导致了第一个错误(主题中提到的错误)。我在代码中使用填充和截断,但出现错误。

非常感谢任何解决此问题的帮助。

提前致谢。


快速回答:您需要实施自己的collate_fn创建时的函数DataLoader. See PyTorch 论坛的讨论 https://discuss.pytorch.org/t/dataloader-gives-stack-expects-each-tensor-to-be-equal-size-due-to-different-image-has-different-objects-number/91941/7.

您应该能够将函数对象传递给DataLoader实例化:

def my_collate_fn(data):
    # TODO: Implement your function
    # But I guess in your case it should be:
    return tuple(data)

return DataLoader(
    ds,
    batch_size=batch_size,
    num_workers=2,
    collate_fn=my_collate_fn
)

这应该是解决这个问题的方法,但作为临时补救措施,以防出现紧急情况或快速测试很好,只需更改batch_size to 1以防止火炬试图将不同形状的东西堆叠起来。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch dataloader - 运行时错误:堆栈期望每个张量大小相等,但在条目 0 处得到 [157],在条目 1 处得到 [154] 的相关文章

随机推荐

  • Laravel 文件下载 - php_fileinfo 扩展未启用

    我正在使用 Laravel 5 4 13 和 PHP 7 1 并迁移到共享主机 我正在尝试使该网站正常运行 但由于缺少扩展名而无法正常运行 php fileinfo 这是网站崩溃的代码 file base path storage app
  • EJB 3.0 - 嵌套事务!= 需要新的?

    我刚刚阅读了 掌握 EJB 3 0 的事务章节 10 现在我对嵌套事务感到困惑 书上说 EJB 定义的事务管理器不 支持嵌套事务 它只需要支持扁平交易 站点 278 注释 这个事实不仅在这本书中有描述 我在其他书籍 网站中也发现了这种说法
  • 在 Google 日历 API 中开会

    如何在java中的google calendar api中添加google meet 请帮我 我还没看懂谷歌文档 https developers google com calendar create events https develo
  • Java 运算符:|= 按位或并赋值示例[重复]

    这个问题在这里已经有答案了 我刚刚浏览了某人写的代码 我看到了 用法 查找Java运算符 它建议按位或和分配操作 任何人都可以解释并给我一个例子吗 这是读取它的代码 for String search textSearch getValue
  • Spring Boot:@TestConfiguration 在集成测试期间不覆盖 Bean

    我有一个Bean定义在一个装饰有的类中 Configuration Configuration public class MyBeanConfig Bean public String configPath return productio
  • 有没有办法强制 print!/println!使用 Windows 换行符 (CR LF)

    我在 Windows 10 上使用 Rust 1 9 当使用一些代码并比较从标准输出捕获的结果时 我注意到输出使用以 0x0A 10 LF 结尾的 Linux 行 而不是 Windows 0x0D 0x0A 13 10 CR LF 我尝试了
  • 是否可以通过逆向工程从 .ipa 文件获取原始源代码?

    我目前是一名iPhone应用程序开发人员 试图通过提取当前的ipa解决方案和github解决方案来学习生成音频脉冲的机制 使用Hopper解压时 只会生成重新编译的 不完整的汇编代码 我们看不到任何可供进一步探索的目标类别 有没有其他方法可
  • 如何在流星中将变量从服务器发送到客户端?

    我有一个带有文本输入和按钮的页面 当我将 YouTube 视频的链接插入文本字段并按下按钮时 视频下载到本地文件夹中 问题 如何将下载视频的本地副本的链接发送回客户端 更一般的问题 如何将变量从服务器发送到客户端 该变量是临时的 不会存储在
  • 哪些因素会导致 Win32 错误 665(文件系统限制)?

    我维护一个应用程序 该应用程序从数据记录器收集数据并将该数据附加到二进制文件的末尾 该系统的本质是文件一次可以小步增长 gt 4 GB 我的应用程序的一位用户在他的 NTFS 分区上遇到过尝试追加数据失败的情况 该错误是由于调用 fflus
  • 在 Apache Airflow 中实施 Postgres Sql

    我在 Ubuntu 版本 18 04 3 服务器上实现了 Apache Airflow 当我设置它时 我使用了 sql lite 通用数据库 这使用了顺序执行器 我这样做只是为了玩玩并习惯这个系统 现在我正在尝试使用本地执行器 并且需要将我
  • Java swing 在鼠标拖放中绘制矩形

    我正在创建一个矩形绘图程序 仅当程序拖动到底部时才会绘制正方形 即使向另一个方向拖动 我也想确保正确绘制正方形 我该如何修复它 请帮我 DrawRect java import javax swing import java awt imp
  • 是否有 Safari Mobile(即 iPad 和 iPhone)支持的字体列表?

    我正在寻找 iPad 和 iPhone 版 Safari Mobile 支持的字体的详尽列表 事实上 我可以在我的网站中使用哪些字体 你应该尝试这个网站 http iosfonts com http iosfonts com 它有一个表格
  • 使用 Python 的 Mac 键盘监听器

    我已经尝试了键盘监听器的所有代码 我看到一篇文章说 Mac 会阻止系统监听键盘按下的声音 我正在使用Python 我也使用 pynput 作为库 如何让 Mac 监听我的按键操作 它只监听特殊键 如 Shift Alt 和 Command
  • 命名空间 Visualstudio 不存在于 Microsoft 命名空间中,缺少程序集引用

    我继承了一个我无法编译的 C Visual Studio 2010 项目 因为它正在查找我无法满足的以下参考 using Microsoft VisualStudio Tools Applications Runtime 我是 VS 的新手
  • 删除新展示位置

    我知道对使用placement new 创建的变量调用delete 然后访问该内存块具有未定义的行为 int x new int 2 char ch new x char ch t delete ch 但是 如果在堆栈上分配内存块而不是堆
  • java中没有这样的方法错误

    我收到以下错误 如下所示 java lang NoSuchMethodError org apache poi hssf usermodel HSSFSheet addMergedRegion Lorg apache poi hssf ut
  • 使用 Flash 获取计算机信息

    是否可以使用 Adob e Flash 检索计算机信息 RAM 硬盘大小 CPU 速度等 如果是这样 有人可以向我指出一个网站 告诉我如何做吗 我认为你无法获得 RAM 硬盘大小或时钟速度 Flash 在虚拟机中运行 并且它可能被设置为仅向
  • 在包之间传递关联数组作为参数

    我有两个单独的 Oracle v9 2 PL SQL 包 并且我试图将 package1 中的过程中的关联数组 即索引表 作为参数传递给 package2 中的过程 这可能吗 我不断得到PLS 00306 wrong number or t
  • 保留函数名可以重载吗?

    这个问题是后续问题this one https stackoverflow com q 50898508 5376789 考虑以下程序 include
  • pytorch dataloader - 运行时错误:堆栈期望每个张量大小相等,但在条目 0 处得到 [157],在条目 1 处得到 [154]

    我是 pytorch 的初学者 我正在尝试进行基于方面的情感分析 我面临着主题中提到的错误 我的代码如下 我请求帮助解决此错误 提前致谢 我将分享整个代码和错误堆栈 pip install transformers import trans