H2O R api:从网格搜索中检索最佳模型

2024-04-15

我正在使用h2oR 中的包(v 3.6.0),并且我构建了一个网格搜索模型。现在,我正在尝试访问最小化验证集上的 MSE 的模型。在Python中sklearn,这在使用时很容易实现RandomizedSearchCV:

## Pseudo code:
grid = RandomizedSearchCV(model, params, n_iter = 5)
grid.fit(X)
best = grid.best_estimator_

不幸的是,这在 h2o 中并不那么简单。这是您可以重新创建的示例:

library(h2o)
## assume you got h2o initialized...

X <- as.h2o(iris[1:100,]) # Note: only using top two classes for example 
grid <- h2o.grid(
    algorithm = 'gbm',
    x = names(X[,1:4]),
    y = 'Species',
    training_frame = X,
    hyper_params = list(
        distribution = 'bernoulli',
        ntrees = c(25,50)
    )
)

Viewing grid打印大量信息,包括这部分:

> grid
ntrees distribution status_ok                                                                 model_ids
 50    bernoulli        OK Grid_GBM_file1742e107fe5ba_csv_10.hex_11_model_R_1456492736353_16_model_1
 25    bernoulli        OK Grid_GBM_file1742e107fe5ba_csv_10.hex_11_model_R_1456492736353_16_model_0

通过一些挖掘,您可以访问每个单独的模型并查看每个可以想象的指标:

> h2o.getModel(grid@model_ids[[1]])
H2OBinomialModel: gbm
Model ID:  Grid_GBM_file1742e107fe5ba_csv_10.hex_11_model_R_1456492736353_18_model_1 
Model Summary: 
  number_of_trees model_size_in_bytes min_depth max_depth mean_depth min_leaves max_leaves mean_leaves
1              50                4387         1         1    1.00000          2          2     2.00000


H2OBinomialMetrics: gbm
** Reported on training data. **

MSE:  1.056927e-05
R^2:  0.9999577
LogLoss:  0.003256338
AUC:  1
Gini:  1

Confusion Matrix for F1-optimal threshold:
           setosa versicolor    Error    Rate
setosa         50          0 0.000000   =0/50
versicolor      0         50 0.000000   =0/50
Totals         50         50 0.000000  =0/100

Maximum Metrics: Maximum metrics at their respective thresholds
                      metric threshold    value idx
1                     max f1  0.996749 1.000000   0
2                     max f2  0.996749 1.000000   0
3               max f0point5  0.996749 1.000000   0
4               max accuracy  0.996749 1.000000   0
5              max precision  0.996749 1.000000   0
6           max absolute_MCC  0.996749 1.000000   0
7 max min_per_class_accuracy  0.996749 1.000000   0

并带有一个lot经过挖掘,你终于可以得到这个:

> h2o.getModel(grid@model_ids[[1]])@model$training_metrics@metrics$MSE
[1] 1.056927e-05

为了得到一个应该是模型选择的顶级指标,这似乎需要做很多繁琐的工作。在我的情况下,我有一个包含数百个模型的网格,而我当前的 hacky 解决方案似乎不太“R 式”:

model_select_ <- function(grid) {
  model_ids <- grid@model_ids
  min = Inf
  best_model = NULL

  for(model_id in model_ids) {
    model <- h2o.getModel(model_id)
    mse <- model@model$training_metrics@metrics$MSE
    if(mse < min) {
      min <- mse
      best_model <- model
    }
  }

  best_model
}

对于机器学习实践如此核心的东西来说,这似乎有点矫枉过正,而且让我感到奇怪的是,h2o 没有一种“更干净”的方法来提取最佳模型,或者至少是模型指标。

我错过了什么吗?是否没有“开箱即用”的方法来选择最佳模型?


是的,有一种简单的方法可以提取 H2O 网格搜索的“顶部”模型。还有一些实用函数可以提取所有模型指标(例如h2o.mse)您一直在尝试访问。有关如何执行这些操作的示例可以在H2O-R/演示 https://github.com/h2oai/h2o-3/tree/master/h2o-r/demos and h2o-py/演示 https://github.com/h2oai/h2o-3/tree/master/h2o-py/demos上的子文件夹h2o-3 https://github.com/h2oai/h2o-3GitHub 存储库。

