如何访问 RandomForestClassifier(spark.ml-version)创建的模型中的各个树?

2023-12-13

如何访问 Spark ML 生成的模型中的各个树随机森林分类器?我正在使用 RandomForestClassifier 的 Scala 版本。


其实它有trees属性:

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.{
  RandomForestClassificationModel, RandomForestClassifier, 
  DecisionTreeClassificationModel
}

val meta = NominalAttribute
  .defaultAttr
  .withName("label")
  .withValues("0.0", "1.0")
  .toMetadata

val data = sqlContext.read.format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt")
  .withColumn("label", $"label".as("label", meta))

val rf: RandomForestClassifier = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")

val trees: Array[DecisionTreeClassificationModel] = rf.fit(data).trees.collect {
  case t: DecisionTreeClassificationModel => t
}

正如你所看到的,唯一的问题是获取正确的类型,这样我们就可以实际使用它们:

trees.head.transform(data).show(3)
// +-----+--------------------+-------------+-----------+----------+
// |label|            features|rawPrediction|probability|prediction|
// +-----+--------------------+-------------+-----------+----------+
// |  0.0|(692,[127,128,129...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
// |  1.0|(692,[158,159,160...|   [0.0,59.0]|  [0.0,1.0]|       1.0|
// |  1.0|(692,[124,125,126...|   [0.0,59.0]|  [0.0,1.0]|       1.0|
// +-----+--------------------+-------------+-----------+----------+
// only showing top 3 rows

Note:

如果您使用管道,您也可以提取单个树:

import org.apache.spark.ml.Pipeline

val model = new Pipeline().setStages(Array(rf)).fit(data)

// There is only one stage and know its type 
// but lets be thorough
val rfModelOption = model.stages.headOption match {
  case Some(m: RandomForestClassificationModel) => Some(m)
  case _ => None
}

