机器学习的训练数据(Training Dataset)、测试数据(Testing Dataset)和验证数据(Validation Dataset)

2023-10-27

三者的意义
- 训练数据:用来训练模型的数据
- 验证数据:用来检验模型准确率
- 测试数据:再一次确认验证数据集中的模型是好的模型。

一般步骤:

测试数据集和验证数据的数据一定不能用来训练,否则会出现过拟合的现象

代码:

import math
import os

from IPython import display
from matplotlib import cm
from matplotlib import gridspec
from matplotlib import pyplot as plt
from sklearn import metrics
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.python.data import Dataset

tf.logging.set_verbosity(tf.logging.INFO)
pd.options.display.max_rows = 10

# 读取数据
if os.path.exists('data.csv'):
    california_housing_dataframe = pd.read_csv('data.csv', sep=',')
else:
    california_housing_dataframe = pd.read_csv(
        "https://storage.googleapis.com/mledu-datasets/california_housing_train.csv", sep=",")
    california_housing_dataframe.to_csv('data.csv')


def preprocess_features(california_housing_dataframe):
    '''
    准备输入测试数据特征列
    :param california_housing_dataframe:
    :return:
    '''
    selected_features = california_housing_dataframe[
        [
            'latitude',
            'longitude',
            'housing_median_age',
            'total_rooms',
            'total_bedrooms',
            'population',
            'households',
            'median_income'
        ]
    ]
    processed_features = california_housing_dataframe.copy()
    # 添加人均住房面积
    processed_features['rooms_per_person'] = (
            california_housing_dataframe['total_rooms'] /
            california_housing_dataframe['population']
    )
    return processed_features


def preprocess_targets(california_housing_dataframe):
    '''
    目标值输入函数
    :param california_housing_dataframe:
    :return: 房子的价值
    '''
    output_targets = pd.DataFrame()
    output_targets['median_house_value'] = (
            california_housing_dataframe['median_house_value'] / 1000.0
    )
    return output_targets


# 传入12000组数据
training_examples = preprocess_features(california_housing_dataframe.head(12000))
training_examples.describe()

# 传入12000组目标值
training_targets = preprocess_targets(california_housing_dataframe.head(12000))
training_targets.describe()

# 验证数据集
validation_examples = preprocess_features(california_housing_dataframe.tail(5000))
validation_examples.describe()

validation_targets = preprocess_targets(california_housing_dataframe.tail(5000))
validation_targets.describe()

plt.figure(figsize=(13, 8))
ax = plt.subplot(1, 2, 1)
ax.set_title("Validation Data")
ax.set_ylim([32, 43])
ax.set_autoscalex_on(False)
ax.set_xlim([-126, -112])
plt.scatter(validation_examples['longitude'],
            validation_examples['latitude'],
            cmap="coolwarm",
            c=validation_targets['median_house_value'] / training_targets['median_house_value'].max()
            )
plt.plot()


def my_input_fn(features, targets, batch_size=1, shuffle=True, num_epochs=None):
    '''
    输入函数
    :param features: 特征列
    :param targets: 目标值
    :param batch_size: batch size
    :param shuffle: 是否乱序
    :param num_epochs: epoch的数量
    :return: 一个迭代批次数据,包含特征列和标签
    '''
    ds = Dataset.from_tensor_slices((dict(features), targets))
    ds = ds.batch(batch_size=batch_size)
    if shuffle:
        ds.shuffle(buffer_size=10000)
    features, labels = ds.make_one_shot_iterator().get_next()
    return features, labels


def construct_feature_columns(input_features):
    '''
    :param input_features:特征
    :return: 构造的特征列
    '''
    return set([tf.feature_column.numeric_column(key=my_feature)
                for my_feature in input_features])


