【推荐算法】双塔模型代码(tensorflow)

2023-11-20

【推荐算法】双塔模型介绍_MachineCYL的博客-CSDN博客

上文介绍了双塔模型的原理和结构,这篇介绍一下双塔模型的代码实现。我使用的是tensorflow来实现双塔模型和模型训练。

一、前期准备

  • tensorflow使用的版本是2.0.0
  • 数据格式(如果需要获取数据,可以见下方链接):

 不过由于是Demo,我只用了部分字段进行训练。

二、详细代码

  • DSSM模型代码
import tensorflow as tf

def dssm_model(feature_inputs, item_feature_columns, user_feature_columns, hidden_units):
    item_tower = tf.keras.layers.DenseFeatures(item_feature_columns)(feature_inputs)
    for num_nodes in hidden_units:
        item_tower = tf.keras.layers.Dense(num_nodes, activation='relu')(item_tower)

    user_tower = tf.keras.layers.DenseFeatures(user_feature_columns)(feature_inputs)
    for num_nodes in hidden_units:
        user_tower = tf.keras.layers.Dense(num_nodes, activation='relu')(user_tower)

    output = tf.keras.layers.Dot(axes=1)([item_tower, user_tower])
    output = tf.keras.layers.Dense(1, activation='sigmoid')(output)

    model = tf.keras.Model(feature_inputs, output)
    return model
  • 模型训练、预测与保存代码
def gen_dataset(data_df: pd.DataFrame, columns: dict):
    data_dict = dict()

    def _get_type(type_str):
        if type_str == "int32":
            return np.int32
        elif type_str == "float32":
            return np.float32
        elif type_str == "string" or type_str == "str":
            return np.str
        else:
            return np.int32

    for key in columns.keys():
        data_dict[key] = np.array(data_df[key]).astype(_get_type(columns[key]))

    return data_dict


def parse_argvs():
    parser = argparse.ArgumentParser(description='[DSSM]')
    parser.add_argument("--data_path", type=str, default='./data/')
    parser.add_argument("--model_path", type=str, default='./model_param')
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--monitor", type=str, default="val_accuracy", choices=["val_accuracy", "val_auc"])
    parser.add_argument("--batch_size", type=int, default=12)
    args = parser.parse_args()
    print('[input params] {}'.format(args))

    return parser, args


