一次热编码期间出现 RunTimeError

2024-04-14

我有一个数据集,其中类值以 1 步从 -2 到 2(i.e., -2,-1,0,1,2)其中 9 标识未标记的数据。 使用一种热编码

self._one_hot_encode(labels)

我收到以下错误:RuntimeError: index 1 is out of bounds for dimension 1 with size 1

due to

self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)

错误应该从[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 1, 1, 1, 1, 1, 1],我在映射设置中有 9 等于索引 9 到 1。即使在浏览了过去的问题和类似问题的答案之后(例如,索引 1 超出尺寸 1 的维度 0 的范围 https://stackoverflow.com/questions/67185851/index-1-is-out-of-bounds-for-dimension-0-with-size-1)。 涉及错误的部分代码如下:

def _one_hot_encode(self, labels):
    # Get the number of classes
    classes = torch.unique(labels)
    classes = classes[classes != 9] # unlabelled 
    self.n_classes = classes.size(0)

    # One-hot encode labeled data instances and zero rows corresponding to unlabeled instances
    unlabeled_mask = (labels == 9)
    labels = labels.clone()  # defensive copying
    labels[unlabeled_mask] = 0
    self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
    self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
    self.one_hot_labels[unlabeled_mask, 0] = 0

    self.labeled_mask = ~unlabeled_mask

def fit(self, labels, max_iter, tol):
    
    self._one_hot_encode(labels)

    self.predictions = self.one_hot_labels.clone()
    prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)

    for i in range(max_iter):
        # Stop iterations if the system is considered at a steady state
        variation = torch.abs(self.predictions - prev_predictions).sum().item()
        

        prev_predictions = self.predictions
        self._propagate()

数据集示例:

ID  Target  Weight  Label   Score   Scale_Cat   Scale_num
0   A   D   65.1    1   87  Up  1
1   A   X   35.8    1   87  Up  1
2   B   C   34.7    1   37.5    Down    -2
3   B   P   33.4    1   37.5    Down    -2
4   C   B   33.1    1   37.5    Down    -2
5   S   X   21.4    0   12.5    NA  9

我用作参考的源代码在这里:https://mybinder.org/v2/gh/thibaudmartinez/label-propagation/master?filepath=notebook.ipynb https://mybinder.org/v2/gh/thibaudmartinez/label-propagation/master?filepath=notebook.ipynb

错误的完整跟踪:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-126-792a234f63dd> in <module>
      4 label_propagation = LabelPropagation(adj_matrix_t)
----> 6 label_propagation.fit(labels_t) # causing error
      7 label_propagation_output_labels = label_propagation.predict_classes()
      8 

<ipython-input-115-54a7dbc30bd1> in fit(self, labels, max_iter, tol)
    100 
    101     def fit(self, labels, max_iter=1000, tol=1e-3):
--> 102         super().fit(labels, max_iter, tol)
    103 
    104 ## Label spreading

