【机器学习-分类】决策树预测

2023-11-15

我用一些机器学习的算法对数据进行一个分类,下面是一些需要用到的基础代码,以决策树为例,并不包括针对项目的模型处理和修改,留作记忆学习。

对于数据划分训练集直接省略

def Tree_score(depth = 3,criterion = 'entropy',samples_split=2):
	#构建树
	tree = DecisionTreeClassifier(criterion = criterion,max_depth = depth,min_samples_split=samples_split)
	#训练树
	tree.fit(Xtrain, Ytrain)
	#训练集和测试集精确度得分
	train_score = tree.score(Xtrain, Ytrain)
	test_score = tree.score(Xtest, Ytest)
	#return train_score,test_score

下面是对于树的得分曲线绘制,是可以作图观察最优的参数,参考页面忘记了,基本没啥改动

p,k=0
def tree_best_plot(picture_path):
    global p,k
    depths = range(2,25)
    #先是考虑用gini,考虑不同的深度depth
    scores = [Tree_score(d,'gini') for d in depths]
    train_scores = [s[0] for s in scores]
    test_scores = [s[1] for s in scores]

    plt.figure(figsize = (6,6),dpi = 144)
    plt.grid()
    plt.xlabel("max_depth of decision Tree")
    plt.ylabel("score")
    plt.title("'gini'")
    plt.plot(depths,train_scores,'.g-',label = 'training score')
    plt.plot(depths,test_scores,'.r--',label = 'testing score')
    plt.legend()
    path=picture_path+'gini_'+str(k)+'.jpg'
    k+=1
    plt.savefig(path, bbox_inches='tight', dpi=450)
    
    #信息熵(entropy),深度对模型精度的影响
    scores = [Tree_score(d) for d in depths]
    train_scores = [s[0] for s in scores]
    test_scores = [s[1] for s in scores]
    plt.figure(figsize = (6,6),dpi = 144)
    plt.grid()
    plt.xlabel("max_depth of decision Tree")
    plt.ylabel("score")
    plt.title("'entropy'")
    plt.plot(depths,train_scores,'.g-',label = 'training score')
    plt.plot(depths,test_scores,'.r--',label = 'testing score')
    plt.legend()
    path=picture_path+'entropy'+str(p)+'.jpg'
    plt.savefig(path, bbox_inches='tight', dpi=450)
    p+=1

但是也有利用函数自动调整参数的方法,上面那个就不太需要的,但是在应用的时候感觉画图得到的精确度要普遍高于函数搜寻最优参数的最优参。

函数是使用包GridSearchCV,param里是选择的元素集合

from sklearn.model_selection import GridSearchCV
#为了节省时间  去掉gini的检查
param = {'criterion':['entropy','gini'],'max_depth':[2,3,4,5,6,7],'min_samples_leaf':[2,3,4,5,6,7],'min_impurity_decrease':[0.1,0.2,0.3,0.5],'min_samples_split':[2,3,4,5,6,7,8]}
grid = GridSearchCV(DecisionTreeClassifier(),param_grid=param,cv=5)
#用数据进行训练
grid.fit(Xtrain,Ytrain)
print('最优分类器:',grid.best_params_,'最优分数:', grid.best_score_)  # 得到最优的参数和分值

训练好的数据需要保存为模型的形式:

joblib.dump(clf,'predictor.pkl')

模型的加载:model=joblib.load('predictor.pkl')

如果直接封装的就是sklearn中的模型,可以直接调用model.predict和model.score

##返回精确度
score1=model.score(Xtrain, Ytrain) 
score2 = model.score(Xtest, Ytest)

这个得到的是预测的概率值:

y1=model.predict_proba(Xtest)

原本的数据是列表,就只有一个长度9,现在改成array形式,reshape变成[1,9],然后进行预测概率(y1类型是ndarray),然后直接用np.max得到一个具体的值,说实话我记得我关于array的相关代码都是在panda库相关的书里学习的,也都忘的差不多了,啊代码对于金鱼脑子真的好难啊老是用错。

