Bert Estimator input_fn 函数调用逻辑

2023-11-16

Bert Estimator input_fn 函数调用逻辑

网上有很多讲 Bert 源码的,本身代码难度不大,主要两个重点,一个是数据集的处理,以满足 masked LM 和 next sentence predict 两个任务的需求,这一部分主要围绕 create_pretaining_data.py 看就行,另一部分就是预训练,主要围绕 run_pretraining.pymodeling.py 看就行,主要是 Transformer 模型 Encoder 部分的堆砌,至于 fine-tune 也类似这两个重点过程。

像 TensorFlow 这种库,“友好”到 Estimator 这种高级库可以大大简化我们的工作,找个 demo 做填空题基本就能搞定,“不友好”到封装的我们都不知道它是怎么运行的,例如,很多人一开始看代码都不清楚 input_fn 或者 model_fn 是怎么被调用执行的,params 参数下的 batch_size 又是在哪里被鼓捣进去的,我们下面以 input_fn 为例简单唠叨唠叨:

    • estimator.train(input_fn=train_input_fn, …) 这是第一步,就是调用,此时其实 input_fn 就是 input_fn_builder 函数的 return input_fn,此时实际并没有进入到这个子函数内部,所以重点就是看 train 函数了。实际上又扔给了 TPUEstimator 的父类,也就是 Estimator 的 train方法了。

    •   return super(TPUEstimator, self).train(
                  input_fn=input_fn,
                  hooks=hooks,
                  steps=steps,
                  max_steps=max_steps,
                  saving_listeners=saving_listeners)
      
    • 看一下这个 train_train_model,进去

      saving_listeners = _check_listeners_type(saving_listeners)
      loss = self._train_model(input_fn, hooks, saving_listeners)
      logging.info('Loss for final step: %s.', loss)
      return self
      
      def _train_model(self, input_fn, hooks, saving_listeners):
        if self._train_distribution:
          return self._train_model_distributed(input_fn, hooks, saving_listeners)
        else:
          return self._train_model_default(input_fn, hooks, saving_listeners)
      # 这里走下面这个默认的就好了,都一样其实
      
      features, labels, input_hooks = (
                self._get_features_and_labels_from_input_fn(
                    input_fn, ModeKeys.TRAIN))
      

      在这个地方开始获取数据特征与标签了,也就是要实际进入到 input_fn 内部了

    • 进去看看

      def _get_features_and_labels_from_input_fn(self, input_fn, mode):
        """Extracts the `features` and labels from return values of `input_fn`."""
        return estimator_util.parse_input_fn_result(
          self._call_input_fn(input_fn, mode))
      
      # 开始 call 调用了,这里的 mode 是 train 或者 eval 等
      
      # _call_input_fn 的第一句是这个
      input_fn_args = function_utils.fn_args(input_fn)
      
      def fn_args(fn):
        """Get argument names for function-like object.
      
        Args:
          fn: Function, or function-like object (e.g., result of `functools.partial`).
      
        Returns:
          `tuple` of string argument names.
      
        Raises:
          ValueError: if partial function has positionally bound arguments
        """
        if isinstance(fn, functools.partial):   # 不符合
          args = fn_args(fn.func)
          args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
        else:
          if _is_callable_object(fn):    # 不符合
            fn = fn.__call__
          args = tf_inspect.getfullargspec(fn).args   # ----> 吊炸天的一个函数
          if _is_bounded_method(fn):
            args.pop(0)  # remove `self` or `cls`
        return tuple(args)
      

      fn_args 就是获取函数或者 function-like 对象的参数的,getfullargspec 方法好像很吊,可以直接获取函数在哪个文件的哪一行,有哪些参数,你的函数内部有哪些变量,很吊的样子,这个是 Python 内部提供的,不过 tf 自己也封装了一下,这个不必纠结,总之 args 就是 input_fn 的参数,即 params

    • 回到 _call_fn_input 中,现在 params 参数是存在的,

      继续该函数

          with self._ctx.with_mode(mode) as ctx:
            # Setting the batch size in params first. This helps user to have same
            # input_fn for use_tpu=True/False.
            batch_size_for_input_fn = ctx.batch_size_for_input_fn    # 进去
            if batch_size_for_input_fn is not None:
              _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY,
                                  batch_size_for_input_fn)
      

      看一下 ctx.batch_size_for_input_fn

        def batch_size_for_input_fn(self):
          """Returns the shard batch size for `input_fn`."""
          global_batch_size = self.global_batch_size    # 在这里
          if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()
              and not self.is_input_slice_broadcast_to_all_cores()):
            return global_batch_size
      

      global_batch_size 是什么鬼?进去看看

        @property
        def global_batch_size(self):
          mode = self._assert_mode()
          if mode == model_fn_lib.ModeKeys.TRAIN:
            return self._train_batch_size         # 这个实际就是我们传的train_batch_size
          elif mode == model_fn_lib.ModeKeys.EVAL:
            return self._eval_batch_size
          elif mode == model_fn_lib.ModeKeys.PREDICT:
            return self._predict_batch_size
          else:
            return None
      

      回到上面倒数第三张图中

          with self._ctx.with_mode(mode) as ctx:
            # Setting the batch size in params first. This helps user to have same
            # input_fn for use_tpu=True/False.
            batch_size_for_input_fn = ctx.batch_size_for_input_fn    # 进去
            if batch_size_for_input_fn is not None:
              _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY,
                                  batch_size_for_input_fn)
            # For export_saved_model, input_fn is never passed to Estimator. So,
            # `is_export_mode` must be False.
            if ctx.is_running_on_cpu(is_export_mode=False):
              with ops.device('/device:CPU:0'):
                return input_fn(**kwargs)         # ----> 真正带着 params = {"batch_size": 32} 
      

      _add_item_to_params 就是把 params 内加一个 batch_size 参数,_BATCH_SIZE_KEY 是定义的一个字符串_BATCH_SIZE_KEY = 'batch_size'

      def _add_item_to_params(params, key, value):
        """Adds a new item into `params`."""
        if hasattr(params, 'set_hparam'):
          # For HParams, we need to use special API.
          if key in params:
            params.set_hparam(key, value)
          else:
            params.add_hparam(key, value)
        else:
          # Now params is Python dict.
          params[key] = value            #  ----> 就是这句话
      
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Bert Estimator input_fn 函数调用逻辑 的相关文章

