一般来说,Spark 在处理分类数据时依赖于列元数据。在您的管道中,这是由StringIndexer
(ft_string_indexer
)。机器学习总是预测标签,而不是原始字符串。通常你会使用IndexToString
变压器由ft_index_to_string
.
在火花中IndexToString
可以使用提供的标签列表 https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.ml.feature.IndexToString@setLabels(value:Array%5BString%5D):IndexToString.this.type or Column
元数据。很遗憾sparklyr
实施受到两个方面的限制:
-
它只能使用元数据 https://github.com/rstudio/sparklyr/blob/a375d1f9c0be5af8cafc4892c3e7946b8144f3f9/R/ml_feature_transformation.R#L131-L132,未在预测列上设置。
-
ft_string_indexer
丢弃经过训练的模型,因此它不能用于提取标签。
我可能错过了一些东西,但看起来你必须手动映射预测,例如通过joining
转换后的数据:
pred %>%
select(prediction=Resp_cat, Resp_prediction=Resp) %>%
distinct() %>%
right_join(pred)
Joining, by = "prediction"
# Source: lazy query [?? x 9]
# Database: spark_connection
prediction Resp_prediction ID Numb Resp Resp_cat id777a79821e1e
<dbl> <chr> <int> <int> <chr> <dbl> <dbl>
1 7 171 1 3 171 7 0
2 0 153 2 10 153 0 1
3 3 132 3 8 132 3 2
4 5 122 4 7 122 5 3
5 6 198 5 4 198 6 4
6 2 164 6 9 164 2 5
7 4 137 7 6 137 4 6
8 1 184 8 5 184 1 7
9 0 153 9 1 153 0 8
10 1 184 10 2 184 1 9
# ... with more rows, and 2 more variables: rawPrediction <list>,
# probability <list>
解释:
pred %>%
select(prediction=Resp_cat, Resp_prediction=Resp) %>%
distinct()
创建从预测(编码标签)到原始标签的映射。我们重命名Resp_cat
to prediction
所以它可以作为连接密钥,并且Resp
to Resp_prediction
以免与实际发生冲突Resp
.
最后我们应用右等值连接:
... %>% right_join(pred)
Note:
您应该指定树的类型:
ml_decision_tree(
response = "Resp_cat", features = "Numb",type = "classification")