-
estimator.train(input_fn=train_input_fn, …) 这是第一步,就是调用,此时其实 input_fn 就是 input_fn_builder
函数的 return input_fn,此时实际并没有进入到这个子函数内部,所以重点就是看 train
函数了。实际上又扔给了 TPUEstimator 的父类,也就是 Estimator 的 train
方法了。
-
return super(TPUEstimator, self).train(
input_fn=input_fn,
hooks=hooks,
steps=steps,
max_steps=max_steps,
saving_listeners=saving_listeners)
-
看一下这个 train 方法,其中实际调用了 _train_model,进去看看:
saving_listeners = _check_listeners_type(saving_listeners)
loss = self._train_model(input_fn, hooks, saving_listeners)
logging.info('Loss for final step: %s.', loss)
return self
def _train_model(self, input_fn, hooks, saving_listeners):
    """Dispatch training to the distributed or the default implementation.

    If a train distribution strategy is configured on this Estimator,
    training is delegated to `_train_model_distributed`; otherwise the
    default single-machine path `_train_model_default` is used.
    """
    if self._train_distribution:
        return self._train_model_distributed(input_fn, hooks, saving_listeners)
    else:
        return self._train_model_default(input_fn, hooks, saving_listeners)
# 这里走下面这个默认的就好了,都一样其实
features, labels, input_hooks = (
self._get_features_and_labels_from_input_fn(
input_fn, ModeKeys.TRAIN))
在这个地方开始获取数据特征与标签了,也就是要实际进入到 input_fn
内部了
-
进去看看
def _get_features_and_labels_from_input_fn(self, input_fn, mode):
    """Extracts the `features` and labels from return values of `input_fn`."""
    # `mode` is one of the ModeKeys values (TRAIN / EVAL / ...);
    # `_call_input_fn` actually invokes the user-supplied input_fn, and
    # `parse_input_fn_result` normalizes its return value into
    # (features, labels, input_hooks).
    return estimator_util.parse_input_fn_result(
        self._call_input_fn(input_fn, mode))
# 开始 call 调用了,这里的 mode 是 train 或者 eval 等
# _call_input_fn 的第一句是这个
input_fn_args = function_utils.fn_args(input_fn)
def fn_args(fn):
    """Get argument names for function-like object.

    Args:
      fn: Function, or function-like object (e.g., result of `functools.partial`).

    Returns:
      `tuple` of string argument names.

    Raises:
      ValueError: if partial function has positionally bound arguments
    """
    if isinstance(fn, functools.partial):  # not the case for a plain input_fn
        args = fn_args(fn.func)
        args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
    else:
        if _is_callable_object(fn):  # also not the case here
            fn = fn.__call__
        # ----> introspects the function's full signature; for input_fn this
        # yields its parameter names, i.e. `params`
        args = tf_inspect.getfullargspec(fn).args
        if _is_bounded_method(fn):
            args.pop(0)  # remove `self` or `cls`
    return tuple(args)
fn_args
就是获取函数或者 function-like 对象的参数的,getfullargspec
方法好像很吊,可以直接获取函数在哪个文件的哪一行,有哪些参数,你的函数内部有哪些变量,很吊的样子,这个是 Python 内部提供的,不过 tf 自己也封装了一下,这个不必纠结,总之 args 就是 input_fn 的参数,即 params
-
回到 _call_input_fn 中,现在 params 参数是存在的,
继续该函数
with self._ctx.with_mode(mode) as ctx:
# Setting the batch size in params first. This helps user to have same
# input_fn for use_tpu=True/False.
batch_size_for_input_fn = ctx.batch_size_for_input_fn # 进去
if batch_size_for_input_fn is not None:
_add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY,
batch_size_for_input_fn)
看一下 ctx.batch_size_for_input_fn
def batch_size_for_input_fn(self):
    """Returns the shard batch size for `input_fn`."""
    # NOTE(review): this quote appears truncated — in the upstream TPU
    # context this is a @property and the method continues with the
    # per-shard (on-TPU) case after this `if`; confirm against the source.
    global_batch_size = self.global_batch_size  # resolved per mode, see the property below
    if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()
        and not self.is_input_slice_broadcast_to_all_cores()):
        # On CPU (or broadcast-iterator input) the input_fn sees the
        # full global batch size, unsharded.
        return global_batch_size
global_batch_size 是什么鬼?进去看看
@property
def global_batch_size(self):
    """Return the global batch size for the current mode, or None."""
    mode = self._assert_mode()
    if mode == model_fn_lib.ModeKeys.TRAIN:
        return self._train_batch_size  # this is the train_batch_size we passed in
    elif mode == model_fn_lib.ModeKeys.EVAL:
        return self._eval_batch_size
    elif mode == model_fn_lib.ModeKeys.PREDICT:
        return self._predict_batch_size
    else:
        # Unknown / export mode: no batch size is defined.
        return None
回到上面 _call_input_fn 的代码片段中
with self._ctx.with_mode(mode) as ctx:
# Setting the batch size in params first. This helps user to have same
# input_fn for use_tpu=True/False.
batch_size_for_input_fn = ctx.batch_size_for_input_fn # 进去
if batch_size_for_input_fn is not None:
_add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY,
batch_size_for_input_fn)
# For export_saved_model, input_fn is never passed to Estimator. So,
# `is_export_mode` must be False.
if ctx.is_running_on_cpu(is_export_mode=False):
with ops.device('/device:CPU:0'):
return input_fn(**kwargs) # ----> 真正带着 params = {"batch_size": 32}
_add_item_to_params
就是把 params 内加一个 batch_size 参数,_BATCH_SIZE_KEY
是定义的一个字符串_BATCH_SIZE_KEY = 'batch_size'
def _add_item_to_params(params, key, value):
    """Adds a new item into `params`.

    `params` is either an HParams-like object (detected by the presence
    of `set_hparam`) or a plain Python dict.
    """
    if hasattr(params, 'set_hparam'):
        # For HParams, we need to use special API.
        if key in params:
            params.set_hparam(key, value)
        else:
            params.add_hparam(key, value)
    else:
        # Now params is Python dict.
        params[key] = value  # ----> this is the line that injects 'batch_size'