机器学习(2)——鸢尾花数据集

2023-11-06

在上次房价数据集中做出一些改进,对鸢尾花数据集进行预测。

需要导入的库

from sklearn.datasets import load_iris #导入鸢尾花数据集
from sklearn.linear_model import LogisticRegression #导入sklearn中的逻辑回归模型
from sklearn.model_selection import train_test_split,cross_val_score #导入数据集的划分和交叉验证函数
import matplotlib.pyplot as plt
import paddle
import numpy as np
import paddle.nn as nn

读取鸢尾花数据集,将其分为训练集和测试集

#数据读取
iris = load_iris()
iris_x = iris.data
iris_y = iris.target
train_x,test_x,train_y,test_y = train_test_split(iris_x,iris_y,test_size=0.3)   #划分数据集和测试集
train_data = np.insert(train_x, 4, train_y, 1)
test_data = np.insert(test_x, 4, test_y, 1)

通过 train_test_split()函数对数据集进行划分,设置测试数据占总样本的0.3

通过np.insert函数插入矩阵

a=np.insert(arr, obj, values, axis)
#arr原始数组,可一可多,obj插入元素位置,values是插入内容,axis是按行按列插入(0:行、1:列)。

即实现了target和data的合并
 

构建datasets类,和房价预测一模一样

# create datasets类
# 三个必须的函数:1.构造函数(初始化工作)2.__getitem__函数(根据index确保数据能被找到并返回这一行数据)3.返回数据集长度
# 不可或缺,定义错误会导致dataloader使用时出错;
class MyDataset(paddle.io.Dataset):
    """
    继承paddle.io.Dataset类
    """
    def __init__(self, data):
        """
        实现构造函数(初始化这个class)
        """
        super(MyDataset, self).__init__()
        self.data = data

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,指定index-->返回数据
        """
        data = self.data[index]
        x_data = data[:-1]
        label = data[-1]
        return x_data, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集长度
        """
        return self.data.shape[0]

构建神经网络结构 

注意:输入单元为4,输出单元为3

## create model structure
# 两个必须的函数:1.构造函数(初始化网络结构)2.forwad函数(定义前向传播过程)
class Mymodel(paddle.nn.Layer):
    def __init__(self):
        super(Mymodel, self).__init__()
        self.linear1 = nn.Linear(4, 3)   #维度(4,3)

    def forward(self, inputs):
        y = self.linear1(inputs)
        return y

custom_dataset_train = MyDataset(train_data)
train_loader = paddle.io.DataLoader(custom_dataset_train, batch_size=50, shuffle=True,drop_last=True)
custom_dataset_test = MyDataset(test_data)
test_loader = paddle.io.DataLoader(custom_dataset_test, batch_size=len(test_data), shuffle=False)

lr_model = Mymodel()
optim = paddle.optimizer.Adam(parameters=lr_model.parameters(), learning_rate=0.1)
# 设置损失函数
loss_fn = paddle.nn.CrossEntropyLoss()

训练过程

