Tensorflow之Estimator(二)实践

2023-11-19

1. 前言

这篇文章介绍Tensorflow的高级API,模型的建立和简化过程。

2. Estimator优势

本文档介绍了Estimator一种可极大地简化机器学习编程的高阶TensorFlow API。用了Estimator你会得到数不清的好处。

  • 您可以在本地主机上或分布式多服务器环境中运行基于 Estimator 的模型,而无需更改模型。此外,您可以在CPU、GPU或TPU上运行基于Estimator 的模型,而无需重新编码模型
  • 使用dataset高效处理数据,搭配上Estimator再GPU或者TPU上高效的运行模型,提高整体的模型运行的时间。
  • 使用Estimator编写应用时,您必须将数据输入管道从模型中分离出来。这种分离简化了不同数据集的实验流程
  • Estimator提供安全的分布式训练循环,可以控制如何以及何时:
    • 构建图
    • 初始化变量
    • 开始排队
    • 处理异常
    • 创建检查点文件并从故障中恢复
    • 保存 TensorBoard 的摘要
  • Estimator简化了在模型开发者之间共享实现的过程。
  • 您可以使用高级直观代码开发先进的模型。简言之,采用Estimator创建模型通常比采用低阶TensorFlow API更简单。
  • Estimator本身在tf.layers之上构建而成,可以简化自定义过程。

3. 预创建的Estimator

  • 编写一个或多个数据集导入函数
    • 一个字典,其中键是特征名称,值是包含相应特征数据的张量(或 SparseTensor)
    • 一个包含一个或多个标签的张量
def input_fn(dataset):
   # manipulate dataset, extracting the feature dict and the label
   return feature_dict, label
  • 定义特征列。每个tf.feature_column都标识了特征名称、特征类型和任何输入预处理操作。
# Define three numeric feature columns.
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column('median_education',
                    normalizer_fn=lambda x: x - global_education_mean)
  • 实例化相关的预创建的Estimator。 例如,下面是对名为LinearClassifier预创建Estimator进行实例化的示例代码:
# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
    feature_columns=[population, crime_rate, median_education],
    )
  • 调用训练、评估或推理方法。例如,所有 Estimator 都提供训练模型的 train 方法。
# my_training_set is the function created in Step 1
estimator.train(input_fn=my_training_set, steps=2000)

4. 自定义Estimator

4.1 input_fn输入函数

输入函数可以直接返回feature_dict, label,也可以返回的是dataset.make_one_shot_iterator(),这样就和我们高效的数据预处理接上了

def input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the read end of the pipeline.
    return dataset.make_one_shot_iterator().get_next()

4.2 feature_columns创建特征列

您必须定义模型的特征列来指定模型应该如何使用每个特征。无论是使用预创建的Estimator还是自定义Estimator,您都要使用相同的方式定义特征列。

以下代码为每个输入特征创建一个简单的 numeric_column,表示应该将输入特征的值直接用作模型的输入:

# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

4.3 model_fn模型函数

def model_fn(
   features, # This is batch_features from input_fn
   labels,   # This is batch_labels from input_fn
   mode,     # An instance of tf.estimator.ModeKeys
   params):  # Additional configuration

前两个参数是从输入函数中返回的features和labels,mode参数表示调用程序是请求训练、预测还是评估。所以在model_fn里面需要实现训练、预测、评估3种请求方式。

调用程序可以将params传递给Estimator的构造函数。传递给构造函数的所有params 转而又传递给model_fn。

classifier = tf.estimator.Estimator(
    model_fn=my_model,
    params={
        'feature_columns': my_feature_columns,
        # Two hidden layers of 10 nodes each.
        'hidden_units': [10, 10],
        # The model must choose between 3 classes.
        'n_classes': 3,
    })

5. 定义模型

5.1 定义输入层

在 model_fn 的第一行调用 tf.feature_column.input_layer,以将特征字典和 feature_columns 转换为模型的输入,会应用特征列定义的转换,从而创建模型的输入层。如下所示:

# Use `input_layer` to apply the feature columns.
net = tf.feature_column.input_layer(features, params['feature_columns'])

5.2 隐藏层

如果您要创建深度神经网络,则必须定义一个或多个隐藏层。Layers API 提供一组丰富的函数来定义所有类型的隐藏层,包括卷积层、池化层和丢弃层。

隐藏层是用户自己发挥想象力,定义的可以很复杂的地方。

# Build the hidden layers, sized according to the 'hidden_units' param.
for units in params['hidden_units']:
    net = tf.layers.dense(net, units=units, activation=tf.nn.relu)

5.3 输出层

# Compute logits (1 per class).
logits = tf.layers.dense(net, params['n_classes'], activation=None)

tf.nn.softmax 函数会将这些对数转换为概率。

5.4 实现训练、评估和预测

创建模型函数的最后一步是编写实现预测、评估和训练的分支代码。

