鸢尾花分类

2023-11-04

鸢尾花数据集

鸢尾花数据集包含四个特征和一个标签。这四个特征确定了单株鸢尾花的下列植物学特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度。我们的模型会将这些特征表示为float32数值数据。 该标签确定了鸢尾花品种,品种必须是下列任意一种:山鸢尾 (0)、变色鸢尾 (1)、维吉尼亚鸢尾 (2)。我们的模型会将该标签表示为int32分类数据。

算法

该程序会训练一个具有以下拓扑结构的深度神经网络分类器模型:2 个隐藏层,每个隐藏层包含 10 个节点。

下图展示了特征、隐藏层和预测(并未显示隐藏层中的所有节点):

 

推理

在无标签样本上运行经过训练的模型会产生三个预测,即相应鸢尾花属于指定品种的可能性。这些输出预测的总和是 1.0。例如,对无标签样本的预测可能如下所示:0.03(山鸢尾)、0.95(变色鸢尾)、0.02(维吉尼亚鸢尾)。该预测表示指定无标签样本是变色鸢尾的概率为 95%。

数据分析

所有功能的出发点都是数据,下面对数据进行全面的分析。首先,这个数据集是网上公开的一个CSV格式的数据集,可以将数据下载,使用pandas对数据进行分析。在此处已经将数据下载好并放在data文件下。

下面使用pandas对数据进行读取,并部分展示,由于下载的数据自带格式所以做了一下处理。

import pandas as pd
import tensorflow as tf
CSV_COLUMN_NAMES = ['SepalLength','SepalWidth','PetalLength', 'PetalWidth', 'Species']
data_train=pd.read_csv('./data/iris_test.csv',names=CSV_COLUMN_NAMES,header=0)
data_test=pd.read_csv('./data/iris_training.csv',names=CSV_COLUMN_NAMES,header=0)
data_train.head()

Out[1]:

  SepalLength SepalWidth PetalLength PetalWidth Species
0 5.9 3.0 4.2 1.5 1
1 6.9 3.1 5.4 2.1 2
2 5.1 3.3 1.7 0.5 0
3 6.0 3.4 4.5 1.6 1
4 5.5 2.5 4.0 1.3 1

 

构建模型

在构建模型结构时暂时不需要数据的输入,但需要制定有几个特征需要输入,从上面的结构中可以看出。输入的特征有四个:花萼长度(seqpallength),花萼宽度(sepalwidth),花瓣长度 (petallength),花瓣宽度 (petalwidth)。而最后一个是预测值,species 代表他是哪种鸢尾花:山鸢尾(0),变色鸢尾(1),维吉尼亚鸢尾 (2)。所以我们先将属性值与标记值区分开

train_x, train_y = data_train, data_train.pop('Species')
test_x, test_y = data_test, data_test.pop('Species')
train_x.head()

 

Out[2]:

  SepalLength SepalWidth PetalLength PetalWidth
0 5.9 3.0 4.2 1.5
1 6.9 3.1 5.4 2.1
2 5.1 3.3 1.7 0.5
3 6.0 3.4 4.5 1.6
4 5.5 2.5 4.0 1.3
my_feature_columns = []  #从train_x提取特征值
for key in train_x.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))
print(my_feature_columns)

 

[_NumericColumn(key='SepalLength', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None), _NumericColumn(key='SepalWidth', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None), _NumericColumn(key='PetalLength', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None), _NumericColumn(key='PetalWidth', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None)]

下面开始构建训练模型

classifier = tf.estimator.DNNClassifier(
   # 这个模型接受哪些输入的特征
    feature_columns=my_feature_columns,
    # 包含两个隐藏层,每个隐藏层包含10个神经元.
    hidden_units=[10, 10],
    # 最终结果要分成的几类
    n_classes=3)
 

 

INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpg62ivkem
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpg62ivkem', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7efc2dc54160>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

训练模型

下面就需要为模型提供数据并进行训练,由于tensorflow采用一个批量梯度下降算法更新参数,这里可以构造一个函数来生成数据,并且可以在这个函数当中对数据进行打乱。

def train_func(train_x,train_y):
    dataset=tf.data.Dataset.from_tensor_slices((dict(train_x), train_y))
    dataset = dataset.shuffle(1000).repeat().batch(100)
    return dataset