if __name__ == '__main__':
    parser, args = parse_argvs()
    data_path = args.data_path
    model_path = args.model_path
    monitor = args.monitor
    epoch = args.epoch
    batch_size = args.batch_size

    # ====================================================================================
    # read data
    data_path = os.path.abspath(data_path)
    print("[DSSM] read file path: {}".format(data_path))
    train_data = pd.read_csv(os.path.join(data_path, "trainingSamples.csv"), sep=",")
    test_data = pd.read_csv(os.path.join(data_path, "testSamples.csv"), sep=",")
    data_pd = pd.concat([train_data, test_data])

    # ====================================================================================
    # define input for keras model
    columns_dict = {
        'movieId': 'int32',
        'movieGenre1': 'string',
        'movieAvgRating': 'float32',
        'userId': 'int32',
        'userGenre1': 'string',
        'userAvgRating': 'float32'
    }

    inputs = dict()
    for key in columns_dict.keys():
        inputs[key] = tf.keras.layers.Input(name=key, shape=(), dtype=columns_dict[key])
    print("[DSSM] input for keras model: \n {}".format(inputs))

    # ====================================================================================
    # movie embedding feature
    movie_col = tf.feature_column.categorical_column_with_identity(key='movieId', num_buckets=1001)
    movie_emb_col = tf.feature_column.embedding_column(movie_col, 10)

    movie_genre_1_vocab = data_pd['movieGenre1'].dropna().unique()
    movie_genre_1_col = tf.feature_column.categorical_column_with_vocabulary_list(key='movieGenre1',
                                                                                  vocabulary_list=movie_genre_1_vocab)
    movie_genre_1_emb_col = tf.feature_column.embedding_column(movie_genre_1_col, 10)

    movie_avg_rating = tf.feature_column.numeric_column(key='movieAvgRating')

    # user embedding feature
    user_col = tf.feature_column.categorical_column_with_identity(key='userId', num_buckets=30001)
    user_emb_col = tf.feature_column.embedding_column(user_col, 10)

    user_genre_1_vocab = data_pd['userGenre1'].dropna().unique()
    user_genre_1_col = tf.feature_column.categorical_column_with_vocabulary_list(key='userGenre1',
                                                                                 vocabulary_list=user_genre_1_vocab)
    user_genre_1_emb_col = tf.feature_column.embedding_column(user_genre_1_col, 100)

    user_avg_rating = tf.feature_column.numeric_column(key='userAvgRating')

    # ====================================================================================
    # train model
    model = dssm_model(feature_inputs=inputs,
                       item_feature_columns=[movie_emb_col, movie_genre_1_emb_col, movie_avg_rating],
                       user_feature_columns=[user_emb_col, user_genre_1_emb_col, user_avg_rating],
                       hidden_units=[30, 10])

    model.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=['accuracy', tf.keras.metrics.AUC(curve='ROC')])

    filepath = os.path.join(model_path, "checkpoint", "dssm-weights-best.hdf5")
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath, monitor=monitor, verbose=1, save_best_only=True, mode='max')

    train_data_input = gen_dataset(data_df=train_data, columns=columns_dict)
    model.fit(x=train_data_input, y=train_data["label"].values,
              epochs=epoch, callbacks=[checkpoint], verbose=2, batch_size=batch_size, validation_split=0.1)

    # ====================================================================================
    # predict, use best model.
    test_data_input = gen_dataset(data_df=test_data, columns=columns_dict)
    model.load_weights(filepath=filepath)

    pred_ans = model.predict(x=test_data_input, batch_size=batch_size)
    print("\n[BEST] ===============================================================")
    print("[test] LogLoss: {} ".format(round(log_loss(test_data["label"].values, pred_ans), 4)))
    print("[test] Accuracy: {} ".format(round(accuracy_score(test_data["label"].values, pred_ans >= 0.5), 4)))
    print("[test] AUC: {} ".format(round(roc_auc_score(test_data["label"].values, pred_ans), 4)))
    print("[test] classification_report: \n{} ".format(classification_report(test_data["label"].values, pred_ans >= 0.5, digits=4)))

    # ====================================================================================
    # save model
    model_path = os.path.abspath(model_path)
    print("[DSSM] save model path: {}".format(model_path))

    model.summary()
    tf.keras.models.save_model(
        model,
        os.path.join(model_path, "dssm"),
        overwrite=True,
        include_optimizer=True,
        save_format=None,
        signatures=None,
        options=None
    )
  • 运行结果展示(部分)

需要获取训练数据和代码可以访问我的github,如果觉得有帮助,请star收藏,谢谢~

DSSM代码

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【推荐算法】双塔模型代码(tensorflow) 的相关文章