重点关注第三个参数 mode。如下表所示,当有人调用train、evaluate或predict时,Estimator框架会调用模型函数并将mode参数设置为ModeKeys.TRAIN,ModeKeys.EVAL,ModeKeys.PREDICT。

模型函数必须提供代码来处理全部三个mode值。对于每个mode值,您的代码都必须返回 tf.estimator.EstimatorSpec的一个实例,其中包含调用程序所需的信息。我们来详细了解各个mode。

  • 训练 ModeKeys.TRAIN

构建训练操作需要优化器。我们将使用 tf.train.AdagradOptimizer。
我们使用优化器的 minimize 方法根据我们之前计算的损失构建训练操作。
minimize 方法还具有 global_step 参数。

optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
if mode == tf.estimator.ModeKeys.TRAIN:
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
  • 评估 ModeKeys.EVAL

虽然返回指标是可选的。TensorFlow 提供一个指标模块 tf.metrics 来计算常用指标。为简单起见,我们将只返回准确率。

# Compute evaluation metrics.
accuracy = tf.metrics.accuracy(labels=labels,
                               predictions=predicted_classes,
                               name='acc_op')
metrics = {'accuracy': accuracy}
tf.summary.scalar('accuracy', accuracy[1])

if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
        mode, loss=loss, eval_metric_ops=metrics)
  • 预测 ModeKeys.PREDICT

该模型必须经过训练才能进行预测。经过训练的模型存储在磁盘上,位于您实例化 Estimator 时建立的 model_dir 目录中。

此模型用于生成预测的代码如下所示:

# Compute predictions.
predicted_classes = tf.argmax(logits, 1)
if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'class_ids': predicted_classes[:, tf.newaxis],
        'probabilities': tf.nn.softmax(logits),
        'logits': logits,
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

predictions 存储的是下列三个键值对:

  • class_ids 存储的是类别 ID(0、1 或 2),表示模型对此样本最有可能归属的品种做出的预测。
  • probabilities 存储的是三个概率(在本例中,分别是 0.02、0.95 和 0.03)
  • logit 存储的是原始对数值(在本例中,分别是 -1.3、2.6 和 -0.9)

我们通过 predictions 参数(属于 tf.estimator.EstimatorSpec)将该字典返回到调用程序。Estimator 的 predict 方法会生成这些字典。

回到顶部

6. 实例化Estimator

通过 Estimator 基类实例化自定义 Estimator,如下所示:

# Build 2 hidden layer DNN with 10, 10 units respectively.
classifier = tf.estimator.Estimator(
    model_fn=my_model,
    params={
        'feature_columns': my_feature_columns,
        # Two hidden layers of 10 nodes each.
        'hidden_units': [10, 10],
        # The model must choose between 3 classes.
        'n_classes': 3,
    })

在这里,params 字典与 DNNClassifier 的关键字参数用途相同;即借助 params 字典,您无需修改 model_fn 中的代码即可配置 Estimator。

使用 Estimator 训练、评估和生成预测要用的其余代码与预创建的 Estimator 一章中的相同。例如,以下行将训练模型:

# Train the Model.
classifier.train(
    input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
    steps=args.train_steps)

7. 工作流程

  1. 假设存在合适的预创建的Estimator,使用它构建第一个模型并使用其结果确定基准。
  2. 使用此预创建的Estimator构建和测试整体管道,包括数据的完整性和可靠性。
  3. 如果存在其他合适的预创建的Estimator,则运行实验来确定哪个预创建的Estimator效果最好。
  4. 可以通过构建自定义Estimator进一步改进模型。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow之Estimator(二)实践 的相关文章

