Swing算法原理比较简单,是阿里早期使用的一种召回算法,已在阿里多个业务中被验证为非常有效的召回方式。它认为 user-item-user 的结构比 itemCF 的单边结构更稳定。
Swing指的是秋千,例如用户 u 和用户 v,都购买过同一件商品 i,则三者之间会构成一个类似秋千的关系图。若用户 u 和用户 v 之间除了购买过 i 外,还购买过商品 j,则认为两件商品是具有某种程度上的相似的。
也就是说,商品与商品之间的相似关系,是通过用户关系来传递的。为了衡量物品 i 和 j 的相似性,考察都购买了物品 i 和 j 的用户 u 和用户 v,如果这两个用户共同购买的物品越少,则物品 i 和 j 的相似性越高。
Swing算法的表达式如下:

$$ s(i,j) = \sum_{u \in U_i \cap U_j} \sum_{\substack{v \in U_i \cap U_j \\ v > u}} \frac{1}{\alpha + |I_u \cap I_v|} $$

其中 $U_i$ 表示购买过物品 $i$ 的用户集合,$I_u$ 表示用户 $u$ 购买过的物品集合,$\alpha$ 为平滑系数。
swing python 实现:
# -*- coding: utf-8 -*-
"""
Author  : Thinkgamer
File    : Swing.py
Software: PyCharm
Desc    : Swing item-to-item recall algorithm on the MovieLens (ml-100k) dataset
"""
import pandas as pd
from itertools import combinations
import json
import os

# Smoothing term alpha in the Swing score: 1 / (alpha + |I_u & I_v|).
alpha = 0.5
# Number of most-similar items kept per item when saving the model.
top_k = 20
def load_data(train_path, test_path):
    """Read the tab-separated MovieLens rating files into DataFrames.

    Args:
        train_path: path to the training ratings file (e.g. ua.base).
        test_path: path to the test ratings file (e.g. ua.test).

    Returns:
        (train_data, test_data): DataFrames with columns
        userid, movieid, rate, event_timestamp.
    """
    columns = ["userid", "movieid", "rate", "event_timestamp"]

    def _read(path):
        # Each rating line is "user \t movie \t rate \t timestamp".
        return pd.read_csv(path, sep="\t", engine="python", names=columns)

    train_data = _read(train_path)
    test_data = _read(test_path)
    print(train_data.head(5))
    print(test_data.head(5))
    return train_data, test_data
def get_uitems_iusers(train):
    """Build the user->items and item->users inverted indexes.

    Args:
        train: DataFrame with at least "userid" and "movieid" columns.

    Returns:
        (u_items, i_users): dict userid -> set of movieids the user rated,
        and dict movieid -> set of userids that rated it.
    """
    u_items = dict()
    i_users = dict()
    # Iterate the two columns directly: DataFrame.iterrows materializes a
    # Series per row and is far slower for this access pattern.
    for user, item in zip(train["userid"], train["movieid"]):
        # setdefault returns the set, so insert can be chained in one step.
        u_items.setdefault(user, set()).add(item)
        i_users.setdefault(item, set()).add(user)
    print("使用的用户个数为:{}".format(len(u_items)))
    print("使用的item个数为:{}".format(len(i_users)))
    return u_items, i_users
def cal_similarity(u_items, i_users, alpha=0.5):
    """Compute the Swing similarity for every pair of items.

    For each item pair (i, j), sums 1 / (alpha + |I_u & I_v|) over every
    unordered pair of users (u, v) that rated both i and j.

    Args:
        u_items: dict userid -> set of itemids the user rated.
        i_users: dict itemid -> set of userids that rated the item.
        alpha: smoothing term of the Swing formula. Was the hard-coded
            module-level constant; now a parameter with the same default.

    Returns:
        dict itemid -> {itemid: score}. Symmetric: sim[i][j] == sim[j][i].
    """
    item_sim_dict = dict()
    # combinations is consumed lazily; no need to materialize the full
    # O(n^2) pair list (~1.4M pairs on ml-100k) in memory.
    for i, j in combinations(i_users.keys(), 2):
        common_users = i_users[i] & i_users[j]
        score = 0.0
        for u, v in combinations(common_users, 2):
            # Fewer co-rated items between u and v => stronger evidence.
            score += 1.0 / (alpha + len(u_items[u] & u_items[v]))
        # Store both directions: the original kept only sim[i][j], so the
        # per-item top-k truncation downstream saw incomplete neighbour
        # lists for every item that appeared second in a pair.
        item_sim_dict.setdefault(i, dict())[j] = score
        item_sim_dict.setdefault(j, dict())[i] = score
        # NOTE: the original printed a running counter for every pair,
        # flooding stdout; dropped as pure debug noise.
    return item_sim_dict
