使用 keras 可以在训练过程中实时获取输出层吗?

2023-12-12

我尝试在训练期间获得输出层。我正在尝试对模型进行实时 3D 可视化并使其具有交互性。我正在使用谷歌colab与tensorflow 2.0和python 3。

这是我的代码:

Imports

  from __future__ import absolute_import, division, print_function, unicode_literals
 try:
   # Use the %tensorflow_version magic if in colab.
     %tensorflow_version 2.x
 except Exception:
       pass

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import tensorflow_hub as hub
import tensorflow_datasets as tfds

from tensorflow.keras import datasets, layers, models

from tensorflow.keras import backend as K
from tensorflow.keras.backend import clear_session

from tensorflow.keras.callbacks import Callback as Callback

import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

Get data

splits = tfds.Split.TRAIN.subsplit([70, 30])

(training_set, validation_set), dataset_info = tfds.load('tf_flowers',with_info=True, as_supervised=True, split=splits)
 for i, example in enumerate(training_set.take(5)):
        print('Image {} shape: {} label: {}'.format(i+1, example[0].shape, example[1]))

检查类和图像的数量

 num_classes = dataset_info.features['label'].num_classes

 num_training_examples = 0
 num_validation_examples = 0

 for example in training_set:
   num_training_examples += 1

 for example in validation_set:
   num_validation_examples += 1

 print('Total Number of Classes: {}'.format(num_classes))
 print('Total Number of Training Images: {}'.format(num_training_examples))
 print('Total Number of Validation Images: {} \n'.format(num_validation_examples))

开始

   IMAGE_RES = 299
   BATCH_SIZE = 32
def format_image(image, label):
   image = tf.image.resize(image, (IMAGE_RES, IMAGE_RES))/255.0
  return image, label

 (training_set, validation_set), dataset_info = tfds.load('tf_flowers', with_info=True, as_supervised=True, split=splits)
  train_batches = training_set.shuffle(num_training_examples//4).map(format_image).batch(BATCH_SIZE).prefetch(1)
    validation_batches = validation_set.map(format_image).batch(BATCH_SIZE).prefetch(1)

URL = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"
feature_extractor = hub.KerasLayer(URL,
  input_shape=(IMAGE_RES, IMAGE_RES, 3),
trainable=False)

model_inception = tf.keras.Sequential([
feature_extractor,
layers.Dense(num_classes, activation='softmax')
])

 model_inception.summary()

这是自定义回调,我尝试在训练期间获取输出层

    import datetime
 from keras.callbacks import Callback

class MyCustomCallback(tf.keras.callbacks.Callback):

  def on_train_batch_begin(self, batch, logs=None):
     print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_train_batch_end(self, batch, logs=None):
     for i in range(len(model_inception.layers)):
      inp = self.model.input                                    # input placeholder
      outputs = [layer.output for layer in self.model.layers]     # all layer outputs
      functors = [K.function([inp, K.learning_phase()], [out]) for out in outputs]    # evaluation functions
      input_shape = [1] + list(self.model.input_shape[1:])
      test = np.random.random(input_shape)
      layer_outs = [func([test, 1.]) for func in functors] 
      print('\n Training: batch {} ends at {}'.format( layer_outs , datetime.datetime.now().time()))

  def on_test_batch_begin(self, batch, logs=None):
    print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_end(self, batch, logs=None):
   # layer_output = get_3rd_layer_output(self.validation_data)[0]  
    print('Training: batch {} ends at {} with the output layer {}'.format(batch, datetime.datetime.now().time()))

 The problem is in callback of how i can get the output/input of each layer at the end of each batch

这是使用我的自定义回调进行的模型编译和训练

 model_inception.compile(
  optimizer='adam', 
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])

 EPOCHS = 2

 history = model_inception.fit(train_batches,
                epochs=EPOCHS,
                steps_per_epoch=20,
                validation_data=validation_batches,callbacks=[MyCustomCallback()])

当我尝试运行它时出现当前错误

AttributeError                            Traceback (most recent call last)
<ipython-input-10-5909c67ba93f> in <module>()
      9                     epochs=EPOCHS,
     10                     steps_per_epoch=20,
---> 11                     validation_data=validation_batches,callbacks=[MyCustomCallback()])
     12 
     13 # #Testing

11 frames
/tensorflow-2.0.0/python3.6/tensorflow_core/python/eager/lift_to_graph.py in <listcomp>(.0)
 247   # Check that the initializer does not depend on any placeholders.
 248   sources = object_identity.ObjectIdentitySet(sources or [])
-->249   visited_ops = set([x.op for x in sources])
 250   op_outputs = collections.defaultdict(set)
 251 

AttributeError: 'int' object has no attribute 'op'

如果您阅读了自定义回调的源代码,here

有一个财产model对于我们定义的每个自定义回调。

您可以在 Cutomcallback 中定义的函数内使用模型对象。

例如,

def on_train_batch_end(self, batch, logs=None):
    #here you can get the model reference. 
    self.model.predict(dummy_data)

self.model 是 keras.models.Model 的实例,可以使用它调用相应的函数。

更多参考可以找到here and here

请关注以下评论以获得答案。

[EDIT 1]

OP注释中的代码段

