Svm实现多分类

2023-10-26

机器学习---Svm实现多分类详解

Svm实现多类分类原理

1.支持向量机分类算法最初只用于解决二分类问题,缺乏处理多分类问题的能力。后来随着需求的变化,需要svm处理多分类分为。目前构造多分类支持向
量机分类器的方法主要有两类: 一类是“同时考虑所有分类”方法,另一类是组合二分类器解诀多分类问题。
第一类方法主要思想是在优化公式的同时考虑所有的类别数据,J.Weston 和C.Watkins 提出的“K-Class 多分类算法”就属于这一类方法。该算法在经典的SVM理论的基础上,重新构造多类分类型,同时考虑多个类别,然后将问题也转化为-个解决二次规划(Quadratic Programming,简称QP)问题,从而实现多分类。该算法由于涉及到的变量繁多,选取的目标函数复杂,实现起来比较困难,计算复杂度高。
第二类方法的基本思想是通过组合多个二分类器实现对多分类器的构造,常见的构造方法有“一对一”(one-against-one)和“一对其余”(one-against-the rest两种。 其中“一对一”方法需要对n类训练数据两两组合,构建n(n- 1)/2个支持向量机,每个支持向量机训练两种不同类别的数据,最后分类的时候采取“投票”的方式决定分类结果。“一对其余”方法对n分类问题构建n个支持向量机,每个支持向量机负责区分本类数据和非本类数据。该分类器为每个类构造一个支持向量机, 第k个支持向量机在第k类和其余n-1个类之间构造一个超平面,最后结果由输出离分界面距离wx+ b最大的那个支持向量机决定。

本文将上述“一对其余”的SVM多分类方法对莺尾花数据进行分类识别,并设法减少训练样本个数,提高训练速度。

代码实现

#-*- coding:utf-8 -*-
'''
@project: exuding-bert-all
@author: exuding
@time: 2019-04-23 09:59:52
'''
#svm 高斯核函数实现多分类
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets

sess = tf.Session()

#加载数据集,并为每类分离目标值
iris = datasets.load_iris()
#提取数据的方法
x_vals = np.array([[x[0],x[3]] for x in iris.data])

y_vals1 = np.array([1 if y==0 else -1 for y in iris.target])
y_vals2 = np.array([1 if y==1 else -1 for y in iris.target])
y_vals3 = np.array([1 if y==2 else -1 for y in iris.target])
#合并数据的方法
y_vals = np.array([y_vals1,y_vals2,y_vals3])
#数据集四个特征,只是用两个特征就可以
class1_x = [x[0] for i,x in enumerate(x_vals) if iris.target[i]==0]
class1_y = [x[1] for i,x in enumerate(x_vals) if iris.target[i]==0]
class2_x = [x[0] for i,x in enumerate(x_vals) if iris.target[i]==1]
class2_y = [x[1] for i,x in enumerate(x_vals) if iris.target[i]==1]
class3_x = [x[0] for i,x in enumerate(x_vals) if iris.target[i]==2]
class3_y = [x[1] for i,x in enumerate(x_vals) if iris.target[i]==2]
#从单类目标分类到三类目标分类,利用矩阵传播和reshape技术一次性计算所有的三类SVM,一次性计算,y_target的占位符维度是[3,None]
batch_size = 50
x_data = tf.placeholder(shape = [None,2],dtype=tf.float32)
y_target = tf.placeholder(shape=[3,None],dtype=tf.float32)
#TODO
prediction_grid = tf.placeholder(shape=[None,2],dtype=tf.float32)
b = tf.Variable(tf.random_normal(shape=[3,batch_size]))
#计算高斯核函数 TODO
gamma = tf.constant(-10.0)
dist = tf.reduce_sum(tf.square(x_data),1)
dist = tf.reshape(dist,[-1,1])
sq_dists = tf.add(tf.subtract(dist,tf.multiply(2.,tf.matmul(x_data,tf.transpose(x_data)))),tf.transpose(dist))
my_kernel = tf.exp(tf.multiply(gamma,tf.abs(sq_dists)))
#扩展矩阵维度
def reshape_matmul(mat):
    v1 = tf.expand_dims(mat,1)
    v2 = tf.reshape(v1,[3,batch_size,1])
    return (tf.matmul(v2,v1))
#计算对偶损失函数
model_output = tf.matmul(b,my_kernel)
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b),b)
y_target_cross = reshape_matmul(y_target)
second_term = tf.reduce_sum(tf.multiply(my_kernel,tf.multiply(b_vec_cross,y_target_cross)),[1,2])
loss = tf.reduce_sum(tf.negative(tf.subtract(first_term,second_term)))
#创建预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data),1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid),1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA,tf.multiply(2.,tf.matmul(x_data,tf.transpose(prediction_grid)))),tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma,tf.abs(pred_sq_dist)))
#创建预测函数,这里实现的是一对多的方法,所以预测值是分类器有最大返回值的类别
prediction_output = tf.matmul(tf.multiply(y_target,b),pred_kernel)
prediction = tf.arg_max(prediction_output-tf.expand_dims(tf.reduce_mean(prediction_output,1),1),0)
accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction,tf.arg_max(y_target,0)),tf.float32))
#准备好核函数,损失函数,预测函数以后,声明优化器函数和初始化变量
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)
init = tf.initialize_all_variables()
sess.run(init)
#该算法收敛的相当快,所以迭代训练次数不超过100次
loss_vec = []
batch_accuracy = []
for i in range(100):
    rand_index = np.random.choice(len(x_vals),size=batch_size)
    rand_x = x_vals[rand_index]
    rand_y = y_vals[:,rand_index]
    sess.run(train_step,feed_dict={x_data:rand_x,y_target:rand_y})
    temp_loss = sess.run(loss,feed_dict={x_data:rand_x,y_target:rand_y})
    loss_vec.append(temp_loss)
    acc_temp = sess.run(accuracy,feed_dict={x_data:rand_x,y_target:rand_y,prediction_grid:rand_x})
    batch_accuracy.append(acc_temp)
    if(i+1)%25 ==0:
        print('Step #' + str(i+1))
        print('Loss #' + str(temp_loss))

