在 PyTorch 中使用焦点损失处理不平衡数据集

2024-01-08

我发现这个实现focal loss在 GitHub 中,我使用它来解决不平衡数据集二元分类问题。

# IMPLEMENTATION CREDIT: https://github.com/clcarwin/focal_loss_pytorch
    class FocalLoss(nn.Module):
    def __init__(self, gamma=0.5, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1,2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,C
        target = target.view(-1,1)

        logpt = F.log_softmax(input)
        logpt = logpt.gather(1,target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type()!=input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0,target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()

also

gamma=args.gamma
alpha=args.alpha

criterion = FocalLoss(gamma, alpha)
m = nn.Sigmoid()

我在训练阶段使用如下标准:

for i_batch, sample_batched in enumerate(dataloader_train):  
            #pdb.set_trace()        
            feats = torch.stack(sample_batched['image']) 
            labels = torch.as_tensor(sample_batched['label']).cuda() 
            print('feats shape: ', feats.shape)
            print('labels shape: ', labels.shape)
            output = model(feats)
            loss = criterion(m(output[:,1]-output[:,0]), labels.float())

错误是:

train: True test: False
preparing datasets and dataloaders......
creating models......

=>Epoches 1, learning rate = 0.0010000, previous best = 0.0000
training...
feats shape:  torch.Size([64, 419, 512])
labels shape:  torch.Size([64])
main_classifier.py:86: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
  logpt = F.log_softmax(input)
Traceback (most recent call last):
  File "main_classifier.py", line 346, in <module>
    loss = criterion(m(output[:,1]-output[:,0]), labels.float())
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "main_classifier.py", line 87, in forward
    logpt = logpt.gather(1,target)
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

我应该如何修复这个错误?

FocalLoss 的这种实现正确吗?


Unlike BCEWithLogitLoss,输入与您使用的相同的参数CrossEntropyLoss解决了问题:

#loss = criterion(m(output[:,1]-output[:,0]), labels.float())
loss = criterion(output, labels)

来自 NVidia 的 Piotr 的功劳 https://discuss.pytorch.org/t/logpt-logpt-gather-1-target-indexerror-dimension-out-of-range-expected-to-be-in-range-of-1-0-but-got-1/145225/6?u=mona_jalal

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 PyTorch 中使用焦点损失处理不平衡数据集 的相关文章

  • 从数据框中按索引删除行

    我有一个数组wrong indexes train其中包含我想从数据框中删除的索引列表 0 63 151 469 1008 要删除这些索引 我正在尝试这样做 df train drop wrong indexes train 但是 代码失败
  • python 中的代表

    我实现了这个简短的示例来尝试演示一个简单的委托模式 我的问题是 这看起来我已经理解了委托吗 class Handler def init self parent None self parent parent def Handle self
  • 如何迭代按值排序的 Python 字典?

    我有一本字典 比如 a 6 b 1 c 2 我想迭代一下by value 不是通过键 换句话说 b 1 c 2 a 6 最直接的方法是什么 sorted dictionary items key lambda x x 1 对于那些讨厌 la
  • 从 ffmpeg 获取实时输出以在进度条中使用(PyQt4,stdout)

    我已经查看了很多问题 但仍然无法完全弄清楚 我正在使用 PyQt 并且希望能够运行ffmpeg i file mp4 file avi并获取流式输出 以便我可以创建进度条 我看过这些问题 ffmpeg可以显示进度条吗 https stack
  • 通过列表理解压平列表列表

    我正在尝试使用 python 中的列表理解来展平列表 我的清单有点像 1 2 3 4 5 6 7 8 只是为了打印这个列表列表中的单个项目 我编写了这个函数 def flat listoflist for item in listoflis
  • 将数据帧行转换为字典

    我有像下面的示例数据这样的数据帧 我正在尝试将数据帧中的一行转换为类似于下面所需输出的字典 但是当我使用 to dict 时 我得到了索引和列值 有谁知道如何将行转换为像所需输出那样的字典 任何提示都非常感激 Sample data pri
  • if 语句未命中中的 continue 断点

    在下面的代码中 两者a and b是生成器函数的输出 并且可以评估为None或者有一个值 def testBehaviour self a None b 5 while True if not a or not b continue pri
  • Pandas 中允许重复列

    我将一个大的 CSV 包含股票财务数据 文件分割成更小的块 CSV 文件的格式不同 像 Excel 数据透视表之类的东西 第一列的前几行包含一些标题 公司名称 ID 等在以下列中重复 因为一家公司有多个属性 而不是一家公司只有一栏 在前几行
  • 忽略 Mercurial hook 中的某些 Mercurial 命令

    我有一个像这样的善变钩子 hooks pretxncommit myhook python path to file myhook 代码如下所示 def myhook ui repo kwargs do some stuff 但在我的例子中
  • 如何计算numpy数组中元素的频率?

    我有一个 3 D numpy 数组 其中包含重复的元素 counterTraj shape 13530 1 1 例如 counterTraj 包含这样的元素 我只显示了几个元素 array 136 129 130 103 102 101 我
  • 为什么Python的curses中escape键有延迟?

    In the Python curses module I have observed that there is a roughly 1 second delay between pressing the esc key and getc
  • 为什么在 Python 2.4 中使用 Unicode 数据会出现 ASCII 编码错误,而在 2.7 中却不会?

    我有一个程序 当在 Python 2 7 中运行时 会生成正确的 Unicode 输出到标准输出 当在 Python 2 4 中运行时 我得到UnicodeEncodeError ascii codec can t encode chara
  • 更改 `base_compiledir` 以将编译后的文件保存在另一个目录中

    theano base compiledir指编译后的文件存放的目录 有没有办法可以永久设置theano base compiledir到不同的位置 也许通过修改一些内部 Theano 文件的内容 http deeplearning net
  • 对图像块进行多重处理

    我有一个函数必须循环遍历图像的各个像素并计算一些几何形状 此函数需要很长时间才能运行 在 24 兆像素图像上大约需要 5 小时 但似乎应该很容易在多个内核上并行运行 然而 我一生都找不到一个有据可查 解释充分的例子来使用 Multiproc
  • Seaborn Pairplot 图例不显示颜色

    我一直在学习如何在Python中使用seaborn和pairplot 这里的一切似乎都工作正常 但由于某种原因 图例不会显示相关的颜色 我无法找到解决方案 因此如果有人有任何建议 请告诉我 x sns pairplot stats2 hue
  • 在 pytube3 中获取 youtube 视频的标题?

    我正在尝试构建一个应用程序来使用 python 下载 YouTube 视频pytube3 但我无法检索视频的标题 这是我的代码 from pytube import YouTube yt YouTube link print yt titl
  • 在 Pandas 中使用正则表达式的多种模式

    我是Python编程的初学者 我正在探索正则表达式 我正在尝试从 描述 列中提取一个单词 数据库名称 我无法给出多个正则表达式模式 请参阅下面的描述和代码 描述 Summary AD1 Low free DATA space in data
  • 使用yield 进行字典理解

    作为一个人为的例子 myset set a b c d mydict item yield join item s for item in myset and list mydict gives as cs bs ds a None b N
  • 在Python中按属性获取对象列表中的索引

    我有具有属性 id 的对象列表 我想找到具有特定 id 的对象的索引 我写了这样的东西 index 1 for i in range len my list if my list i id specific id index i break
  • Scrapy Spider不存储状态(持久状态)

    您好 有一个基本的蜘蛛 可以运行以获取给定域上的所有链接 我想确保它保持其状态 以便它可以从离开的位置恢复 我已按照给定的网址进行操作http doc scrapy org en latest topics jobs html http d

随机推荐