由于您使用的是 R,因此这里有一个相关代码示例 https://github.com/h2oai/h2o-3/blob/master/h2o-r/demos/H2O_tutorial_eeg_eyestate_NOPASS.ipynb其中包括网格搜索和排序结果。您还可以在 R 文档中找到如何访问此信息h2o.getGrid功能。

打印出所有模型的 auc,按验证 AUC 排序:

auc_table <- h2o.getGrid(grid_id = "eeg_demo_gbm_grid", sort_by = "auc", decreasing = TRUE)
print(auc_table)

以下是输出示例:

H2O Grid Details
================

Grid ID: eeg_demo_gbm_grid 
Used hyper parameters: 
  -  ntrees 
  -  max_depth 
  -  learn_rate 
Number of models: 18 
Number of failed models: 0 

Hyper-Parameter Search Summary: ordered by decreasing auc
   ntrees max_depth learn_rate                  model_ids               auc
1     100         5        0.2 eeg_demo_gbm_grid_model_17 0.967771493797284
2      50         5        0.2 eeg_demo_gbm_grid_model_16 0.949609591795923
3     100         5        0.1  eeg_demo_gbm_grid_model_8  0.94941792664595
4      50         5        0.1  eeg_demo_gbm_grid_model_7 0.922075196552274
5     100         3        0.2 eeg_demo_gbm_grid_model_14 0.913785959685157
6      50         3        0.2 eeg_demo_gbm_grid_model_13 0.887706691652792
7     100         3        0.1  eeg_demo_gbm_grid_model_5 0.884064379717198
8       5         5        0.2 eeg_demo_gbm_grid_model_15 0.851187402678818
9      50         3        0.1  eeg_demo_gbm_grid_model_4 0.848921799270639
10      5         5        0.1  eeg_demo_gbm_grid_model_6 0.825662907513139
11    100         2        0.2 eeg_demo_gbm_grid_model_11 0.812030639460551
12     50         2        0.2 eeg_demo_gbm_grid_model_10 0.785379521713437
13    100         2        0.1  eeg_demo_gbm_grid_model_2  0.78299280750123
14      5         3        0.2 eeg_demo_gbm_grid_model_12 0.774673686150002
15     50         2        0.1  eeg_demo_gbm_grid_model_1 0.754834657912535
16      5         3        0.1  eeg_demo_gbm_grid_model_3 0.749285131682721
17      5         2        0.2  eeg_demo_gbm_grid_model_9 0.692702793188135
18      5         2        0.1  eeg_demo_gbm_grid_model_0 0.676144542037133

表中的第一行包含具有最佳 AUC 的模型,因此下面我们可以获取该模型并提取验证 AUC:

best_model <- h2o.getModel(auc_table@model_ids[[1]])
h2o.auc(best_model, valid = TRUE)

为了h2o.getGrid函数能够按验证集上的指标进行排序,您需要实际传递h2o.grid函数avalidation_frame。在上面的示例中,您没有传递validation_frame,因此您无法评估验证集上网格中的模型。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

H2O R api:从网格搜索中检索最佳模型 的相关文章