随机推荐

  • 华为云云耀云服务器L实例评测

    大家好 我是雄雄 欢迎关注微信公众号 雄雄的小课堂 目录 前言 效果图 购买云耀云服务器L实例 重置密码 放开端口 远程连接 安装云监控面板 进入监控面板 前言 有幸参与了华为云云耀云服务器L实例的评测名额 借着评测 顺便教给大家一项技能
  • 多视图聚类(multi-view clustering)简介

    多视图聚类 目前大概有以下几种 多视图k means聚类 多视图谱聚类 多视图图聚类 多视图子空间聚类 multi view subspace clustering 深度学习多视图聚类 deep multi view clustering
  • Vector迭代器实现

    实现数组的迭代器 实现内容 1 使用C 语言实现一个长度可扩充的数组结构 要求使用class实现 不能直接使用vector等现成的数据结构 2 要求实现为可以用于不同数据类型的数组结构 并不是说同一个对象需要存储多种类型的数据 建议使用te
  • 【满分】【华为OD机试真题2023 JAVA&JS】租车骑绿道

    华为OD机试真题 2023年度机试题库全覆盖 刷题指南点这里 租车骑绿道 时间限制 1s 空间限制 256MB 限定语言 不限 题目描述 部门组织绿道骑行团建活动 租用公共双人自行车骑行 每辆自行车最多坐两人 做大载重M 给出部门每个人的体
  • 毕业设计 单片机与OpenMV机器视觉目标跟踪系统

    文章目录 0 前言 课题简介 设计框架 3 openMV实现舵机定位色块STM32 3 硬件设计 4 软件设计 4 1 硬件连接 4 2 软件代码 OpenMV端 4 3 软件代码 STM32端 4 4 利用PC端测试数据数据是否发送接收正
  • 《银行法律法规》一、经济金融基础知识——3、金融市场

    第三章 金融市场 第一节 金融市场概述 考点1 金融市场功能 概念 金融市场是指货币资金融通和金融工具交易的场所 金融市场的融资行为既包括以银行等金融机构为信用媒介的间接融资行为 也包括各类交易主体之间的直接融资行为 主体 是各类融资活动的
  • 运维企业实战Shell脚本合集+万能工具箱

    文章目录 系统维护篇 服务器日常巡检脚本 下线登录用户 企业级Linux日常自动抓取服务器巡检 登录 执行命令记录 备份脚本 终端对话 广播消息 批量查询IP归属地 手机号归属地信息 Linux开机后自动执行命令或脚本 一键自动格式化输出S
  • Anaconda中安装指定版本的tensorflow1.14.0/tensorflow-gpu1.14.0

    在运行github中一个项目时 由于其使用的tensorflow的版本是1 14 0 而我的版本是2 6 0的版本 因为版本过高导致运行失败 所以需要安装tensorflow1 14 0 首先在anaconda的命令行中输入如下命令 pip
  • 【Qt】【CMake】【CMakeLists.txt】-PROJECT_NAME 和 CMAKE_PROJECT_NAME 的区别

    Qt CMake CMakeLists txt PROJECT NAME 和 CMAKE PROJECT NAME 的区别 原帖 https stackoverflow com questions 38938315 difference b
  • 2000+Docker镜像,Kolla是如何管理的

    根据 DockerHub 上的数据 整个 Kolla 项目管理的 镜像有 2000 多个 这么多的镜像 是怎么定义 又是如何构建的呢 简介 我们一直在说的 Kolla 通常情况下泛指 包括了 Kolla 和 Kolla Ansible 两个
  • 二进制部署K8s

    一 环境需求 节点IP 节点名称 所需组件 192 168 248 11 k8s master docker etcd apiserver controller manager scheduler kube proxy flannel 19
  • cobra库:基于cobra-cli命令行生成项目结构

    cobra库 基于cobra cli命令行生成项目结构 一 新建go项目 在F盘创建文件夹cobra started 1 使用mod对go项目进行管理 go mod init cobra started 二 使用cobra cli代码生成
  • 手写嵌入式操作系统(基于stm8单片机)

    include
  • maven学习总结

    众所周知 maven的两大作用是项目构建和依赖管理 除此之外 基于多模块项目 maven常用的功能还有模块化管理 项目构建 Maven是一个构建工具 可以根据项目中的配置文件 pom xml 来自动执行项目的构建过程 它可以将源代码编译 运
  • win10 win7局域网、AD域内共享文件夹方法

    第一 确保访问电脑和被访问电脑同在域中 可右击此电脑 属性 域 查看 第二 确保防火墙关闭 如图均已关闭 第三 选择要共享的文件夹 右击 属性 共享 高级共享 全新 Everyone或指定个人 第四 分享地址 即 本机IP地址 win r输
  • 惠普 g5 服务器 centos安装系统,hp 380G5 安装centos 7

    最近给服务器升级操作系统 发现hp的老机器安装centos 7时不能识别硬盘 原因 hp的服务器G5 使用的是CCISS driver 新的机器使用的是HPSA driver RHEL7 已经移除了 cciss 的支持 处理 安装时候 修改
  • 常数据成员、常成员函数

    定义常数据成员 类型 const 对象名 或者 const 类型 对象名 例如 const clock c1 9 9 9 或者 clock const c2 10 10 10 常对象的几条特殊规则 1 常对象 不能被赋值 2 常对象 不能访
  • 【Pytorch Lighting】第 7 章:半监督学习

    大家好 我是Sonhhxg 柒 希望你看完之后 能对你有所帮助 不足请指正 共同学习交流 个人主页 Sonhhxg 柒的博客 CSDN博客 欢迎各位 点赞 收藏 留言 系列专栏 机器学习 ML 自然语言处理 NLP 深度学习 DL fore
  • 2、halcon+利用光流场检测运动的物体

    这个事例是应用optical flow mg这个算子来在一个图像序列中计算其光溜 并且分割其运动物体 dev update off 把程序窗口 变量窗口 显示窗体变为off状态 dev close window 关闭显示窗口 read im
  • Bert Estimator input_fn 函数调用逻辑

    目录 Bert Estimator input fn 函数调用逻辑 Bert Estimator input fn 函数调用逻辑 网上有很多讲 Bert 源码的 本身代码难度不大 主要两个重点 一个是数据集的处理 以满足 masked LM