下面可以进行模型训练,进行1000个回合的训练,每次调100的数据。

classifier.train(
    input_fn=lambda:train_func(train_x,train_y),
    steps=1000)

 

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpg62ivkem/model.ckpt.
INFO:tensorflow:loss = 131.49458, step = 0
INFO:tensorflow:global_step/sec: 230.342
INFO:tensorflow:loss = 11.232605, step = 100 (0.435 sec)
INFO:tensorflow:global_step/sec: 508.696
INFO:tensorflow:loss = 8.845632, step = 200 (0.196 sec)
INFO:tensorflow:global_step/sec: 509.375
INFO:tensorflow:loss = 7.626791, step = 300 (0.196 sec)
INFO:tensorflow:global_step/sec: 515.75
INFO:tensorflow:loss = 6.2779765, step = 400 (0.194 sec)
INFO:tensorflow:global_step/sec: 459.292
INFO:tensorflow:loss = 6.688465, step = 500 (0.218 sec)
INFO:tensorflow:global_step/sec: 509.881
INFO:tensorflow:loss = 6.19946, step = 600 (0.196 sec)
INFO:tensorflow:global_step/sec: 523.267
INFO:tensorflow:loss = 6.8552265, step = 700 (0.191 sec)
INFO:tensorflow:global_step/sec: 526.077
INFO:tensorflow:loss = 6.5732956, step = 800 (0.190 sec)
INFO:tensorflow:global_step/sec: 519.665
INFO:tensorflow:loss = 6.094924, step = 900 (0.192 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/tmpg62ivkem/model.ckpt.
INFO:tensorflow:Loss for final step: 5.365243.

Out[6]:

<tensorflow.python.estimator.canned.dnn.DNNClassifier at 0x7efc2dc54b00>

模型预测

可以使用下面方法对测试集的数据进行预测,并查看效果

def eval_input_fn(features, labels, batch_size):
    features=dict(features)
    if labels is None:
        # No labels, use only features.
        inputs = features
    else:
        inputs = (features, labels)
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    assert batch_size is not None, "batch_size must not be None"
    dataset = dataset.batch(batch_size)
    return dataset
predict_arr = []
predictions = classifier.predict(
        input_fn=lambda:eval_input_fn(test_x,labels=test_y,batch_size=100))
for predict in predictions:
    predict_arr.append(predict['probabilities'].argmax())
result = predict_arr == test_y
result1 = [w for w in result if w == True]
print("准确率为 %s"%str((len(result1)/len(result))))

INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpg62ivkem/model.ckpt-1000 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 准确率为 0.9833333333333333

 

 

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

鸢尾花分类 的相关文章

随机推荐

  • [CISCN 2022 初赛]login_normal

    叠甲 菜 很菜 就是懂一点但是不多 可能涉及的错误会很多 有错误欢迎指出 同时对于这个疑问有解答的也欢迎留言 总之就是很菜 QAQ 这一道题 首先要考代码审计 来猜它这个 login 的格式 然后在通过它的 login 之后 通过传入可见字
  • 【Android】ViewModel原理分析

    概述 本文主要通过分析ViewModel源码解决以下两个疑问 1 ViewModel如何保证的唯一性 2 ViewModel如何保证数据不丢失的 为了解决这些问题 从ViewModel的构造开始 一般创建ViewModel的方法如下 Vie
  • 《消息队列高手课》内存管理:如何避免内存溢出和频繁的垃圾回收?

    不知道你有没有发现 在高并发 高吞吐量的极限情况下 简单的事情就会变得没有那么简单了 一个业务逻辑非常简单的微服务 日常情况下都能稳定运行 为什么一到大促就卡死甚至进程挂掉 再比如 一个做数据汇总的应用 按照小时 天这样的粒度进行数据汇总都
  • SQL Server用户登录失败

    SQL Server数据库中 如果我们忘记了 sa密码 又删除了jhyf kj administrators帐号 我们可以用下面的方法来修复 1 首先停止所有与SQLServer相关的服务 net stop SQL Server Integ
  • Spring Boot全面总结(超详细,建议收藏)

    前言 本文非常长 建议先mark后看 也许是最后一次写这么长的文章 说明 前面有4个小节关于Spring的基础知识 分别是 IOC容器 JavaConfig 事件监听 SpringFactoriesLoader详解 它们占据了本文的大部分内
  • 2021极客大挑战web部分wp

    Dark 看到url http c6h35nlkeoew5vzcpsacsidbip2ezotsnj6sywn7znkdtrbsqkexa7yd onion 发现后缀为 onion 为洋葱 下载后使用洋葱游览器访问 Welcome2021
  • git学习:github上传自己的代码到别人的仓库

    转载 原博客链接 总结 向别人贡献自己的代码 和传到自己仓库的区别 要先fork转化 clone仓库文件到电脑本地 然后进入文件夹 若想提交到非默认分支 要先git checkout到分支 pull分支下的最新代码 若还想创建新分支 用gi
  • 入門篇-耦合Coupling AC/DC/GND差別在哪

    摘自 https www strongpilab com p 156 示波器操作 入門篇 耦合Coupling AC DC GND差別在哪 2016 06 26 儀器 Instrument 示波器 Scope 0 示波器的Vertical選
  • Crested Ibis vs Monster——AT动态规划思想

    题目描述 Ibis is fighting with a monster The health of the monster is H Ibis can cast N kinds of spells Casting the i th spe
  • 对caffe2的一些初步体会(草稿)

    Caffe2的一些关键设计思想 所有运算都抽象为Operator Blob和Tensor的概念 Blob和Net都存放在Workspace中 一个Workspace中可以有多个Net 这些Net中使用到的相同名称的Blob实际对应于这个Wo
  • 图数据库nebula

    目录 1 查询方式 按需不需要基于索引查询 可以分为两类 为什么有的需要索引 go 依据路径查询属性 fetch 获取指定边 点的属性值 lookup match 1 查询方式 nebula可以用来查询的语句关键字主要有 GO FETCH
  • virt与virsh常用命令

    前提 客户机虚拟机上配置qemu guest agent 并对guest的xml配置文件做一些修改 那么就可以使用很多特有的命令 对虚拟机进行配置 例如 修改虚拟机密码 root localhost virsh set user passw
  • Excel工具类

    目录 1导入导出 2测试 2 1导入测试 2 1 1JSON导入 2 1 2对象导入 2 2导出测试 2 2 1导出模版 2 2 2导出用户表 3依赖 4工具包 1导入导出 UserImport package com excel enti
  • 关于嵌入式系统的学习路线图

    来源 本文乃同济大学软件学院王院长 JacksonWan 在同济网论坛发表的帖子 谈谈软件学院高年级同学的学习方向 的第二部分 三部分依次为 一 关于企业计算方向 二 关于嵌入式系统方向 三 关于游戏软件方向 嵌入式系统方向 嵌入式系统无疑
  • Java加密技术(九)——初探SSL

    在 Java加密技术 八 中 我们模拟了一个基于RSA非对称加密网络的安全通信 现在我们深度了解一下现有的安全网络通信 SSL 我们需要构建一个由CA机构签发的有效证书 这里我们使用上文中生成的自签名证书 zlex cer 这里 我们将证书
  • 2021-06-24

    daily plan 2021 05 2021年05月 2021 05 06 Spring概要与入门 https www yuque com haohaoxuexicainengtiantianxiangshang ldmxww eb8tv
  • USB fastboot

    1 Samsung fastboot flashing unlock 2 bootloader增加解锁密码 diff git a app aboot aboot c b app aboot aboot c index e4d46e4 1b4
  • win10安装.NET Framework 4.5.2时会提示:这台计算机中已经安装了 .NET Framework 4.5.2 或版本更高的更新

    问题现象 win10安装 NET Framework 4 5 2时会提示 这台计算机中已经安装了 NET Framework 4 5 2 或版本更高的更新 问题原因 Win10系统自带的 net framework版本为4 7 问题解决 1
  • 1352--奖金(拓扑排序)

    输入样例 2 1 1 2 输出样例 201 解析 拓扑排序 判断是否存在结果 include
  • 鸢尾花分类

    鸢尾花数据集 鸢尾花数据集包含四个特征和一个标签 这四个特征确定了单株鸢尾花的下列植物学特征 花萼长度 花萼宽度 花瓣长度 花瓣宽度 我们的模型会将这些特征表示为float32数值数据 该标签确定了鸢尾花品种 品种必须是下列任意一种 山鸢尾