Tensorflow 数据集 API 中的过采样功能

2023-11-22

我想问一下目前的数据集API是否允许实现过采样算法?我处理高度不平衡的阶级问题。我认为在数据集解析(即在线生成)过程中对特定类进行过采样会很好。我已经看到了rejection_resample函数的实现,但是这会删除样本而不是复制它们,并且会减慢批次生成速度(当目标分布与初始分布有很大不同时)。我想要实现的目标是:举个例子,看看它的类概率来决定是否重复它。然后打电话dataset.shuffle(...) dataset.batch(...)并获取迭代器。最好的(在我看来)方法是对低概率类别进行过采样并对最可能的类别进行子采样。我想在网上做,因为它更灵活。


这个问题已经在issue中解决了#14451。 只需在此处发布 anwser 即可使其对其他开发人员更加可见。

示例代码对低频类进行过采样,对高频类进行欠采样,其中class_target_prob在我的例子中只是均匀分布。我想检查最近手稿的一些结论卷积神经网络中类别不平衡问题的系统研究

特定类的过采样是通过调用完成的:

dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)

这是完成所有操作的完整代码片段:

# sampling parameters
oversampling_coef = 0.9  # if equal to 0 then oversample_classes() always returns 1
undersampling_coef = 0.5  # if equal to 0 then undersampling_filter() always returns True

def oversample_classes(example):
    """
    Returns the number of copies of given example
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    # soften ratio is oversampling_coef==0 we recover original distribution
    prob_ratio = prob_ratio ** oversampling_coef 
    # for classes with probability higher than class_target_prob we
    # want to return 1
    prob_ratio = tf.maximum(prob_ratio, 1) 
    # for low probability classes this number will be very large
    repeat_count = tf.floor(prob_ratio)
    # prob_ratio can be e.g 1.9 which means that there is still 90%
    # of change that we should return 2 instead of 1
    repeat_residual = prob_ratio - repeat_count # a number between 0-1
    residual_acceptance = tf.less_equal(
                        tf.random_uniform([], dtype=tf.float32), repeat_residual
    )

    residual_acceptance = tf.cast(residual_acceptance, tf.int64)
    repeat_count = tf.cast(repeat_count, dtype=tf.int64)

    return repeat_count + residual_acceptance


def undersampling_filter(example):
    """
    Computes if given example is rejected or not.
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    prob_ratio = prob_ratio ** undersampling_coef
    prob_ratio = tf.minimum(prob_ratio, 1.0)

    acceptance = tf.less_equal(tf.random_uniform([], dtype=tf.float32), prob_ratio)

    return acceptance


dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)

dataset = dataset.filter(undersampling_filter)

dataset = dataset.repeat(-1)
dataset = dataset.shuffle(2048)
dataset = dataset.batch(32)

sess.run(tf.global_variables_initializer())

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

更新#1

这是一个简单的Jupyter笔记本它在玩具模型上实现了上述过采样/欠采样。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow 数据集 API 中的过采样功能 的相关文章

