但是,RandomForest模型无法通过客户端代码新建,因此似乎无法在管道api中使用RandomForest。
嗯,确实如此,但您只是想使用错误的类。代替mllib.tree.RandomForest
你应该使用ml.classification.RandomForestClassifier
。这是一个基于的示例MLlib 文档中的一个 https://spark.apache.org/docs/latest/mllib-ensembles.html#classification.
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLUtils
import sqlContext.implicits._
case class Record(category: String, features: Vector)
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainData, testData) = (splits(0), splits(1))
val trainDF = trainData.map(lp => Record(lp.label.toString, lp.features)).toDF
val testDF = testData.map(lp => Record(lp.label.toString, lp.features)).toDF
val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("label")
val rf = new RandomForestClassifier()
.setNumTrees(3)
.setFeatureSubsetStrategy("auto")
.setImpurity("gini")
.setMaxDepth(4)
.setMaxBins(32)
val pipeline = new Pipeline()
.setStages(Array(indexer, rf))
val model = pipeline.fit(trainDF)
model.transform(testDF)
这里有一件事我无法弄清楚。据我所知,应该可以使用从中提取的标签LabeledPoints
直接,但由于某种原因它不起作用并且pipeline.fit
raises IllegalArgumentExcetion
:
RandomForestClassifier 的输入带有无效的标签列标签,但没有指定类的数量。
因此,丑陋的伎俩StringIndexer
。应用后我们得到所需的属性({"vals":["1.0","0.0"],"type":"nominal","name":"label"}
)但有些课程ml
没有它似乎工作得很好。