通过 TensorFlow 中 CSV 的分类特征数组列创建多热 SparseTensor

2023-12-28

这是推荐系统中处理稀疏特征(例如一些ID特征)的典型方式。我正在寻找一种方便的方法来为 TensorFlow 管道准备数据。

我做了很多搜索,但尚未找到好的解决方案。

下面是似乎接近我需要的,但尚未工作。

See #######下面的部分

数据文件如下:

csv = [
  '1221,cc,1',
  '213,aa|cc|ff,1',
]

对于第二行,我需要一些 SparseTensor,例如 multi-hot

 aa bb cc dd ee ff
| 0  0  1  0  0  0 |
| 1  0  1  0  0  1 |

完整版本的代码是:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import shutil
import sys

import tensorflow as tf  # pylint: disable=g-bad-import-order

_CSV_COLUMNS = ['a_id', 'b_id', 'tags', 'label']
_CSV_COLUMN_DEFAULTS = [[0], [0], [''], [0]]


def input_fn(data_file, num_epochs, shuffle, batch_size):
    """Generate an input function for the Estimator."""

    assert tf.gfile.Exists(data_file), (
        '%s not found. Please make sure you have run data_download.py and '
        'set the --data_dir argument to the correct path.' % data_file)

    """
$ cat vocab.txt
a
b
c
d
e
f
g
h
i
j
k
l
m
n
    """
    table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file='vocab.txt', num_oov_buckets=1)

    def parse_csv(value):
        print('Parsing', data_file)
        columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
        features = dict(zip(_CSV_COLUMNS, columns))

        ########################  BEGIN  ###########################

        # support multi-hot sparse features
        split_tags = tf.string_split([columns[2]], '|')  # hard-coded 'tags' column index
        # Output: tags.indices Tensor("StringSplit:0", shape=(?, 2), dtype=int64)
        print('tags.indices', split_tags.indices)

        indice_idx = tf.map_fn(lambda x : x[0], split_tags.indices)
        # Output: indice_idx Tensor("map/TensorArrayStack/TensorArrayGatherV3:0", shape=(?,), dtype=int64)
        print('indice_idx', indice_idx)
        value_idx = tf.map_fn(lambda x : x[1], split_tags.indices)

        value_arr = tf.cast(tf.gather(split_tags.values, value_idx), tf.int64)
        # Output:  value_arr shape (?,)
        print('value_arr shape', value_arr.shape)

        # stack is doing: [1, 2, 3], [4, 5, 6] ==> [[1, 2], [3, 4], [5,6]]
        new_indices = tf.stack([indice_idx, value_arr], axis=1)
        print('new_indices', new_indices)

        new_values = tf.ones_like(value_arr)
        # Output:  new_values Tensor("ones_like:0", shape=(?,), dtype=int64)
        print('new_values', new_values)

        with tf.Session() as s1:
            s1.run([tf.global_variables_initializer(), tf.tables_initializer()])
            ##### FAIL here #####
            # InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string
            # [[Node: arg0 = Placeholder[dtype=DT_STRING, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
            print(split_tags.values.eval())
            print(indice_idx.eval())
            print('value_arr', value_arr.eval())
            print('new_values', new_values.eval())

        categorial_tensor = tf.SparseTensor(
            indices=new_indices,
            values=new_values,
            dense_shape=[new_indices.shape[1], 4])

        ########################   END   ###########################

        categorical_cols = {
            'tags': categorial_tensor}

        features.update(categorical_cols)

        labels = features.pop('label')
        return features, tf.equal(labels, 1)

    # Extract lines from input files using the Dataset API.
    dataset = tf.data.TextLineDataset(data_file)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=6)  # num of lines in the file

    dataset = dataset.map(parse_csv, num_parallel_calls=5)

    # We call repeat after shuffling, rather than before, to prevent separate
    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    return dataset


"""
$ cat data.csv
1,2,a|c|g,1
0,1,c|f,0
0,2,b|g,1
0,1,b|v,0
0,1,g|j|k|l,1
0,1,a,0
"""
train_file = 'data.csv'
epochs_between_evals = 2
batch_size = 40
ds = input_fn(train_file, epochs_between_evals, True, batch_size)

