Generally speaking, there is no documentation, because as of Spark 1.6 / 2.0 most of the related API is not intended to be public. This should change in Spark 2.1.0 (see SPARK-7146: https://issues.apache.org/jira/browse/SPARK-7146).

The API is relatively complex, because it has to follow specific conventions in order to make a given Transformer or Estimator compatible with the Pipeline API. Some of these methods may be required for features like reading and writing or grid search. Others, like keyword_only, are just simple helpers and are not strictly required.
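To illustrate the latter point: keyword_only simply rejects positional arguments and records the keyword arguments on the instance as _input_kwargs, so they can later be forwarded to _set. A minimal sketch of that behavior (the Demo class here is purely illustrative, not part of the API):

from pyspark import keyword_only

class Demo(object):
    @keyword_only
    def __init__(self, a=None, b=None):
        # keyword_only has stored the keyword arguments here
        print(self._input_kwargs)

Demo(a=1, b=2)   # {'a': 1, 'b': 2}
# Demo(1, 2)     # TypeError - positional arguments are rejected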
Assuming you have defined the following mix-in for the mean parameter:
from pyspark.ml.pipeline import Estimator, Model, Pipeline
from pyspark.ml.param.shared import *
from pyspark.sql.functions import avg, stddev_samp


class HasMean(Params):

    mean = Param(Params._dummy(), "mean", "mean",
                 typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasMean, self).__init__()

    def setMean(self, value):
        return self._set(mean=value)

    def getMean(self):
        return self.getOrDefault(self.mean)
one for the standard deviation parameter:
class HasStandardDeviation(Params):

    standardDeviation = Param(Params._dummy(),
                              "standardDeviation", "standardDeviation",
                              typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasStandardDeviation, self).__init__()

    def setStddev(self, value):
        return self._set(standardDeviation=value)

    def getStddev(self):
        return self.getOrDefault(self.standardDeviation)
and one for the threshold:
class HasCenteredThreshold(Params):

    centeredThreshold = Param(Params._dummy(),
                              "centeredThreshold", "centeredThreshold",
                              typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasCenteredThreshold, self).__init__()

    def setCenteredThreshold(self, value):
        return self._set(centeredThreshold=value)

    def getCenteredThreshold(self):
        return self.getOrDefault(self.centeredThreshold)
then you could create a basic Estimator as follows:
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark import keyword_only


class NormalDeviation(Estimator, HasInputCol,
                      HasPredictionCol, HasCenteredThreshold,
                      DefaultParamsReadable, DefaultParamsWritable):

    @keyword_only
    def __init__(self, inputCol=None, predictionCol=None, centeredThreshold=1.0):
        super(NormalDeviation, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    # Required in Spark >= 3.0
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    # Required in Spark >= 3.0
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @keyword_only
    def setParams(self, inputCol=None, predictionCol=None, centeredThreshold=1.0):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _fit(self, dataset):
        c = self.getInputCol()
        # Compute the sample mean and standard deviation of the input column
        mu, sigma = dataset.agg(avg(c), stddev_samp(c)).first()
        return NormalDeviationModel(
            inputCol=c, mean=mu, standardDeviation=sigma,
            centeredThreshold=self.getCenteredThreshold(),
            predictionCol=self.getPredictionCol())


class NormalDeviationModel(Model, HasInputCol, HasPredictionCol,
                           HasMean, HasStandardDeviation, HasCenteredThreshold,
                           DefaultParamsReadable, DefaultParamsWritable):

    @keyword_only
    def __init__(self, inputCol=None, predictionCol=None,
                 mean=None, standardDeviation=None,
                 centeredThreshold=None):
        super(NormalDeviationModel, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, predictionCol=None,
                  mean=None, standardDeviation=None,
                  centeredThreshold=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        x = self.getInputCol()
        y = self.getPredictionCol()
        threshold = self.getCenteredThreshold()
        mu = self.getMean()
        sigma = self.getStddev()
        # Flag rows deviating from the mean by more than threshold * sigma
        return dataset.withColumn(y, (dataset[x] - mu) > threshold * sigma)
Credits to Benjamin Manns (https://stackoverflow.com/users/234944/benjamin-manns) for the usage of DefaultParamsReadable, DefaultParamsWritable (https://stackoverflow.com/a/52467470), applicable in PySpark >= 2.3.0.
Finally, it can be used as follows:
df = sc.parallelize([(1, 2.0), (2, 3.0), (3, 0.0), (4, 99.0)]).toDF(["id", "x"])
normal_deviation = NormalDeviation().setInputCol("x").setCenteredThreshold(1.0)
model = Pipeline(stages=[normal_deviation]).fit(df)
model.transform(df).show()
## +---+----+----------+
## | id| x|prediction|
## +---+----+----------+
## | 1| 2.0| false|
## | 2| 3.0| false|
## | 3| 0.0| false|
## | 4|99.0| true|
## +---+----+----------+
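Because both classes mix in DefaultParamsWritable and DefaultParamsReadable, the fitted pipeline model can also be persisted and reloaded (PySpark >= 2.3.0). A minimal sketch, assuming the custom classes are importable in the session that loads the model; the path is just an example:

from pyspark.ml import PipelineModel

# Write the fitted PipelineModel, overwriting any previous copy
model.write().overwrite().save("/tmp/normal-deviation-model")

# Load it back and apply it to the same data
reloaded = PipelineModel.load("/tmp/normal-deviation-model")
reloaded.transform(df).show()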