使用树输出预测 Spark 中梯度提升树情况下的类概率

2024-04-09

众所周知,Spark 中的 GBT 目前可以为您提供预测标签。

我正在考虑尝试计算一个类的预测概率(假设所有实例都落在某个叶子下)

构建 GBT 的代码

import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLUtils

//Importing the data
val data = sc.textFile("data/mllib/credit_approval_2_attr.csv") //using the credit approval data set from UCI machine learning repository

//Parsing the data
val parsedData = data.map { line =>
    val parts = line.split(',').map(_.toDouble)
    LabeledPoint(parts(0), Vectors.dense(parts.tail))
}

//Splitting the data
val splits = parsedData.randomSplit(Array(0.7, 0.3), seed = 11L)
val training = splits(0).cache() 
val test = splits(1)

// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 2 // We can use more iterations in practice.
boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 2
boostingStrategy.treeStrategy.maxBins = 32
boostingStrategy.treeStrategy.subsamplingRate = 0.5
boostingStrategy.treeStrategy.maxMemoryInMB =1024
boostingStrategy.learningRate = 0.1

// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val model = GradientBoostedTrees.train(training, boostingStrategy)  

model.toDebugString

为了简单起见,这给了我两棵深度为 2 的树,如下所示:

 Tree 0:
    If (feature 3 <= 2.0)
     If (feature 2 <= 1.25)
      Predict: -0.5752212389380531
     Else (feature 2 > 1.25)
      Predict: 0.07462686567164178
    Else (feature 3 > 2.0)
     If (feature 0 <= 30.17)
      Predict: 0.7272727272727273
     Else (feature 0 > 30.17)
      Predict: 1.0
  Tree 1:
    If (feature 5 <= 67.0)
     If (feature 4 <= 100.0)
      Predict: 0.5739387416147804
     Else (feature 4 > 100.0)
      Predict: -0.550117566730937
    Else (feature 5 > 67.0)
     If (feature 2 <= 0.0)
      Predict: 3.0383669122382835
     Else (feature 2 > 0.0)
      Predict: 0.4332824083446489

我的问题是:我可以使用上面的树来计算预测概率,例如:

对于用于预测的特征集中的每个实例

exp(树 0 的叶子分数 + 树 1 的叶子分数)/(1+exp(树 0 的叶子分数 + 树 1 的叶子分数))

这给了我一种概率。但不确定这是否是正确的方法。另外,是否有任何文档解释如何计算叶子分数(预测)。如果有人可以分享,我将非常感激。

任何建议都会很棒。


这是我使用 Spark 内部依赖项的方法。稍后您需要导入线性代数库进行矩阵运算,即将树预测与学习率相乘。

import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.linalg.distributed.{RowMatrix}

假设您使用 GBT 构建模型:

val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

使用模型对象计算概率:

// Get the log odds predictions from each tree
val treePredictions = testData.map { point => model.trees.map(_.predict(point.features)) }

// Transform the arrays into matrices for multiplication
val treePredictionsVector = treePredictions.map(array => Vectors.dense(array))
val treePredictionsMatrix = new RowMatrix(treePredictionsVector)
val learningRate = model.treeWeights
val learningRateMatrix = Matrices.dense(learningRate.size, 1, learningRate)
val weightedTreePredictions = treePredictionsMatrix.multiply(learningRateMatrix)

// Calculate probability by ensembling the log odds
val classProb = weightedTreePredictions.rows.flatMap(_.toArray).map(x => 1 / (1 + Math.exp(-1 * x)))
classProb.collect

// You may tweak your decision boundary for different class labels
val classLabel = classProb.map(x => if (x > 0.5) 1.0 else 0.0)
classLabel.collect

以下是您可以直接复制并粘贴到 Spark-Shell 中的代码片段:

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.linalg.distributed.{RowMatrix}
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

// Load and parse the data file.
val csvData = sc.textFile("data/mllib/sample_tree_data.csv")
val data = csvData.map { line =>
  val parts = line.split(',').map(_.toDouble)
  LabeledPoint(parts(0), Vectors.dense(parts.tail))
}
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a GBT model.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 50
boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 6
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

// Get class label from raw predict function
val predictedLabels = model.predict(testData.map(_.features))
predictedLabels.collect

