基于Python的西瓜数据集 3.0α的SVM实现

2023-05-16

在西瓜数据集 3.0α 上分别用线性核和高斯核训练一个 SVM,并比较其支持向量的差别。
数据集下载地址:
https://amazecourses.obs.cn-north-4.myhuaweicloud.com/datasets/watermelon_3a.csv
任选数据集中的一种分布类型的数据,分别用软、硬间隔SVM和各类核函数训练,并分析他们分类的效果。
数据集下载地址:https://amazecourses.obs.cn-north-4.myhuaweicloud.com/datasets/SVM.zip

由于课业繁忙,实在是没有时间从底层数学逻辑来实现SVM从而更好地理解把握支持向量机的原理。因此本次作业利用sklearn实现。
此博客为第一问的SVM的简单实现。
1. SVM实现
下方的链接博客很详细的给出了sklearn中SVM的参数,值得参考。

https://www.cnblogs.com/guodavid/p/10174763.html

  • 数据导入与处理
def load_dataset(fname):
    # fname = 'ensemble_study/dataset/weatherHistory.csv'
    data = pd.read_csv(fname, index_col=0)
    return data


def process_data(data: pd.core.frame.DataFrame):
    data.drop('编号', axis=1, inplace=True)
    feature_list = data['色泽'].unique().tolist()
    # print(feature_list)
    data['色泽'] = data['色泽'].apply(lambda n: feature_list.index(n))
    feature_list = data['根蒂'].unique().tolist()
    data['根蒂'] = data['根蒂'].apply(lambda n: feature_list.index(n))
    feature_list = data['敲声'].unique().tolist()
    data['敲声'] = data['敲声'].apply(lambda n: feature_list.index(n))
    feature_list = data['纹理'].unique().tolist()
    data['纹理'] = data['纹理'].apply(lambda n: feature_list.index(n))
    feature_list = data['脐部'].unique().tolist()
    data['脐部'] = data['脐部'].apply(lambda n: feature_list.index(n))
    feature_list = data['触感'].unique().tolist()
    data['触感'] = data['触感'].apply(lambda n: feature_list.index(n))
    feature_list = ['否', '是']
    data['好瓜'] = data['好瓜'].apply(lambda n: feature_list.index(n))

    return data


def split_train_test_set(data: pd.core.frame.DataFrame):
    y = data['好瓜'].values
    data.drop('好瓜', axis=1, inplace=True)
    xtrain, xtest, ytrain, ytest = train_test_split(data, y, test_size=0.2)
    return xtrain, xtest, ytrain, ytest

数据处理部分主要做的是将数据集中的中文特征标签全部利用序列号编号处理,即利用数值来代表,再去除掉编号这一无用数据特征,将目标特征单独提出,之后利用sklearn的train_test_split方法随机划分数据集。

- 线性核,高斯核的支持向量机实现

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2020/12/19 19:54
# @Author  : Ryu
# @Site    : 
# @File    : SVM.py
# @Software: PyCharm

from data_process import *
from sklearn import svm
from sklearn.metrics import accuracy_score
from visual import visual
import copy

if __name__ == '__main__':
    file_name = 'D:\Pythonwork\FisherLDA\SVM\watermelon_3a.csv'

    data = load_dataset(file_name)
    raw_data = copy.deepcopy(data)
    train = process_data(raw_data)
    xtrain, xtest, ytrain, ytest = split_train_test_set(train)

    # 线性核处理
    linear_svm = svm.LinearSVC(C=0.5, class_weight='balanced')
    linear_svm.fit(xtrain, ytrain)
    y_pred = linear_svm.predict(xtest)
    print('线性核的准确率为:{}'.format(accuracy_score(y_pred=y_pred, y_true=ytest)))

    # 高斯核处理
    gauss_svm = svm.SVC(C=0.5, kernel='rbf', class_weight='balanced')
    gauss_svm.fit(xtrain, ytrain)
    y_pred2 = gauss_svm.predict(xtest)
    print('高斯核的准确率: %s' % (accuracy_score(y_pred=y_pred2, y_true=ytest)))

    #多项式核
    poly_svm = svm.SVC(C=0.5, kernel='poly', degree=3, gamma='auto', coef0=0, class_weight='balanced')
    poly_svm.fit(xtrain, ytrain)
    y_pred3 = poly_svm.predict(xtest)
    print('多项式核的准确率: %s' % (accuracy_score(y_pred=y_pred3, y_true=ytest)))

    #sigmoid核
    sigmoid_svm = svm.SVC(C=0.5, kernel='sigmoid', degree=3, gamma='auto', coef0=0, class_weight='balanced')
    sigmoid_svm.fit(xtrain, ytrain)
    y_pred4 = sigmoid_svm.predict(xtest)
    print('sigmoid核的准确率: %s' % (accuracy_score(y_pred=y_pred4, y_true=ytest)))


    visual(data, 'gauss_svm', gauss_svm)
    visual(data, 'sigmoid svm', sigmoid_svm)

需要说明的是,上述代码中惩罚系数均使用0.5完成。线性核和高斯核的两个SVM中笔者均将class_weighted这个参数设置为了‘balanced’,利用自动计算的样本权值来调整数据集分布——主要愿意是样本数据只有17个,实在是太小了,导致任意一个样本的分类不当都会对整个SVM的准确率产生极大的影响。

  • 实验分析
    在这里插入图片描述
    在class_weighted参数未加入时,实验效果极差。由于样本数很少,划分训练测试数据集时比率取到0.2已是极限。在这种情况下,如果不采用加权样本分类的方法,两个核函数的SVM最终结果基本只有很小概率能够达到50%以上。在加入了之后正确率基本能够稳定在67%以上。并且线性划分的效果基本都好于高斯核的效果。这可能也与训练集过于简单有关。在这里插入图片描述
    由于线性svm在sklearn中没有特征向量支持,故选用sigmoid的核替代展示。可以明显的发现,sigmoid的分类效果不尽人如意。

在这里插入图片描述
高斯核的分类效果相较sigmoid更好,经多次试验发现也更加稳定。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于Python的西瓜数据集 3.0α的SVM实现 的相关文章

随机推荐