从 keras 模型中将特征提取到数据集中

2024-05-04

我使用以下代码(由here https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py)运行 CNN 来训练 MNIST 图像:

from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

batch_size = 128
num_classes = 10
epochs = 1

# input image dimensions
img_rows, img_cols = 28, 28

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))

print(model.save_weights('file.txt')) # <<<<<----

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

我的目标是使用 CNN 模型将 MNIST 特征提取到数据集中,我可以将其用作另一个分类器的输入。在这个例子中,我不关心分类操作,因为我需要的只是训练图像的特征。我发现的唯一方法是save_weights as:

print(model.save_weights('file.txt'))

如何从 keras 模型中将特征提取到数据集中?


训练或加载现有训练模型后,您可以创建另一个模型:

extract = Model(model.inputs, model.layers[-3].output) # Dense(128,...)
features = extract.predict(data)

并使用.predict方法从特定层返回向量,在这种情况下,每个图像将变为 (128,),即 Dense(128, ...) 层的输出。

您还可以使用 2 个输出联合训练这些网络函数式API https://keras.io/getting-started/functional-api-guide/。按照指南进行操作,您会发现可以将模型链接在一起,并具有多个输出,每个输出可能有一个单独的损失。这将使您的模型能够学习共享特征,这些特征对于同时对 MNIST 图像进行分类和您的任务很有用。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