随机推荐

  • struct - 使用 qsort 对 C 字符串进行排序

    我正在对一堆 IP 进行排序 但由于某种原因 它们的顺序错误 我不太确定问题出在哪里 66 249 71 3 190 148 164 245 207 46 232 182 190 148 164 245 190 148 164 245 20
  • Google 应用已发布到内部测试轨道,但无法找到/下载

    我已成功完成 APK 到内部测试轨道的发布过程 但是 当我尝试使用下面屏幕截图中的 在 GOOGLE PLAY 上查看 链接查看 Google Play 商店上下载的应用程序时 it opens a new window with the
  • 超链接在 Android UC 浏览器中不起作用

    我被一个问题困扰 我正在尝试通过放置在我的网站中的超链接打开 Android 应用程序 下面是链接 href intent Intent action com example myapp category android intent ca
  • 在 Objective-C 中,我可以在 c 浮点数组上声明 @property 吗?

    thing h interface Thing NSObject float stuff 30 property float stuff end thing m implementation Thing synthesize stuff e
  • 玩!没有正确关闭 H2

    我正在使用 Play 编写一个部署在 Tomcat 中的 Web 应用程序 因为应用程序不会处理太多数据 所以我将默认的 H2 数据库与 Hibernate 一起使用 当我想要部署新版本的应用程序时 我关闭 tomcat 擦除旧的 web
  • 如何使 bash 脚本与一个又一个命令一起工作?

    我有一个如下所示的 bash 脚本 首先 它将sorted bam 文件作为输入 并使用 stringtie 工具将每个样本gtf 作为输出 然后每个样本 gtf 的路径将被赋予到 mergelist txt 中 然后对它们使用 strin
  • 如何跟踪 celery 中的重试次数

    在 Celery 中 如何跟踪当前的重试 我知道我可以做这样的事情 app task bind True default retry delay 900 max retries 5 def send email self sender No
  • 活动开启两次

    我有一个使用的应用程序城市飞艇 http urbanairship com 用于推送通知 当通知到达并且用户单击它时 我的应用程序中的活动 A 应该打开并执行某些操作 我已经安装了BroadcastReceiver如图所示在文档中 http
  • 在 C++ 中将数组转换为集合

    有没有更简单的方法使用 C 将数组转换为集合而不是循环遍历其元素 最好使用标准模板库 对于所有标准库容器类型 请使用构造函数 http en cppreference com w cpp container set set std set
  • ASP.NET Owin OAuth (Google / Facebook) 正在重定向到默认的 login.aspx,而不是远程登录页面

    我正在使用 Owin 库 包括 Google 和 Facebook 设置 OAuth 从表面上看 Owin 启动课程注册得很好 我发现我没有被重定向到 Facebook 或 Google 的相应登录页面 而是被重定向到默认的 login a
  • 从 SDK 上的“getLastKnownLocation”获取 null

    我有一个与位置 API 相关的问题 我尝试了以下代码 LocationManager lm LocationManager getSystemService Context LOCATION SERVICE Location loc get
  • 避免 D3.js 中子节点重叠

    我正在使用 D3 js 构建一个树结构 显示 Facebook 用户和他 她的 Facebook 好友 根节点是用户 子节点是好友 我的 UI 中有固定宽度 问题是子节点将相互重叠 var nodes tree nodes root rev
  • 使用 Resharper 7 测试运行程序进行 Jasmine 测试的堆栈跟踪

    如何让 Resharper 7 测试运行程序显示 Jasmine 测试的堆栈跟踪 我的设置是 Resharper 7 在 Jasmine 中构建 测试运行器和 PhantomJs 执行任何失败的测试时 错误消息始终以以下内容结尾 Excep
  • jquery-ui - 取消拖动转义键

    我有一个可拖动的列表divs 和一个可放置区域 在 chrome FF 和 IE9 中 鼠标拖放功能运行良好 我想添加键盘交互 拖拽div使用按键时应恢复到列表esc钥匙 所以首先我这样做了 document keyup function
  • ng-grid 行模板中的日期格式

    我创建了一个具有以下列定义的 ng grid columns field CompanyPkid visible false field CompanyName visible false field StartDate visible f
  • 如何让用户能够使用我的应用程序播放视频?

    昨晚刚刚花了几个小时为 Honeycomb 开发了一个非常漂亮的视频播放器 现在我当然希望人们能够使用它 如何让我的应用程序监听 接收 视频播放广播 我猜这与manifest xml文件 但我无法在 Android 开发者网站上找到任何有关
  • 然后 Groupby 检查行匹配并计算该值的并发实例数

    我有这个数据框 car color years max years 0 audi black 1 7 1 audi blue 2 7 2 audi purple 4 7 3 audi black 6 7 4 bmw blue 1 5 5 b
  • 为什么 CAS(原子)操作比同步或易失性操作更快

    据我了解 synchronized关键字将本地线程缓存与主内存同步 volatile 关键字基本上总是在每次访问时从主内存中读取变量 当然 访问主内存比本地线程缓存要昂贵得多 因此这些操作的成本很高 然而 CAS 操作使用低级硬件操作 但仍
  • 有条件的 Mercurial 忽略文件

    我在 Mercurial 中有一个文件 我希望开发机器提取该文件 但我希望部署服务器不提取该文件 它具有开发机器没有的特殊模块 这是可能的 还是我应该有一个自定义的推送到服务器解决方案 而不是仅仅进行 hg pull 执行此操作的典型方法是
  • H2O R api:从网格搜索中检索最佳模型

    我正在使用h2oR 中的包 v 3 6 0 并且我构建了一个网格搜索模型 现在 我正在尝试访问最小化验证集上的 MSE 的模型 在Python中sklearn 这在使用时很容易实现RandomizedSearchCV Pseudo code