如何将 TensorFlow (v. 2) Hub 中预训练的 KerasLayer 与 tfrecords 结合起来?

2023-12-25

我有一个包含 23 个类的 tfrecord,每个类有 35 张图像(总共 805 张)。我当前的 tfrecord 读取函数是:

def read_tfrecord(serialized_example):
 feature_description = {
    'image': tf.io.FixedLenFeature((), tf.string),
    'label': tf.io.FixedLenFeature((), tf.int64),
    'height': tf.io.FixedLenFeature((), tf.int64),
    'width': tf.io.FixedLenFeature((), tf.int64),
    'depth': tf.io.FixedLenFeature((), tf.int64)
 }

 example = tf.io.parse_single_example(serialized_example, feature_description)
 image = tf.io.parse_tensor(example['image'], out_type=float)
 image_shape = [example['height'], example['width'], example['depth']]
 image = tf.reshape(image, image_shape)
 label = tf.cast(example["label"], tf.int32)
 image = image/255

 return image, label

然后我有一个 make_dataset 函数,如下所示:

def make_dataset(tfrecord, BATCH_SIZE, EPOCHS, cache=True):
 files = tf.data.Dataset.list_files(os.path.join(os.getcwd(), tfrecord))
 dataset = tf.data.TFRecordDataset(files)

 if cache:
    if isinstance(cache, str):
      dataset = dataset.cache(cache)
    else:
      dataset = dataset.cache()

 dataset = dataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
 dataset = dataset.map(map_func=read_tfrecord, num_parallel_calls=AUTOTUNE)
 dataset = dataset.repeat(EPOCHS)
 dataset = dataset.batch(batch_size=BATCH_SIZE)
 dataset = dataset.prefetch(buffer_size=AUTOTUNE)

 return dataset

这个 make_dataset 函数被传递到

train_ds = make_dataset(tfrecord=FLAGS.tf_record, BATCH_SIZE=BATCH_SIZE, EPOCHS=EPOCH)
image_batch, label_batch = next(iter(train_ds))
feature_extractor_layer = hub.KerasLayer(url, input_shape=IMAGE_SHAPE + (3,)) 
feature_batch = feature_extractor_layer(image_batch)
feature_extractor_layer.trainable = False
model = tf.keras.Sequential([feature_extractor_layer, layers.Dense(2048, input_shape=(2048,)), layers.Dense(len(CLASS_NAMES), activation='softmax')])

model.summary()
predictions = model(image_batch)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=rate),
                              loss='categorical_crossentropy',
                              metrics=['acc'])

batch_stats_callback = CollectBatchStats()
STEPS_PER_EPOCH = np.ceil(image_count / BATCH_SIZE)
history = model.fit(image_batch, label_batch, epochs=EPOCH, batch_size=BATCH_SIZE, steps_per_epoch=STEPS_PER_EPOCH, callbacks=[batch_stats_callback])

这段代码的运行意义在于,它输出有关我有多少个 epoch 的常用信息以及一些训练精度数据(为 0,损失约为 100k)。我得到的错误对我来说没有任何意义,因为它说:函数实例化在外部推理上下文中的索引:100 处具有未定义的输入形状。您可以将该数字替换为 1000 以下的任何数字(不确定它是否超过了我 tfrecord 中的图像数量)。

我对这个完全不知所措。

EDIT:

看来我收到的这个“错误”只不过是一条警告消息。我怀疑这与 TensorFlow Hub 的使用和潜在的 eagerexecution 有关。我添加了

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

在文件的开头,警告已经消失。


正如 kelkka 所指定的,这不是一个错误,而只是一个警告。

在程序开始时添加以下代码行可以限制打印警告消息。

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

该环境变量的其他值及其行为如下所述:

  • 0 = 记录所有消息(默认行为)
  • 1 = 不打印 INFO 消息
  • 2 = 不打印信息和警告消息
  • 3 = 不打印信息、警告和错误消息

有关控制警告消息的详细程度的更多信息,请参阅此堆栈溢出答案 https://stackoverflow.com/questions/35869137/avoid-tensorflow-print-on-standard-error.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何将 TensorFlow (v. 2) Hub 中预训练的 KerasLayer 与 tfrecords 结合起来? 的相关文章