val trees = rfModelOption.map {
  _.trees //  ... as before
}.getOrElse(Array())
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何访问 RandomForestClassifier(spark.ml-version)创建的模型中的各个树? 的相关文章

  • 如何通过sparkSession向worker提交多个jar?

    我使用的是火花2 2 0 下面是我在 Spark 上使用的 java 代码片段 SparkSession spark SparkSession builder appName MySQL Connection master spark ip
  • 通用 scala 函数,其输入是变量数量的函数

    我想定义一个函数f需要另一个函数g 我们需要g采取采取n双打 对于某些固定n 并返回一个 Double 函数调用f g 应该返回具体值n 例如 f Math max 2因为 Math sin 具有类型 Double Double gt Do
  • 减少/折叠幺半群列表,但减少器返回任一

    我发现自己遇到过几次这样的情况 我有一个减速器 组合 fn 如下所示 def combiner a String b String Either String String a b asRight String 它是一个虚拟实现 但 fn
  • 在 Akka 中配置嵌套 Router

    我有一些嵌套的路由器 应创建它FromConfig 我想要的是这样的 test akka actor deployment worker router round robin nr of instances 5 slave router b
  • 如何将模型结果保存到文本文件?

    我正在尝试将从模型生成的频繁项集保存到文本文件中 该代码是 Spark ML 库中 FPGrowth 示例的示例 Using saveAsTextFile直接在模型上写入 RDD 位置而不是实际值 import org apache spa
  • 为什么 Scala 中的隐式类必须驻留在另一个特征/类/对象中?

    基于scala文档 http docs scala lang org overviews core implicit classes html http docs scala lang org overviews core implicit
  • 多个 scala 库导致 intellij 出错?

    我正在使用 intellij 14 和 scala 2 11 6 使用 homebrew 安装并使用符号链接 ln s usr local Cellar scala 2 11 6 libexec src usr local Cellar s
  • 如何使用 PySpark 预处理图像?

    我有一个项目 需要为 1 设置大数据架构 AWS S3 SageMaker 的概念验证使用 PySpark 预处理图像 2 执行 PCA and 3 训练一些机器或深度学习模型 我的问题是了解如何使用 PySpark 操作图像数据 但无法在
  • 更改 build.sbt 自定义任务中的版本

    我在 build sbt 中定义了一个自定义任务 val doSmth taskKey Unit smth doSmth version 1 0 SNAPSHOT 但它不会改变版本 我真正想要的是自定义 sbt 发布任务 它将始终将相同的版
  • 模拟 BlazeClientBuilder[IO] 以返回模拟客户端[IO]

    我正在使用BlazeClientBuilder IO resource方法得到Client IO 现在 我想模拟客户端进行单元测试 但不知道该怎么做 有没有一个好的方法来嘲笑这个 我会怎么做 class ExternalCall val r
  • 高效序列化案例类

    对于我正在工作的图书馆 我需要提供一个高效 便捷 typesafe序列化 scala 类的方法 理想的情况是用户可以创建一个案例类 并且只要所有成员都是可序列化的 它似乎也应该如此 我准确地知道序列化和反序列化阶段的类型 因此不需要 也不能
  • 对两种类型之间的二元关系进行建模

    有企业 也有人 用户可以对某个企业点赞或发表评论 但效果是一样的can not发生在一个人身上 当用户发布有关某个企业的内容或对其点赞时 该企业就被称为target喜欢或帖子 trait TargetingRelation Targetin
  • 使用 net.liftweb.json 或 scala.util.parsing.json 解析大型 (30MB) JSON 文件会出现 OutOfMemoryException。有什么建议吗?

    我有一个包含大量测试数据的 JSON 文件 我想解析这些数据并推送我正在测试的算法 它的大小约为 30MB 包含大约 60 000 个元素的列表 我最初在 scala util parsing json 中尝试了简单的解析器 如下所示 im
  • 为什么用scala写的代码比用java写的慢6倍?

    我不确定我在编写 scala 代码时是否犯了一些错误 问题是 The four adjacent digits in the 1000 digit number that have the greatest product are 9 9
  • scala中的反引号有什么用[重复]

    这个问题在这里已经有答案了 我在一本书上找到了以下代码 val list List 5 4 3 2 1 val result 0 list running total next element running total next elem
  • Java 8 Stream,获取头部和尾部

    Java 8 引入了Stream http download java net jdk8 docs api java util stream Stream html类似于 Scala 的类Stream http www scala lang
  • 具有继承类型的 Aux 模式推理失败

    我有一个复杂的玩具算法 我希望纯粹在类型级别上表示 根据饮食要求选择当天菜肴的修改 对卷积表示歉意 但我认为我们需要每一层才能达到我想要使用的最终界面 我的代码有一个问题 如果我们表达一个类型约束Aux 模式生成的类型基于另一个泛型类型 它
  • 将 IndexToString 应用于 Spark 中的特征向量

    Context 我有一个数据框 其中所有分类值都已使用 StringIndexer 进行索引 val categoricalColumns df schema collect case StructField name StringType
  • Spark 中的 Distinct() 函数如何工作?

    我是 Apache Spark 的新手 正在学习基本功能 有一个小疑问 假设我有一个元组 键 值 的 RDD 并且想从中获取一些唯一的元组 我使用distinct 函数 我想知道该函数基于什么基础认为元组是不同的 是基于键 值还是两者 di
  • 为什么自类型类可以声明类

    我知道 Scala 只能混合特征 这对于依赖注入和蛋糕模式是有意义的 我的问题是为什么我仍然可以声明一个需要另一个 类 但不需要特征的类 Code class C class D self C gt 这仍然编译成功 我认为它应该编译失败 因

