使用Spark ALS模型 + Faiss向量检索实现用户扩量实例

2023-10-31

1、通过ALS模型实现用户/商品Embedding的效果,获得其向量表示

准备训练数据, M = (U , I, R) 即 用户集U、商品集I、及评分数据R。

(1)商品集I的选择:可以根据业务目标确定商品候选集,比如TopK热度召回、或者流行度不高但在业务用户中区分度比较高的商品集等。个人建议量级控制在5W内,1W-2W左右比较合适,太大的话,用户产生行为的商品比较少,评分数据会非常的稀疏。

(2)用户集U的选择: 最好是粗召回策略确定的用户范围,因为ALS模型会生成所有U用户的特征向量表示,对于没有见过的用户u,没有其向量表示,其推荐也是冷启动策略。这里可以根据业务需要限制一个大范围,比如4000W-5000W的或大几百万的用户(从计算效率和内存使用上,个人建议500W内比较合适)。比如用户U定义为某些类目下购买人群、或者近期活跃人群等符合业务人群目标的潜在客户群。模型训练完之后,也是在这个用户集U中筛选出TopK相似的用户做推荐或扩量。

(3)评分数据R的选择:我们能采集到的大多是隐式反馈的数据,比如购买行为、浏览行为、收藏行为等。确定了U、I,确定了评分指标类型,就可以统计一段时间内,U对I的反馈数据R。数据量级大约在7亿条-10亿条,在模型参数设置合理的情况下,大约20-30分钟就可以训练完。

from pyspark.ml.recommendation import ALS
from pyspark.sql.functions import expr, isnull

""" ALS模型参数解读,和大小设置建议:
:param
         rank=10, maxIter=10, regParam=0.1, numUserBlocks=10,
         numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
         seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10,
         intermediateStorageLevel="MEMORY_AND_DISK",
         finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096

NumBlocks分块数:分块是为了并行计算,默认为10。可以根据数据量级适当放大,比如20。 可以对 numUserBlocks\numItemBlocks 单独进行配置并行度 ,也可以通过setNumBlocks(30)一起设置。

正则化参数:默认为1。 
秩rank:模型中隐藏因子的个数,默认是10。即特征向量的维度。
implicitPrefs:显式偏好信息-false,隐式偏好信息-true,默认false(显示) 。 电商场景中 购买、点击、分享,都是隐式反馈。
alpha:隐式反馈时的置信度参数,默认是1.0。只用于隐式的偏好数据。
setMaxIter(10):最大迭代次数,设置太大发生java.lang.StackOverflowError。建议范围 10 ~20。 超过20,比较容易失败。
coldStartStrategy: 预测时冷启动策略。默认是nan, 可以选择 drop。
"""

ratings = spark.sql("""
        select
          user_acct, user_id, main_sku_id, item_id, rating
        from dmb_dev.dmb_dev_als_model_rating_matrix
        """).repartition(3600)
train_data, test_data = ratings.randomSplit([0.9, 0.1], seed=4226)
train_data.cache()       
als = ALS() \
    .setImplicitPrefs(True) \
    .setAlpha(0.7) \
    .setMaxIter(20) \
    .setRank(10) \
    .setRegParam(0.01) \
    .setNumBlocks(30) \
    .setUserCol("user_id") \
    .setItemCol("item_id") \
    .setRatingCol("rating") \
    .setColdStartStrategy("drop")
print(als.explainParams())

als_model = als.fit(train_data)
als_model.write().overwrite().save(model_save_path)

# 训练集合所有用户U的向量表示
candidate_user_factors = als_model.userFactors.withColumnRenamed("id", "user_id")\
    .join(train_data.select("user_acct", "user_id").dropDuplicates(), ["user_id"])\
    .withColumn("bin_group", expr("round(rand(),1)"))
candidate_user_factors.cache()
candidate_user_factors.write.format("orc").mode("overwrite")\
    .saveAsTable("dev.dev_als_model_all_trained_users_factor_result")
train_data.unpersist()

# query用户的向量表示
target_user_factors = spark.sql("""
        select
          user_acct, user_id
        from dev.dev_wdy_als_seed_users_table
        group by user_acct, user_id
        """).join(candidate_user_factors, ["user_acct", "user_id"])