x1=[190,1,2,1,1, 0,1, 0,0]
x1_array=np.array(x1).reshape(1,len(x1))
print('x1_array')
print(x1_array)
y1=model.predict_proba(x1_array)
print(y1)
y1=np.max(y1)
y2=model.predict(x1_array)#这个得出来的就是概率最大的那一项。但是是用list的形式,比如这里y2的值会是[2]
print((y2[0],y1))

然后还要加入解析器,使用argparse.ArgumentParser。
argparse的相关使用找了一个链接Python- argparse.ArgumentParser()用法

但是解析器还有点半懂不懂。

#使用的库包括...等
import argparse
import joblib

对了,中间有遇到一个问题,下载模型使用vscode无法下载,但是使用pycharm可以正常使用。

学习记忆所写,希望走过路过的大佬多多指导,谢谢。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【机器学习-分类】决策树预测 的相关文章

  • 基于网易云音乐的歌词js逆向

    歌曲的歌词 一 py源码 import json import execjs import requests 实例话一个node对象 node execjs get js源文件编译 ctx node compile open 网易云2号 j
  • 微博模型训练——僵尸用户识别(二)

    上文通过使用决策树算法简单实现了僵尸用户的识别 https blog csdn net weixin 43906500 article details 116992642 本文综合利用多种机器学习方法实现对僵尸用户的识别 使用的机器学习方法
  • shell 实现目录下文件修改记录监控

    文件监控可以配合rsync实现文件自动同步 例如监听某个目录 当文件变化时 使用rsync命令将变化的文件同步 可用于代码自动发布 inotify 是linux内核的一个特性 在内核 2 6 13 以上都可以使用 如果在shell环境下 可

