Tensorflow 2.0 中 KerasLayer 的 TimeDistributed

2024-02-01

我正在尝试使用来自tensorflow-hub的预训练模型构建CNN + RNN:

base_model = hub.KerasLayer('https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4', input_shape=(244, 244, 3)
base_model.trainable = False

model = Sequential()
model.add(TimeDistributed(base_model, input_shape=(15, 244, 244, 3)))
model.add(LSTM(512))
model.add(Dense(256, activation='relu'))
model.add(Dense(3, activation='softmax'))

adam = Adam(learning_rate=learning_rate)
model.compile(loss='categorical_crossentropy' , optimizer=adam , metrics=['accuracy'])
model.summary()

这是我得到的:

2020-01-29 16:1

6:37.585888: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2494000000 Hz
2020-01-29 16:16:37.586205: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x3b553f0 executing computations on platform Host. Devices:
2020-01-29 16:16:37.586231: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): Host, Default Version
Traceback (most recent call last):
  File "./RNN.py", line 45, in <module>
    model.add(TimeDistributed(base_model, input_shape=(None, 244, 244, 3)))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/training/tracking/base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/sequential.py", line 178, in add
    layer(x)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/base_layer.py", line 842, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/layers/wrappers.py", line 256, in call
    output_shape = self.compute_output_shape(input_shape).as_list()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/layers/wrappers.py", line 210, in compute_output_shape
    child_output_shape = self.layer.compute_output_shape(child_input_shape)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/base_layer.py", line 639, in compute_output_shape
    raise NotImplementedError
NotImplementedError

有什么建议么? 是否可以将 KerasLayer 转换为 Conv2D...层?


看来你不能使用TimeDistributed层来解决这个问题。但是,由于您不希望 Resnet 进行训练而只需要输出,因此您可以执行以下操作来避免TimeDistributed layer.

代替model.add(TimeDistributed(base_model, input_shape=(15, 244, 244, 3))), do

Option 1

# 2048 is the output size
model.add(
    Lambda(
        lambda x: tf.reshape(base_model(tf.reshape(x, [-1, 244, 244,3])),[-1, 15, 2048])
    , input_shape=(15, 244, 244, 3))
)

Option 2

如果您不想过多依赖输出形状(但这会牺牲性能)。

