下面的实现提供了可配置的 `direction` 和 `tolerance` 参数,基于窗口函数实现(`last` 对应 'backward',`first` 对应 'forward',两者结合对应 'nearest')。
另外,根据我的经验,pandas 的 `merge_asof` 中的 `by` 参数经常会被用到,所以我也把这个参数加入了函数。`by`
参数还有一个额外的好处:它通过创建窗口分区进一步提高性能。
from pyspark.sql import functions as F, Window as W
def merge_asof(df_left, df_right, on: str, by=None, tolerance=None, direction: str = 'backward'):
    """Spark equivalent of ``pandas.merge_asof``.

    For each row of ``df_left``, attaches the value of every extra column of
    ``df_right`` taken from the nearest right-side row (per ``direction``)
    ordered by ``on``, optionally within equality groups given by ``by``.

    Parameters
    ----------
    df_left, df_right : pyspark.sql.DataFrame
        Left (kept as-is) and right (looked-up) frames.
    on : str
        Ordered column (e.g. a timestamp) present in both frames.
    by : str | list[str] | None
        Optional equality key(s); matching happens inside each group, which
        also improves performance by creating window partitions.
    tolerance : numeric | None
        Maximum allowed ``abs(left.on - right.on)``; matches farther away
        become null. ``None`` disables the check.
    direction : {'backward', 'forward', 'nearest'}
        Which neighbouring right-side row to take.

    Returns
    -------
    pyspark.sql.DataFrame
        ``df_left``'s rows with ``df_right``'s extra columns appended.

    Raises
    ------
    ValueError
        If ``direction`` is not one of the supported values.
    """
    def backward():
        # last non-null right-side struct at or before the current row
        # (default window frame: unbounded preceding .. current row)
        return add_diff(F.last(stru1, True).over(w0))

    def forward():
        # first non-null right-side struct at or after the current row
        return add_diff(F.first(stru1, True).over(w0.rowsBetween(0, W.unboundedFollowing)))

    def nearest():
        # smallest 'diff' first; array_sort places nulls last, so a missing
        # side never shadows a real match
        return F.array_sort(F.array(backward(), forward()))[0]

    def add_diff(col):
        # prepend the absolute distance so tolerance filtering and
        # nearest-sorting both key on the leading 'diff' field
        return F.struct(
            F.abs(F.col(on) - col[on]).alias('diff'),
            col[on].alias(on),
            col[c].alias(c),
        )

    # Explicit dispatch instead of eval(f'{direction}()'): no code execution
    # from a caller-supplied string, and typos fail fast with a clear error.
    directions = {'backward': backward, 'forward': forward, 'nearest': nearest}
    if direction not in directions:
        raise ValueError(
            f"direction must be one of {sorted(directions)}, got {direction!r}"
        )

    # Without a `by` key, add a constant column so the window can still
    # partition (single global partition).
    df_r = df_right if by else df_right.withColumn('_by', F.lit(1))
    df_l = df_left if by else df_left.withColumn('_by', F.lit(1))
    df_l = df_l.withColumn('_df_l', F.lit(True))  # marker to recover left rows
    by = [by] if isinstance(by, str) else by or ['_by']
    join_on = [on] + by
    df = df_l.join(df_r, join_on, 'full')
    w0 = W.partitionBy(*by).orderBy(on)
    for c in set(df_right.columns) - set(join_on):
        # struct of (on, c) only where c is present, so last/first with
        # ignorenulls=True skip gaps introduced by the full join
        stru1 = F.when(~F.isnull(c), F.struct(on, c))
        stru2 = directions[direction]()
        # `is not None` (not plain truthiness) so tolerance=0 is honoured
        if tolerance is not None:
            stru2 = stru2.withField(c, F.when(stru2['diff'] <= tolerance, stru2[c]))
        df = df.withColumn(c, stru2[c])
    # Keep only the original left rows; Spark's drop() silently ignores
    # '_by' when it was never added.
    return df.filter('_df_l').drop('_df_l', '_by')
一些解释

首先,函数的参数相对 pandas 版本稍作调整;然后基于 `on` 和 `by` 两个参数执行完全外连接(full join)。
df_r = df_right if by else df_right.withColumn('_by', F.lit(1))
df_l = df_left if by else df_left.withColumn('_by', F.lit(1))
df_l = df_l.withColumn('_df_l', F.lit(True))
by = [by] if isinstance(by, str) else by or ['_by']
join_on = [on] + by
df = df_l.join(df_r, join_on, 'full')
然后,对于右侧数据框中的每一列(`on` 和 `by` 列除外),根据 `direction` 和 `tolerance` 计算出新的取值。
w0 = W.partitionBy(*by).orderBy(on)
for c in set(df_right.columns) - set(join_on):
stru1 = F.when(~F.isnull(c), F.struct(on, c))
stru2 = eval(f'{direction}()')
if tolerance:
stru2 = stru2.withField(c, F.when(stru2['diff'] <= tolerance, stru2[c]))
df = df.withColumn(c, stru2[c])
首先创建 `stru1` 列(struct 类型),保存 `on` 和 `c` 两列的值。接着根据 `direction` 的取值('backward'、'forward'、'nearest')调用对应的函数——每个取值都定义了一个专门的函数。这些函数会向结构列添加一个额外字段('diff',即与左侧 `on` 值的绝对距离)。最后,如果 'diff' 超出 `tolerance`,该列的值将被置为 null。
一些例子
df1_spark = spark.createDataFrame([{"timestamp": 0.5 * i, "a": i * 2} for i in range(66)])
df2_spark = spark.createDataFrame([{"timestamp": 0.33 * i, "b": i} for i in range(100)])
merge_asof(df1_spark, df2_spark, on='timestamp', direction='backward').show(3)
# +---------+---+---+
# |timestamp| a| b|
# +---------+---+---+
# | 0.0| 0| 0|
# | 0.5| 2| 1|
# | 1.0| 4| 3|
# +---------+---+---+
merge_asof(df1_spark, df2_spark, on='timestamp', direction='forward').show(3)
# +---------+---+---+
# |timestamp| a| b|
# +---------+---+---+
# | 0.0| 0| 0|
# | 0.5| 2| 2|
# | 1.0| 4| 4|
# +---------+---+---+
merge_asof(df1_spark, df2_spark, on='timestamp', direction='nearest').show(3)
# +---------+---+---+
# |timestamp| a| b|
# +---------+---+---+
# | 0.0| 0| 0|
# | 0.5| 2| 2|
# | 1.0| 4| 3|
# +---------+---+---+
merge_asof(df1_spark, df2_spark, on='timestamp', tolerance=0.05, direction='nearest').show()
# +---------+---+----+
# |timestamp| a| b|
# +---------+---+----+
# | 0.0| 0| 0|
# | 0.5| 2|null|
# | 1.0| 4| 3|
# | 1.5| 6|null|
# | 2.0| 8| 6|
# | 2.5| 10|null|
# | 3.0| 12| 9|
# | 3.5| 14|null|
# | 4.0| 16| 12|
# | 4.5| 18|null|
# | 5.0| 20| 15|
# | 5.5| 22|null|
# | 6.0| 24|null|
# | 6.5| 26|null|
# | 7.0| 28|null|
# | 7.5| 30|null|
# | 8.0| 32|null|
# | 8.5| 34|null|
# | 9.0| 36|null|
# | 9.5| 38|null|
# +---------+---+----+