随机推荐

  • u3d修改服务器ip,Unity ping一个服务器 ip 的工具类

    using UnityEngine using System Collections public class UnityPing MonoBehaviour private static string s ip private stati
  • mysql使用sql语句根据时间段查询数据

    1 sql语句 SELECT 字段 from 表名 where 时间字段 BETWEEN 2019 05 22 AND 2019 06 21 注 此种方法查到的是5 22到6 20之间数据 不包括6 21当天的数据 2 在mybatis中m
  • MySQL如何查看,删除用户

    1 查看所有用户 需要在root用户下进行 select host user password from mysql user 2 删除用户 mysql gt Delete FROM user Where User 用户名 and Host
  • LU矩阵分解

    LU分解 Pseudocode LU matrix decompose matrix for j 0 1 n L 为单位下三角矩阵 L j j 1 0 上三角矩阵的行列索引关系 j rows gt i columns for i 0 1 j
  • php socket 错误处理,PHP Socket or TCP 连接错误信息显示乱码问题处理

    错误说明 在项目中编码都是使用UTF 8编码 当用到Socket或者TCP连接的时候出现错误 错误信息不是UTF 8的编码 所以输出看到的是乱码且在输出json格式输出的时候是空白 比如在本地位win7系统 错误信息提示 Can not c
  • 多智能体强化学习与博弈论-博弈论基础3

    多智能体强化学习与博弈论 博弈论基础3 之前主要介绍了如何判断博弈中是否到达了纳什均衡 在这篇文章中将主要介绍如何计算纳什均衡 本文主要介绍下列几种情况下的纳什均衡 两个智能体 每个智能体有两个动作 两个智能体 每个智能体有多个动作 零和博
  • MySql Windows安装教程

    找到下载 gt 拉到最下面找到社区版下载 gt 下载 下面是我下载好的 度盘链接 提取码 sws3 解压到指定目录 Mysql国内镜像 Index of mysql MySQL 8 0 此时解压后的文件中没有data目录和ini文件 然后做
  • 帆软大屏开发手册

    1 需求调研 模块 输出 业务需求调研 业务需求调研报告 硬件调研 大屏采购硬件清单 数据调研 数据质量调研报告 关键性技术预研 技术预研报告 1 1 业务需求调研 1 1 1 根据业务场景抽取关键指标 关键指标是一些概括性词语 是对一组或
  • 使用Python,Keras和TensorFlow训练第一个CNN

    使用Python Keras和TensorFlow训练第一个CNN 这篇博客将介绍如何使用Python和Keras训练第一个卷积神经网络架构 ShallowNet 并在动物和CIFAR 10数据集上对其进行了训练 ShallowNet对动物
  • python+PyCharm+OpenCV配置

    文章目录 一 python的配置 二 PyCharm的安装 三 OpenCV的配置 一 python的配置 第一步 下载python安装包 从python的官网 python下载地址 中找到最新版本的python安装包 点击进行下载即可 需
  • python实现——处理Excel表格(超详细)

    目录 xls和xlsx 基本操作 1 用openpyxl模块打开Excel文档 查看所有sheet表 2 1 通过sheet名称获取表格 2 2 获取活动表 3 1 获取表格的尺寸 4 1 获取单元格中的数据 4 2 获取单元格的行 列 坐
  • 睿智的目标检测56——Pytorch搭建YoloV5目标检测平台

    睿智的目标检测56 Pytorch搭建YoloV5目标检测平台 学习前言 源码下载 YoloV5改进的部分 不完全 YoloV5实现思路 一 整体结构解析 二 网络结构解析 1 主干网络Backbone介绍 2 构建FPN特征金字塔进行加强
  • Arria 10上进行DDR3管脚分配

    本文介绍下DDR3的管脚分配 其它系列的DDR管脚分配也基本一样的 FPGA型号 10AX027H4F34I3SG DDR3型号 MT41J128M16JT 125 QuartusI Prime18 0 首先介绍下A10器件能支持的DDR系
  • 图像处理和图像识别中常用的matlab函数

    下面仅给出函数的大概意思 详细用法见 help 函数名 或 matlab help 1 imread read image from graphics file 2 imshow display image in Handle Graphi
  • MySQL高性能及性能优化技巧---更适合开发人员

    更新次数 更新时间 首发 2021 10 25 第一次更新 2021 10 26 1 删除了书中大量不必要的存储引擎类型 2 摘要完毕Mysql架构与历史部分 第二次更新 2021 10 29 1 摘要基准测试内容 2 删除了大量对概念的举
  • hdu 6208 The Dominator of Strings

    Problem acm hdu edu cn showproblem php pid 6208 Meaning 有 n 个字符串 问是否能找到其中一串 使得其它串都是它的子串 Analysis 如果存在这个串 那它一定是 n 个中的最长串
  • LeetCode刷题记录 字节跳动题库

    1 两数之和 哈希 一遍遍历 3 无重复字符的最长子串 哈希 流动窗口 双指针 因为右端点的位置一定不会朝左边走 建议再看看同类型的题目 2 两数相加 题 42 接雨水 单调递减栈 核心思想 对于每个点找其左边和右边第一个大于或等于它的点
  • 程序员最美的情人节玫瑰花,JAVA代码实现的3D玫瑰噢

    用纯javascript脚本编写的神奇3D圣诞树 令人印象深刻 2月14日情人节就要来临了 还是Roman Cortes 这次他又带来了用javascript脚本编写的红色玫瑰花 用代码做出的玫瑰花 这才是牛逼程序员送给女友的最好情人节礼物
  • idea自动去除导入但未使用的包

    使用idea开发过程中通常我们可能会引入某个包使用但是在后续更改中这个包就不需要了 一个个去除很麻烦 他每个java文件去除的快捷键是ctrl shift o 如果想要更智能的方法我们可以做如下配置 1 使用ctrl alt s进入sett
  • 【机器学习-分类】决策树预测

    我用一些机器学习的算法对数据进行一个分类 下面是一些需要用到的基础代码 以决策树为例 并不包括针对项目的模型处理和修改 留作记忆学习 对于数据划分训练集直接省略 def Tree score depth 3 criterion entrop