model.add(
    Lambda(
        lambda x: tf.stack([base_model(xx) for xx in tf.unstack(x, axis=1) ], axis=1)
    , input_shape=(15, 244, 244, 3))
)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow 2.0 中 KerasLayer 的 TimeDistributed 的相关文章

  • 如何使盒子阴影显示在容器中的下一个元素上?

    请看这段代码 http codepen io Varin pen kkGgVd http codepen io Varin pen kkGgVd div class container div class outside2 div clas
  • 如何在odoo中重写js函数

    我想加载 shop checkout url 函数是 odoo define change info order website sale change info order function require use strict oe w
  • 如何在 Angular 2 中订阅 DOMContentLoaded 事件?

    我正在将 UI 主题从 Angular 1 移植到 Angular 2 在第 1 个版本中 我有 viewContentLoaded事件 我想将其重新制作为 Angular 2 我正在尝试使用 HostListener DOMContent
  • 使用 Swift 解析框架

    有人尝试过将 Parse Framework 与 swift 一起使用吗 只要添加桥接文件 您就可以使用 swift 和 Objective C 代码 这是我的查询 从 Parse 返回的 对象 数组正确地包含了我的所有数据 但该方法在将
  • 为什么 .each 在我的 Rails 视图中完成后会重复数组? [复制]

    这个问题在这里已经有答案了 在我的 Rails 视图页面中 我有以下循环 它应该循环遍历我的 tag list 数组并打印每个标签 由于某种原因 它在打印每个单独的标签后会重复该数组 例如 这个数组有两个元素 ruby python 每个方
  • NV_path_rendering替代方案[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我刚刚观看了 Siggraph 2012 的一个非常令人印象深刻的演示 http nvidia fullviewmedia com sig
  • addEventListener keydown 不起作用

    我在互联网上找到了一些基本的 Pong 代码 并尝试添加按键 代码在这里 http cssdeck com labs ping pong game tutorial with html5 canvas and sounds http css
  • Selenium Python 使用代理运行浏览器[重复]

    这个问题在这里已经有答案了 我正在尝试编写一个非常简单的脚本 该脚本从 txt 文件获取代理 不需要身份验证 并用它打开浏览器 然后沿着代理列表循环此操作一定次数 我确实知道如何打开 txt 文件并使用它 我的主要问题是让代理正常工作 我见
  • 使用 OpenLayers 动态添加自定义标记到地图

    我想让用户在地图上添加自定义标记以及每个标记的描述 任何提示 任何教程的链接都会非常有用 您可以注册一个函数来在地图上 点击 事件 当用户单击它时 会自动添加该标记 尝试这样的事情 map is your map created using
  • 使用 PowerShell 检查 AD 中是否存在组

    我想为该组创建代码来检查该组是否存在 但是 我无法开始工作 因为它成功地将用户和组的部分成员仅添加到一个组中 而不是其他组 因为我设法在活动目录中创建一个组并从 csv 中读取 这是我的代码和结果 似乎在成功添加用户并添加组成员后我总是收到
  • 新的 .NET 6 控制台模板中的 C# 函数重载不起作用

    我在尝试重载该函数时遇到错误Print object in the 新的 NET 6 C 控制台应用程序模板 https learn microsoft com en us dotnet core tutorials top level t
  • git stash 和编辑帅哥

    我完全喜欢git add p and git stash但我偶尔会遇到以下问题 该问题是通过以下命令序列重现的 git add p my file 然后我手动编辑大块 using e 因为 git 建议的分割不适合我 git stash k
  • 如何在 yii 中设置 cron 作业

    我是 yii 的新手 我正在做一个项目 我写了一个向客户发送自动提醒的功能 假设这个函数位于 url http somedomain com index php somecontroller someaction 我想为此网址设置 cron
  • python中匹配3个或更多相同的字符

    我正在尝试使用正则表达式在字符串中查找三个或更多相同的字符 例如 你好 不匹配 噢 会的 我尝试过做类似的事情 re compile 1 3 a zA Z re compile w 1 5 但似乎都不起作用 w 1 2 是您正在寻找的正则表
  • Android 使用非公历

    我正在创建一个DatePickerDialogFragment用户将在其中选择出生日期 我想确保我可以处理非公历日期 我无法更改在我的设备上使用的日历类型 Android 是否允许用户切换日历类型 如果是的话 步骤是什么 到目前为止我还没有
  • 将元素添加到 D3 圆包节点

    我正在尝试制作一个可缩放的圆形包装图 我希望每个子圆圈包含一个较小的图表 该图表始终具有相同的结构 即 4 列 只有条形的高度会改变 我尝试添加一个简单的rect到目前为止我的图表 但矩形没有添加到圆圈中并且是静态的 JS var marg
  • FindAsync 很慢,但是延迟加载很快

    在我的代码中 我曾经使用加载相关实体await FindAsync 希望我能更好地遵守 C 异步指南 var activeTemplate await exec DbContext FormTemplates FindAsync exec
  • 截断段落前 100 个字符并隐藏段落的其余内容,以通过更多/更少链接显示/隐藏其余内容

    我有一个超过 500 个字符的段落 我只想获取最初的 100 个字符并隐藏其余部分 我还想在 100 个字符旁边插入 更多 链接 单击更多链接时 整个段落应显示并编辑文本 更多 到 更少 单击 更少 时 它应切换行为 段落是动态生成的 我无
  • 同时有两个操作栏(底部和向上)?

    我需要制作两个操作栏 顺便说一下我正在使用actionBarSherlock 所以我真正需要的是在正常操作栏上放置一个 欢迎屏幕 开关 并添加两个正常的 ActionBar 操作选项 与我需要的类似的是 Gmail 和地图 如下所示 htt
  • 如何检测文本是否可读?

    我想知道是否有一种方法可以告诉给定的文本是人类可读的 我所说的人类可读的意思是 它有一些含义 格式就像某人写的文章 或者至少是由软件翻译器生成的供人类阅读的文章 这是背景故事 最近我正在制作一个应用程序 允许用户将短文本上传到数据库 在部署

随机推荐