随机推荐

  • 第二十一章 webpack5原理loader概述

    简介 loader其实是一个函数 用来帮助 webpack 将不同类型的文件转换为 webpack 可识别的模块 loader的分类以及执行顺序 1 分类 pre 前置loader normal 普通loader inline 内联load
  • 编译型语言和解释型语言各自的特点和区别,Python的解释器

    编译型语言和解释型语言各自的特点和区别 Python的解释器 编译型语言 将源代码通过编译器编译生成可执行文件 机器指令 再由机器运行机器码 解释型语言 通过解释器逐行解释每一句源代码 打个比方 编译型相当于用中英文词典 翻译器 将一本英文
  • Vue如何封装组件

    要封装一个 Vue 组件 可以按照以下步骤进行操作 创建一个新的 Vue 单文件组件 vue 文件 并命名为你的组件名 例如 MyComponent vue 在组件文件中 使用
  • 关于python传参引发的一些思考

    人总有不会的 遇到一些问题深究下去必定有所收获 这个问题是在我写python爬虫项目的时候的疑问 可能是我太菜了 以前没学透彻 也可能是上学期学Java的时候按值传递的特点给搞混了 因为当时在用多线程的生产者消费者问题处理资源队列 参考别人
  • task_5 - 副本

    Task01 Task06树模型与集成学习笔记整理 1 Task01 信息论基础 决策树分类思想 用树的节点代表样本集合 通过某些判定条件来对节点内的样本进行分配 将它们划分到当前节点下的子节点 这样决策树希望各个子节点中类别的纯度之和应高
  • 内存文件系统提升磁盘性能瓶颈

    author skate time 2011 08 22 提升磁盘性能瓶颈 linux的内存文件系统 ramdisk ramfs tmpfs ramdisk 是块设备 在使用它们之前必须用选择文件系统将其格式化 并且调整文件系统大小比较麻烦
  • 【廖雪峰python进阶笔记】模块

    1 导入模块 要使用一个模块 我们必须首先导入该模块 Python使用import语句导入一个模块 例如 导入系统自带的模块 math import math 你可以认为math就是一个指向已导入模块的变量 通过该变量 我们可以访问math
  • Python Pandas导出Hbase数据到dataframe

    Python导出Hbase数据的思路 使用happybase连接Hbase 使用table scan 扫数据 将得到的数据整理为dataframe格式 将从Hbase中得到的byte类型的数据转为str类型的数据 示例代码 import h
  • 数据结构之哈希(C++实现)

    数据结构之哈希 C 1 哈希概念 顺序结构以及平衡树中 元素关键码与存储位置之间没有对应关系 因此在查找一个元素的时候 要经过关键码多次比较 顺序表查找的时间复杂度为O N 而平衡树中树的高度为O log 2 N 搜索的效率取决于搜索过程中
  • Mybatis

    文章目录 前言 业务逻辑 使用Mybatis实现 使用Mybatis plus实现 前言 工作的时候 遇到了需要将一个数据库的一些数据插入或更新到另一个数据库 一开始使用insert into TABLE col1 col2 VALUES
  • 全国大学生计算机技能应用大赛Java模拟题

    全国大学生计算机技能应用大赛Java模拟题 竞赛官网 http www cnccac com 单选题 1 以下哪个不是java的垃圾回收算法 A 标记清除算法 B 空间分配算法 C 标记整理算法 D 分代回收算法 2 下列名称在java语言
  • cocos 基础动作加上简单特效

    使用文理缓存创建精灵 cc Director getInstance getTextureCache addImage WechatIMG3 png localsp cc Sprite createWithTexture cc Direct
  • Error inflating class androidx.constraintlayout.widget.ConstraintLayout

    今天下载了android studio 3 3 1体验体验新版本来着 没想到新建项目直接来了个这个 android view InflateException Binary XML file line 2 Error inflating c
  • 常见的距离算法和相似度(相关系数)计算方法

    摘要 1 常见的距离算法 1 1欧几里得距离 Euclidean Distance 以及欧式距离的标准化 Standardized Euclidean distance 1 2马哈拉诺比斯距离 Mahalanobis Distance 1
  • vue3 ---- 递归组件生成menu菜单 && 路由守卫鉴权

    目录 递归组件 el menu 父组件 子组件 路由 Vue路由守卫实现登录鉴权 全局守卫 路由独享的守卫 组件内的守卫 完整的导航解析流程 菜单权限 按钮权限 对于一些有规律的DOM结构 如果我们再一遍遍的编写同样的代码 显然代码是比较繁
  • IDEA切换分支导致项目异常, 部分类爆红问题解决

    关于idea切换分支导致项目异常爆红的方式解决两种办法 1 maven 并没有及时刷新 所以 当我们第一时间出现这个问题的时候 首选是刷新maven 如图所示 2 如果刷新mavne 还是没有解决idea 项目爆红的情况的话 那我们就需要考
  • 计算机不能创建用户,Windows10系统无法创建新用户该怎么办?

    由于工作需要 需要对同一台计算机创建多个用户帐户 Windows7操作系统创建新用户的方法很简单 简单几步就能够轻松完成创建 参照Windows7操作系统创建新用户的步骤 发现并不适用于Windows10操作系统 系统会提示需要登录Micr
  • CocosCreator波浪Shader

    waveEffect effect Copyright c 2017 2020 Xiamen Yaji Software Co Ltd CCEffect techniques passes vert sprite vs vert frag
  • Serverless 的前世今生

    作者 阿里云用户组 从云计算到 Serverless 架构 大家好 我是阿里云 Serverless 产品经理刘宇 很高兴可以和大家一起探索 Serverless 架构的前世今生 从云计算到云原生再到 Serverless 架构 技术飞速发
  • 【推荐算法】双塔模型代码(tensorflow)

    推荐算法 双塔模型介绍 MachineCYL的博客 CSDN博客 上文介绍了双塔模型的原理和结构 这篇介绍一下双塔模型的代码实现 我使用的是tensorflow来实现双塔模型和模型训练 一 前期准备 tensorflow使用的版本是2 0