我正在尝试从某个检查点 (Tensorflow) 恢复训练,因为我正在使用 Colab 并且 12 小时还不够

2024-04-02

这是我正在使用的代码的一部分

checkpoint_dir = 'training_checkpoints1'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                             encoder=encoder,
                             decoder=decoder)

现在这是训练部分

EPOCHS = 900

for epoch in range(EPOCHS):
  start = time.time()

  hidden = encoder.initialize_hidden_state()
  total_loss = 0

  for (batch, (inp, targ)) in enumerate(dataset):
      loss = 0
    
      with tf.GradientTape() as tape:
          enc_output, enc_hidden = encoder(inp, hidden)
        
          dec_hidden = enc_hidden
        
          dec_input = tf.expand_dims([targ_lang.word2idx['<start>']] * batch_size, 1)       
        
          # Teacher forcing - feeding the target as the next input
          for t in range(1, targ.shape[1]):
              # passing enc_output to the decoder
              predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
            
              loss += loss_function(targ[:, t], predictions)
            
              # using teacher forcing
              dec_input = tf.expand_dims(targ[:, t], 1)
    
      batch_loss = (loss / int(targ.shape[1]))
    
      total_loss += batch_loss
    
      variables = encoder.variables + decoder.variables
    
      gradients = tape.gradient(loss, variables)
    
      optimizer.apply_gradients(zip(gradients, variables))
    
      if batch % 100 == 0:
          print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                     batch,
                                                     batch_loss.numpy()))
  # saving (checkpoint) the model every 2 epochs
  if (epoch + 1) % 2 == 0:
    checkpoint.save(file_prefix = checkpoint_prefix)

  print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                    total_loss / num_batches))
  print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

现在我想恢复 exp 这个检查点并从那里开始训练,但我不知道如何做。

path="/content/drive/My Drive/training_checkpoints1/ckpt-9"
checkpoint.restore(path)

Result

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f6653263048>

你应该创建一个检查点管理器一开始为:

