Tensorflow Keras 保留每批的损失

2024-01-05

我正在寻找构建 keras 模型架构的最佳实践的建议/示例。

我一直在摸索 Model() 子类和功能模型的各种迭代,但无法连接所有点。

该模型应具有自定义指标和相关损失,其中:在训练期间,按批次计算指标,并在每个时期结束时根据批次计算的平均值计算最终指标/损失。

根据我的想法,我需要一个Custom_Batch_Metric()除了一个之外,还将维护每个批次的指标列表Custom_Final_Metric()这将对每批结果进行平均。我不知道如何实现这个。

例如......每批我想生成一个metric and loss用于 y_true、y_pred 的相关性。在我的纪元结束时,我想对批次相关性进行平均(或找到最大值)。

如果有人能向我推荐有关此类架构的任何书籍,我将非常感激。


一个简单的解决方案是子类化tf.keras.callbacks.Callback并定义on_train_batch_end(或测试)。然后还有on_epoch_end.

class SaveBatchLoss(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None): 
        batch_end_loss.append(logs['loss'])
import tensorflow as tf
from tensorflow.keras.layers import Dense
from sklearn.datasets import load_iris
import numpy as np

X, y = load_iris(return_X_y=True)
X = X.astype(np.float32)

ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)

model = tf.keras.Sequential([
    Dense(16, activation='relu'),
    Dense(32, activation='relu'),
    Dense(3, activation='softmax')])

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', 
              metrics=['accuracy'])

batch_end_loss = list()

class SaveBatchLoss(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        batch_end_loss.append(logs['loss'])

history = model.fit(ds, epochs=10, callbacks=SaveBatchLoss())

batch_end_loss[::20]
[1.2742226123809814,
 0.9069833755493164,
 0.9728888869285583,
 0.9536505937576294,
 0.8957988023757935,
 0.8624499440193176,
 0.7952101826667786,
 0.7765023112297058,
 0.7615134716033936,
 0.7278715968132019]

作为一种更复杂的方法,将损失值附加到每个批次末尾的列表中,以及每个 opoch 末尾的另一个列表中。像这样:

train_loss_per_train_batch = list()
train_loss_per_train_epoch = list()

for epoch in range(1, 25 + 1): 
    train_loss = tf.metrics.Mean()
    train_acc = tf.metrics.SparseCategoricalAccuracy()
    test_loss = tf.metrics.Mean()
    test_acc = tf.metrics.SparseCategoricalAccuracy()

    for x, y in train:
        loss_value, grads = get_grad(model, x, y)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_loss.update_state(loss_value)
        train_acc.update_state(y, model(x, training=True))
        train_loss_per_train_batch.append(loss_value.numpy())
    
    train_loss_per_train_epoch.append(train_loss.result())

实现此功能的完整训练脚本将是:

import tensorflow as tf
import tensorflow_datasets as tfds

ds = tfds.load('iris', split='train', as_supervised=True)

train = ds.take(125).shuffle(16).batch(4)
test = ds.skip(125).take(25).shuffle(16).batch(4)

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.d1 = tf.keras.layers.Dense(16, activation='relu')
        self.d2 = tf.keras.layers.Dense(32, activation='relu')
        self.d3 = tf.keras.layers.Dense(3, activation='softmax')

    def call(self, x, training=None, **kwargs):
        x = self.d1(x)
        x = self.d2(x)
        x = self.d3(x)
        return x

model = MyModel()

loss_object = tf.losses.SparseCategoricalCrossentropy(from_logits=False)


def compute_loss(model, x, y, training):
  out = model(x, training=training)
  loss = loss_object(y_true=y, y_pred=out)
  return loss


def get_grad(model, x, y):
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x, y, training=True)
    return loss, tape.gradient(loss, model.trainable_variables)


optimizer = tf.optimizers.Adam()

verbose = "Epoch {:2d} Loss: {:.3f} TLoss: {:.3f} Acc: {:.2%} TAcc: {:.2%}"

train_loss_per_train_batch = list()
train_loss_per_train_epoch = list()

for epoch in range(1, 25 + 1):
    train_loss = tf.metrics.Mean()
    train_acc = tf.metrics.SparseCategoricalAccuracy()
    test_loss = tf.metrics.Mean()
    test_acc = tf.metrics.SparseCategoricalAccuracy()

    for x, y in train:
        loss_value, grads = get_grad(model, x, y)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_loss.update_state(loss_value)
        train_acc.update_state(y, model(x, training=True))
        train_loss_per_train_batch.append(loss_value.numpy())

    train_loss_per_train_epoch.append(train_loss.result())

    for x, y in test:
        loss_value, _ = get_grad(model, x, y)
        test_loss.update_state(loss_value)
        test_acc.update_state(y, model(x, training=False))

    print(verbose.format(epoch,
                         train_loss.result(),
                         test_loss.result(),
                         train_acc.result(),
                         test_acc.result()))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow Keras 保留每批的损失 的相关文章