def train_model(learning_rate,
                steps,
                batch_size,
                training_examples,
                training_targets,
                validation_examples,
                validation_targets):
    periods = 10
    steps_per_period = steps // periods
    my_optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    my_optimizer = tf.contrib.estimator.clip_gradients_by_norm(my_optimizer, 5.0)
    linear_regressor = tf.estimator.LinearRegressor(
        feature_columns=construct_feature_columns(training_examples),
        optimizer=my_optimizer
    )
    print("Training model.....")
    training_rmse = []
    validation_rmse = []
    for period in range(0, periods):
        # 开始训练
        linear_regressor.train(
            input_fn=lambda: my_input_fn(
                training_examples,
                training_targets['median_house_value'],
                batch_size=batch_size
            ),
            steps=steps_per_period
        )

        # 预测数据集的处理
        training_predictions = linear_regressor.predict(
            input_fn=lambda: my_input_fn(
                training_examples,
                training_targets['median_house_value'],
                num_epochs=1,
                shuffle=False
            )
        )
        training_predictions = np.array([item['predictions'][0] for item in training_predictions])

        # 验证数据集的处理
        validation_predictions = linear_regressor.predict(
            input_fn=lambda: my_input_fn(
                validation_examples,
                validation_targets['median_house_value'],
                num_epochs=1,
                shuffle=False
            )
        )
        validation_predictions = np.array([item['predictions'][0] for item in validation_predictions])

        tmp_training_rmse = math.sqrt(
            metrics.mean_squared_error(training_predictions, training_targets)
        )
        tmp_validation_rmse = math.sqrt(
            metrics.mean_squared_error(validation_predictions, validation_targets)
        )
        print("period %02d: %0.2f" % (period, tmp_training_rmse))
        training_rmse.append(tmp_training_rmse)
        validation_rmse.append(tmp_validation_rmse)
    print("Model training finished")

    # 输出结果图
    plt.ylabel("RMSE")
    plt.xlabel("Periods")
    plt.title("Root Mean Squred Error vs. Periods")
    plt.tight_layout()
    plt.plot(training_rmse, labels="training")
    plt.plot(validation_rmse, labels="validation")
    plt.legend()

    return linear_regressor


# 训练模型
linear_regressor = train_model(
    learning_rate=0.00003,
    steps=500,
    batch_size=1,
    training_examples=training_examples,
    training_targets=training_targets,
    validation_examples=validation_examples,
    validation_targets=validation_targets
)

# 在测试数据集上评估模型
california_housing_test_data = pd.read_csv('data.csv', sep=',')
test_examples = preprocess_features(california_housing_test_data)
test_targets = preprocess_targets(california_housing_test_data)

predict_test_input_fn = lambda: my_input_fn(
    test_examples,
    test_targets['median_house_value'],
    num_epochs=1,
    shuffle=False
)

test_predictions = linear_regressor.predict(input_fn=predict_test_input_fn)
test_predictions = np.array([item['predictions'][0] for item in test_predictions])

RMSE = math.sqrt(
    metrics.mean_squared_error(test_predictions, test_targets)
)

print("Final RMSE (on test data): %0.2f" % RMSE)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习的训练数据(Training Dataset)、测试数据(Testing Dataset)和验证数据(Validation Dataset) 的相关文章

  • Rabbit MQ详解

    一 什么是RabbitMQ 答 RabbitMQ简称MQ是一套实现了高级消息队列协议的开源消息代理软件 简单来说就是一个消息中间件 是一种程序对程序的通信方法 其服务器也是以高性能 健壮以及可伸缩性出名的Erlang语言编写而成 二 Rab
  • nc文件经度从0-360更改为-180到180,并保存

    从0 360改为 180到180 import xarray as xr rawnc path InPath ds xr open dataset rawnc path lon name lon 你的nc文件中经度的命名 ds longit
  • Python数据分析与机器学习----收入的预测分析

    一 题目 利用age workclass native country等13个特征预测收入是否超过50k 是一个二分类问题 二 训练集 32561个样本 每个样本14个特征 其中6个连续性特征 9个离散型特征 三 测试集 16281个样本
  • Open3D(C++) 四元数奇异值分解

    目录 一 算法原理 1 原理概述 2 实现过程 3 参考文献 二 代码实现 三 结果展示 本文由CSDN点云侠原创 原文链接 如果你不是在点云侠的博客中看到该文章 那么此处便是不要脸的爬虫 一 算法原理 1 原理概述 四元数矩阵的奇异值分解