def save_item_sims(item_sim_dict, path, top_k=20):
    """Keep the top_k most similar items per item and dump them to JSON.

    Args:
        item_sim_dict: dict itemid -> {itemid: score}.
        path: output JSON file path (parent directory must exist).
        top_k: number of neighbours kept per item. Was the hard-coded
            module-level constant; now a parameter with the same default.

    Returns:
        The truncated dict itemid -> {itemid: score}, sorted by score
        descending (dicts preserve insertion order).
    """
    new_item_sim_dict = dict()
    for item, sim_items in item_sim_dict.items():
        # Sort neighbours by score, highest first, and keep the top_k.
        ranked = sorted(sim_items.items(), key=lambda kv: kv[1], reverse=True)
        new_item_sim_dict[item] = dict(ranked[:top_k])
    # with-statement closes the file: the original passed a bare open()
    # into json.dump and leaked the handle.
    with open(path, "w") as fp:
        json.dump(new_item_sim_dict, fp)
    print("item 相似 item({})保存成功!".format(top_k))
    return new_item_sim_dict
def evaluate(item_sim_dict, test):
    """Placeholder: evaluate recall quality of item_sim_dict on the test set.

    Not implemented. See the CF evaluation approach in the book
    《推荐系统开发实战》 (Recommender System Development in Action).
    """
    # TODO: implement hit-rate / recall evaluation as in classic itemCF.
    pass
if __name__ == "__main__":
    train_data_path = "../../data/ml-100k/ua.base"
    test_data_path = "../../data/ml-100k/ua.test"
    item_sim_save_path = "../../model/swing/item_sim_dict.json"

    train, test = load_data(train_data_path, test_data_path)
    if not os.path.exists(item_sim_save_path):
        # Ensure the model directory exists before save_item_sims writes
        # into it; otherwise the first run crashes on a missing directory.
        os.makedirs(os.path.dirname(item_sim_save_path), exist_ok=True)
        u_items, i_users = get_uitems_iusers(train)
        item_sim_dict = cal_similarity(u_items, i_users)
        new_item_sim_dict = save_item_sims(item_sim_dict, item_sim_save_path)
    else:
        # with-statement closes the file: the original passed a bare open()
        # into json.load and leaked the handle.
        with open(item_sim_save_path, "r") as fp:
            new_item_sim_dict = json.load(fp)
    evaluate(new_item_sim_dict, test)
————————————————
Swing Spark实现
创建Swing类,其中的评估函数和predict函数这里并未提供,感兴趣的可以自己实现
/**
 * @ClassName: Swing
 * @Description: Swing item-to-item similarity (user-item-user recall).
 *               The score of items (i, j) is the sum, over unordered pairs
 *               of users (u, v) that rated both items, of
 *               1 / (alpha + |I_u ∩ I_v|).
 * @author: Thinkgamer
 **/
class SwingModel(spark: SparkSession) extends Serializable {

  // Smoothing term alpha of the Swing formula; set through setAlpha.
  var alpha: Option[Double] = Option(0.0)
  var items: Option[ArrayBuffer[String]] = Option(new ArrayBuffer[String]())
  // userid -> (userid -> number of co-rated items); filled by fit().
  var userIntersectionMap: Option[Map[String, Map[String, Int]]] = Option(Map[String, Map[String, Int]]())

  /** Sets the smoothing parameter alpha; returns this for chaining. */
  def setAlpha(alpha: Double): SwingModel = {
    this.alpha = Option(alpha)
    this
  }

  /** Registers the full item vocabulary; returns this for chaining. */
  def setAllItems(array: Array[String]): SwingModel = {
    this.items = Option(array.toBuffer.asInstanceOf[ArrayBuffer[String]])
    this
  }

