加载sklearn模型后无法进行预测

2023-12-23

我使用 Scikit-Learn 创建了一个 ML 模型并保存了它。现在,当我加载模型时,我在转换和预测方面遇到了麻烦。 我在 DataFrame 中有 4 个功能。前两个特征是文本特征,另外两个特征是数字特征。结果列为 1 或 0。

为了训练我的模型,我使用了ColumnTransformer and CountVectorizer用于文本特征的转换和矢量化。我指定了要转换/矢量化的列的名称。 (文本 1 和文本 2 列)。数字列不需要矢量化,因此remainder='passthrough'正在解决这个问题。

有效的部分代码:

features = df.iloc[:, :-1]
results = df.iloc[:, -1]

transformerVectoriser = ColumnTransformer(transformers=[('vector word 1', CountVectorizer(analyzer='word', ngram_range=(1, 1), max_features = 12000, stop_words = 'english'), 'text1'),
                                                       ('vector phrase 3', CountVectorizer(analyzer='word', ngram_range=(3, 3), max_features = 2500, stop_words = 'english'), 'text2')],
                                                      remainder='passthrough') # Default is to drop untransformed columns, passthrough == leave columns as they are

x_train, x_test, y_train, y_test = train_test_split(features, results, test_size=0.3, random_state=0)

x_train = transformerVectoriser.fit_transform(x_train)
x_test = transformerVectoriser.transform(x_test)


model = clf.fit(x_train, y_train)
y_pred = model.predict(x_test)

filename = 'ml_model.sav'
pickle.dump(model, open(filename, 'wb'))

filename = 'ml_transformer.sav'
pickle.dump(transformerVectoriser, open(filename, 'wb'))

但是当我想加载模型并进行预测时,我收到错误:

# LOADING MODEL
model = pickle.load(open('ml_model.sav','rb'))
vectorizer = pickle.load(open('ml_transformer.sav','rb'))

# MAKING PREDICTION
data_for_prediction = vectorizer.transform([data_for_prediction]) #ERROR
print(model.predict_proba(data_for_prediction))

我收到错误:

ValueError: Specifying the columns using strings is only supported for pandas DataFrames

当我训练我的模型时,我使用了Pandasdataframe,当我想进行预测时,我只是将值放入列表中。所以data_for_prediction是列表,看起来像这样:

["text that should be vectorized with vectorizer that i created", "More texts that should be vectorized", 4, 7]

我认为这就是错误,因为我在使用 ColumnTransformer 时使用了列名,但现在当我想要进行预测时,向量化器不知道要向量化什么。 我的最终模型和矢量化器应该在 API 中使用,而 api 应该只接受 JSON,所以我不想将 JSON 转换为 DataFrame 并将其传递给模型。 有没有办法在不使用 pandas 的情况下修复此错误dataframe在我最后的 Flask 应用程序中。


训练数据是一个包含以下列的数据框:

x_train.columns

功能vectorizer.transform()想要相同格式的数据,所以假设

data_f_p = ["text that should be vectorized", 4,7,0]

对应于相同的四列x_train你可以把它变成一个数据框

data_f_p = pd.DataFrame([data_f_p], columns=x_train.columns)
data_f_p = vectorizer.transform(data_f_p)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

加载sklearn模型后无法进行预测 的相关文章