<ipython-input-115-54a7dbc30bd1> in fit(self, labels, max_iter, tol)
     58             Convergence tolerance: threshold to consider the system at steady state.
     59         """
---> 60         self._one_hot_encode(labels)
     61 
     62         self.predictions = self.one_hot_labels.clone()

<ipython-input-115-54a7dbc30bd1> in _one_hot_encode(self, labels)
     42         labels[unlabeled_mask] = 0
     43         self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
---> 44         self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
     45         self.one_hot_labels[unlabeled_mask, 0] = 0
     46 

RuntimeError: index 1 is out of bounds for dimension 1 with size 1


我浏览了你的笔记本(我认为你将 9 更改为 -1 以便运行)并看到这部分代码:

# Learn with Label Propagation
label_propagation = LabelPropagation(adj_matrix_t)
print("Label Propagation: ", end="")
label_propagation.fit(labels_t)
label_propagation_output_labels = label_propagation.predict_classes()

最终调用:

self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)

是出了问题的地方。

请花一点时间阅读有关 scatter 的 pytorch 手册:火炬分散 https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html我们了解到,对于分散来说,了解暗淡、索引、src 和自身矩阵很重要。对于一种热编码,dim=1 或 0 并不重要,我们的 src 矩阵是 1(稍后我们将对此进行更多研究)。您现在在维度 1 上调用 scatter,索引矩阵为 [40,1],结果(自身)矩阵为 [40,5]。

我在这里看到两个问题:

  1. 您正在使用文字类别虚拟变量 (-2,-1,0,1,2) 作为索引矩阵中的编码索引。这将导致 scatter 在 src 矩阵中搜索这些索引。这是索引越界的地方
  2. 您提到有 6 个类 -2、-1、0、1、2 和 9 为未标记的,但您是 5 个类的热门编码。 (是的,我知道您希望未标记的类全部为零,但这用分散实现有点困难。我稍后会解释)。

那么我们该如何解决这个问题呢?

问题一:让我们从一个小例子开始:

index = torch.tensor([[5],[0],[3],[5],[1],[4]]); print(index.shape); print(index)
result = torch.zeros(6, 6, dtype=src.dtype).scatter_(1, index, src); print(result.shape); print(result)

这会给我们

torch.Size([6, 1])
tensor([[5],
        [0],
        [3],
        [5],
        [1],
        [4]])
torch.Size([6, 6])
tensor([[0, 0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]])

索引矩阵是 6 个观测值和 1 个观测值(类别) 自矩阵是 6 个观测值,具有 6 类一热编码向量 scatter(dim=1) 创建 self 矩阵的方式是 torch 首先检查行(观察),然后将该行的值更改为存储在 src 矩阵中同一行但列的值存储在索引中的值。

self[i][index[i][j][k]][k] = src[i][j][k]

因此,在您的情况下,您试图将 1 的值应用到 self[40,1] 中索引 [0] 列(等于 1)的行中。给你问题中的错误。虽然我检查了你的笔记本,错误是 索引 -1 超出了尺寸为 5 的维度 1 的范围。它们都是相同的根本原因。

问题 2:One-hot 编码

在这种情况下,使用冷编码进行完整的one-hot 比one-hot 更容易。原因是,对于单热冷编码,您需要在 src 矩阵中为每个未标记的观察创建一个 0 值。这比仅仅使用 1 作为 src 更痛苦。另请阅读此链接:OHE 全零是否有效? https://stats.stackexchange.com/questions/408626/is-it-valid-to-have-all-zeroes-in-a-one-hot-encoded-categorical-feature我认为对每个类别都使用 one-hot 更有意义。

因此,对于第二个问题,我们只需要简单地将类别映射到结果/自身矩阵的索引中。由于我们有 6 个类别,我们只需将它们映射到 0,1,2,3,4,5 即可。一个简单的 lambda 函数就可以解决这个问题。我使用随机采样器从类列表中获取数据标签,如下所示:(我从 6 个类中随机创建了 40 个观察值)

classes = list([-2,-1,0,1,2,9])

labels = list()
for i in range(0,40):
    labels.append(list([(lambda x: x+2 if x !=9 else 5)(random.sample(classes,1)[0])]))

index_aka_labels = torch.tensor(labels)
print(index_aka_labels)
print(index_aka_labels.shape)
torch.zeros(40, 6, dtype=src.dtype).scatter_(1, index_aka_labels, 1)

最终,我们达到了我们想要的OHE结果:

tensor([[0, 0, 0, 0, 0, 1],
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1, 0],
        ... (40 observations)
        [0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1],
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

一次热编码期间出现 RunTimeError 的相关文章

  • 数据操作 startdate enddate python pandas

    我有一个促销描述数据集 其中包含有关正在运行的各种促销活动及其开始日期 结束日期的信息 promo item start date end date Buy1 get 1 A 2015 01 08 2015 01 12 Buy1 get 1
  • 是否可以将名为“None”的值添加到枚举类型?

    我可以将名为 None 的值添加到枚举中吗 例如 from enum import Enum class Color Enum None 0 represent no color at all red 1 green 2 blue 3 co
  • 箱线图与箱线图有何不同?

    我想知道当我们在海生图书馆中有箱线图时为什么会有箱线图 我知道一件事是箱线图优化了表示数据的方式 特别是对于大型数据集 但我不知道为什么 除此之外 我没有任何充分的理由使用箱线图 箱线图将中位数显示为中心线 第 50 个百分位数 然后将第
  • 日期/时间值的 Django URL 转换器

    我正在尝试使用 Django 内置的 URL 转换器将 URL 中的日期时间字符串转换为视图中的日期对象 如果我手动输入 URL 它们会按预期工作 但尝试为其生成 URL 时找不到匹配项 我的转换器很简单 from django utils
  • 识别 Windows 版本

    我正在编写一个打印出详细 Windows 版本信息的函数 输出可能是这样的元组 32bit XP Professional SP3 English 它将支持 Windows XP 及更高版本 我一直坚持获取 Windows 版本 例如 专业
  • 遍历 globals() 字典

    我 尝试 使用globals 在我的程序中迭代所有全局变量 我就是这样做的 for k v in globals iteritems function k v 当然 这样做时 我只是创建了另外 2 个全局变量 k and v 所以我得到这个
  • __subclasses__ 没有显示任何内容

    我正在实现一个从适当的子类返回对象的函数 如果我搬家SubClass from base py 没有出现子类 subclasses 它们必须在同一个文件中吗 也许我从来没有直接导入subclass py对Python隐藏子类 我能做些什么
  • “分页文件太小,无法完成此操作”尝试训练 YOLOv5 对象检测模型时出错

    我有大约 50000 个图像和注释文件用于训练 YOLOv5 对象检测模型 我在另一台计算机上仅使用 CPU 训练模型没有问题 但需要太长时间 因此我需要 GPU 训练 我的问题是 当我尝试使用 GPU 进行训练时 我不断收到此错误 OSE
  • Python 对象属性 - 访问方法

    假设我有一个具有某些属性的类 在 Pythonic OOP 中 如何访问这些属性是最好的 就像obj attr 或者也许编写 get 访问器 此类事物可接受的命名风格是什么 Edit 您能否详细说明使用单下划线或双前导下划线命名属性的最佳实
  • 将 Matlab MEX 文件中的函数直接嵌入到 Python 中

    我正在使用专有的 Matlab MEX 文件在 Matlab 中导入一些仿真结果 当然没有可用的源代码 Matlab 的接口实际上非常简单 因为只有一个函数 返回一个 Matlab 结构体 我想知道是否有任何方法可以直接从Python调用M
  • 调试 python Web 服务

    我正在使用找到的说明here http www diveintopython net http web services user agent html 尝试检查发送到我的网络服务器的 HTTP 命令 但是 我没有看到按照教程中的建议在控制
  • captureWarnings 设置为 True 不会捕获警告

    我想记录所有警告 我以为这样的设定captureWarnings to True应该可以解决问题 但事实并非如此 代码 import logging import warnings from logging handlers import
  • 按多索引的一级对 pandas DataFrame 进行排序

    我有一个多索引 pandas DataFrame 需要按索引器之一进行排序 这是数据片段 gene VIM treatment dose time TGFb 0 1 2 0 158406 1 2 0 039158 10 2 0 052608
  • 带回溯的 Dijkstra 算法?

    In a 相关主题 https stackoverflow com questions 28333756 finding most efficient path between two nodes in an interval graph
  • 计算素数并附加到列表

    我最近开始尝试使用 python 解决 Euler 项目的问题 并且在尝试计算素数并将其附加到列表中时遇到了这个障碍 我编写了以下代码 但我很困惑为什么它在运行时不输出任何内容 import math primes def isPrime
  • 在 python 中计时时,我应该如何考虑 subprocess.Popen() 开销?

    编码社区的成员比我更聪明 我有一个 python 问题要问你们 我正在尝试优化一个 python 脚本 该脚本 除其他外 返回子进程执行和终止的挂钟时间 我想我已经接近这样的事情了 startTime time time process s
  • Python中如何实现相对导入

    考虑 stuff init py mylib py Foo init py main py foo init py script py script py想要进口mylib py 这只是一个示例 但实际上我只想在父目录中进行模块的相对导入
  • 矩阵求逆 (3,3) python - 硬编码与 numpy.linalg.inv

    对于大量矩阵 我需要计算定义为的距离度量 尽管我确实知道强烈建议不要使用矩阵求逆 但我没有找到解决方法 因此 我尝试通过对矩阵求逆进行硬编码来提高性能 因为所有矩阵的大小均为 3 3 我预计这至少会是一个微小的改进 但事实并非如此 为什么
  • 张量流多元线性回归不收敛

    我正在尝试使用张量流训练具有正则化的多元线性回归模型 由于某种原因 我无法获取以下代码的训练部分来计算我想要用于梯度下降更新的误差 我在设置图表时做错了什么吗 def normalize data matrix averages np av
  • 如何访问模板缓存? - 姜戈

    I am 缓存 HTML在几个模板内 例如 cache 900 stats stats endcache 我可以使用以下方式访问缓存吗低级图书馆 例如 html cache get stats 我确实需要对模板缓存进行一些细粒度的控制 有任

随机推荐