随机推荐

  • 以最少的 malloc 调用次数为二维数组分配内存

    我使用下面的代码片段使用最小数量为二维数组分配内存malloc calls 我想使用下标 p i j 访问数组 define ROW 3 define COL 2 int main void ptr malloc ROW COL sizeo
  • 使用接近“INT_MAX”的“count”值传送数据

    消息传递接口 API 始终使用int作为一个类型count变量 例如 原型为MPI Send is int MPI Send const void buf int count MPI Datatype datatype int dest i
  • 防止向记分板提交欺诈性信息

    我正在开发 Flash 游戏的后端 我需要secure数据进入记分板 该游戏将在许多网站上以横幅广告形式托管 用户将在广告中玩游戏 然后点击进入主网站以保存其详细信息 目前我正在思考这个问题 用户玩游戏并点击提交分数 在后台 横幅将分数和原
  • 从扩展中禁用“wordBasedSuggestions”等默认设置

    我正在开发 VSCode 的扩展 它提供完成项 但其中有单词建议 我知道您可以在用户 工作空间设置中禁用editor wordBasedSuggestions但是有没有办法从扩展中做到这一点 是的 扩展程序可以通过贡献来更改设置的默认值co
  • Sql where 子句在过滤器为空的情况下返回所有内容

    我下面有一个 sql 表 SrNo Name Value 1 A X1 2 B NULL 3 C X3 4 D X4 5 E NULL 6 F NULL 我试图从表中获取所有记录 并满足以下两个条件 a 如果 Value 列上的过滤器为 n
  • 使用 core-plot 库创建 .ipa 时,xCode 4“找不到文件”

    我的应用程序已准备好发布 但无法创建所需的 ipa 我在一个非常小的例子中重现了我的问题 1 创建一个新项目 我使用了导航栏应用程序 2 存档构建 3 分享 ipa 在指定位置创建 4 下载 安装 core plot 5 使用方法2添加库
  • 有没有办法从多个文件夹运行所有 pytest 用例?

    假设我有test case1 py在文件夹中A and test case2 py在文件夹中B 我可以使用一个单一的来运行它们吗pytest命令 文件夹结构 projectfolder A test case1 py projectfold
  • 在 Facebook IOS SDK 中禁用单点登录 (SSO)

    我们构建了一个使用 Facebook SDK 的 iOS 应用程序 不幸的是 我们的客户要求我们禁用应用程序中的后台 这意味着 Facebook 单点登录 SSO 方案对我们不起作用 因为我们的应用程序现在在登录 授权后启动时从头开始 在
  • jq - 如何根据属性值的“黑名单”选择对象

    类似于这里回答的问题 jq 如何根据属性值的 白名单 选择对象 我想根据属性值黑名单选择对象 以下内容可以很好地作为白名单 curl s https api github com repos stedolan jq commits per
  • JQuery 表单提交添加请求标头

    我想问一下调用前是否可以指定 headers myForm submit 我知道您可以在 AJAX post 请求中指定 但是在提交这个简单的表单之前可以吗 是的你可以 需要一定的本土化JavaScript苦差事 我就是这样做的 h1 Cu
  • 如何优化2个相同的内核,占用率50%,可以在CUDA中同时运行?

    我在 CUDA 中有 2 个相同的内核 报告理论占用率为 50 并且可以同时运行 但是 在不同的流中调用它们会显示顺序执行 每个内核调用的网格和块尺寸如下 Grid 3 568 620 Block 256 1 1 With 50 regis
  • 来自样式对象的 PHPExcel 特定单元格格式

    我在项目中使用 PHPExcel 需要设置 Excel 工作表单元格的样式 我所做的是创建一个 PHPExcel 样式对象 如下所示 style red text new PHPExcel Style 然后 我使用此样式的设置函数来填充对象
  • 下载的文件作为控制器(ASP.NET MVC 3)中的流会自动处理吗?

    让我们假设下载所选文件的控制器 public FileResult Download string f Stream file MyModel DownloadFiles f return File file application oct
  • 为设备手动设置 USB 传输类型

    我尝试在 ARM 板 Pandaboard 上运行 Asus xtion 并且我已经安装并使用了 openni 提供的示例 例如 NiSimpleRead 为了让这些示例在此平台上运行 需要进行一些调整 其中之一是将 UsbInterfac
  • 为什么浮动元素的背景看起来独立于内容而移动?

    在下面的 CSS 代码中 背景似乎是divTwo已经落后了divOne 但内容divTwo似乎已被抛在后面 为什么 div 的背景似乎独立于内容移动 divOne width 300px height 100px background co
  • 正确理解相同主机/不同端口和安全性的 CORS

    我不做太多客户端网络编程 所以我试图理解这个概念与我的具体情况的关系 我有一个 RESTful WCF 服务在 50000 多个端口上运行 此外 我还有一堆用 HTML5 CSS3 JavaScript 编写的 Web 表单 不是 ASP
  • Python中如何获取最新的目录

    我正在寻找一种方法 可以找到在另一个目录中创建的最新目录 我唯一的方法是os listdir 但它显示了里面的所有文件和目录 如何仅列出目录以及如何访问目录的属性以查找最新创建的目录 谢谢 import os dirs d for d in
  • PHP下载MySQL数据库备份

    我想让客户能够手动下载其数据库的备份 我正在使用 PHP 和 MySQL 对该网站进行编码 因此 管理员用户登录后 菜单中会出现一个链接 用于将 sql 文件下载到本地计算机 我怎样才能用 PHP 来完成这个任务 尝试从 PHP 备份数据库
  • CSS @font-face 在 ie9 中不起作用

    我设法使用一种自定义字体 该字体适用于每个值得被称为 浏览器 的浏览器 出色地一如既往这些很酷的东西不适用于 ie 在本例中为 ie9 我尝试了以下方法 font face font family Roboto src url fonts
  • 如何访问 RandomForestClassifier(spark.ml-version)创建的模型中的各个树?

    如何访问 Spark ML 生成的模型中的各个树随机森林分类器 我正在使用 RandomForestClassifier 的 Scala 版本 其实它有trees属性 import org apache spark ml attribute