checkpoint_path = os.path.abspath('.') + "/checkpoints"   # Put your path here
ckpt = tf.train.Checkpoint(encoder=encoder,
                           decoder=decoder,
                           optimizer = optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

现在,运行几个纪元后,要恢复最新的检查点,您应该从CheckpointManager:

start_epoch = 0
if ckpt_manager.latest_checkpoint:
    start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
    # restoring the latest checkpoint in checkpoint_path
    ckpt.restore(ckpt_manager.latest_checkpoint)

这将从最新纪元恢复您的会话。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

我正在尝试从某个检查点 (Tensorflow) 恢复训练,因为我正在使用 Colab 并且 12 小时还不够 的相关文章

随机推荐

  • Android O 预览版 findViewById 编译错误

    我尝试测试Android O Developer Preview第二阶段 项目创建后 我只是点击构建并运行 但没有任何成功 Android默认生成的代码如下 Toolbar toolbar Toolbar findViewById R id
  • 与 libbluetooth.so 链接

    在 Ubuntu 14 04 上 我尝试做一个蓝牙设备列表的小示例 但在编译这个简约演示时 我遇到了一个关于与蓝牙共享库链接的简单问题http people csail mit edu albert bluez intro c404 htm
  • 具有边框半径和线性渐变的 CSS 过渡

    鉴于我的 CodePenhttps codepen io scottmgerstl pen MpMeBy https codepen io scottmgerstl pen MpMeBy这是我有问题的图像布局 span class prof
  • 根据事件日志触发powershell

    我有一个用 PowerShell 编写的命令行参数脚本 它接受来自任务计划程序的服务器名称 然而 我的要求是在 SQL 服务器重新启动时执行脚本 因此我已将 PowerShell 脚本附加到事件 17069 但我无法动态传递事件源 在本例中
  • 如何设置 hibernate-mapping 以允许长度超过 255 个字符的字符串?

    所以我试图通过创建一个博客引擎来学习 我正在使用 Hibernate 和 MySQL 这是我的 Post 类的休眠映射
  • JoptionPane 显示确认对话框

    我有一个Java程序 当我运行该程序时 它会给我一个 GUI 如我所附 当我想关闭它时 它会弹出一个确认对话框 如果我按 是 按钮 它将使用以下命令退出程序System exit public static void main String
  • 如何对列表进行排序,其中正值位于负值之前,并且值分别排序?

    我有一个包含正数和负数混合的列表 如下所示 lst 1 2 10 12 4 5 9 2 我想要完成的任务是对列表进行排序 其中正数位于负数之前 也分别排序 期望的输出 1 2 9 10 12 5 4 2 我能够计算出第一部分的排序 其中正数
  • 将具有相同键的节点添加到属性树中

    我正在使用 Boost 的属性树来读取和写入 XML 使用我制作的电子表格应用程序 我想将电子表格的内容保存到 xml 这是一项学校作业 因此我需要使用以下 XML 格式
  • Swift 版本构建配置

    在 Swift v4 2 中 他们引入了扩展Bool toggle 我从早些时候就有了这个扩展 现在当我用 Xcode10 编译时它说Ambiguous use of toggle 如果 Swift 版本是 4 2 或更高版本 我试图让它忽
  • iOS 8 UIView 在键盘出现时不向上移动

    我正在开发一个聊天应用程序 其中有UITableView and a UIView含有一个UITextField and a UIButton在里面 我正在使用以下代码来移动UIView当键盘出现时向上 void keyboardWillS
  • Spring Security Saml 和 SP 应用程序的无状态会话

    我尝试运行启动示例 spring security saml boot https github com vdeotaris spring boot security saml sample https github com vdenota
  • 以编程方式完成 TFS Pull 请求

    使用Microsoft TeamFoundationServer Client 15 112 1 连接到TFS 2017 更新 2服务器我们可以获取有关现有 PR 的详细信息 如下所示 var connection new VssConne
  • 仅针对一个框架的 MSBuild 目标

    我有一个具有多框架目标的项目
  • 不安全的 JavaScript 尝试使用 URL 启动框架导航

    这有点复杂 请耐心等待 网站 A 有一个包含网站 B 的 iframe 网站 B 有一个包含网站 C 的 iframe 网站 C 上有一个按钮 单击后 我想刷新网站 B 的 url 下面是调用的 javascript 用于从网站 C 刷新网
  • 多线程 Objective-C 访问器:GCD 与锁

    我正在争论是否要转向基于 GCD 的多线程访问器模式 多年来我一直在访问器中使用基于自定义锁的同步 但我发现了一些信息 GCD简介 http www mikeash com pyblog friday qa 2009 08 28 intro
  • 如何删除 jQuery Mobile 样式?

    我之所以选择 jQuery Mobile 是因为它的动画功能和动态页面支持 而不是其他框架 然而 我在造型方面遇到了麻烦 我想保留基本页面样式以便执行页面转换 但我还需要完全自定义标题 列表视图 按钮 搜索框的外观 仅处理颜色是不够的 我需
  • Jetty + intellij idea :: 添加库

    I get java lang NoClassDefFoundError当我将 3d party 库添加到我的项目中时 我尝试将库添加到 web inf 模块依赖项 服务器库 但它不起作用 使用jetty和idea将库添加到项目的正确方法是
  • Gradle:应用程序和测试应用程序的已解决版本不同

    当我添加依赖项时 compile net bytebuddy byte buddy android 0 7 8 在我的应用程序中 我收到此错误 Conflict with dependency net bytebuddy byte budd
  • SQL 条件排序依据

    我正在两个表上进行连接 一个是用户表 另一个是高级用户列表 我需要让高级会员首先出现在我的查询中 然而 仅仅因为他们位于高级用户表中并不意味着他们仍然是高级会员 还有一个 IsActive 字段也需要检查 所以基本上我需要按以下顺序返回结果
  • 我正在尝试从某个检查点 (Tensorflow) 恢复训练,因为我正在使用 Colab 并且 12 小时还不够

    这是我正在使用的代码的一部分 checkpoint dir training checkpoints1 checkpoint prefix os path join checkpoint dir ckpt checkpoint tf tra