随机推荐

  • java继承层次结构,在状态模式中实现继承层次结构 - java

    我有一个与此非常相似的设计 这里的NewOrder Registered Granted都有通用方法AddOrderline 和Cancel 因此将这两种方法重构为父类很容易 当我要Cancel一条Shipped行 当前未在图中显示 时 会
  • SegNetr: 重新思考 U 形网络中的局部-全局交互和跳过连接

    SegNetr 会议分析 摘要 贡献 方法 整体框架 1 SegNetr Block 2 Information Retention Skip Connection 实验 1 对比实验 2 消融实验 2 1 Effect of local
  • tslib移植的问题:No raw modules loaded.ts_config:No such file or directory

    1 在开发板上运行校正程序时出现No raw modules loaded 解决方法是把 tslib etc目录下的ts conf 的 module raw input 的注释符号 去掉 但记住不要在前面留有 空格 否则会出现错误Segme
  • python 打开读取文件 出现异常 关闭文件的处理(世界上没有傻问题!但我是个傻子)

    事情梗概 try 尝试读取一个不存在的文件 except Exception as e 打印异常 finally 关闭文件 但是关闭文件时报异常 算了 看代码吧 try f open file name rb file data f rea
  • Vue.js的组件化开发

    组件化开发 什么是组件 web中的组件其实就是页面组成的一部分 好比是电脑中的每一个元件 如硬盘 键盘 鼠标 它是一个具有独立的逻辑和功能或界面 同时又能根据规定的接口规则进行相互融化 变成一个完整的应用 页面就是由一个个类似这样的组成部分
  • iOS开源系列——下拉刷新控件

    EGOTableViewPullRefresh FaceBook开源控件 下拉刷新的鼻祖 SVPullToRefresh 下拉刷新控件 MJRefresh 比较好用的下拉刷新 可以自定义上下拉刷新的文字说明 具体使用看 使用方法 国人写 X
  • 中间件的分类和作用

    要说清这个问题我们用一个生活中的实例来比喻 把分布式系统看作北京市区的交通系统 网络看作市区马路 通过交通工具 汽车 实现通信 每分钟将有几万辆车在马路上行驶 如果没有相应的交通设施和管理规划 北京市将会乱成一团 发生各种交通事故 1 通信
  • java各种报错汇总与分析

    1 没有找到pom文件 需要设置版本号 在这里插入图片描述 https img blog csdnimg cn 20210720112611634 png pic center 解决办法 https blog csdn net SSband
  • 从2023蓝帽杯0解题heapSpary入门堆喷

    从2023蓝帽杯0解题heapSpary入门堆喷 关于堆喷 堆喷射 Heap Spraying 是一种计算机安全攻击技术 它旨在在进程的堆中创建多个包含恶意负载的内存块 这种技术允许攻击者避免需要知道负载确切的内存地址 因为通过广泛地 喷射
  • adb shell 小米手机_【ADB命令实战】免ROOT停用小米手机系统应用

    对于未解锁的手机 总存在那么一些我们用不到 甚至看都不想看到的应用 但是没办法卸载 在这里提供一些禁用掉这些应用的方法供参考 本内容是以小米的MIUI系统为例 其他品牌机型不确保可以成功 毕竟系统应用的包名是不一样的 需要自己去发现 1 打
  • linux-hd.c

    linux kernel hd c C 1991 Linus Torvalds This is the low level hd interrupt support It traverses the request list using i
  • 数据结构与算法课程笔记(二)

    实验二 线性表的顺序存储结构实现 一 实验目的 二 实验内容 一 实验目的 熟悉VC 工程项目的文件组织方式 线性表中数据元素间的关系及其顺序存储结构方式表示方法 顺序表的操作方法与接口函数的设计方法 二 实验内容 1 利用本次实验提供的文
  • element input复合框 修改下拉框样式

    element input复合框 修改下拉框样式 1 项目中经常会遇到修改ui组件库样式的问题 elemetui官网自带样式是这样的 我想修改选中颜色 以及背景颜色 这样设置发现不生效 加上 popper append to body fa
  • python 翻译模块 翻译API使用(百度、有道、谷歌)

    1 翻译模块 api使用分析 1 translate库 使用简单 但是有次数限制 翻译的准确性中等 2 百度api 推荐使用 代码简单 有模块 但是需要注册 获取key值 翻译的准确性中下 3 chrome翻译api 代码复杂 次数限制 但
  • java8的常用的新特性

    Java 8引入了许多新的特性 下面列举了一些常用的新特性 Lambda表达式 Lambda表达式是Java 8中引入的一种函数式编程特性 提供了一种更简洁和灵活的方式来编写匿名函数 方法引用 方法引用允许直接引用已经存在的方法作为Lamb
  • 华为od机试面试题目

    1 华为机试102道题解 2 华为机考题目 2023年7月30日 19 30 22 00 机考提示 注意事项 考前必看 1 注意编译环境的变化及语言选择 选自己熟悉的语言机考 2 机考共3道题 150分钟完成 3 题目难度为 一星和两星 2
  • JS——Mediator(中介者)模式

    我们从日常的生活中打个简单的比方 我们去房屋中介租房 房屋中介人在租房者和房东出租者之间形成一条中介 租房者并不关心他租谁的房 房东出租者也不关心他租给谁 因为有中介的存在 这场交易才变得如此方便 在软件的开发过程中 势必会碰到这样一种情况
  • 自己定义控件

    http blog csdn net lmj623565791 article details 38173061
  • qt信号槽连接方式Qt::UniqueConnection的使用

    qt信号槽连接方式Qt UniqueConnection的使用 qt信号槽连接方式一共有以下五种 具体方式不在一一赘述 本文记录第五种Qt UniqueConnection的使用方法 Qt AutoConnection Qt DirectC
  • 机器学习的训练数据(Training Dataset)、测试数据(Testing Dataset)和验证数据(Validation Dataset)

    三者的意义 训练数据 用来训练模型的数据 验证数据 用来检验模型准确率 测试数据 再一次确认验证数据集中的模型是好的模型 一般步骤 测试数据集和验证数据的数据一定不能用来训练 否则会出现过拟合的现象 代码 import math impor