随机推荐

  • DataTables 使用您自己的参数读取 ajax 响应

    在服务器端模式下使用 DataTables 1 10 15 我创建了一个 PHP 脚本来提供 JSON 响应 其中包括他们在文档中提到的参数 https datatables net manual server side Returned
  • 我们可以使用 spring Batch 顺序处理多个文件,同时使用多个线程处理单个文件数据..?

    我想按顺序处理多个文件 并且每个文件都需要在多个线程的帮助下处理 因此使用了 Spring Batch FlatFileItemReader 和 TaskExecutor 它似乎对我来说工作得很好 正如需求中提到的 我们必须处理多个文件 因
  • 如何在同一个图表上绘制多条带有偏移的曲线

    我从示波器读取波形 波形根据时间分为 10 段 我想绘制完整的波形 一个段在另一个段之上 或之下 可以这么说 具有垂直偏移 此外 还需要彩色图来显示信号强度 我只能得到以下情节 正如您所看到的 所有曲线都是叠加的 这是不可接受的 人们可以向
  • JSONArray 中的 POST 字符串和响应

    其实我关注了很多堆栈溢出 https stackoverflow com questions 28656865 send a jsonarray post request with android volley library答案与我的问题
  • 为什么这个 cppreference 摘录似乎错误地表明原子可以保护关键部分?

    int main std vector
  • 带有圆角图像的自定义 UIProgressView

    我在定制时遇到困难UIProgressView有两个图像 我在互联网上发现了很多有用的问题 但我的问题有点不同 也许我使用的图像在非角矩形上是圆角的 所以 stretchableImageWithLeftCapWidth方法似乎对我的情况没
  • sql - 两行之间的查询

    有疑问 我正在做一个选择 我需要抓取 2 行 我的值为 13000 00000 我需要抓住第 2 行和第 3 行 因为它 介于 10000 最小范围 和 15000 最小范围 之间 该语句仅引入第 2 行 select from TABLE
  • lambda 找不到我的 node_modules

    我正在尝试使用 lambda 上传 node modules 但我得到了 Cannot find module error 我已经设置了一个真正简单的 hello world js 文件 var async require async 我已
  • 如何在 Spark Scala 中将 null NAN 或 Infinite 值替换为默认值

    我正在将 csv 读入 Spark 并将架构设置为所有 DecimalType 10 0 列 当我查询数据时 出现以下错误 NumberFormatException Infinite or NaN 如果我的数据框中有 NaN null i
  • 使用 AJAX.NET 的 $get() 和 $find()

    我正在尝试遵循找到的 PageMethods 示例here http encosia com 2007 07 11 why aspnet ajax updatepanels are dangerous 但是 我在尝试调用时收到错误 get
  • Restlet HTTP 连接池

    我对 Restlet 相当陌生 编写了一小段代码来进行 HTTP 调用 它正在工作 但我想知道如何将 HTTP 连接池 apache 添加到其中 我找不到任何教程或参考代码 Client client new Client Protocol
  • 在 Java JFrame 中显示图像

    在 java JFrame 中的特定坐标处显示图像的最佳方法是什么 我知道有很多方法可以做到这一点 我只需要知道显示我计划在框架中移动的图像的最佳方式 将 ImageIcon 与 JLabel 结合使用是最简单的方法 实际上 您可以根据您的
  • 在 3d 中绘制 3 个向量

    我有 3 个向量 其中一个向量的角度为Phi 另一个角度为Teta 最后一个是点向量Y axe 计算完点后Teta Phi有一个功能 for teta 0 10 2 pi 2 for phi 0 10 2 pi 2 Y current v
  • android 蓝牙连接失败(isSocketAllowedBySecurityPolicy start : device null)

    我试图用蓝牙连接两部手机 galaxy note 1 galaxy note 2 但套接字连接失败 这是我的 LogCat I BluetoothService 24036 BEGIN mConnectThread D BluetoothU
  • 从左侧滑入CSS动画

    我想制作一个简单的动画 当页面加载时 我的徽标应该从框的左侧动画到右侧 我尝试了很多版本 但还没有成功 HTML div img src logo png alt logo style width 170px height 120px di
  • Kafka:使用java更改特定主题的分区数量

    我是 Kafka 新手 正在使用新的 KafkaProducer 和 KafkaConsumer 版本 0 9 0 1 java中是否有任何方法可以在创建特定主题后更改 更新其分区数量 我没有使用 Zookeeper 创建主题 当发布请求到
  • JavaScript 函数通过链式组合

    我检查了重复问题的可能性 并且无法找到准确的解决方案 我用 JavaScript 编写了一些函数链代码 如下所示 并且工作正常 var log function args console log args return function f
  • 浏览器同步无法用 gulp 重新加载

    我正在尝试按如下方式吞咽浏览器同步 var liveReload require browser sync create reload when something changes once scss is converted to css
  • 对话框的 Url 不适用于 angular.bootstrap (无限 $digest 循环)

    我有一个平均堆栈网站 我想用执行函数 https stackoverflow com a 45428344 702977绑定一个按钮以在对话框中启动该网站 function doSomethingAndShowDialog event cl
  • Tensorflow Keras 保留每批的损失

    我正在寻找构建 keras 模型架构的最佳实践的建议 示例 我一直在摸索 Model 子类和功能模型的各种迭代 但无法连接所有点 该模型应具有自定义指标和相关损失 其中 在训练期间 按批次计算指标 并在每个时期结束时根据批次计算的平均值计算