随机推荐

  • CryptoJS与JSEncrypt 加密算法

    crypto js进行AES加密 安装 npm i save crypto js jsencrypt进行RSA加密 安装 npm i save jsencrypt 官网 https github com travist jsencrypt
  • 微软Imagine Cup 2013大赛中国区CSDN高校俱乐部校区比赛成绩及获奖名单

    微软 Imagine Cup 2013 大赛已接近尾声 CSDN高校俱乐部首次参加此大赛 在中国赛区的比赛中 CSDN高校俱乐部校区取得了令人骄傲的成绩 在此向所有的参赛同学表示祝贺和感谢 同时 非常感谢各俱乐部的指导老师 主席 同学对CS
  • 路由期末复习(二)—配置命令

    这篇就专门说说关于配置的知识点 了解基础知识指路 目录 路由器 Telnet服务配置命令 路由器 SSH服务配置命令 SSH配置例子 重点 一图理解SSH配置 用FTP传输文件 使用TFTP传输文件 VLAN的基本配置 配置Hybrid端口
  • APP、软件版本号的命名规范与原则

    APP 软件版本号的命名规范与原则 为了在软件产品生命周期中更好的沟通和标记 我们应该对APP 软件的版本号命名的规范和原则有一定的了解 1 APP 软件的版本阶段 Alpha版 也叫 版 此版本主要是以实现软件功能为主 通常只在软件开发者
  • python中mysql的用法_Python中MySQL用法

    Python中MySQL用法 一 注意事项查看系统版本 arch命令 查看系统是64位还是32位 使用cat etc system release查看内核版本 注意安装MySQL的版本企业版 付费 社区版 免费 MariaDB 注意安装之后
  • java计算机毕业设计基于springboo+vue的汉服文化宣传活动交流网站(汉服社团)

    项目介绍 近年来 随着个人计算机的普及以及互联网的飞速发展 互联网逐渐成为人们获取信息的重要渠道 互联网的便捷性与实时性等特征 在方便人们获取自己感兴趣信息的同时 也在很大程度上为企事业单位节约了大量人力 物力 财力等运营成本 汉服交流网站
  • 【笔记】下单但未支付的订单倒计时自动取消逻辑实现

    平常我们都用过淘宝 京东这些电商平台 同时肯定也在这些平台上面下过单 这种情况不保证大家都有遇到过 但做开发的 肯定也知道有这个环节的存在 确认货品配置无误之后 我们都会点击购买 随之而来的就是一个结算页 让你确认商品信息 收货地址 价格等
  • ElementUi的el-tree组件样式修改

    ElementUi的el tree组件样式修改 需求如下 下拉图标的修改 element ui中的原本的基本样式是这样的 所以第一步呢 就是要把这个下拉按钮的样式修改成加号 在vue文件中 修改样式即可 vue的项目在写样式的时候 回家上s
  • join表连接的三种算法思想:Nested-Loop Join和Index Nested-Loop Join和Block Nested-Loop Join和BKA

    一 Nested Loop Join 在Mysql中 使用Nested Loop Join的算法思想去优化join Nested Loop Join翻译成中文则是 嵌套循环连接 举个例子 select from t1 inner join
  • ChatGPT能为留学生做什么?错误使用有何后果?

    随着AI人工智能行业的迅速发展 越来越多的学生开始利用ChatGPT等软件来获得更高效便利的论文和作业辅助 然而 我们需要认识到一个严肃的问题 学生是否过度依赖AI助手来完成毕业论文 近期出现的Turnitin AI Detector是一个
  • 在Windows下使用vs2019编译libjpeg库

    一 库的编译 1 下载 libjpeg 源码 这里我下载的是 jpegsr9e zip 2 解压源码 3 进入解压后的目录 找到 makefile vs 文件 用文本编辑器打开并编辑 找到 语句 include
  • 设备管理过程

    复杂度2 5 机密度2 5 最后更新2021 04 19 AIX中对设备会有如下五个操作 define aix下能看到设备的定义 但驱动程序并没有加载或初始化 该设备不可用 lsdev看到设备时defined 很多逻辑设备 vg lv等 只
  • CTF练题(5)word隐写基础题,jpg图片隐写,敲击码解密

    2022 11 2 两道misc题目 题目一 word隐写基础 题目信息如下 以及一个无法打开的word文档 解题步骤 1 将该word文档拖入010Editor中进行分析 发现文件头显示为PK 压缩文件 将该文档后缀改为 zip 保存到桌
  • go-zero使用Etcd进行服务注册代码分析

    代码分析 github com tal tech go zero v1 2 3 core discov publisher go package discov import github com tal tech go zero core
  • ld链接器的--start-group和--end-group参数说明

    start group archives end group The archives should be a list of archive files They may be either explicit file names or
  • OpenLayers与Bootstrap样式冲突的解决

    在引入Bootstrap响应式布局样式后 OpenLayers图层瓦片会显示异常 在页面中加入以下样式可以解决 参见 http openlayers org dev examples bootstrap html
  • linux网络95值工具,Linux下网络故障排查工具之ping

    服务器运维人员在日常运维服务器的过程中经常会遇到服务器网络故障 有服务器硬件造成的 也有服务商网络问题造成的 也有区域网络问题造成的 这个时候就需要用到ping traceroute mtr这三个命令 1 ping 最简单的网络请求反馈命令
  • 粒子群算法优化的最小二乘支持向量机分类代码

    粒子群算法优化的最小二乘支持向量机分类代码 在数据挖掘和机器学习领域中 分类是一个非常基础而重要的问题 其中最小二乘支持向量机 LSSVM 是一种有效的分类方法 经常被应用于实际问题中 而粒子群算法 PSO 是一种优化算法 也可以用来优化L
  • C++之函数模板

    1 什么是模板 模板有什么作用 模板分为函数模板和类模板 函数模板是对函数功能框架的描述 具体功能由实际传递的参数决定 有了函数模板 编译器就会根据模板自动生成多个函数名相同 参数列表不同的函数 不需要手动写 例 求一个矩形面积 当传入的长
  • Tensorflow之Estimator(二)实践

    1 前言 这篇文章介绍Tensorflow的高级API 模型的建立和简化过程 2 Estimator优势 本文档介绍了Estimator一种可极大地简化机器学习编程的高阶TensorFlow API 用了Estimator你会得到数不清的好