with tf.Session() as s:
    s.run([tf.global_variables_initializer(), tf.tables_initializer()])
    print(s.run(ds))

None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

通过 TensorFlow 中 CSV 的分类特征数组列创建多热 SparseTensor 的相关文章

  • 使用预训练(Tensorflow)CNN 提取特征

    深度学习已成功应用于多个大型数据集 用于对少数类别 猫 狗 汽车 飞机等 进行分类 其性能优于 SIFT 特征袋 颜色直方图等更简单的描述符 然而 训练这样的网络需要每个类别大量的数据和大量的训练时间 然而 在花时间设计和训练这样一种设备并
  • 使用张量流导出神经网络的权重

    我使用张量流工具编写了神经网络 一切正常 现在我想导出神经网络的最终权重以制定单一的预测方法 我怎样才能做到这一点 您需要在训练结束时使用以下命令保存模型tf train Saver https www tensorflow org ver
  • 在基本 Tensorflow 2.0 中运行简单回归

    我正在学习 Tensorflow 2 0 我认为在 Tensorflow 中实现最基本的简单线性回归是一个好主意 不幸的是 我遇到了几个问题 我想知道这里是否有人可以提供帮助 考虑以下设置 import tensorflow as tf 2
  • model.predict() 返回类而不是概率

    Hello 我是第一次使用 Keras 我训练并保存了一个模型 作为 json 文件及其权重 该模型旨在将图像分为 3 个类别 我的编译方法 model compile loss categorical crossentropy optim
  • 在 Tensorflow2 中将图冻结为 pb

    我们通过图形冻结保存来自 TF1 的许多模型 tf train write graph self session graph def some path get graph definitions with weights output g
  • TensorFlow:Dst 张量未初始化

    The MNIST For ML Beginners当我运行时教程给我一个错误print sess run accuracy feed dict x mnist test images y mnist test labels 其他一切都运行
  • 期望最大化算法的数值示例[重复]

    这个问题在这里已经有答案了 由于我不确定给出的公式 有人可以提供 EM 算法的简单数字示例吗 一个非常简单的具有 4 或 5 个笛卡尔坐标的坐标就可以了 那这个呢 http en wikibooks org wiki Data Mining
  • 如何用Python构建游戏神经网络?

    我是神经网络初学者 我想通过教计算机下跳棋来学习神经网络的基础知识 其实我想学的游戏是盛气凌人 http en wikipedia org wiki Domineering and Hex http en wikipedia org wik
  • Keras 错误:预计会看到 1 个数组

    当我尝试在 keras 中训练 MLP 模型时出现以下错误 我使用的是 keras 版本1 2 2 检查模型输入时出错 您输入的 Numpy 数组列表 传递给您的模型的尺寸不是模型预期的尺寸 预期的 查看 1 个数组 但得到以下 12859
  • 如何跨多个文本文件查找字典中键的频率?

    我应该计算文档 individual articles 中所有文件中字典 d 的所有键值的频率 这里 文档 individual articles 大约有20000个txt文件 文件名为1 2 3 4 例如 假设 d Britain 5 7
  • 在 Windows 上,运行“导入张量流”会生成“没有名为“_pywrap_tensorflow”的模块”错误

    在 Windows 上 TensorFlow 在执行后报告以下一个或两个错误import tensorflow陈述 No module named pywrap tensorflow DLL load failed 对我来说问题是 cuDN
  • 敏感性特异性图 python

    我正在尝试重现类似于此的灵敏度特异性图 其中 X 轴是阈值 但我还没有找到如何做到这一点 一些 skalern 指标 如 ROC 曲线 会返回真阳性和假阳性 但我还没有找到任何选项来制作此图 我试图将概率与实际标签进行比较以保持计数 我得到
  • 是否可以使用具有余弦相似度的 KDTree?

    看来我不能使用这个相似度度量sklearn例如 KDTree 但我需要 因为我正在使用测量单词向量相似度 对于这种情况 快速鲁棒定制算法是什么 我知道关于Local Sensitivity Hashing 但它应该经过大量调整和测试才能找到
  • Tensorflow 初始化给出所有 1

    张量流1 12 0 在下面的代码片段中 wrapped rv val和seq rv val似乎应该是等效的 但事实并非如此 相反 seq rv val 被正确初始化为随机生成的 init val 数组 但wrapped rv val 设置为
  • 什么是tensorflow.python.data.ops.dataset_ops._OptionsDataset?

    我正在使用来自tensorflow的Transformer代码 https www tensorflow org beta tutorials text transformer https www tensorflow org beta t
  • 如何在 keras 模型中使用张量流度量函数?

    使用Python 3 5 2张量流RC 1 1 我正在尝试在 keras 中使用张量流度量函数 所需的功能接口似乎是相同的 但调用 import pandas import numpy import tensorflow contrib k
  • 如何在 python 中使用交叉验证执行 GridSearchCV

    我正在执行超参数调整RandomForest如下使用GridSearchCV X np array df features all features y np array df gold standard labels x train x
  • 机器学习的周期性数据(例如度角 -> 179 与 -179 相差 2)

    我使用 Python 进行核密度估计 并使用高斯混合模型对多维数据样本的可能性进行排名 每一条数据都是一个角度 我不确定如何处理机器学习的角度数据的周期性 首先 我通过添加 360 来删除所有负角 因此所有负角都变成了正角 179 变成了
  • 在自定义 keras 层的调用函数中传递附加参数

    我创建了一个自定义 keras 层 目的是在推理过程中手动更改前一层的激活 以下是基本层 它只是将激活值乘以一个数字 import numpy as np from keras import backend as K from keras
  • 如何使用 keras.backend.gradients() 获取梯度值

    我试图获得 Keras 模型的输出相对于模型输入 x 而不是权重 的导数 似乎最简单的方法是使用 keras backend 中的 梯度 它返回梯度张量 https keras io backend https keras io backe