  /**
   * Computes, for every ordered pair of users, how many items both rated.
   *
   * @param data RDD of (userid, itemid, rating)
   * @return userid -> (userid -> co-rated item count), collected to the driver
   */
  def calUserRateItemIntersection(data: RDD[(String, String, Double)]): Map[String, Map[String, Int]] = {
    val userItems = data.map { case (user, item, _) => (user, item) }
      .groupByKey()
      .map { case (user, its) => (user, its.toSet) }
    val map = (userItems cartesian userItems)
      .map { case ((u, uSet), (v, vSet)) => (u, (v, (uSet & vSet).size)) }
      .groupByKey()
      .map { case (u, pairs) => (u, pairs.toMap) }
      .collectAsMap()
      .toMap
    map.take(10).foreach(println) // debug peek at a few entries
    map
  }

  /**
   * Fits the model on (userid, itemid, rating) triples.
   *
   * @return RDD of (item1, item2, swingScore) for every ordered item pair
   */
  def fit(data: RDD[(String, String, Double)]): RDD[(String, String, Double)] = {
    this.userIntersectionMap = Option(this.calUserRateItemIntersection(data))

    // Capture plain locals so the RDD closure does not drag `this` along.
    val intersection = this.userIntersectionMap.get
    val alphaValue = this.alpha.getOrElse(0.0)

    val itemUsers = data.map(l => (l._2, l._1)).groupByKey().map(l => (l._1, l._2.toSet))
    val result: RDD[(String, String, Double)] = (itemUsers cartesian itemUsers).map {
      case ((item1, users1), (item2, users2)) =>
        // BUG FIX: the original looped over every ordered pair (u1, u2) of
        // the common users, including u1 == u2. That double-counted every
        // user pair and added |commonUsers| spurious self-pair terms. The
        // Swing formula (and the Python reference, which uses
        // itertools.combinations) sums each unordered pair exactly once.
        val commonUsers = (users1 & users2).toArray.sorted
        var score = 0.0
        for (a <- commonUsers.indices; b <- (a + 1) until commonUsers.length) {
          // getOrElse replaces the original's unchecked Map.apply
          // (`map.get(u1).get(u2)`), which threw NoSuchElementException
          // for any user missing from the intersection map.
          val coRated = intersection
            .getOrElse(commonUsers(a), Map.empty[String, Int])
            .getOrElse(commonUsers(b), 0)
          score += 1.0 / (coRated.toDouble + alphaValue)
        }
        (item1, item2, score) // (item1, item2, swingScore)
    }
    result
  }

  // Evaluation / prediction intentionally left unimplemented (see article).
  def evalute(test: RDD[(String, String, Double)]) = { }
  def predict(userid: String) = { }
  def predict(userids: Array[String]) = { }
}
main函数调用
object Swing {
  /** Entry point: trains the Swing model on MovieLens ml-100k ratings. */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[10]").appName("Swing").enableHiveSupport().getOrCreate()
    Logger.getRootLogger.setLevel(Level.WARN)

    val trainDataPath = "data/ml-100k/ua.base"
    val testDataPath = "data/ml-100k/ua.test"

    // Each rating line is "user \t movie \t rate \t timestamp".
    val train: RDD[(String, String, Double)] = spark.sparkContext.textFile(trainDataPath).map(_.split("\t")).map(l => (l(0), l(1), l(2).toDouble))
    val test: RDD[(String, String, Double)] = spark.sparkContext.textFile(testDataPath).map(_.split("\t")).map(l => (l(0), l(1), l(2).toDouble))

    // distinct(): the raw column repeats an item once per rating, so the
    // original collected the vocabulary with massive duplication.
    val items: Array[String] = train.map(_._2).distinct().collect()

    val swing = new SwingModel(spark).setAlpha(1).setAllItems(items)
    val itemSims: RDD[(String, String, Double)] = swing.fit(train)
    swing.evalute(test)
    swing.predict("")
    swing.predict(Array("", ""))
    spark.close()
  }
}
转载
https://blog.csdn.net/Gamer_gyt/article/details/115678598