// Get class probability
val treePredictions = testData.map { point => model.trees.map(_.predict(point.features)) }
val treePredictionsVector = treePredictions.map(array => Vectors.dense(array))
val treePredictionsMatrix = new RowMatrix(treePredictionsVector)
val learningRate = model.treeWeights
val learningRateMatrix = Matrices.dense(learningRate.size, 1, learningRate)
val weightedTreePredictions = treePredictionsMatrix.multiply(learningRateMatrix)
val classProb = weightedTreePredictions.rows.flatMap(_.toArray).map(x => 1 / (1 + Math.exp(-1 * x)))
val classLabel = classProb.map(x => if (x > 0.5) 1.0 else 0.0)
classLabel.collect
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用树输出预测 Spark 中梯度提升树情况下的类概率 的相关文章

  • 这对蒙蒂·霍尔来说是好还是坏的“模拟”?怎么会? [关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 通过试图解释蒙蒂霍尔问题 http en wikipedia org wiki Monty Hall problem昨天在课堂上给一位朋友说 我
  • 单击节点时打开分支?

    我被困住了jsTree http www jstree com 这里 到目前为止 它有效 我可以使用 图标浏览和展开节点 并在单击节点时打开页面 但我仍然希望它在有人单击节点时展开所有直接节点 我环视了至少两个小时 但什么也没找到 官方网站
  • 使用 tree-model-js 将树转换回 JSON

    是否有一种方法可以将 TreeModel 转换为 JSON 字符串 这样它就可以被存储 然后使用tree parse 目前在尝试时JSON stringify root 它给出了关于循环引用的明显错误 因为子级包含父级 父级包含子级 Use
  • 在 oracle 树查询中连接其他表

    给定一个简单的 id description 表t1 比如 id description 1 Alice 2 Bob 3 Carol 4 David 5 Erica 6 Fred 以及一个父子关系表t2 比如 parent child 1
  • 获取图表中走过的最长路线

    我有一组相互连接的节点 我有以下节点网络 这里0是起点 我想遍历尽可能多的节点 并且一个节点只遍历一次 另外 在从 0 到目标节点的旅程中 我只想有一个奇数编号的节点 如 1 3 5 7 现在我需要找出从起始位置 0 开始可以行驶的最长路线
  • 使用 rand(3) 生成随机数(9)

    您有一个函数 rand 3 它生成从 1 到 3 的随机整数 使用此函数构造另一个函数 rand 9 它生成从 1 到 9 的随机整数 这是一个简单的解决方案 rand 3 3 rand 3 1 您想要这样做的原因是它提供了从 1 到 9
  • 使用树输出预测 Spark 中梯度提升树情况下的类概率

    众所周知 Spark 中的 GBT 目前可以为您提供预测标签 我正在考虑尝试计算一个类的预测概率 假设所有实例都落在某个叶子下 构建 GBT 的代码 import org apache spark SparkContext import o
  • 构建具有继承的通用树

    我正在构建一个通用的Tree
  • Visual Studio代码侧边栏垂直引导线(自定义侧边栏)

    有人知道 Visual Studio 代码的扩展可以像 netbeans 一样在侧边栏 用于文件和文件夹 上显示垂直指南吗 或者vscode中有一些设置吗 Netbeans 快照 https i stack imgur com CFJsw
  • 如何递归探索Python嵌套字典? [复制]

    这个问题在这里已经有答案了 我很好奇是否有一种方法可以在 python 中递归地探索嵌套字典 我的意思是 假设我们有一个如下示例 d a b c 1 2 3 获取最里面字典的内容需要什么代码 c 1 2 3 遍历a and b 在这种情况下
  • QTableView 仅显示使用 QAbstractItemModel 实现的树模型的叶子

    假设我有一个树结构 树叶在bold 抱歉这些点 A A1 A2 B B1 B11 B2 C 存储在 QAbstractItemModel 中 具有设置的父 子关系 如何在 QTableView 中仅显示树叶 基本思想是实现一个 QSortF
  • 寻找成熟的 M-Tree 实现 [关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我正在寻找一个成熟的 java M Tree 实现 甚至任何 M Tree 实现 除了我找到的唯一实现 http en wikipedia
  • 单场淘汰赛 - 可能的组合数量

    单场淘汰赛中 8 人参加的组合有多少种 比赛总数为 7 场 但我还需要这组比赛的组合数量 如果玩家在树中的哪个位置开始并不重要 而只关心他 她与哪些对手战斗以及他 她能坚持多久 我们可以说左边的玩家总是获胜 然后只需计算创建的方法数量最下面
  • Beaglebone Black 上的 GPIO

    我目前遇到了 Beaglebone black GPIO 引脚的问题 我正在寻找一种正确的方法来读取 C 中的 GPIO 引脚 p8 4 的值 如果我理解正确的话 我尝试使用一个库 该库使用了在引入设备树之前不支持的旧方法 我尝试寻找其他解
  • 将 rbf 与 scipy 一起使用时出现内存错误

    I want to plot some points with the rbf function like here to get the density distribution of the points 如果我运行以下代码 它工作正常
  • 提取给定节点的所有父节点

    我正在尝试使用以下命令提取每个给定 GO Id 节点 的所有父级EBI RDF sparql 端点 https www ebi ac uk rdf services sparql 我是根据this https stackoverflow c
  • Tic-Tac-Toe AI:如何制作树?

    在制作井字游戏机器人时 我在尝试理解 树 时遇到了巨大的障碍 我理解这个概念 但我不知道如何实现它们 有人可以向我展示一个如何为这种情况生成树的示例吗 或者关于生成树的好教程 我想最困难的部分是生成部分树 我知道如何实现生成整棵树 但不知道
  • ';'预期但发现“导入” - Scala 和 Spark

    我正在尝试使用 Spark 和 Scala 来编译一个独立的应用程序 我不知道为什么会收到此错误 topicModel scala 2 expected but import found error import org apache sp
  • Webix 树节点的 Font Awesome 图标

    Webix 与 Font Awesome 集成 http docs webix com desktop icon types html 但是如何使用 Font Awesome 图标代替树中的默认文件夹 文件图标来设置各个节点的样式呢 这是我
  • 如何将模型结果保存到文本文件?

    我正在尝试将从模型生成的频繁项集保存到文本文件中 该代码是 Spark ML 库中 FPGrowth 示例的示例 Using saveAsTextFile直接在模型上写入 RDD 位置而不是实际值 import org apache spa

随机推荐