在 keras 中加载模型后的不同预测

2023-11-27

我在 Keras 中构建了一个序列模型,经过训练后它给了我很好的预测,但是当我保存然后加载模型时,我没有在同一数据集上获得相同的预测。为什么? 请注意,我检查了模型的权重,它们以及模型的架构都是相同的,并使用 model.summary() 和 model.getWeights() 检查。在我看来这很奇怪,我不知道如何处理这个问题。 我没有任何错误,但预测不同

  1. 我尝试使用 model.save() 和 load_model()

  2. 我尝试使用 model.save_weights() ,然后重新构建模型,然后加载模型

我对这两个选项都有同样的问题。

def Classifier(input_shape, word_to_vec_map, word_to_index, emb_dim, num_activation):

    sentence_indices = Input(shape=input_shape, dtype=np.int32)
    emb_dim = 300  # embedding di 300 parole in italiano
    embedding_layer = pretrained_embedding_layer(word_to_vec_map, word_to_index, emb_dim)

    embeddings = embedding_layer(sentence_indices)   

    X = LSTM(256, return_sequences=True)(embeddings)
    X = Dropout(0.15)(X)
    X = LSTM(128)(X)
    X = Dropout(0.15)(X)
    X = Dense(num_activation, activation='softmax')(X)

    model = Model(sentence_indices, X)

    sequentialModel = Sequential(model.layers)    
    return sequentialModel

    model = Classifier((maxLen,), word_to_vec_map, word_to_index, maxLen, num_activation)
    ...
    model.fit(Y_train_indices, Z_train_oh, epochs=30, batch_size=32, shuffle=True)

    # attempt 1
    model.save('classificationTest.h5', True, True)
    modelRNN = load_model(r'C:\Users\Alessio\classificationTest.h5')

    # attempt 2
    model.save_weights("myWeight.h5")

    model = Classifier((maxLen,), word_to_vec_map, word_to_index, maxLen, num_activation)
    model.load_weights(r'C:\Users\Alessio\myWeight.h5') 

    # PREDICTION TEST
    code_train, category_train, category_code_train, text_train = read_csv_for_email(r'C:\Users\Alessio\Desktop\6Febbraio\2test.csv')

    categories, code_categories = get_categories(r'C:\Users\Alessio\Desktop\6Febbraio\2test.csv')

    X_my_sentences = text_train
    Y_my_labels = category_code_train
    X_test_indices = sentences_to_indices(X_my_sentences, word_to_index, maxLen)
    pred = model.predict(X_test_indices)

    def codeToCategory(categories, code_categories, current_code):

        i = 0;
        for code in code_categories:
            if code == current_code:
                return categories[i]
            i = i + 1 
        return "no_one_find"   

    # result
    for i in range(len(Y_my_labels)):
        num = np.argmax(pred[i])

    # Pretrained embedding layer
    def pretrained_embedding_layer(word_to_vec_map, word_to_index, emb_dim):
    """
    Creates a Keras Embedding() layer and loads in pre-trained GloVe 50-dimensional vectors.

    Arguments:
    word_to_vec_map -- dictionary mapping words to their GloVe vector representation.
    word_to_index -- dictionary mapping from words to their indices in the vocabulary (400,001 words)

    Returns:
    embedding_layer -- pretrained layer Keras instance
    """

    vocab_len = len(word_to_index) + 1                  # adding 1 to fit Keras embedding (requirement)

    ### START CODE HERE ###
    # Initialize the embedding matrix as a numpy array of zeros of shape (vocab_len, dimensions of word vectors = emb_dim)
    emb_matrix = np.zeros((vocab_len, emb_dim))

    # Set each row "index" of the embedding matrix to be the word vector representation of the "index"th word of the vocabulary
    for word, index in word_to_index.items():
        emb_matrix[index, :] = word_to_vec_map[word]

    # Define Keras embedding layer with the correct output/input sizes, make it trainable. Use Embedding(...). Make sure to set trainable=False. 
    embedding_layer = Embedding(vocab_len, emb_dim)
    ### END CODE HERE ###

    # Build the embedding layer, it is required before setting the weights of the embedding layer. Do not modify the "None".
    embedding_layer.build((None,))

    # Set the weights of the embedding layer to the embedding matrix. Your layer is now pretrained.
    embedding_layer.set_weights([emb_matrix])

    return embedding_layer

您有什么建议吗?

提前致谢。

Edit1:如果使用在同一“页面”中保存和加载的代码(我使用笔记本jupyter),它工作正常。如果我更改“页面”,它就不起作用。难道是和tensorflow session有关系?

Edit2:我的最终目标是使用 Java 中的 Deeplearning4J 加载在 Keras 中训练的模型。因此,如果您知道将 keras 模型“转换”为 DL4J 中其他可读内容的解决方案,无论如何它都会有所帮助。

Edit3:添加函数 pretrained_embedding_layer()

Edit4:使用 gensim 读取 word2Vec 模型中的字典

from gensim.models import Word2Vec
model = Word2Vec.load('C:/Users/Alessio/Desktop/emoji_ita/embedding/glove_WIKI')

def getMyModels (model):
word_to_index = dict({})
index_to_word = dict({})
word_to_vec_map = dict({})
for idx, key in enumerate(model.wv.vocab):
    word_to_index[key] = idx
    index_to_word[idx] = key
    word_to_vec_map[key] = model.wv[key]
return word_to_index, index_to_word, word_to_vec_map

加载模型时是否以相同的方式预处理数据?

如果是,您是否设置了预处理函数的种子? 如果你用 keras 构建字典,句子的顺序是否相同?

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 keras 中加载模型后的不同预测 的相关文章