x_min,x_max = x_vals[:,0].min()-1,x_vals[:,0].max()+1
y_min,y_max = x_vals[:,1].min()-1,x_vals[:,1].max()+1
xx,yy = np.meshgrid(np.arange(x_min,x_max,0.02),np.arange(y_min,y_max,0.02))
grid_points = np.c_[xx.ravel(),yy.ravel()]
grid_predictions = sess.run(prediction,feed_dict={x_data:rand_x,y_target:rand_y,prediction_grid:grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)
#绘制训练结果,批量准确度和损失函数
#等高线图
plt.contourf(xx,yy,grid_predictions,cmap=plt.cm.Paired,alpha=0.8)
plt.plot(class1_x,class1_y,'ro',label = 'I.setosa')
plt.plot(class2_x,class2_y,'kx',label = 'I.versicolor')
plt.plot(class3_x,class3_y,'gv',label = 'T.virginica')
plt.title('Gaussian Svm Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5,3.0])
plt.xlim([3.5,8.5])
plt.show()

plt.plot(batch_accuracy,'k-',label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Sepal Width')
plt.legend(loc = 'lower right')
plt.show()

plt.plot(loss_vec,'k--')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

训练的图片

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Svm实现多分类 的相关文章

  • 北大学生控诉字节跳动backup制度,怎么破解职场pua?

    5月13日下午 一名北大学生在校内论坛未名BBS上写下4000多字长文 陈述自己在字节跳动实习的经历 该同学表示 2021年1月份在字节跳动办理实习生入职 四月中旬实习期已满 且因毕业事宜繁忙向leader表达了近期需要离职的诉求 但竟遭遇
  • unity3d课后练习(四)

    文章目录 1 基本操作演练 建议做 2 编程实践 1 基本操作演练 建议做 下载 Fantasy Skybox FREE 构建自己的游戏场景 在 Asset Store 中搜索 Fantasy Skybox FREE 下载完成后 按照介绍导
  • better-scroll的学习和使用

    better scroll的学习和初始化 介绍 在日常的移动端开发中 列表滚动条的处理是非常常见的需求 横竖的滚动条使用better scroll都可以帮助我们在开发中实现 什么是better scroll better scroll是一个
  • Lattice Planner从入门到放弃

    Lattice Planner相关背景和更正式的公式推导可以直接参考其原始论文 Optimal Trajectory Generation for Dynamic Street Scenarios in a Fren t Frame ICR

随机推荐

  • protobuf快速上手

    protobuf快速上手 一 序列化与反序列化 序列化与反序列化的场景 常用的工具 二 protobuf工作原理 三 快速上手 protobuf中的数据类型 proto文件格式 编译选项 快速上手 四 通讯录demo 编写proto文件 编
  • 多DBWR进程与IO slave

    DBWn DBWn定期写脏数据到磁盘 频繁的磁盘I O会影响性能 所以每当数据库内存中产生脏数据时 是不一定也不应该产生写数据到磁盘的操作的 DBWn会尽量少的写入磁盘 虽然一个数据库DBW0进程适用于所有系统 为了提高数据库写的能力可以配
  • 使用gSOAP与WebService - 第一部分 为VC++从WSDL读取信息

    CurrencyConvertor How use gSOAP and WebServices Part 1 Get ready with VC 6 from WSDL file Download Demo 42 1 KB Download
  • Angular js 中angular is not defined 的问题

    觉得很搞笑 我现在还是不知道这是什么情况 反正这样能解决问题 1 我用下面这种方式引入js文件 在js文件中使用angular module方法会报angular is not defined 2 然后把引入的angular js文件放在上
  • Proteus(8.9版本) 51单片机-烟雾探测器的设计-仿真

    第一步是收集有关室外温度 湿度和气体浓度的信息 作为敏感元件烟雾传感器的输入信息 当信号输入值与放大模块的A D转换器输入电平相匹配时 无需放大放大器 当信号输入值与放大模块的A D转换器输入级别不匹配时 放大器将放大电气信号 A D电路的
  • 小熊派学习:手册查询和ADC深入使用

    弯曲传感器 折弯弯曲传感器 它的电阻值就会上升 那么flex value的值就会越来越小 连带地让led value的值越小 LED就会越暗 涉及到 上下拉电阻 电源至元器件引脚上的电阻称为上拉电阻 作用是平时使该引脚为高电平 地至元器件引
  • sqli-labs第二十三关(注释过滤绕过)

    从源码可知此关将注释符全部过滤掉 需要绕过 使用and 1 1即可 http localhost 90 sqli labs master Less 23 id 1 union select 111 select group concat s
  • 【linux-kali】网络模式host-only设置及注意事项

    网络模式 host only 设置 环境 kali vmware windows10 步骤 1 关闭kali系统下 虚拟机 编辑 虚拟机网络编辑器 Vmnet1 设置或确认子网IP 192 167 0 0 和DHCP范围 2 宿主机 上网网
  • CentOS8安装mysql8.0.24

    记录一下CentOS8安装mysql的过程 CentOS系统版本为CentOS Linux release 8 1 1911 安装的mysql版本为8 0 24 一 下载mysql安装包并解压 执行以下命令 创建mysql安装目录 mkdi
  • WebSocket的基本使用

    目录 为何使用websocket 1 后端搭建 2 搭建webSocket前后分离 1 配置跨域过滤器与初始化websocket 2 定义websocket服务 3 定义控制器进行测试webSocket向前端发送消息 2 前端准备 3 进行
  • 如何从gitee上拉项目?

    目录 第一步 下载git软件 第二步 一直下一步 傻瓜式安装 第三部 使用 新建一个文件夹 2 右击 打开命令窗口 3 复制项目下载url 4 命令窗口输入这样一串命令 第一步 下载git软件 CNPM Binaries Mirror np
  • spring-boot是否还和spring mvc一样存在父子容器

    文章目录 一 spring boot在自动集成了spring springmvc后是否在有父子容器之分 1 看下spring boot run方法 2 为什么spring mvc弄了一个父子容器 二 spring mvc中父子容器初始化过程
  • @Autowired 和 @Resource 的区别

    Autowired 和 Resource 的区别 区别 Autowired Resource 区别 区别1 Autowired 是spring提供的注解 Resource 是JDK提供的注解 区别2 Autowired 默认的注入方式是By
  • 第三十章、containers容器类部件QMdiArea多文档界面部件功能介绍及开发应用

    专栏 Python基础教程目录 专栏 使用PyQt开发图形界面Python应用 专栏 PyQt入门学习 老猿Python博文目录 一 引言 老猿在前期学习PyQt相关知识时 对每个组件的属性及方法都研究得很透彻 并将学习的感悟都写成了博文
  • linux中jdk安装/java环境安装

    第一步首先下载java jdk jdk 8u144 linux x64链接 https pan baidu com s 1uvSB 7JP037AdZJPDdGF6A 提取码 mdat 然后使用工具将文件传输到linux上 然后将tar g
  • 在树莓派中安装ROS系统(Kinetic)

    在树莓派中安装ROS系统 重新梳理了一下树莓派的安装流程 现在我们来开始吧 打开官网教程 http wiki ros org kinetic step1 安装源 中国 sudo sh c etc lsb release echo deb h
  • 物联网+区块链溯源方案

    物联网硬件 蓝牙 wifi 加区块链的方式可有效对现实世界中的实例进行链上映射 本文介绍一种基于硬件的轮胎区块链防伪溯源以及渠道管控的方案思路 更多区块链技术与应用分类 区块链应用 区块链开发 以太坊 Fabric BCOS 密码技术 共识
  • 服务器搭建系列之7:k8s安装postgresql数据库,2022最新版本

    Dockerfile FROM postgres EXPOSE 5432 deploy yaml 命名空间 apiVersion v1 kind Namespace metadata name fandai apiVersion apps
  • Maven安装与配置,Eclipse配置Maven【图文并茂的保姆级教程】

    Welcome Huihui s Code World 接下来看看由辉辉所写的关于Maven的相关操作吧 目录 Welcome Huihui s Code World 一 Maven是什么 二 Maven的下载 辉辉小贴士 maven中各个
  • Svm实现多分类

    机器学习 Svm实现多分类详解 Svm实现多类分类原理 代码实现 训练的图片 Svm实现多类分类原理 1 支持向量机分类算法最初只用于解决二分类问题 缺乏处理多分类问题的能力 后来随着需求的变化 需要svm处理多分类分为 目前构造多分类支持