max_epoch = 200
for epoch in range(max_epoch):
    lr_model.train()
    train_loss = []
    for batch_id, (x_data,y_data) in enumerate(train_loader()):
        # pay attention to this dtype
        x_data = paddle.to_tensor(x_data,dtype="float32")
        y_data = paddle.to_tensor(y_data,dtype="int64")
        optim.clear_grad()
        y_hat = lr_model(x_data)
        loss = loss_fn(y_hat,y_data)
        loss.backward()
        optim.step()
        train_loss.append(loss.item())
        import pdb
        # pdb.set_trace()
    lr_model.eval()
    for batch_id, data in enumerate(test_loader()):
        x_data = paddle.to_tensor(x_data,dtype="float32")
        y_data = paddle.to_tensor(y_data,dtype="int64")
        y_hat = lr_model(x_data)
        loss_test = loss_fn(y_hat,y_data)
    train_loss = np.mean(train_loss)
    print("epoch:"+str(epoch)+"\t train loss:" + str(round(train_loss,4)) + "\t test loss:" + str(round(loss_test.item(),4)))

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习(2)——鸢尾花数据集 的相关文章

  • 如何在 Debian 上的 virtualenv 中安装 numpy?

    注 参见这另一篇文章 https stackoverflow com questions 6442754 how to install h5py numpylibhdf5 as non root on a debian linux syst
  • scipy.optimize on pandas dataframe

    我试图搜索它 但结果很差 有人可以向我解释一下如何在 Pandas DataFrame 上执行 optimize minimize 以便最小化 DataFrame 中的类别和结果列之间的错误 考虑这个例子 import pandas as
  • 键入的完整命令行

    我想获得输入时的完整命令行 This join sys argv 在这里不起作用 删除双引号 另外 我不想重新加入已解析和拆分的内容 有任何想法吗 你太迟了 当键入的命令到达 Python 时 您的 shell 已经发挥了它的魔力 例如 引
  • Flask 中“缺少 CSRF 令牌”,但它在模板中呈现

    问题 当我尝试登录 使用 Flask login 时 我得到Bad Request The CSRF session token is missing但令牌正在呈现 在模板中 secret key 已设置 并且我在本地运行localhost
  • 无法在 virtualenv 中安装 libxml2

    我有一个问题libxml2蟒蛇模块 我正在尝试将其安装在python3 虚拟环境使用以下命令 pip install libxml2 python3 但它显示以下错误 Collecting libxml2 python3 Using cac
  • 查找正在导入哪些 python 模块

    从应用程序中使用的特定包中查找所有 python 模块的简单方法是什么 sys modules是将模块名称映射到模块的字典 您可以检查其键以查看导入的模块 See http docs python org library sys html
  • 查找与另一列 Pandas 中的唯一值关联的列中的值的交集

    如果我有一个像这样的数据框 非常小的例子 col1 col2 0 a 1 1 a 2 2 b 1 3 b 2 4 b 4 5 c 1 6 c 2 7 c 3 我想要所有的交集col2当价值观与其独特性相关时col1值 因此在这种情况下 交集
  • 高级描述熊猫

    有没有像 pandas 那样更高级的功能 通常我会继续这样 r pd DataFrame np random randn 1000 columns A r describe 我会得到一份很好的总结 就像这样 A count 1000 000
  • 将 window.location 传递给 Flask url_for

    我正在使用 python 在我的页面上 当匿名用户转到登录页面时 我想将一个变量传递到后端 以便它指示用户来自哪里 发送 URL 因此 当用户单击此锚链接时 a href Sign in a 我想发送用户当前所在页面的当前 URL
  • 为什么 Collections.counter 这么慢?

    我正在尝试解决罗莎琳德的基本问题 即计算给定序列中的核苷酸 并在列表中返回结果 对于那些不熟悉生物信息学的人来说 它只是计算字符串中 4 个不同字符 A C G T 出现的次数 我期望collections Counter是最快的方法 首先
  • 如何获取分类数据的分组条形图

    I have a big dataset with information about students And I have to build a graph of dependencies between different value
  • 如何按 pandas 中的值对系列进行分组?

    我现在有一只熊猫Series与数据类型Timestamp 我想按日期对其进行分组 并且每组中有许多行具有不同的时间 看似显而易见的方法类似于 grouped s groupby lambda x x date 然而 熊猫的groupby按索
  • 如何通过 Python Requests 库使用基本 HTTP 身份验证?

    我正在尝试在 Python 中使用基本的 HTTP 身份验证 我正在使用Requests https docs python requests org 图书馆 auth requests post http hostname auth HT
  • Python 2.7 缩进错误[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 这个问题是由拼写错误或无法再重现的问题引起的 虽然类似的问题可能是on topic help on topic在这里 这个问题的解决方式不
  • 如何展平解析树并存储在字符串中以进行进一步的字符串操作 python nltk

    我正在尝试从树结构中获取扁平树 如下所示 我想将整个树放在一个字符串中 就像没有检测到坏树错误一样 S NP SBJ NP DT The JJ high JJ seven day PP IN of NP DT the CD 400 NNS
  • Django 将 JSON 数据传递给静态 getJSON/Javascript

    我正在尝试从 models py 中获取数据并将其序列化为views py 中的 JSON 对象 模型 py class Platform models Model platformtype models CharField max len
  • Python 相当于 Scala 案例类

    Python 中是否有与 Scala 的 Case Class 等效的东西 就像自动生成分配给字段而无需编写样板的构造函数一样 当前执行此操作的现代方法 从 Python 3 7 开始 是使用数据类 https www python org
  • 没有名为“turtle”的模块

    我正在学习并尝试用Python3制作贪吃蛇游戏 我正在进口海龟 我正在使用 Linux mint 19 PyCharm python37 python3 tk Traceback most recent call last File hom
  • MoviePY 无法在 Windows 上检测 ImageMagick 二进制文件

    我刚买了一台新笔记本电脑 想要设置MoviePY在那新的Windows 64x Python3 7 0 机器 我对所有内容都进行了三次检查 但是当涉及到我的代码的文本部分时 它向我抛出了这个错误 OSError MoviePy Error
  • 如何同时接受int和float类型的输入?

    我正在制作一个货币转换器 如何让 python 同时接受整数和浮点数 我就是这样做的 def aud brl amount From to ER 0 42108 if amount int if From strip aud and to

随机推荐