从 keras 模型中将特征提取到数据集中 的相关文章

  • 在 django ORM 中查询时如何将 char 转换为整数?

    最近开始使用 Django ORM 我想执行这个查询 select student id from students where student id like 97318 order by CAST student id as UNSIG
  • 如何使用 opencv.omnidir 模块对鱼眼图像进行去扭曲

    我正在尝试使用全向模块 http docs opencv org trunk db dd2 namespacecv 1 1omnidir html用于对鱼眼图像进行扭曲处理Python 我正在尝试适应这一点C 教程 http docs op
  • 用枢轴点拟合曲线 Python

    我有下面的图 我想用 2 条线来拟合它 使用 python 我设法适应上半部分 def func x a b x np array x return a x b popt pcov curve fit func up x up y 我想用另
  • 如何使用 Pandas、Numpy 加速 Python 中的嵌套 for 循环逻辑?

    我想检查一下表的字段是否TestProject包含了Client端传入的参数 嵌套for循环很丑陋 有什么高效简单的方法来实现吗 非常感谢您的任何建议 def test parameter a list parameter b list g
  • Pandas Merge (pd.merge) 如何设置索引和连接

    我有两个 pandas 数据框 dfLeft 和 dfRight 以日期作为索引 dfLeft cusip factorL date 2012 01 03 XXXX 4 5 2012 01 03 YYYY 6 2 2012 01 04 XX
  • 使用 xlrd 打开 BytesIO (xlsx)

    我正在使用 Django 需要读取上传的 xlsx 文件的工作表和单元格 使用 xlrd 应该可以 但因为文件必须保留在内存中并且可能不会保存到我不知道如何继续的位置 本例中的起点是一个带有上传输入和提交按钮的网页 提交后 文件被捕获req
  • Python beautifulsoup 仅限 1 级文本

    我看过其他 beautifulsoup 得到相同级别类型的问题 看来我的有点不同 这是网站 我正试图拿到右边那张桌子 请注意表的第一行如何展开为该数据的详细细分 我不想要那个数据 我只想要最顶层的数据 您还可以看到其他行也可以展开 但在本例
  • “隐藏”内置类对象、函数、代码等的名称和性质[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我很好奇模块中存在的类builtins无法直接访问的 例如 type lambda 0 name function of module
  • 在Python中检索PostgreSQL数据库的新记录

    在数据库表中 第二列和第三列有数字 将会不断添加新行 每次 每当数据库表中添加新行时 python 都需要不断检查它们 当 sql 表中收到的新行数低于 105 时 python 应打印一条通知消息 警告 数量已降至 105 以下 另一方面
  • 如何使用python在一个文件中写入多行

    如果我知道要写多少行 我就知道如何将多行写入一个文件 但是 当我想写多行时 问题就出现了 但是 我不知道它们会是多少 我正在开发一个应用程序 它从网站上抓取并将结果的链接存储在文本文件中 但是 我们不知道它会回复多少行 我的代码现在如下 r
  • Docker 中的 Python 日志记录

    我正在 Ubuntu Web 服务器上的 Docker 容器中测试运行 python 脚本 我正在尝试查找由 Python Logger 模块生成的日志文件 下面是我的Python脚本 import time import logging
  • 在 Sphinx 文档中*仅*显示文档字符串?

    Sphinx有一个功能叫做automethod从方法的文档字符串中提取文档并将其嵌入到文档中 但它不仅嵌入了文档字符串 还嵌入了方法签名 名称 参数 我如何嵌入only文档字符串 不包括方法签名 ref http www sphinx do
  • 如何通过索引列表从 dask 数据框中选择数据?

    我想根据索引列表从 dask 数据框中选择行 我怎样才能做到这一点 Example 假设我有以下 dask 数据框 dict A 1 2 3 4 5 6 7 B 2 3 4 5 6 7 8 index x1 a2 x3 c4 x5 y6 x
  • Numpy - 根据表示一维的坐标向量的条件替换数组中的值

    我有一个data多维数组 最后一个是距离 另一方面 我有距离向量r 例如 Data np ones 20 30 100 r np linspace 10 50 100 最后 我还有一个临界距离值列表 称为r0 使得 r0 shape Dat
  • Cython 和类的构造函数

    我对 Cython 使用默认构造函数有疑问 我的 C 类 Node 如下 Node h class Node public Node std cerr lt lt calling no arg constructor lt lt std e
  • 如何使用原始 SQL 查询实现搜索功能

    我正在创建一个由 CS50 的网络系列指导的应用程序 这要求我仅使用原始 SQL 查询而不是 ORM 我正在尝试创建一个搜索功能 用户可以在其中查找存储在数据库中的书籍列表 我希望他们能够查询 书籍 表中的 ISBN 标题 作者列 目前 它
  • 如何在 pygtk 中创建新信号

    我创建了一个 python 对象 但我想在它上面发送信号 我让它继承自 gobject GObject 但似乎没有任何方法可以在我的对象上创建新信号 您还可以在类定义中定义信号 class MyGObjectClass gobject GO
  • 将 Python 中的日期与日期时间进行比较

    所以我有一个日期列表 datetime date 2013 7 9 datetime date 2013 7 12 datetime date 2013 7 15 datetime date 2013 7 18 datetime date
  • 使用for循环时如何获取前一个元素? [复制]

    这个问题在这里已经有答案了 可能的重复 Python 循环内的上一个和下一个值 https stackoverflow com questions 1011938 python previous and next values inside
  • Keras:多类 NLP 任务中 model.evaluate 与 model.predict 的准确性差异

    我正在使用以下代码在 keras 中为 NLP 任务训练一个简单模型 训练集 测试集和验证集的变量名称是不言自明的 该数据集有 19 个类 因此网络的最后一层有 19 个输出 标签也是 one hot 编码的 nb classes 19 m

随机推荐

  • Firefox 中的相对位置[重复]

    这个问题在这里已经有答案了 可能的重复 Firefox 是否支持表格元素上的position relative https stackoverflow com questions 5148041 does firefox support p
  • C++11 类型推导与 const char *

    In GotW 94 http herbsutter com 2013 08 12 gotw 94 solution aaa style almost always auto Herb Sutter 对 经典 C 声明进行了区分 const
  • 自定义 ViewEngine ASP.NET MVC 3

    我正在为 ASP NET MVC 的自定义视图引擎寻找最简单的解决方案 这样我就可以超越路径来寻找视图 实际上 我正在尝试在我的解决方案中构建一个主题系统 我查看了网络 但发现了很难学习和实施的解决方案 Thanks 这就是我用的 它在主题
  • 使用 Windows Media Foundation 枚举时如何获取硬件 ID

    我在用MFEnumDeviceSources 枚举连接的设备 我正在寻找一个已连接两个的特定网络摄像头 枚举工作正常 我可以打印友好名称这是FLIR Video对于我的两台相机 我正在努力弄清楚如何从 Media Foundation 设备
  • 生成总和为 N 的所有数字排列

    我正在编写一个程序来创建所有数字 起初 我尝试使用分区函数对数字进行分区 然后对每个数字集进行排列 但是我认为这行不通 最好的方法是递归排列 同时对数字求和 这超出了我的能力范围 抱歉 如果这听起来真的很愚蠢 但我真的不知道 Example
  • 部分渲染冗余方法调用

    我知道 JSF 可能会调用托管 bean 方法几次 即使它在 xhtml 中只调用一次 我知道这是由于编码 方法造成的 我想请您向我解释一下以下案例 我有一个类似这样的 JSF 文件
  • RecyclerView 中的 OnLongItemClick

    我开始在 Android 中使用 RecyclerView 一切工作正常 直到我为我的适配器实现触摸侦听器 来自这个主题 https stackoverflow com a 26826692 2923403 https stackoverf
  • 如何设置 MySQL Workbench 自动断开与服务器的连接?

    有没有办法设置Workbench在空闲时自动与服务器断开连接 命令行 mysql 客户端在空闲时断开连接 然后在运行查询时重新连接 我也希望 Workbench 自动断开连接 我无法修改服务器的超时设置 但命令行客户端可以按照当前服务器设置
  • 在 R 中连接/匹配数据帧

    我有两个数据框 第一列有两列 x是水深 y是每个深度的温度 第二个也有两列 x也是水深 但与第一个表中的深度不同 第二栏z是盐度 我想通过以下方式连接两个表x 通过增加z到第一张桌子 我已经学会了如何使用 key 来连接表tidyr 但只有
  • 读取进程的进程内存不会返回所有内容

    我正在尝试扫描第三方应用程序的内存 我已经查到地址了 现在是在0x0643FB78 问题是 从那以后我就再也爬不上去LPMODULEENTRY32 gt modBaseAddr is 0x00400000 and LPMODULEENTRY
  • 在 Java 中引发竞争条件

    我必须编写一个引发竞争条件的单元测试 以便我可以测试稍后是否可以解决问题 问题是竞争条件很少发生 可能是因为我的计算机只有两个核心 代码如下 class MyDateTime String getColonTime datetime is
  • Select2 基本示例不起作用

    我想得到select2使用 symfony2 脚本的库 我正在尝试实现提供的基本示例https select2 github io examples html https select2 github io examples html pa
  • C# 中的通用 foreach 循环

    给出以下代码的编译器告诉我 使用未分配的局部变量 x 有什么想法吗 public delegate Y Function
  • 如何使用 JS 和 Chrome 控制台向频道发送 Discord 消息?

    如何使用 JS 和 Chrome 控制台在不使用 Discord API 的情况下将 Discord 消息发送到 Discord 频道 看来这是不可能的事了 打开不和谐控制台 ctrl shift i 不起作用 请参阅下面的编辑 然后进入网
  • 为什么 Java 原始数据类型不称为 java 数据类型?

    我有一个问题 为什么 Java 原始数据类型不直接称为 Java 数据类型 或类似的名称 因为Java有更多的数据类型原语 http java sun com docs books tutorial java nutsandbolts da
  • Boost.Intrusive 和 unordered_map

    我希望使用侵入性的 unordered map 由于某种原因 库中只有一个 unordered set 还有一个侵入式哈希表 但我不确定它是否具有相同的功能 而且它没有相同的接口 我错了吗 我错过了 unordered map 链接吗 如果
  • 无法处理 ajax 中的 302 重定向,为什么? [复制]

    这个问题在这里已经有答案了 我有一个使用表单身份验证用 asp net mvc 编写的后端服务器 当用户未通过身份验证时 服务器将自动发送 302 重定向到登录操作并返回登录页面 在客户端 我有一个项目列表 只有经过身份验证的用户才能访问此
  • C# 如何比较两个字符串并指出哪些部分不同

    例如 如果我有 string a personil string b personal 我想得到 string c person i l 然而 它不一定是单个字符 我也可以这样 string a disfuncshunal string b
  • 带有 OpenCV 的增强现实 SDK [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • 从 keras 模型中将特征提取到数据集中

    我使用以下代码 由here https github com keras team keras blob master examples mnist cnn py 运行 CNN 来训练 MNIST 图像 from future import