Huggingface 分类与预测作斗争

2024-01-21

我正在微调 longformer,然后使用进行预测TextClassificationPipeline and model(**inputs)方法。我不确定为什么会得到不同的结果

import pandas as pd
import datasets
from transformers import LongformerTokenizerFast, LongformerForSequenceClassification, Trainer, TrainingArguments, LongformerConfig
import torch.nn as nn
import torch
from torch.utils.data import DataLoader#Dataset, 
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
#import wandb
import os
from datasets import Dataset
from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification

tokenizer = LongformerTokenizerFast.from_pretrained('folder_path/', max_length = maximum_len)

从保存的位置加载微调模型。使用原始分词器

saved_location='c:/xyz'
model_saved=AutoModelForSequenceClassification.from_pretrained(saved_location)
pipe = TextClassificationPipeline(model=model_saved, tokenizer=tokenizer, device=0)#tokenizer_saved, padding=True, truncation=True)
prediction = pipe(["The text to predict"], return_all_scores=True)
prediction
[[{'label': 'LABEL_0', 'score': 0.7107483148574829},
  {'label': 'LABEL_1', 'score': 0.2892516553401947}]]

第二种方法

inputs = tokenizer("The text to predict", return_tensors="pt").to(device)
outputs = model_saved(**inputs)#, labels=labels)
print (outputs['logits'])
#tensor([[ 0.4552, -0.4438]], device='cuda:0', grad_fn=<AddmmBackward0>)
torch.sigmoid(outputs['logits'])
#tensor([[0.6119, 0.3908]], device='cuda:0', grad_fn=<SigmoidBackward0>)

AutoModelForSequenceClassification返回概率0.71 and 0.29。当我看第二种方法时。它返回逻辑0.4552, -0.4438转换为概率0.6119, 0.3908

#更新1

第一个链接文本分类管道 https://huggingface.co/docs/transformers/v4.17.0/en/main_classes/pipelines#transformers.TextClassificationPipeline克罗诺克的回答如下

function_to_apply (str, optional, defaults to "default") — The function to apply to the model outputs in order to retrieve the scores. Accepts four different values:
"default": if the model has a single label, will apply the sigmoid function on the output. If the model has several labels, will apply the softmax function on the output.
"sigmoid": Applies the sigmoid function on the output.
"softmax": Applies the softmax function on the output.
"none": Does not apply any function on the output.

因为这是一个二元分类问题(单标签),它不应该应用 sigmoid 吗?


我假设model.config.num_labels==2,如果是这样的话,则文本分类管道 https://huggingface.co/docs/transformers/v4.17.0/en/main_classes/pipelines#transformers.TextClassificationPipeline应用 softmax 而不是 sigmoid 来计算概率 (code https://github.com/huggingface/transformers/blob/198c335d219a5eb4d3f124fdd1ce1a9cd9f78a9b/src/transformers/pipelines/text_classification.py#L142).

import torch

logits = torch.tensor([ 0.4552, -0.4438])
print(torch.softmax(logits,0))

Output:

tensor([0.7107, 0.2893])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Huggingface 分类与预测作斗争 的相关文章

  • 递归 lambda 表达式可能吗?

    我正在尝试编写一个调用自身的 lambda 表达式 但我似乎找不到任何语法 或者即使它是可能的 本质上我想将以下函数传输到以下 lambda 表达式中 我意识到这是一个愚蠢的应用程序 它只是添加 但我正在探索可以在 python 中使用 l
  • 如何在Python中流式传输和操作大数据文件

    我有一个相对较大 1 GB 的文本文件 我想通过跨类别求和来减小其大小 Geography AgeGroup Gender Race Count County1 1 M 1 12 County1 2 M 1 3 County1 2 M 2
  • 用缺失的日期填充其他列 Nan Pandas DataFrame

    我实际上是从几个 Excel 文件中提取数据来监控我的每日卡路里摄入量 我设法使用列表理解来生成日期 我尝试使用合并或连接 但它不起作用 ValueError 您正在尝试合并对象和 float64 列 date list 2021 05 2
  • 如何检查python xlrd库中的excel文件是否有效

    有什么办法与xlrd库来检查您使用的文件是否是有效的 Excel 文件 我知道还有其他库可以检查文件头 我可以使用文件扩展名检查 但为了多平台性我想知道是否有任何我可以使用的功能xlrd库本身在尝试打开文件时可能会返回类似 false 的内
  • 检查 Python 中的可迭代对象中的所有元素的谓词是否计算为 true

    我很确定有一个常见的习语 但我无法通过谷歌搜索找到它 这是我想做的 用Java Applies the predicate to all elements of the iterable and returns true if all ev
  • 如何在Python中同时运行两只乌龟?

    我试图让两只乌龟一起移动 而不是一只接着另一只移动 例如 a turtle Turtle b turtle Turtle a forward 100 b forward 100 但这只能让他们一前一后地移动 有没有办法让它们同时移动 有没有
  • Mac OS X 中文件系统的 Unicode 编码在 Python 中不正确?

    在 OS X 和 Python 中处理 Unicode 文件名有点困难 我试图在代码中稍后使用文件名作为正则表达式的输入 但文件名中使用的编码似乎与 sys getfilesystemencoding 告诉我的不同 采取以下代码 usr b
  • Python3.0 - 标记化和取消标记化

    我正在使用类似于以下简化脚本的内容来解析较大文件中的 python 片段 import io import tokenize src foo bar src bytes src encode src io BytesIO src src l
  • 在 Django OAuth Toolkit 中安全创建新应用程序

    如何将 IsAdminUser 权限添加到 Django OAuth Toolkit 中的 o applications 视图 REST FRAMEWORK DEFAULT PERMISSION CLASSES rest framework
  • Django send_mail SMTPSenderRefused 530 与 gmail

    一段时间以来 我一直在尝试使用 Django 从我正在开发的网站接收电子邮件 现在 我还没有部署它 并且我正在使用Django开发服务器 我不知道这是否会影响它 这是我的 settings py 配置 EMAIL BACKEND djang
  • Geodjango距离查询未检索到正确的结果

    我正在尝试根据地理位置的接近程度来检索一些帖子 正如您在代码中看到的 我正在使用 GeoDjango 并且代码在视图中执行 问题是距离过滤器似乎被完全忽略了 当我检查查询集上的距离时 我得到了预期距离 1m 和 18km 但 18km 的帖
  • SMTP_SSL SSLError: [SSL: UNKNOWN_PROTOCOL] 未知协议 (_ssl.c:590)

    此问题与 smtplib 的 SMTP SSL 连接有关 当与 SMTP 无 ssl 连接时 它正在工作 在 SMTP SSL 中尝试相同的主机和端口时 出现错误 该错误仅基于主机 gmail 设置也工作正常 请检查下面的示例 如果 Out
  • Weka J48 分类器:无法处理数字类?

    我现在尝试使用 Weka 在我的训练数据上构建 J48 C4 5 分类器模型 首先我这样做 这似乎很顺利 java Xmx10G cp weka weka jar weka core converters TextDirectoryLoad
  • ANTLR 获取并拆分词法分析器内容

    首先 对我的英语感到抱歉 我还在学习 我为我的框架编写 Python 模块 用于解析 CSS 文件 我尝试了 regex ply python 词法分析器和解析器 但我发现自己在 ANTLR 中 第一次尝试 我需要解析 CSS 文件中的注释
  • 将seaborn.palplot轴添加到现有图形中以可视化不同调色板

    将seaborn人物添加到子图中是usually https seaborn pydata org examples cubehelix palette html创建图形时通过传递 ax 来完成 例如 sns kdeplot x y cma
  • 在 keras 中保存和加载权重

    我试图从我训练过的模型中保存和加载权重 我用来保存模型的代码是 TensorBoard log dir output model fit generator image a b gen batch size steps per epoch
  • SocketIO + Flask 检测断开连接

    我在这里有一个不同的问题 但意识到它可以简化为 如何检测客户端何时从页面断开连接 关闭其页面或单击链接 换句话说 套接字连接关闭 我想制作一个带有更新用户列表的聊天应用程序 并且我在 Python 上使用 Flask 当用户连接时 浏览器发
  • 从 NumPy 数组到 Mat 的 C++ 转换 (OpenCV)

    我正在围绕 ArUco 增强现实库 基于 OpenCV 编写一个薄包装器 我试图构建的界面非常简单 Python 将图像传递给 C 代码 C 代码检测标记并将其位置和其他信息作为字典元组返回给 Python 但是 我不知道如何在 Pytho
  • 在 Django 查询中使用 .extra(select={...}) 引入的值上使用 .aggregate() ?

    我正在尝试计算玩家每周玩游戏的次数 如下所示 player game objects extra select week WEEK games game date aggregate count Count week 但姜戈抱怨说 Fiel
  • 给定文档,选择相关片段

    当我在这里提出问题时 自动搜索返回的问题的工具提示给出了问题的前一点 但其中相当一部分没有给出任何比理解问题更有用的文本 标题 有谁知道如何制作一个过滤器来删除问题中无用的部分 我的第一个想法是修剪仅包含某个列表中的单词的任何前导句子 例如

随机推荐