keras/scikit-learn:使用 fit_generator() 进行交叉验证

2023-12-02

是否可以使用Keras 的 scikit-learn API和...一起fit_generator()方法?或者使用另一种方式来产生批次进行训练?我正在使用 SciPy 的稀疏矩阵,在输入 Keras 之前必须将其转换为 NumPy 数组,但由于内存消耗较高,我无法同时转换它们。这是我的批量生成函数:

def batch_generator(X, y, batch_size):
    n_splits = len(X) // (batch_size - 1)
    X = np.array_split(X, n_splits)
    y = np.array_split(y, n_splits)

    while True:
        for i in range(len(X)):
            X_batch = []
            y_batch = []
            for ii in range(len(X[i])):
                X_batch.append(X[i][ii].toarray().astype(np.int8)) # conversion sparse matrix -> np.array
                y_batch.append(y[i][ii])
            yield (np.array(X_batch), np.array(y_batch))

和交叉验证的示例代码:

from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn import datasets

from keras.models import Sequential
from keras.layers import Activation, Dense
from keras.wrappers.scikit_learn import KerasClassifier

import numpy as np


def build_model(n_hidden=32):
    model = Sequential([
        Dense(n_hidden, input_dim=4),
        Activation("relu"),
        Dense(n_hidden),
        Activation("relu"),
        Dense(3),
        Activation("sigmoid")
    ])
    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model


iris = datasets.load_iris()
X = iris["data"]
y = iris["target"].flatten()

param_grid = {
    "n_hidden": np.array([4, 8, 16]),
    "nb_epoch": np.array(range(50, 61, 5))
}

model = KerasClassifier(build_fn=build_model, verbose=0)
skf = StratifiedKFold(n_splits=5).split(X, y) # this yields (train_indices, test_indices)

grid = GridSearchCV(model, param_grid, cv=skf, verbose=2, n_jobs=4)
grid.fit(X, y)

print(grid.best_score_)
print(grid.cv_results_["params"][grid.best_index_])

为了更多地解释它,它使用了所有可能的超参数组合param_grid建立一个模型。然后,每个模型都会在训练-测试数据分割上一一进行训练和测试(folds) 由...提供StratifiedKFold。那么给定模型的最终得分是所有折叠的平均得分。

那么是否可以在实际拟合之前在上面的代码中插入一些预处理子步骤来转换数据(稀疏矩阵)?

我知道我可以编写自己的交叉验证生成器,但它必须产生索引,而不是真实数据!


实际上,您可以使用稀疏矩阵作为带有生成器的 Keras 的输入。这是我在之前的项目中使用的版本:

> class KerasClassifier(KerasClassifier):
>     """ adds sparse matrix handling using batch generator
>     """
>     
>     def fit(self, x, y, **kwargs):
>         """ adds sparse matrix handling """
>         if not issparse(x):
>             return super().fit(x, y, **kwargs)
>         
>         ############ adapted from KerasClassifier.fit   ######################   
>         if self.build_fn is None:
>             self.model = self.__call__(**self.filter_sk_params(self.__call__))
>         elif not isinstance(self.build_fn, types.FunctionType):
>             self.model = self.build_fn(
>                 **self.filter_sk_params(self.build_fn.__call__))
>         else:
>             self.model = self.build_fn(**self.filter_sk_params(self.build_fn))
> 
>         loss_name = self.model.loss
>         if hasattr(loss_name, '__name__'):
>             loss_name = loss_name.__name__
>         if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
>             y = to_categorical(y)
>         ### fit => fit_generator
>         fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit_generator))
>         fit_args.update(kwargs)
>         ############################################################
>         self.model.fit_generator(
>                     self.get_batch(x, y, self.sk_params["batch_size"]),
>                                         samples_per_epoch=x.shape[0],
>                                         **fit_args)                      
>         return self                               
> 
>     def get_batch(self, x, y=None, batch_size=32):
>         """ batch generator to enable sparse input """
>         index = np.arange(x.shape[0])
>         start = 0
>         while True:
>             if start == 0 and y is not None:
>                 np.random.shuffle(index)
>             batch = index[start:start+batch_size]
>             if y is not None:
>                 yield x[batch].toarray(), y[batch]
>             else:
>                 yield x[batch].toarray()
>             start += batch_size
>             if start >= x.shape[0]:
>                 start = 0
>   
>     def predict_proba(self, x):
>         """ adds sparse matrix handling """
>         if not issparse(x):
>             return super().predict_proba(x)
>             
>         preds = self.model.predict_generator(
>                     self.get_batch(x, None, self.sk_params["batch_size"]), 
>                                                val_samples=x.shape[0])
>         return preds
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

keras/scikit-learn:使用 fit_generator() 进行交叉验证 的相关文章

