package
com.buwenbuhuo.spark.sql.project
import
java.text.DecimalFormat
import
org.apache.spark.sql.Row
import
org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import
org.apache.spark.sql.types._
/**
 * Aggregate function that turns a column of city names into a remark string
 * listing the two most frequent cities and the combined share of the rest,
 * e.g. "北京:21.20%,天津:13.20%,其他:65.60%".
 *
 * @author Buwenbuhuo (不温卜火)
 * @create 2020-08-06 13:24
 * @see https://buwenbuhuo.blog.csdn.net/
 */
class CityRemarkUDAF extends UserDefinedAggregateFunction {

  /** Input schema: a single string column named "city". */
  override def inputSchema: StructType =
    StructType(Array(StructField("city", StringType)))

  /**
   * Buffer schema: slot 0 ("map") holds city name -> occurrence count,
   * slot 1 ("total") holds the overall number of counted rows.
   */
  override def bufferSchema: StructType =
    StructType(Array(
      StructField("map", MapType(StringType, LongType)),
      StructField("total", LongType)))

  /** The aggregate result is a string. */
  override def dataType: DataType = StringType

  /** Identical input is declared to yield identical output. */
  override def deterministic: Boolean = true

  /** Resets the buffer to an empty map and a zero total. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
    buffer(1) = 0L
  }

  /**
   * Folds one input row into the buffer (merge within a partition).
   *
   * Rows whose single field is not a non-null String are ignored on purpose,
   * so they count towards neither the per-city map nor the total.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
        // 1. One more row overall.
        buffer(1) = buffer.getLong(1) + 1L
        // 2. One more row for this city: read the map, bump the entry, write it back.
        val counts: collection.Map[String, Long] = buffer.getMap[String, Long](0)
        buffer(0) = counts + (cityName -> (counts.getOrElse(cityName, 0L) + 1L))
      case _ =>
    }
  }

  /**
   * Combines two buffers (merge across partitions); the result is stored in `buffer1`.
   *
   * The maps are read with `getMap`, the same accessor used by `update` and
   * `evaluate`, rather than casting the slot to an immutable `Map` via `getAs`.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val counts1: collection.Map[String, Long] = buffer1.getMap[String, Long](0)
    val counts2: collection.Map[String, Long] = buffer2.getMap[String, Long](0)
    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. Totals simply add up.
    buffer1(1) = total1 + total2

    // 2. Per-city counts: start from counts2 and add every entry of counts1 on top.
    buffer1(0) = counts1.foldLeft(counts2) {
      case (acc, (cityName, count)) =>
        acc + (cityName -> (acc.getOrElse(cityName, 0L) + count))
    }
  }

  /**
   * Produces the final remark string: the top two cities by count followed by
   * an "其他" entry for everything else, joined with ",".
   *
   * The "其他" share is derived from the integer counts instead of subtracting
   * floating-point ratios from 1.0; the subtraction could leave a tiny negative
   * remainder that was rendered as "-0.00%". When the total is zero the entry
   * keeps the value 1.0, as before.
   */
  override def evaluate(buffer: Row): Any = {
    val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
    val total: Long = buffer.getLong(1)

    // Highest counts first; keep at most two cities.
    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)

    val topRemarks: List[CityRemark] = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }

    // Exact remainder in rows, converted to a ratio only once.
    val otherCount: Long = total - cityCountTop2.map(_._2).sum
    val otherRatio: Double = if (total == 0L) 1D else otherCount.toDouble / total

    (topRemarks :+ CityRemark("其他", otherRatio)).mkString(",")
  }
}
/**
 * Pairs a city name with a ratio and renders it as "name:percent".
 *
 * @param cityName  label shown before the colon
 * @param cityRatio fraction to be displayed as a percentage (0.212 -> "21.20%")
 */
case class CityRemark(cityName: String, cityRatio: Double) {

  // Percent pattern with exactly two fraction digits; one formatter per instance.
  val formatter: DecimalFormat = new DecimalFormat("0.00%")

  /** Renders the remark as the city name, a colon, and the formatted percentage. */
  override def toString: String = {
    val percent = formatter.format(cityRatio)
    cityName + ":" + percent
  }
}