target_user_factors.cache()
target_user_factors.write.format("orc").mode("overwrite")\
    .saveAsTable("dev.dev_als_model_seed_users_factor")

# 候选用户向量表示
search_user_factors = candidate_user_factors.join(target_user_factors,
                                                  candidate_user_factors["user_acct"] == target_user_factors["user_acct"],
                                                  "left_outer")\
    .where(isnull(target_user_factors["user_acct"]))\
    .select(candidate_user_factors["user_acct"], candidate_user_factors["user_id"],
            candidate_user_factors["features"], candidate_user_factors["bin_group"])
search_user_factors.write.format("orc").mode("overwrite")\
    .saveAsTable("dev.dev_als_model_candidate_users_factor")
candidate_user_factors.unpersist()
target_user_factors.unpersist()

2、通过Faiss快速实现向量TopK相似检索

如果没有装faiss,可以选择安装CPU/GPU版本, pip install faiss-cpu

关于faiss的使用说明,可以参考向量数据库入坑指南:聊聊来自元宇宙大厂 Meta 的相似度检索技术 Faiss - 知乎

 faiss来自facebook 开源 Meta Research · GitHub的github库为:GitHub - facebookresearch/faiss: A library for efficient similarity search and clustering of dense vectors.

根据业务需求的查询速度、精准度要求来选择合适的Faiss TopK向量查询方法。


# 判断 npy文件是否存在,不存在则执行以下操作;否则跳过此步骤,直接读取文件。
user_embedding = spark.sql("""
    select 
        features[0],features[1],features[2],features[3],features[4],
        features[5],features[6],features[7],features[8],features[9] 
    from 
        dev.dev_als_model_candidate_users_factor
    where bin_group=0.1""").toPandas()

# 量级500W内执行顺利,再大的量级容易内存溢出失败。
np.save("user_embedding_01.npy", np.array(user_embedding, order='C'))

user_embedding = np.load("user_embedding_01.npy")
print("user_embedding data sample:", user_embedding[:3])
print("user embedding shape", user_embedding.shape)
dimension = user_embedding.shape[1]
nums_user = user_embedding.shape[0]

faiss.normalize_L2(user_embedding)
index = faiss.IndexFlatIP(dimension)
index.add(user_embedding)
print("index is trained:", index.is_trained)
print("index n total:", index.ntotal)

# 判断文件是否存在,如果存在则直接读取,否则先下载保存到本地。
## 这里k=30 或更大时,查询易失败。 k=20, 查询耗时久,但会成功,大约3小时。 k=10时,
k = 5
query1 = spark.sql("""select features[0],features[1],features[2],features[3],features[4],features[5],features[6],features[7],
                        features[8],features[9] from dev.dev_als_model_seed_users_factor""").toPandas()
np.save("query.npy", np.array(query1, order='C'))
query = np.load("query.npy")
print("query shape:", query.shape)

# 查询
t0 = time.time()
Deg, Ind = index.search(query, k)
t1 = time.time()
print("平均耗时 %7.3f min" % ((t1 - t0)/60))

# 保存索引
faiss.write_index(index, "faiss_01.index")
np.save("Ind_01.npy", Ind)
np.save("Deg_01.npy", Deg)

res = []
for i in range(query.shape[0]):
    q_vector = query[i]
    r_list = Ind[i]
    for j in range(len(r_list)):
        r_vector = user_embedding[r_list][j]
        sim = Deg[i][j]
        res.append(([float(v) for v in r_vector], float(sim)))

res = spark.createDataFrame(res, ["recommend_vector", "similarity"]).repartition(10)
res.cache()

res.write.format("orc").mode("overwrite")\
        .saveAsTable("dev.dev_als_model_recommend_vector_result")

user_embedding = spark.sql("""
    select 
        *
    from 
        dev.dev_als_model_candidate_users_factor
    where bin_group=0.1""")
res.join(user_embedding, res["recommend_vector"] == user_embedding["features"])\
        .write.format("orc").mode("overwrite")\
        .saveAsTable("dev.dev_als_model_recommend_user_pin_result")

查询速度实验对比数据:

IndexIVFFlat IndexFlatIP IndexFlatIP
user embedding shape (4474857, 10) user embedding shape (4474857, 10) user embedding shape (4474857, 10)
query shape: (78525, 10) query shape: (78525, 10) query shape: (34525, 10)
k=5 k=5 k=10
平均耗时  10.522 min 平均耗时 > 6h 平均耗时 3h-4h

业务中查询的候选集可能有4000W-5000W,而且对于查询响应时间有要求,使用IndexIVFFlat更符合上线需求。

nlist = 50
quantizer = faiss.IndexFlatL2(dimension)
index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)
assert not index.is_trained
index.train(user_embedding)
assert index.is_trained
index.add(user_embedding)   # 添加索引可能会有一点慢
index.nprobe = 10    # 默认 nprobe 是1 ,这里设置为10

3、通过I2I2U 或者 I2U2U来获得用户扩量结果

上述实现的是U2U的扩量方法,使用的是User-Factor向量表示。第一个U来自于业务营销目标I下的历史已购人群。即这是一个I2U2U的扩量方法。 I (目标商品)---> U(历史购买) ---> U(TopK相似) 。

当然也可以通过使用Item-Factor向量表示,实现 I2I2U,即 I (目标商品)---> I(TopK相似) ---> U(历史购买) ,这样来做商品相似召回,实现用户的扩量。

基于实验效果,或历史数据的验证来选择使用哪种方法投产。

4、算法设计框架总结

可以看到,这个算法设计框架其实是 Embedding + Faiss ,即用户/商品的向量表示 + Faiss快速向量相似检索 的设计模式。

那么第一部分的ALS模型当然可以替换成任何一种可以效果更好的Embedding算法模型,比如BERT 、Transformer等深度学习模型。而第二部分Faiss的查询可以保持不动,只要替换查询数据源就可以了。当然也可以将其优化成GPU的,或更快速的查询方式,以满足线上业务的需求。

但整体的算法设计框架是不变的,Embedding向量化 + Faiss相似检索。

Done.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用Spark ALS模型 + Faiss向量检索实现用户扩量实例 的相关文章