随机推荐

  • “表达式中的声明和声明”是 GNU C 特有的吗?

    Are 表达式中的声明和声明特定于 GNU C 或者这个功能也包含在C99标准中 它是 GCC 扩展 参见海湾合作委员会文档 例如这里是 gcc 4 3 3 查看 GCC 扩展的完整列表 和C99 规范可在此处获取 如果您使用 pedant
  • Phantomjs抓取网页功能不起作用

    我正在使用 phantomjs 学习如何抓取网页 到目前为止我已经开发了以下代码 我知道我能够连接到该网站 但我根本无法从表中获取数据 am我走在正确的轨道上吗 我的目标是从表中抓取数据this地点 我还知道我需要使用 includeJs
  • 在 javascript “创建阶段”中,函数是否在变量之前设置?

    我正在学习 Udemy 课程Javascript 理解奇怪的部分现在 我刚刚了解了解释器解释 JS 时发生的创建阶段和执行阶段 我有一个问题 但我首先会向您展示我正在使用的代码 http codepen io rsf pen bEgpNY
  • 何时使用 `<>` 和 `!=` 运算符?

    找不到太多这方面的信息 尝试比较两个值 但它们不能相等 就我而言 它们可以 并且经常是 大于或小于 我应该使用 if a lt gt b dostuff or if a b dostuff 这一页说它们相似 这意味着它们至少有一些不同之处
  • Java 文档:“从接口 X 继承的方法”的含义是什么

    我一定缺少一些基本的 Java 术语 类可以扩展 因此它们的方法可以是遗传由他们的子类 接口可以是实施的 实现类必须实现接口的所有方法 接口本身不实现任何内容 仅进行声明 那么 为什么当我查看 HashSet 的文档时 https docs
  • 并排反应传单

    我想并排显示两个图块层 就像并排的传单插件一样 https github com digidem leaflet side by side 但是 我不确定如何通过反应来做到这一点 有没有办法在react中使用上述插件 您对如何实现此功能还有
  • 如何从 angular.dart 组件内部调用 jquery 插件?

    我正在通过尝试制作一个可以访问现有 jquery 插件的组件来学习 angular dart 组件 我正在尝试类似以下的事情 library mylib import dart html querySelector import packa
  • 使用 JavaScript 动态加载 JavaScript

    经过一个多小时的尝试让它工作后 我认为这是因为跨域策略 但我真的认为这会起作用 我也找不到很多相关信息 但是 这是我的问题 我有一个网站叫http mysite com然后我包括一个第三方脚本 我写的 及其地址http supercools
  • SASS 语法未在 css 中生成 &:hover

    我一直在四处寻找 在 stackoverwflow 和其他资源上发现了一些类似的问题 但其中大多数是关于语法错误的 有人可以告诉我这段代码有什么问题以及为什么 SASS 没有在生成的 css 中生成 hover 吗 这是我的 SASS 代码
  • 如何使用鼠标拖动事件在java小程序上绘制矩形并使其保持不变

    我有可以绘制矩形的程序 我有两个问题无法解决 当我绘制矩形后 它不会留下来 我拥有的唯一清除画布的代码 重绘仅在鼠标拖动时调用 为什么当我释放鼠标或移动鼠标时 我的画布会变清晰 第二件事并不是什么大问题 但我无法弄清楚 当我的矩形的高度或宽
  • Google Apps 脚本 V8 运行时使用哪个版本的 ECMAScript?

    当您创建新的 Google Apps 脚本时 它似乎默认支持 v8 运行时 这文档 states Apps 脚本支持两种 JavaScript 运行时 现代的 V8 运行时和由 Mozilla 的 Rhino JavaScript 解释器提
  • 当查询 SSRS 数据集之间没有数据时,向报告添加值

    这基本上与我在这个线程中提出的问题相同 当查询 SSRS 中没有数据时向报告添加值 现在唯一的区别是我想将相同的功能扩展到不同的数据集 想象一下 我有两个数据集 Dataset1 Dataset2 两者具有相同的主键 在本例中 销售代表 类
  • Python 多处理存储数据,直到在每个进程中进一步调用

    我有一个无法在进程之间共享的类型的大对象 它有方法来实例化它并处理它的数据 我当前的做法是首先在主父进程中实例化该对象 然后在发生某些事件时将其传递给子进程 问题是 每当子进程运行时 它们每次都会将对象复制到内存中 这需要一段时间 我想将它
  • 是否可以在React中使用CSS自定义FullCalendar?

    我刚刚从 FullCalendar 开始 我在一个react项目 现在一切都很好 但我想定制实际的日历 我希望它尊重我的客户需求 我的问题 是否可以添加班级名称像这样的 FullCalendar 组件 我尝试过 但无法到达 css 文件中的
  • Retrofit:如何解析组合了数组和对象的JSON数组?

    我正在开发一个 Android 应用程序 它使用 Retrofit OkHttp 连接到 REST API 并使用 JSON 数据 我对 Retrofit 还很陌生 所以我仍在学习它是如何工作的 但到目前为止 一切都非常顺利 然而 我遇到了
  • 不要与 SVN 进行 diff 合并

    我想了解我在功能分支上所做的所有更改的差异 目前我使用 svn log stop on copy awk r NAME print 1 xargs l svn diff c gt code diff 不幸的是 这包括主干合并到我的分支中并使
  • 如何读取android设备上beacon的UDID、Major、Minor?

    我正在尝试为 Android 开发 BLE 应用程序 有什么方法可以检测和读取 Android 设备上信标的 UDID 主要 次要吗 我已阅读 RadiusNetworks android ibeacon service 但我不明白为什么
  • 使用 Unity(而不是温莎城堡)可以实现这一点吗?

    This 博客文章展示了一种使用 Castle Windsor 和 NSubstitute 实现自动模拟的方法 我不知道也不使用 Castle Windsor 但我确实使用 Unity 和 NSubstitute 有没有办法使用 Unity
  • 如何在一个命令行操作中解压文件并重命名文件夹?

    我想下载一个文件 解压它并重命名该文件夹 我可以下载该文件并将其解压 curl https s3 amazonaws com sampletest sample tar gz tar xz 如何在同一命令中重命名文件夹 curl https
  • keras/scikit-learn:使用 fit_generator() 进行交叉验证

    是否可以使用Keras 的 scikit learn API和 一起fit generator 方法 或者使用另一种方式来产生批次进行训练 我正在使用 SciPy 的稀疏矩阵 在输入 Keras 之前必须将其转换为 NumPy 数组 但由于