我知道你想最后得到 DataFrame 。我看到两种可能的解决方案。我想说,在它们之间进行选择是品味问题。
从 RDD 创建列
以 RDD 的形式获取 id 和 cluster 对非常容易:
val idPointRDD = data.rdd.map(s => (s.getInt(0), Vectors.dense(s.getDouble(1),s.getDouble(2)))).cache()
val clusters = KMeans.train(idPointRDD.map(_._2), 3, 20)
val clustersRDD = clusters.predict(idPointRDD.map(_._2))
val idClusterRDD = idPointRDD.map(_._1).zip(clustersRDD)
然后你从中创建 DataFrame
val idCluster = idClusterRDD.toDF("id", "cluster")
它之所以有效,是因为地图不会改变 RDD 中数据的顺序,这就是为什么你可以只用预测结果压缩 id。
使用UDF(用户定义函数)
第二种方法涉及使用clusters.predict
方法作为 UDF:
val bcClusters = sc.broadcast(clusters)
def predict(x: Double, y: Double): Int = {
bcClusters.value.predict(Vectors.dense(x, y))
}
sqlContext.udf.register("predict", predict _)
现在我们可以使用它来向数据添加预测:
val idCluster = data.selectExpr("id", "predict(x, y) as cluster")
请记住,Spark API 不允许取消 UDF 注册。这意味着闭包数据将保存在内存中。
错误/非最佳解决方案
Using clusters.predict without broadcasting
It won't work in the distributed setup. Edit: actually it will work, I was confused by implementation of predict for RDD, which uses broadcast.
sc.makeRDD(clusters.predict(parsedData).toArray()).toDF()
toArray
收集驱动程序中的所有数据。这意味着在分布式模式下,您将把集群 ID 复制到一个节点中。