随机推荐

  • 转储 HPROF 文件时 DDMS 未显示预期输出 [关闭]

    这个问题不太可能对任何未来的访客有帮助 它只与一个较小的地理区域 一个特定的时间点或一个非常狭窄的情况相关 通常不适用于全世界的互联网受众 为了帮助使这个问题更广泛地适用 访问帮助中心 help reopen questions I am
  • Ajax 更新程序在 Internet Explorer 中不工作

    我有一个 ajax 更新程序的问题 我无法解决 我有这个代码服务 我曾经遇到过这个问题 我假设您有一个函数 可以定期向服务器发出请求以获取数据 然后如果自上次发出请求以来该数据已更改 则更新页面 即使您在 HTTP 标头中告诉它不要缓存 I
  • Python:将字符串转换为函数名; getattr 还是相等?

    我正在编辑 PROSS py 以使用蛋白质结构的 cif 文件 在现有的 PROSS py 内部 有以下函数 如果它不与任何类关联 我相信这是正确的名称 它们仅存在于 py 文件中 def unpack pdb line line ATOF
  • 签名 Jar 文件中的 Spring 组件扫描 (@Autowire) 速度很慢

    几年前 我们在独立的 java 应用程序中遇到了 Spring 组件扫描缓慢的问题 所以我在 stackoverflow 中询问 慢弹簧元件扫描 https stackoverflow com questions 17747364 slow
  • Firebase HTTP 请求的云函数

    我想从 Android 向云功能发送 HTTP 请求 发布一些值 然后将这些值输入到实时数据库中 index js const functions require firebase functions exports testPost fu
  • AngularJS 指令链接函数在 jasmine 测试中未调用

    我正在创建一个元素指令 该指令在其中调用服务link功能 app directive depositList depositService function depositService return templateUrl deposit
  • 是否可以使用仅限 Tab 的智能感知补全作为 Visual Studio 2019 中所有文件的默认设置?

    默认情况下 Visual Studio 2019 Intellisense 使用自动完成代替仅制表符完成对于所有文件 在每个文件的基础上 您可以使用以下命令 在自动和仅选项卡智能感知完成之间切换 编辑 gt 智能感知 gt 菜单选项 工具栏
  • 为什么标准化设备坐标系是左手坐标系?

    起初我想知道为什么 NDC 的范围是从 1 到 1 而不是从 0 到 1 我想也许将原点放在中心是有用的 但为什么它使用左手坐标系呢 是否只是距离较远的物体 Z 值较高 这对我来说已经是一个足够好的理由了 但为什么它使用左手坐标系呢 让我们
  • 将 JPA query.getResultList() 转换为 MY 对象

    我正在对我的数据库执行查询JPA 该查询 查询 4 个表 结果聚合来自不同表的列 我的查询是这样的 Query query em createQuery SELECT o A o B o C e D c E FROM Table1 o Ta
  • 使用 SmtpClient 发送邮件作为回复

    场景 需要发送一封邮件 该邮件实际上是来自 asp net c 程序的回复邮件 我设法将邮件发送给客户端 但它作为新邮件发送 Code var SMTP genRepository GetData SELECT FROM LOCATION
  • 修改 xargs 中的替换字符串

    当我使用时xargs有时我不需要显式使用替换字符串 find name txt xargs rm rf 在其他情况下 我想指定替换字符串以便执行以下操作 find name txt xargs I mv foo bar 上一个命令会将当前目
  • Android 应用程序中的所有 Activity 共享一个 SQLiteOpenHelper 实例是否可以?

    将 SQLiteOpenHelper 的单个实例作为子类应用程序的成员 并且让所有需要 SQLiteDatabase 实例的活动从一个帮助器获取它是否可以 单击此处查看我关于此主题的博客文章 http www androiddesignpa
  • Java 可序列化、ObjectInputstream、非阻塞 I/O

    我刚刚开始使用 Java 序列化 我不清楚如何在非阻塞 I O 的场景中从源获取对象 我能找到的所有文档都表明使用 ObjectInputStream 是proper读取序列化对象的方法 但是 正如我提到的 我正在使用 java nio 并
  • DAO架构的必要性是什么

    当用Java编程时 是否总是需要根据DAO架构来编码 如果是的话 使用它有什么好处 我正在做一个项目 其类图如下所示 这样做有什么缺点 实体类 private void fillSONumber try ZnAlSalesOrder o n
  • 相同变音符号(变音符号)的不同 UTF-8 签名 - 编写变音符号的 2 种二进制方法

    我有一个很大的问题 我在网上找不到任何帮助 我将一个网站的页面从 OSX 移至 Linux 两个系统都在 de DE UTF 8 中运行 并遇到了一个相当未知的问题 有些文件不再被发现 但显然以 明显 相同的名称存在于硬盘上 所有这些文件都
  • Flink 作业在集群节点上的分布

    我们有 4 个作业 运行在 3 个节点上 每个节点有 4 个槽位 在 Flink 1 3 2 上 作业均匀分布在每个节点上 升级到 flink 1 5 后 每个作业都在单个节点上运行 如果没有剩余插槽 则可以转移到另一个节点 有没有办法恢复
  • 如何在WP8中引用System.Net.Http?

    我对 WP8 开发相对较新 并且遇到了一个我无法解决的问题 即使经过几个小时的谷歌搜索后也是如此 我正在使用 Visual Studio 2012 并使用 NuGet 实现了 System Net Http 已检查引用 复制本地设置为 tr
  • 无法在 IIS 7 Windows Server 2008 64 位上运行经典 ASP

    我们有几个用经典 ASP 构建的 Web 应用程序 目前在 Windows Server 2003 32 位和 IIS 6 上运行 我们正在尝试将其迁移到运行带有 IIS 7 的 Windows Server 2008 64 位的新服务器
  • 可以在命令行上指定额外的插件目录吗

    使用 Eclipse 3 4 是否可以从命令行提供附加插件目录 就像是 eclipse plugin dir D myproduct V1 1 plugins clean 这只是为了节省每次插件的复制 虽然可以使用脚本完成复制 但用户可能没
  • 通过 TensorFlow 中 CSV 的分类特征数组列创建多热 SparseTensor

    这是推荐系统中处理稀疏特征 例如一些ID特征 的典型方式 我正在寻找一种方便的方法来为 TensorFlow 管道准备数据 我做了很多搜索 但尚未找到好的解决方案 下面是似乎接近我需要的 但尚未工作 See 下面的部分 数据文件如下 csv