随机推荐

  • Angular2 传递函数作为组件输入不起作用

    我有一个以函数作为输入的组件 我已经从父级传递了这个函数 尽管调用了该函数 但该函数无法访问声明该函数的实例的依赖项 这是组件 Component selector custom element template val export cl
  • WCF ChannelFactory 与生成代理

    只是想知道当您可以使用 ChannelFactory 调用时 在什么情况下您更愿意从 WCF 服务生成代理 这样你就不必生成代理并担心服务器更新时重新生成代理了 Thanks 创建 WCF 客户端有 3 种基本方法 让 Visual Stu
  • 限制ManyToManyField的最大选择

    我试图限制模型记录在 ManyToManyField 中可以拥有的最大选择数量 在此示例中 有一个可以与区域相关的博客站点 在此示例中 我想将博客站点限制为只能有 3 个区域 这似乎是以前被问过 回答过的问题 但经过几个小时的探索后 我还没
  • Helm 图表之间的依赖关系是否应该反映微服务之间的依赖关系?

    给定以下服务方案及其依赖项 我想设计一组 Helm 图表 API Gateway calls Service A and Service C Service A calls Service B Service B calls Databas
  • ASP.NET Identity 2.0:如何重新哈希密码

    我正在 ASP NET 5 0 Web 应用程序中将用户从旧用户存储迁移到 ASP NET Identity 2 0 我有一种验证旧哈希值的方法 但我想在登录时将它们升级到 ASP NET Identity 2 0 哈希值 我创建了一个自定
  • 使用 MemoryStream 写入 XML

    我注意到有两种不同的方法将数据写入 XML 文件 为简洁起见 省略了错误处理 第一种方法是构建 XML 文档 然后将 XML 保存到文件中 using XmlWriter writer XmlWriter Create fileName w
  • 如何在makefile配方中设置环境变量?

    这是一个简化的 Makefile all for i 0 i lt 5 i do var var i echo var done echo var 我认为 var 的值是 0 1 2 3 4 但输出是 0 0 1 0 1 2 0 1 2 3
  • 如何实现CoreData记录的重新排序?

    我在 iPhone 应用程序中使用 CoreData 但 CoreData 不提供允许您对记录重新排序的自动方法 我想过使用另一列来存储订单信息 但是使用连续数字作为排序索引有问题 如果我正在处理大量数据 重新排序记录可能涉及更新排序信息上
  • Play 框架如何运作?

    我喜欢玩 与其他企业 Java 框架相比 它对于开发人员来说使用起来非常简单 但是 它是如何做到的呢 是什么让像 Java 这样的编译语言能够实现编辑 刷新循环 是什么让 Play 按其工作方式工作 Play 使用 Eclipse 编译器在
  • 在 PL/SQL 中打印记录字段

    如何在 PL SQL 中打印记录变量的所有字段 记录变量有很多字段 那么有没有比打印每个字段更好的方法呢 也尝试过动态sql但没有帮助 基于 Ollies 使用 dbms output 构建 但用于动态遍历光标 设置用于测试 create
  • ASP.NET MVC 3 - 在 jquery 对话框中编辑动态添加到模型集合的项目

    我是 MVC 新手 所以我不确定这里最好的方法是什么 我有一个视图模型 其中包含几个像这样的集合 public class MainViewModel public List
  • iPhone Web 应用程序可以使用相机吗?

    我有一个网络应用程序 我想拍照然后将它们上传到服务器 这可以通过网络应用程序完成吗 编辑 现在可以了 请参阅下面的答案 不可以 webapp 无法访问内部设备 尝试使用 PhoneGap 来缩小您的应用程序和内部设备之间的差距 但这将编译一
  • 通过列表和数组中的索引获取结构体项目

    当我使用数组时structs 例如 System Drawing Point 我可以通过索引获取项目并更改它 例如 此代码工作正常 Point points new Point new Point 0 0 new Point 1 1 new
  • 寻找曲线上的最佳权衡点

    假设我有一些数据 我想为其拟合参数化模型 我的目标是找到该模型参数的最佳值 我正在使用AIC BIC MDL奖励低误差模型的标准类型 但也会惩罚高复杂性的模型 可以说 我们正在为这些数据寻找最简单但最令人信服的解释 a la奥卡姆剃刀 根据
  • 如何在不删除 R 中存在 NA 的行的情况下执行聚类

    我有一个数据 其元素中包含一些 NA 值 我想做的是执行聚类而不删除行NA 存在的地方 我明白那个gower距离测量单位daisy允许这种情况 但为什么我下面的代码不起作用 我欢迎 雏菊 以外的其他选择 plot heat map with
  • Flutter Workmanager 插件在运行任务时无法与任何其他插件一起使用

    初始化工作管理器并创建任一任务后 如果我们在任务执行中使用任何插件 它将无法被识别并抛出如下错误 MissingPluginException 在通道 lyokone location 上找不到方法 getLocation 的实现 实际代码
  • 为什么 stdafx.h 会这样工作?

    像往常一样 当我的大脑搞乱了我自己无法弄清楚的事情时 我会向你们寻求帮助 这次我一直想知道为什么 stdafx h 会这样工作 据我了解 它做了两件事 包括我们的标准标头might 使用并且很少改变 作为编译器书签 代码不再预编译 现在 这
  • BOOST_CHECK_EQUAL 带有pair 和自定义运算符<<

    当尝试执行 BOOST CHECK EQUAL pair pair 时 尽管声明了它 但 gcc 找不到pair的流运算符 有趣的是 std out 找到了运算符 ostream operator lt lt ostream s const
  • 检测类型是否是主模板的专业化或用户提供的专业化

    假设我有这个 template
  • Tensorflow 数据集 API 中的过采样功能

    我想问一下目前的数据集API是否允许实现过采样算法 我处理高度不平衡的阶级问题 我认为在数据集解析 即在线生成 过程中对特定类进行过采样会很好 我已经看到了rejection resample函数的实现 但是这会删除样本而不是复制它们 并且