def on_train_batch_end(self, batch, logs=None): 
    for i in range(len(model_inception.layers)): 
        get_layer_output = K.function(inputs = self.model.layers[i].input, outputs = self.model.layers[i].output) 
        print('\n Training: output of the layer {} is {} ends at {}'.format(i, get_layer_output.outputs , datetime.datetime.now().time()))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 keras 可以在训练过程中实时获取输出层吗? 的相关文章

  • 如何使用 Tensorflow-GPU 和 Keras 修复低易失性 GPU-Util?

    我有一台 4 GPU 机器 在上面运行带有 Keras 的 Tensorflow GPU 我的一些分类问题需要几个小时才能完成 nvidia smi returns Volatile GPU Util which never exceeds
  • 通过 python 中的另外两个修改数组[重复]

    这个问题在这里已经有答案了 假设我们有三个一维数组 A 长度为 5 B 长度相同 示例中为5 C 更长 比如长度为 100 C最初用零填充 A给出索引C应更改的元素 它们可能会重复 以及B给出应添加到初始零的值C 例如 如果A 1 3 3
  • 键入的完整命令行

    我想获得输入时的完整命令行 This join sys argv 在这里不起作用 删除双引号 另外 我不想重新加入已解析和拆分的内容 有任何想法吗 你太迟了 当键入的命令到达 Python 时 您的 shell 已经发挥了它的魔力 例如 引
  • 为什么在连接两个字符串时 Python 比 C 更快?

    目前我想比较 Python 和 C 用来处理字符串的速度 我认为 C 应该比 Python 提供更好的性能 然而 我得到了完全相反的结果 这是 C 程序 include
  • 将 Python Pandas DataFrame 写入 Word 文档

    我正在努力创建一个使用 Pandas DataFrames 的 Python 生成的报告 目前我正在使用DataFrame to string 方法 但是 这会作为字符串写入文件 有没有办法让我实现这一目标 同时将其保留为表格 以便我可以使
  • 使用 NumPy 编写一个函数来计算具有特定公差的积分

    我想编写一个自定义函数来以特定容差对表达式 python 或 lambda 函数 进行数字积分 我知道与scipy integrate quad人们可以简单地改变epsabs但我想使用 numpy 自己编写该函数 From 这篇博文 htt
  • django 模板 - 如何动态访问变量?

    假设我有一个具有以下上下文的 django 模板 data1 this is data1 data2 this is data2 data name data2 现在我知道了data name 假设它是 data2 是否可以用它来访问变量d
  • 查找正在导入哪些 python 模块

    从应用程序中使用的特定包中查找所有 python 模块的简单方法是什么 sys modules是将模块名称映射到模块的字典 您可以检查其键以查看导入的模块 See http docs python org library sys html
  • Django 多对多关系(类别)

    我的目标是向我的 Post 模型添加类别 我希望以后能够按不同类别 有时是多个类别 查询所有帖子 模型 py class Category models Model categories 1 red 2 blue 3 black title
  • Python - Unicode 到 ASCII 的转换

    我无法在不丢失数据的情况下将以下 Unicode 转换为 ASCII u ABRA xc3O JOS xc9 I tried encode and decode他们不会这么做 有人有建议吗 Unicode 字符u xce0 and u xc
  • ValueError:数据必须为正(boxcox scipy)

    我正在尝试将我的数据集转换为正态分布 0 8 298511e 03 1 3 055319e 01 2 6 938647e 02 3 2 904091e 02 4 7 422441e 02 5 6 074046e 02 6 9 265747e
  • 在请求中设置端口

    我正在尝试利用cgminer使用 Python 的 API 我对利用requests图书馆 我了解如何做基本的事情requests but cgminer想要更具体一点 我想缩小 import socket import json sock
  • Python:在字典中查找具有唯一值的键?

    我收到一个字典作为输入 并且想要返回一个键列表 其中字典值在该字典的范围内是唯一的 我将用一个例子来澄清 假设我的输入是字典 a 构造如下 a dict a cat 1 a fish 1 a dog 2 lt unique a bat 3
  • 是否可以在Python中将日+月(不是年)与当前日+月进行比较?

    我正在获取 5 月 10 日 格式的数据 我试图弄清楚它是今年还是明年 该日期仅一年 因此 5 月 10 日表示 2015 年 5 月 10 日 而 5 月 20 日表示 2014 年 5 月 20 日 为此 我想将字符串转换为日期格式并进
  • 如何在matplotlib中调整x轴

    I have a graph like this x轴上的数据表示小时 所以我希望x轴设置为0 24 48 72 而不是现在的值 很难看到 0 100 之间的数据 fig1 plt figure ax fig1 add subplot 11
  • 从迭代器外部将 StopIteration 发送到 for 循环

    有几种方法可以打破一些嵌套循环 他们是 1 使用中断 继续 for x in xrange 10 for y in xrange 10 print x y if x y gt 50 break else continue only exec
  • 无需访问 Internet 即可部署 Django 的简单方法?

    我拥有的是使用 Django 开发的 Intranet 站点的开发版本以及放置在 virtualenv 中的一些外部库 它运行良好 我可以在任何具有互联网连接的计算机上使用相同的参数 使用 pip 轻松设置 virtualenv 但是 不幸
  • pandas.read_fwf 忽略提供的数据类型

    我正在从文本文件导入数据框 我想指定列的数据类型 但 pandas 似乎忽略了dtype input 一个工作示例 from io import StringIO import pandas as pd string USAF WBAN S
  • 在Python中停止ThreadPool中的进程

    我一直在尝试为控制某些硬件的库编写一个交互式包装器 用于 ipython 有些调用对 IO 的影响很大 因此并行执行任务是有意义的 使用 ThreadPool 几乎 效果很好 from multiprocessing pool import
  • Biopython 可以执行 Seq.find() 来解释歧义代码吗

    我希望能够在 Seq 对象中搜索考虑歧义代码的子序列 Seq 对象 例如 以下内容应该是正确的 from Bio Seq import Seq from Bio Alphabet IUPAC import IUPACAmbiguousDNA

随机推荐