随机推荐

  • 如何通过索引将项目添加到 Laravel Eloquent Collection 中?

    我尝试了以下方法 但它不起作用 index 2 collection gt put index item4 例如 如果 collection 看起来像这样 collection item1 item2 item3 我想结束 collecti
  • Gradle 在 bin 目录中创建重复的启动脚本

    我正在尝试通过 gradle 创建多个启动脚本文件 但不知何故 一个特定的启动脚本文件正在重复 startScripts enabled false run enabled false def createScript project ma
  • Java 泛型放在 Map>

    有没有办法以类型安全的方式进行以下实现 public void myMethod Map
  • 如何比较 Svelte 3 中的 Prop 变化

    Svelte 3 中是否有一种机制可以在渲染之前比较组件内的 prop 更改 类似于 反应从Props获取DerivedState https reactjs org docs react component html static get
  • 车把模板中 href 标签中的 Ember 插值

    我正在尝试建立一个到谷歌地图的简单链接 并将动态地址插入到 href 字段中 我已经尝试过下面的代码以及大量其他乱七八糟的东西 但没有运气 如何在车把 href 字段中插入动态 ember 字符串 我正在使用 ember 导轨和车把 如果我
  • 将二进制路径添加到 emacs $PATH

    我尝试了以下方法 setenv PATH concat getenv PATH mybin setq exec path append exec path mybin 但这从来没有奏效 我试过M 并键入二进制名称之一 并且在使用二进制名称进
  • Select2:init后如何设置数据?

    我需要在初始化 select2 后设置一个数据数组 所以我想做这样的事情 var select select select2 select data id 1 text value1 id 1 text value1 但我收到以下错误 当附
  • 连接字符串和实体框架的问题

    我有一个数据库 sql 2008 mdf 文件 一个带有 edmx 文件的类库项目 是使用向导创建的 所以连接字符串也是由向导制作的 该项目位于 teamfoundation 服务器上 我可以在编码时使用所有向导创建的对象 但是当我运行程序
  • DisplayFormat 未应用于十进制值

    我有一个模型属性 我正在尝试使用 EditorFor 模板进行渲染 并且我正在尝试使用 DisplayFormat 属性应用格式 然而 它根本不起作用 它完全被忽略了 这是我的模板 model System Decimal Html Tex
  • IoC:如何动态创建对象

    我无法理解如何在需要动态创建对象的场景中使用 IoC 假设我有这样的课程 abstract class Field public Field ICommandStack commandStack abstract class Entity
  • 使用 python 的树莓派旋转编码器脚本

    我有一个设置 其中有一个电机以每秒约 1 转的速度转动直径 5 厘米的轴 我需要在预定的转数后停止电机 现在假设是 10 转 我使用的传感器机制只是一个磁铁和簧片开关 以下脚本可以很好地记录每次触发开关的情况 import RPi GPIO
  • Android SQLite 数据库损坏

    这个链接准确地描述了我的问题 http old nabble com Android database corruption td28044218 html a28044218 http old nabble com Android dat
  • 如何在postgresql中使用设置种子选择可重复的随机数?

    我想要实现的是为流程选择一个控制组 为此 我使用 random 为了调试 一致性 我希望能够以可重复的方式设置随机数 意思是 一旦它为用户 123 分配随机数 0 001 我就运行查询 在不同的时间 我删除以前的数据 调用相同的查询 并再次
  • 如何从应用程序设置 Azure (webapp) 接收数据到我的 webjob

    我用 C 创建了一个 Azure WebJob 我在 Azure 上有一个 Web 应用程序 我将 WebJob 添加到了我的订阅中 一切都很好 但在应用程序设置中我添加了一个新条目 例如
  • C 比较两个位图的最快方法

    有两个字符数组形式的位图数组 有数百万条记录 使用 C 来比较它们的最快方法是什么 我可以想象在 for 循环中一次使用按位运算符异或 1 个字节 关于位图的重要一点 算法运行的 1 到 10 次中 位图可能会有所不同 大多数时候它们都是一
  • async void 方法每次调用时都会创建一个新线程吗?

    我有以下场景 async void DoStuff button1 Click s p gt DoStuff 我不确定当我打电话时会发生什么async void方法 而第一次调用仍然不完整 该调用是否会在每次调用时创建一个新线程 还是会销毁
  • 将 jar 库导入 android-studio

    android studio 0 2 7 Fedora 18 Hello 我正在尝试将 jtwitter jar 添加到我的项目中 首先我尝试执行以下操作 1 Drag the jtwitter jar into the root dire
  • 使用 Wagtail 页面或 Django 模型的指南?

    例如 我想使用wagtail建立一个电子商务网站 其中一个组件是订单 我认为 order 不应该是 wagtail Page 而是简单的 Django 模型 见下面的代码 from django db import models from
  • 如何计算两个国家到国家、国家到城市、城市到城市之间的距离? [关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 如何计算两个国家到国家 国家到城市
  • 加载sklearn模型后无法进行预测

    我使用 Scikit Learn 创建了一个 ML 模型并保存了它 现在 当我加载模型时 我在转换和预测方面遇到了麻烦 我在 DataFrame 中有 4 个功能 前两个特征是文本特征 另外两个特征是数字特征 结果列为 1 或 0 为了训练