随机推荐

  • 如何使用 django 、 Location.objects.all() 获取第一个元素和最后一个元素

    这是我的代码 obj list Location objects all first element obj list 0 last element obj list 1 then return render to response tem
  • 如何使用 Python 在 Seaborn 中保存绘图 [重复]

    这个问题在这里已经有答案了 我有一个 Pandas 数据框并尝试将绘图保存在 png 文件中 然而 似乎有些事情并没有按预期进行 这是我的代码 import pandas import matplotlib pyplot as plt im
  • 我们可以调用 va_start() 两次而不调用 va_end() 吗?

    这是我的最小示例 include
  • Jenkins Pipeline 特定阶段的触发器

    我从一开始就在使用 Jenkins 但我想做点什么 但我找不到如何做 事实上 我想用两种不同的方式触发我的项目 每 4 小时和每次提交 但对于每种情况 我不希望执行所有 Jenkinsfile 只执行某些特定阶段 是否可以使用声明式管道来做
  • 使用 Git 和 Heroku 进行正确的持续集成和持续部署

    我正在使用 heroku 和 git 开发一个 ruby on Rails 网站 我应该使用哪些工具和功能来建立以下简单的开发流程 代码 gt 签入 gt 自动测试 gt 自动部署 我将代码签入我的存储库 首选选项 托管 git 如 git
  • 确定 sprintf 缓冲区大小 - 标准是什么?

    当像这样转换 int 时 char a 256 sprintf a d 132 确定多大的最佳方法是什么a应该 我认为手动设置它是可以的 因为我已经看到它到处使用 但它应该有多大 32 位系统上可能的最大 int 值是多少 是否有一些棘手的
  • 如何在 Chart.js 中循环工具提示附加数据

    这里我有一个图表 其中包含来自数据库表的 x 轴数据和 y 轴数据 现在我面临的问题是 无论我尝试将第三个数据附加到afterbody工具提示中的回调函数 它将在每个工具提示中显示完整数据 但我想分别将这些数据附加到每个工具提示中 like
  • 主动存储种子Rails

    我想为我的数据库添加一些包含活动存储附件的实例 但我不知道如何做到这一点 我尝试了一些方法但没有成功 这是我的种子 User create email email protected cdn cgi l email protection p
  • DMCS 中的 D 代表什么?

    所以 我正在阅读有关单声道 C 编译器 http www mono project com docs about mono languages csharp 我知道这些应用程序的用途是什么 但我只是想知道缩写代表什么 另外 gmcs smc
  • Android hprov-dump 给我错误:期待 1.0.3

    我在 eclipses DDMS 中使用了转储 HPROF 文件选项 并将我的 hprof 文件命名为 in hprof 但是当我尝试执行以下操作时hprov conf in hprof out hprof从命令行它给我错误 错误 期待 1
  • 使文本块只读

    目前我在滚动查看器控件中放置了一个文本块 如何使文本块只读 文本块 http msdn microsoft com en us library system windows controls textblock aspx已经是只读的 它们旨
  • 为什么会出现错误 ORA-00937

    对于每名获得三架以上飞机认证的飞行员 找到 援助和他 或她 所乘坐的飞机的最大航程 认证为 我有四张桌子 FLIGHTS flno varchar 出发地 varchar 目的地 varchar 距离 整数 出发 日期 到达 日期 飞机 a
  • Android httpclient cookie拒绝非法路径属性

    我正在构建一个 Android 应用程序 它使用 httpclient 将数据发布到 WordPress 服务器并检索数据 由于 cookie 中的路径无效 我无法发送发布数据 这是我检索到的日志 Cookie rejected Basic
  • 高效解析大型 JSON 数组的前四个元素

    我在用Jackson从 json 解析 JSONinputStream如下所示 36 100 The 3n 1 problem 56717 0 1000000000 0 6316 0 0 88834 0 45930 0 46527 5209
  • 如何复制图像?

    我想复制image png form folder1 to folder2 怎么做 folder1 image png folder2 Thanks 尝试这样的事情 var fs require fs var inStr fs create
  • PHP包含html页面字符集问题

    使用下面的代码查询 mysql 数据库后 我生成了一个 html 文件 myFile page htm fh fopen myFile w 或 die 无法打开文件 fwrite fh row 文本 fclose fh 在 mysql 数据
  • 如何运行内存中下载的文件? [复制]

    这个问题在这里已经有答案了 可能的重复 使用 C 加载 EXE 文件并从内存中运行它 https stackoverflow com questions 3553875 load an exe file and run it from me
  • 从 C++ 中查找 python 函数参数

    我正在从 C 调用 python 函数 我想知道是否可以确定参数的数量和这些参数的名称 我已阅读链接如何从 C 语言中查找 Python 函数的参数数量 https stackoverflow com questions 1117164 h
  • 如何从最小最大算法中获取实际移动而不是移动值

    我目前正在为国际象棋编写一个带有 alpha beta 剪枝的极小极大算法 从我见过的所有示例中 极小极大算法将返回一个 int 值 该值表示最佳得分或最佳移动所产生的棋盘状态 我的问题是我们如何返回与分数返回值相关的最佳动作 例如 下面的
  • 如何将 TensorFlow (v. 2) Hub 中预训练的 KerasLayer 与 tfrecords 结合起来?

    我有一个包含 23 个类的 tfrecord 每个类有 35 张图像 总共 805 张 我当前的 tfrecord 读取函数是 def read tfrecord serialized example feature description