0. Introduction to Collaborative Filtering
Collaborative filtering (CF), in short, recommends items to a user based on the preferences of a group of users with similar tastes or shared experience. Depending on what they focus on, collaborative filtering algorithms fall into three categories:
User-based collaborative filtering: use statistical similarity measures to find neighboring users with similar tastes or interests, then recommend the items those similar users are interested in.
Item-based collaborative filtering: based on the assumption that "items that interest a user must be similar to items the user has already rated highly", it computes similarity between items instead of similarity between users.
Model-based collaborative filtering: fit a model to historical data, then use that model to make predictions.
MLlib implements model-based collaborative filtering, training the model with the ALS (alternating least squares) algorithm.
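To make the user-based variant concrete before turning to ALS, here is a minimal self-contained sketch (not part of MLlib): it uses hypothetical rating maps, picks the neighbor with the highest cosine similarity to a target user, and recommends the neighbor's unseen items.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class UserBasedCF {
    // Cosine similarity over two users' sparse rating vectors (item -> rating)
    public static double cosine(Map<String, Double> a, Map<String, Double> b) {
        double dot = 0, normA = 0, normB = 0;
        for (Map.Entry<String, Double> e : a.entrySet()) {
            normA += e.getValue() * e.getValue();
            Double other = b.get(e.getKey());
            if (other != null) dot += e.getValue() * other;
        }
        for (double v : b.values()) normB += v * v;
        if (normA == 0 || normB == 0) return 0;
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    public static void main(String[] args) {
        // Hypothetical ratings: user -> (book -> rating)
        Map<String, Map<String, Double>> ratings = new HashMap<>();
        ratings.put("alice", Map.of("b1", 5.0, "b2", 3.0, "b3", 4.0));
        ratings.put("bob",   Map.of("b1", 5.0, "b2", 3.0, "b4", 5.0));
        ratings.put("carol", Map.of("b5", 4.0, "b3", 1.0));

        // Find the neighbor most similar to alice...
        Map<String, Double> target = ratings.get("alice");
        String best = null;
        double bestSim = -1;
        for (String user : ratings.keySet()) {
            if (user.equals("alice")) continue;
            double sim = cosine(target, ratings.get(user));
            if (sim > bestSim) { bestSim = sim; best = user; }
        }

        // ...and recommend the neighbor's items that alice has not rated yet
        List<String> recs = new ArrayList<>();
        for (String item : ratings.get(best).keySet()) {
            if (!target.containsKey(item)) recs.add(item);
        }
        System.out.println(best + " " + recs); // bob [b4]
    }
}
```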
1. Data Source
Book-Crossing dataset:
It contains three files:
Ratings file: "User-ID";"ISBN";"Book-Rating"
Books file: "ISBN";"Book-Title";"Book-Author";"Year-Of-Publication";"Publisher";"Image-URL-S";"Image-URL-M";"Image-URL-L"
Users file: "User-ID";"Location";"Age"
2. Data Preprocessing
MLlib's ALS implementation has a small limitation: user and item IDs must be numeric, specifically 32-bit non-negative integers. In the ratings file the user ID is already an int, but the ISBN is a string and must first be converted to an int. Here we map every book in BX-Books to an auto-incrementing ID starting from 1, keep the mapping in a HashMap, and then translate the ISBNs in BX-Book-Ratings to those IDs. While mapping the data we found that some ratings reference ISBNs that never appear in the books table; those entries are dropped, and the remaining records form the new ratings file.
// CsvReader comes from the javacsv library (com.csvreader.CsvReader)
import com.csvreader.CsvReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;

public static void processing() throws IOException {
    ArrayList<String[]> books = new ArrayList<>();
    ArrayList<String[]> ratings = new ArrayList<>();
    // Read the books file
    CsvReader reader = new CsvReader("data/BX-Books.csv", ';');
    reader.readHeaders();
    while (reader.readRecord()) {
        books.add(reader.getValues());
    }
    reader.close();
    // Read the ratings file
    reader = new CsvReader("data/BX-Book-Ratings.csv", ';');
    reader.readHeaders();
    while (reader.readRecord()) {
        ratings.add(reader.getValues());
    }
    reader.close();
    // Map each ISBN to an auto-incrementing int ID starting at 1
    HashMap<String, Integer> map = new HashMap<>();
    for (int i = 0; i < books.size(); i++) {
        map.put(books.get(i)[0], i + 1);
    }
    // Rewrite the ratings with the int IDs
    FileWriter fileWriter = new FileWriter("data/book-rating.txt");
    for (String[] rating : ratings) {
        // Keep only ratings whose ISBN appears in the books table
        if (map.containsKey(rating[1])) {
            fileWriter.write(rating[0].replaceAll("\"", "") + ";");
            fileWriter.write(map.get(rating[1]) + ";");
            fileWriter.write(rating[2].replaceAll("\"", "") + "\n");
        }
    }
    fileWriter.close();
}
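The mapping-and-filtering step can be exercised in isolation. A minimal sketch with hypothetical in-memory records in place of the CSV files (`IsbnMappingDemo` and its sample rows are illustrative, not from the dataset):

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class IsbnMappingDemo {
    // Rewrite ratings as "userId;bookId;rating", dropping unknown ISBNs
    public static List<String> remap(List<String[]> books, List<String[]> ratings) {
        // Map each ISBN (first column of a book record) to an ID starting at 1
        Map<String, Integer> map = new HashMap<>();
        for (int i = 0; i < books.size(); i++) {
            map.put(books.get(i)[0], i + 1);
        }
        List<String> out = new ArrayList<>();
        for (String[] r : ratings) {
            // Keep only ratings whose ISBN appears in the books table
            if (map.containsKey(r[1])) {
                out.add(r[0] + ";" + map.get(r[1]) + ";" + r[2]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<String[]> books = List.of(
            new String[]{"0195153448", "Classical Mythology"},
            new String[]{"0002005018", "Clara Callan"});
        List<String[]> ratings = List.of(
            new String[]{"276725", "0195153448", "0"},
            new String[]{"276726", "0155061224", "5"}, // ISBN not in books table
            new String[]{"276727", "0002005018", "8"});
        // The middle rating is dropped; the rest use the int IDs
        System.out.println(remap(books, ratings)); // [276725;1;0, 276727;2;8]
    }
}
```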
3. Model Training and Generating Recommendations
Create a class that reads the ratings file and randomly splits it into an 80% training set and a 20% test set, then set the model parameters such as the maximum number of iterations, the regularization parameter, and the cold-start strategy. The full parameter list (from the Spark documentation) is:
numBlocks is the number of blocks the users and items will be partitioned into in order to parallelize computation (defaults to 10).
rank is the number of latent factors in the model (defaults to 10).
maxIter is the maximum number of iterations to run (defaults to 10).
regParam specifies the regularization parameter in ALS (defaults to 1.0).
implicitPrefs specifies whether to use the explicit feedback ALS variant or one adapted for implicit feedback data (defaults to false which means using explicit feedback).
alpha is a parameter applicable to the implicit feedback variant of ALS that governs the baseline confidence in preference observations (defaults to 1.0).
nonnegative specifies whether or not to use nonnegative constraints for least squares (defaults to false).

import java.io.Serializable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class Recommend {
    public static class Rating implements Serializable {
        private int userId;
        private int bookId;
        private float rating;

        public Rating() {}

        public Rating(int userId, int bookId, float rating) {
            this.userId = userId;
            this.bookId = bookId;
            this.rating = rating;
        }

        public int getUserId() {
            return userId;
        }

        public int getBookId() {
            return bookId;
        }

        public float getRating() {
            return rating;
        }

        // Parse one "userId;bookId;rating" line from the preprocessed file
        public static Rating parseRating(String str) {
            String[] fields = str.split(";");
            if (fields.length != 3) {
                throw new IllegalArgumentException("Each line must contain 3 fields");
            }
            int userId = Integer.parseInt(fields[0]);
            int bookId = Integer.parseInt(fields[1]);
            float rating = Float.parseFloat(fields[2]);
            return new Rating(userId, bookId, rating);
        }
    }
    public static void main(String[] args) {
        SparkSession spark = SparkSession
            .builder()
            .appName("JavaALSExample")
            .getOrCreate();
        JavaRDD<Rating> ratingsRDD = spark
            .read().textFile("data/book-rating.txt").javaRDD()
            .map(Rating::parseRating);
        Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
        // 80% training set, 20% test set
        Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
        Dataset<Row> training = splits[0];
        Dataset<Row> test = splits[1];
        ALS als = new ALS()
            .setMaxIter(10)
            .setRegParam(0.01)
            .setUserCol("userId")
            .setItemCol("bookId")
            .setRatingCol("rating");
        ALSModel model = als.fit(training);
        // Cold-start strategy: drop NaN predictions so the RMSE stays defined
        model.setColdStartStrategy("drop");
        Dataset<Row> predictions = model.transform(test);
        RegressionEvaluator evaluator = new RegressionEvaluator()
            .setMetricName("rmse")
            .setLabelCol("rating")
            .setPredictionCol("prediction");
        double rmse = evaluator.evaluate(predictions);
        System.out.println("Root-mean-square error = " + rmse);
        // Top-10 book recommendations for every user
        Dataset<Row> userRecs = model.recommendForAllUsers(10);
        // Top-10 user recommendations for every book
        Dataset<Row> bookRecs = model.recommendForAllItems(10);
        // Top-10 book recommendations for a subset of users
        Dataset<Row> users = ratings.select(als.getUserCol()).distinct().limit(3);
        Dataset<Row> userSubsetRecs = model.recommendForUserSubset(users, 10);
        // Top-10 user recommendations for a subset of books
        Dataset<Row> books = ratings.select(als.getItemCol()).distinct().limit(3);
        Dataset<Row> bookSubSetRecs = model.recommendForItemSubset(books, 10);
        userRecs.show();
        bookRecs.show();
        userSubsetRecs.show(false); // print without truncating columns
        bookSubSetRecs.show();
        spark.stop();
    }
}
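The RegressionEvaluator above, configured with the "rmse" metric, computes the square root of the mean squared difference between the prediction and rating columns. The metric itself is simple enough to sketch without Spark (the sample values below are hypothetical):

```java
public class RmseDemo {
    // Root-mean-square error between predicted and actual ratings
    public static double rmse(double[] pred, double[] label) {
        double sum = 0;
        for (int i = 0; i < pred.length; i++) {
            double d = pred[i] - label[i];
            sum += d * d;
        }
        return Math.sqrt(sum / pred.length);
    }

    public static void main(String[] args) {
        // Hypothetical predicted vs. actual ratings
        double[] pred  = {4.0, 7.5, 2.0};
        double[] label = {5.0, 8.0, 1.0};
        // squared errors 1, 0.25, 1 -> mean 0.75 -> sqrt(0.75) ~ 0.866
        System.out.println("Root-mean-square error = " + rmse(pred, label));
    }
}
```

Lower is better; a perfect model would score 0 on the test set.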