随机推荐

  • mysql 利用binlog增量备份,还原实例(日志备份数据库)

    一 什么是增量备份 增量备份 就是将新增加的数据进行备份 假如你一个数据库 有10G的数据 每天会增加10M的数据 数据库每天都要备份一次 这么多数据是不是都要备份呢 还是只要备份增加的数据呢 很显然 我只要备份增加的数据 这样减少服务器的
  • C++ 调用python

    本文代码已在vs2017上验证 c 调用python需要三类文件 这些文件都可以在python安装目录下找到 1 include文件夹 位于python目录下 2 dll文件 位于python目录下 如python37 dll 3 lib文
  • 超分辨率概述

    1 什么是超分辨率增强 Video super resolution is the task of upscaling a video from a low resolution to a high resolution 超分辨率 Supe
  • Git & GitHub 入门6:用好commit message

    git log 可以查看所有的 commit messages 修改repo中的文件内容后 add该文件 直接运行命令git commit进入message编辑状态 可以输入多行commit message说明 完成后点击ECS键退出编辑
  • Gin-swaggo为gin框架提供Swagger 文档

    官方 https github com swaggo gin swagger 开始使用 为API方法增加注释 加在controller api 层 See Declarative Comments Format 运行下面命令下载swgo g
  • L2-4 部落PTA

    在一个社区里 每个人都有自己的小圈子 还可能同时属于很多不同的朋友圈 我们认为朋友的朋友都算在一个部落里 于是要请你统计一下 在一个给定社区中 到底有多少个互不相交的部落 并且检查任意两个人是否属于同一个部落 输入格式 输入在第一行给出一个
  • hadoop3.2.1编译安装

    基础环境 centos 7 7 三台 hadoop需要的环境 Requirements Unix System JDK 1 8 Maven 3 3 or later ProtocolBuffer 2 5 0 CMake 3 1 or new
  • echart 折线图设置y轴单位_如何让echarts中y轴的单位位于数值的右上角

    展开全部 1 创建折线图的数据区 包括年份和数据 2 仅选择数据区创建折线图 插入选项卡 图表62616964757a686964616fe78988e69d8331333363396364工具组 折线图 3 得到的折线图x坐标不满足要求
  • c++可变参数模板函数

    可变参数模版函数 类型一致 可变参数 使用头文件 cstdarg va list arg ptr 开头指针 va start arg ptr n 从开头开始读取n个 va arg arg ptr T 根据数据类型取出数据 va end ar
  • jdk1.8升级后 sun.io.CharToByteConverter 错误处理

    项目工程中用到jdk1 6相关方法 可以使用 但是升级到jdk1 8以后 编译出现java lang NoClassDefFoundError sun io CharToByteConverter错误 后经查询 是jdk1 8版本中已经从s
  • 前端02:CSS选择器等基础知识

    CSS基础选择器 设置字体样式 文本样式 CSS的三种引入方式 能使用Chrome调试工具调试样式 HTML专注做结构呈现 样式交给CSS 即结构 HTML 和样式CSS相分离 CSS主要由量分布构成 选择器以及一条或多条声明 选择器 给谁
  • 深度学习10篇文章之Interleaved Group Convolution

    本文主要讲解Ting Zhang的Interleaved Group Convolutions for Deep Neural Networks 该文对Group convolution有较为详细的讲解 Abstract 文章开篇引出了 I
  • 新昌中学2021高考成绩查询,2021绍兴市地区高考成绩排名查询,绍兴市高考各高中成绩喜报榜单...

    距离2018年高考还有不到一个月的时间了 很多人在准备最后冲刺的同时 也在关心高考成绩 2018各地区高考成绩排名查询 高考各高中成绩喜报榜单尚未公布 下面是往年各地区高考成绩排名查询 高考各高中成绩喜报榜单 想要了解同学可以参考下 同时关
  • 轻松学懂图(下)——Dijkstra和Bellman-Ford算法

    概述 在上一篇文章中讲述了Kruskal和Prim算法 用于得到最小生成树 今天将会介绍两种得到最短路径的算法 Dijlkstra和Bellman Ford算法 Dijkstra算法 算法的特点 属于单源最短路径算法 什么是单源呢 通俗的说
  • 前端使用自定义指令实现埋点【vue3】

    vue项目有时候会需要进行数据采集 记录用户行为习惯 而且很多页面都会使用到 所以用vue自定义指令来实现埋点功能 埋点的几种方式 页面埋点 浏览次数及时长等 点击埋点 每一次点击行为 曝光埋点 统计区域是否被用户浏览 import cre
  • 神经网络量化----TensorRT深刻解读

    神经网络量化 TensorRT深刻解读 目录 神经网络量化 TensorRT深刻解读 前言 一 TensorRT简介 二 难点 1 架构 2 功能 三 实现 1 conv和ReLU的融合 2 conv和ReLU的融合 quant utils
  • oracle 解锁 账户_oracle用户解锁三种方法

    ORA 28000 the account is locked 的解决办法 2009 11 11 18 51 ORA 28000 the account is locked 第一步 使用 PL SQL 登录名为 system 数据库名称不变
  • python cplex优化包工具箱教程

    python cplex优化包教程 在做优化课题时 常常需要用到优化算法 个人优化算法专栏链接如下 最优化实战例子 需要掌握一些优化算法 但是一些比较出名的优化工具箱还是要会用 今天讲解下cplex工具箱 CPLEX Optimizer 是
  • RocketMQ-实际开发中遇到的几个问题

    消息幂等性 什么是幂等性 一个操作任意执行多次与执行一次的结果相同 这个操作就是幂等 生产者发送消息之后 为了确保消费者消费成功 我们通常会采用手动签收方式确认消费 MQ就是使用了消息超时 重传 确认机制来保证消息必达 场景 1 订单服务
  • 使用Spark ALS模型 + Faiss向量检索实现用户扩量实例

    1 通过ALS模型实现用户 商品Embedding的效果 获得其向量表示 准备训练数据 M U I R 即 用户集U 商品集I 及评分数据R 1 商品集I的选择 可以根据业务目标确定商品候选集 比如TopK热度召回 或者流行度不高但在业务用