随机推荐

  • 为什么我在执行 sql 脚本时收到“不一致的结束行”警告窗口?

    当我尝试执行 sql 脚本时 它会出现一个警告窗口 以下文件中的行结尾不一致 是否要 使其正常化 我只是想知道为什么会出现此问题以及如何永久修复它 请帮忙 因为有些行以 CR LF 对结尾 有些行仅以 CR 或 LF 结尾 基本上以某种方式
  • 如何使 sqlalchemy 在反映表时返回 float 而不是 Decimal?

    我有一个在 Python 代码之外定义的 MySQL 数据库 我使用反射将其放入 SQLAlchemy 因此我没有任何可以修改的类定义 我不必担心失去精度 并且我对 Python 中的结果进行了一些算术运算 因此我宁愿不必手动将一堆值转换为
  • 为什么可变参数模板构造函数比复制构造函数更匹配?

    以下代码无法编译 include
  • Django 的 I18N 与第三方应用程序

    我有一个 Django 项目 它使用django tagging并且应该以德语运行 所以我查看了来源并发现django tagging确实使用gettext lazy因此是完全可翻译的 但是 包中没有可用的翻译 所以我认为必须有一种方法可以
  • 结合 R Markdown 和动画包

    有没有办法结合起来animation package和 r 降价 我想生成动画 我想在从 r markdown 生成的 html 文件中包含和描述该动画 当然我可以嵌入代码saveHTML or saveGIF文件已生成的 r markdo
  • Windows 中 TEMP 目录的限制?

    我有一个用 Python 编写的应用程序 它将大量数据写入 TEMP 文件夹 奇怪的是 每隔一段时间 它就会死去 然后回来IOError Errno 28 No space left on device 该驱动器有plenty的自由空间 T
  • 在 Linux 上通过 jenkins 运行 angular2 测试时出现 Karma 错误

    使用 karma 和 jenkins 运行我的 angular2 单元测试时 我看到以下错误 当我在本地计算机 Windows 上运行测试时 我的测试运行良好 但是当在 Linux 上通过 jenkins 运行测试时 我收到以下错误 Mis
  • 如何在WPF MVVM中调用窗口的Loaded事件?

    从我的 OnLoaded 事件创建命令很容易 处理程序代码 但如何从视图中调用它 从此不再切蛋糕 它调用 xaml cs 中的代码 我将如何创建一个 ICommand 相等的 您可以通过附加行为来完成此类事情 为了节省一些时间 看看 Mar
  • Promise.all() 被拒绝后的值,显示 [''PromiseStatus'']:如果存在 catch 块,则已解决

    我有两个承诺 一个被拒绝 另一个被解决 Promise all 被调用 当其中一个承诺被拒绝时 它执行了 Promise all 的 catch 块 const promise1 Promise resolve Promise 1 Reso
  • GTK 或 Qt 的图表小部件 [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心以获得指导 GTK 或 Qt 是否有一
  • NuGet 包依赖项

    对于一个包对其他库的每个依赖项 这些库是否也需要解析和安装 例如 我创建了一个使用的包NLog Postsharp and WindowsAzure Storage 我的软件包的客户端现在也必须安装这些软件包吗 为什么不能将这些依赖项 DL
  • 没有可见的接口错误

    我的模型的实现文件中有一个错误 我已将其注释掉 我可以做什么来解决这个问题 提前致谢 import CalculatorBrain h interface CalculatorBrain property nonatomic strong
  • 为什么在 `array.length && ...` 的短路计算中呈现“0”

    目前 我看到这样的行为 render const list return div list length div List rendered div div 我的预期是在该条件下不会呈现任何内容 但会呈现字符串 0 字符串 0 是list
  • spring-mvc中如何传递参数来重定向页面

    我写了以下控制器 RequestMapping value logOut method RequestMethod GET public String logOut Model model RedirectAttributes redire
  • 检查在自定义 Chrome 选项卡中打开哪个网址

    chrome自定义选项卡中是否有类似于Webview的onPageStarted的功能 在 onNavigation 捆绑包始终为空 根据设计 Chrome 自定义选项卡不可能做到这一点 您可以知道用户已经导航 但无法知道他们去了哪里 看
  • 检测浏览器关闭/导航到其他页面并注销的最佳方法

    我正在 GWT 中编写一个应用程序 我需要检测用户何时离开我的应用程序或何时关闭浏览器窗口 onUnload 事件 并执行注销 会话失效和其他一些清理任务 注销操作由 servlet 执行 我目前正在通过挂钩 onUnload 事件并打开一
  • 检查 LatLngBounds.Builder 是否为空

    这是我的代码 LatLngBounds Builder builder new LatLngBounds Builder for int x firstVisibleItem x lt lastVisibleItem x builder i
  • 如何提取直接 Facebook 视频 url

    我正在尝试从 facebook 视频链接中提取 facebook 视频文件页面的 url 但我无法继续操作 例如 我的 Facebook 视频网址是 https www facebook com nerdandco videos 16621
  • 如何在 Woocommerce 中检查产品是否具有特定产品属性

    我想确定产品是否具有属性 例如 if product has attribute pa color do something 我怎样才能做到这一点 您只需使用WC Product method get attribute 这边走 If ne
  • 在 keras 中加载模型后的不同预测

    我在 Keras 中构建了一个序列模型 经过训练后它给了我很好的预测 但是当我保存然后加载模型时 我没有在同一数据集上获得相同的预测 为什么 请注意 我检查了模型的权重 它们以及模型的架构都是相同的 并使用 model summary 和