3.【Python】分类算法—Softmax Regression

2023-11-02

3.【Python】分类算法—Softmax Regression


前言

Softmax回归算法主要用于多分类问题,是逻辑回归算法的推广,值得注意的是,Softmax回归算法中任意两个类是线性可分的。


一、Softmax Regression模型

1.Softmax Regression模型

对于Softmax Regression模型,输入特征为 X ( i ) ϵ R n + 1 X^{(i)}\epsilon R^ {n+1} X(i)ϵRn+1,类标记为 y ( i ) ϵ 0 , 1 , . . . , k y^{(i)}\epsilon{0,1,...,k} y(i)ϵ0,1...,k。假设函数为每一个样本估计其所属的类别的概率 P ( y = j ∣ X ) P(y=j |X) P(y=jX)。具体假设函数如下,其中 Θ \Theta Θ表示向量。
在这里插入图片描述

则对于每一个样本估计其所属的类别的概率为:

在这里插入图片描述

2.Softmax Regression的损失函数

在Softmax Regression算法的损失函数中引入指令函数 I ( ⋅ ) I(\cdot ) I(),表示为:
在这里插入图片描述
与Logistic Regression算法中对于损失函数的处理方式类似,都采用极大似然法,并以负的log似然函数作为损失函数,表示为:
在这里插入图片描述
I { y ( i ) = j } I \left \{ y^{(i)}=j \right \} I{y(i)=j}表示属于第j类时, I { y ( i ) = j } = 1 I \left \{ y^{(i)}=j \right \}=1 I{y(i)=j}=1或者 I { y ( i ) = j } = 0 I \left \{ y^{(i)}=j \right \}=0 I{y(i)=j}=0

3.Softmax Regression的求解

对于损失函数可以采用梯度下降法进行求解,求解其梯度表示为:
在这里插入图片描述
梯度下降的公式可以通过下式更新:

在这里插入图片描述
以下将用python代码实现softmax regression的更新过程,构建梯度更新函数gradientAscent,以此实现模型中权重的更新。

import  numpy as np

def gradientAscent (feature_data,label_data,k,maxCycle,alpha):
    '''利用梯度下降法训练softmax模型
    input:feature_data(mat):特征
          label_data(mat):标签
          k(int):类别的个数
          maxCycle(int):最大的迭代次数
          alpha(float):学习率

    output:weights(mat):权重
    '''
    m,n = np.shape(feature_data)
    #np.ones返回一个全1的n维数组
    weights = np.mat(np.ones((n,k)))
    i = 0
    while i <= maxCycle:
        err = np.exp(feature_data * weights)
        if i %100 == 0:
            print("\t----iter:", i,\
                  " , cost: ", cost(err,label_data))
        rowsum = -err.sum(axis = 1)
        rowsum = rowsum.repeat(k,axis = 1)
        err = err/rowsum
        for x in range(m):
            err[x,label_data[x,0]] +=1
        weights = weights + (alpha/m) * feature_data.T * err
        i += 1
    return weights

其中,函数cost用于计算当前损失函数的值,输入为当前预测值err和样本标签label_data。

def cost(err,label_data):
    '''计算损失函数的值
    input:err(mat):exp的值
           label_data(mat):标签的值
    output:sum_cost / m(float):损失函数的值
    '''
    m = np.shape(err)[0]
    sum_cost = 0.0
    for i in range(m):
        if err[i,label_data[i,0]]/np.sum(err[i,:]) > 0:
            sum_cost -= np.log(err[i,label_data[i,0]]/np.sum(err[i,:]))
        else:
            sum_cost -= 0
    return sum_cost/m

二、Softmax Regression和Logistic Regression

1.Softmax Regression中的参数特点

在Softmax Regression中有些参数是没用的,称为参数冗余。假设从参数向量 θ j {\theta _{j}} θj中减去向量 ψ \psi ψ,对预测结果没用任何影响,说明模型中存在多组最优解。
在这里插入图片描述

2.由Softmax Regression到Logistic Regression

Logistic Regression是Softmax Regression特征数为2 时的特殊情况,此时Softmax Regression的假设函数为:
在这里插入图片描述
由于Softmax Regression具有冗余性,减去 ψ \psi ψ依然等价,令 ψ = θ 1 \psi={\theta _{1}} ψ=θ1 θ 1 {\theta _{1}} θ1 θ 2 {\theta _{2}} θ2同时减去 ψ \psi ψ可得:
在这里插入图片描述
特征数为2 时两者假设函数是等价的。

三、Softmax Regression实践

1.构建Softmax Regression算法的训练模型

使用类似如图的数据对Softmax Regression模型进行训练。
在这里插入图片描述

(图侵删)

训练模型的主函数如下,首先需要导入训练数据data.txt,而后利用梯度下降法gradientAscent对模型进行训练,已经在前面给出代码,最终将模型参数保存至weights中。

if __name__ == "__main__":
    inputfile = "data.txt"
    #1.导入训练数据
    print("------1.load data------")
    feature,label,k = load_data("data.txt")
    #2.训练softmax regression模型
    print("------2.training-------")
    weights = gradientAscent(feature,label,k,1000,0.4) #最大迭代次数1000,学习率0.4
    #3.保存最终的模型
    print("------3.save model------")
    save_model("weights",weights)

首先构建导入训练数据的load_data函数,得出训练数据的特征feature_data 、标签label_data和训练样本的类别个数k。

#data.txt为文件名,inputfile为文件
def load_data(inputfile):
    '''
    input: inputfile(string)训练数据的文件位置
    output: feature_data(mat)特征
        label_data(mat)标签
        k(int)类别的个数
    '''
    f = open(inputfile) #打开文件
    feature_data = []
    label_data = []
    #逐行读取
    for line in f.readlines():
        feature_tmp = []
        #strip spilt
        lines = line.strip().split("\t")
        feature_tmp.append(1) #偏置项
        for i in range(len(lines)-1): #读除最后一行的前几行
            feature_tmp.append(float(lines[i]))
        label_data.append(int(lines[-1])) #读最后一行

        feature_data.append(feature_tmp)
    f.close() #关闭文件
    #.T为转置
    return np.mat(feature_data),\
           np.mat(label_data).T,len(set(label_data))

gradientAscent函数在第一节第3.点总已经定义。
最后构建训练模型的save_model函数,将模型和weights保存在file_name中。

def save_model(file_name, weights):
    '''保存最终的模型
    input:
        file_name(string): 保存的文件名
        weights(mat):softmax模型
    '''
    f_w = open(file_name, "w")
    m, n = np.shape(weights)
    for i in range(m):
        w_tmp = []
        for j in range(n):
            w_tmp.append(str(weights[i,j]))
        f_w.write("\t".join(w_tmp)+"\n")
    f_w.close()

2.预测测试数据

将文件命名为softmax_regression_test.py,构建测试模型的主程序,首先导入模型的权重等参数,然后导入测试数据,其后利用训练好的softmax模型对测试数据进行预测,最后将预测结果保存到文件中。

测试模型的主程序如下:

if __name__ == "__main__":
    #1.保存softmax模型
    print("------1.load model------")
    w, m, n = load_weights("weights")
    #2.导入测试数据
    print("------2.load data------")
    test_data = load_data(4000,m)
    #3.利用训练好的softmax regression模型对测试数据进行预测
    print("------3.get prediction-------")
    result = predict(test_data, w)
    #4.保存最终的预测结果
    print("------4.save prediction------")
    save_model("result", result)

构建load_weights函数导入训练模型参数,主要是权重矩阵和矩阵的行列数。

def load_weights(weights_path):
    '''导入softmax训练模型
    input:weights_path(string)权重的存储位置
    output:weights(mat)将权重存到矩阵中
           m(int)权重的行数
           n(int)权重的列数
    '''
    f = open(weights_path)
    w = []
    for lines in f.readlines():
        lines = line.strip().split("\t")
        w_tmp = []
        for x in lines:
            w_tmp.append(float(x))
        w.append(w_tmp)
    f.close()
    weights = np.mat(w)
    m, n = np.shape(weights)
    return np.mat(w)

构建load_data函数导入测试数据,这里测试数据为随机数,需要导入random模块。

import random as rd

def load_data(num,m):
    '''导入测试数据
    input:
        num(int):生成的测试样本的个数
        m(int):样本的维数
    output:testdataset(mat)生成测试样本
    '''
    testdataset = np.mat(np.ones(num, m))
    for i in range(num):
        #随机生成[-3,3]之间的随机数
        testdataset[i, 1] = rd.random() * 6 - 3
        #随机生成[0,15]之间的随机数
        testdataset[i,2] = rd.random() * 15
    return testdataset

构建predict函数对测试数据进行预测。

def predict(test_data, weights):
    '''利用训练好的softmax模型对测试数据进行预测
    input:
        test_data(mat):测试数据的特征
        weights(mat):模型的权重
    output:h.argmax(axis=1)所属的类别
    '''
    h = test_data * weights #每个样本属于每一个类别的概率
    #每一列最大值所在位置的索引
    return h.argmax(axis=1) #获得最终的类别标签

构建save_result函数保存预测结果。

def save_result (file_name, result):
    '''保存最终的预测结果
    input:
        file_name(string):保存最终结果的文件名
        result(mat):最终的预测结果
    '''
    f_result = open(file_name, "w")
    m = np.shape(result)[0]
    for i in range(m):
        f_result.write(str(result[i, 0])+ "\n")
    f_result.close()

总结

以上针对Softmax Regression算法的原理和python具体实现过程进行了介绍,python的具体实现过程主要分为两个部分—训练和预测。

参考文献:《Python机器学习算法》

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

3.【Python】分类算法—Softmax Regression 的相关文章

  • 翠儿。让流永远运行

    我对 tweepy python 库比较陌生 我想确保我的流 python 脚本始终在远程服务器上运行 因此 如果有人能够分享如何实现这一目标的最佳实践 那就太好了 现在我正在这样做 if name main while True try
  • 在 Pandas 中按日期获取有效合约

    我在检测 pandas DataFrame 中的活动合约方面遇到了一些困难 假设每一行都是一个协商 对于每一行 我有两列 initial date 和 end date 我想知道的是按日期划分的活跃合约数量 到目前为止我做了一个非常低效的方
  • 从正在运行的 python 脚本检测优化标志是否为 -O 或 -OO

    有时我想生成一个子进程 其优化标志与启动父进程时使用的优 化标志相同 我可以使用类似的东西 optimize not debug 但这样我就可以匹配两者 O and OO flags 是否有一些 python 内部状态包含该信息 经过一番深
  • 为什么 .setGeometry() 不改变 QWidget 实例的大小?

    我想使用 QWidget 更改 QPushButton 的大小 setGeometry https doc qt io qtforpython 5 PySide2 QtWidgets QWidget html PySide2 QtWidge
  • 在Python3.6中调用C#代码

    由于完全不了解 C 编码 我希望在我的 python 代码中调用 C 函数 我知道有很多关于同一问题的问答 但由于一些奇怪的原因 我无法从示例 python 模块导入简单的 c 类库 以下是我所做的事情 C 类库设置 我使用的是 VS 20
  • 如何在Python中循环并存储自变量中的值

    我对 python 很陌生 所以这听起来可能很愚蠢 我进行了搜索 但没有找到解决方案 我在 python 中有一个名为 ExcRng 的函数 我可以对该函数执行什么样的 for 循环 以便将值存储在独立变量中 我不想将它们存储在列表中 而是
  • 如何确定非阻塞套接字是否真正连接?

    这个问题不仅限于Python 这是一个一般的套接字问题 我有一个非阻塞套接字 想要连接到一台可访问的机器 在另一端 该端口不存在 为什么 select 仍然成功 我预计会超时 sock send 因管道损坏而失败 select 之后如何确定
  • 在python中将文本文件解析为列表

    我对 Python 完全陌生 我正在尝试读取包含单词和数字组合的 txt 文件 我可以很好地读取 txt 文件 但我正在努力将字符串转换为我可以使用的格式 import matplotlib pyplot as plt import num
  • 即使使用 .loc[row_indexer,col_indexer] = value 时也会设置 WithCopyWarning

    这是我的代码中得到的行之一SettingWithCopyWarning value1 Total Population value1 Total Population replace to replace value 4 然后我将其更改为
  • Pandas重置索引未生效[重复]

    这个问题在这里已经有答案了 我不确定我在哪里误入歧途 但我似乎无法重置数据帧上的索引 当我跑步时test head 我得到以下输出 正如您所看到的 数据帧是一个切片 因此索引超出范围 我想做的是重置该数据帧的索引 所以我跑test rese
  • 错误:permission_manager_qt.cpp(82) 不支持的权限类型:13

    我正在开发具有内置浏览器功能的 python 代码 PyQt 5 13 import sys from PyQt5 QtCore import from PyQt5 QtGui import from PyQt5 QtWidgets imp
  • matplotlib matshow 标签

    我一个月前开始使用 matplotlib 所以我仍在学习 我正在尝试用 matshow 制作热图 我的代码如下 data numpy array a reshape 4 4 cax ax matshow data interpolation
  • 获取列表中倒数第二个元素[重复]

    这个问题在这里已经有答案了 我可以通过以下方式获取列表的倒数第二个元素 gt gt gt lst a b c d e f gt gt gt print lst len lst 2 e 有没有比使用更好的方法print lst len lst
  • 如何列出 python PDB 中的当前行?

    在 perl 调试器中 如果重复列出离开当前行的代码段 可以通过输入命令返回到当前行 点 我无法使用 python PDB 模块找到任何类似的东西 如果我list如果我自己离开当前行并想再次查看它 似乎我必须记住当前正在执行的行号 对我来说
  • 将输入发送到 python 子进程而不等待结果

    我正在尝试为一段代码编写一些基本测试 该代码通常通过 stdin 无休止地接受输入 直到给出特定的退出命令 我想检查程序是否在给出一些输入字符串时崩溃 经过一段时间来考虑处理 但似乎无法弄清楚如何发送数据而不是陷入等待我不知道的输出关心 我
  • Django 在选择列表更改时创建毫无意义的迁移

    我正在尝试使用可调用创建一个带有选择字段的模型 以便 Django 在选择列表更改时不会创建迁移 如中所述this https stackoverflow com questions 31788450 stop django from cr
  • 将一个列表的元素除以另一个列表的元素

    我有两个清单 比如说 a 10 20 30 40 50 60 b 30 70 110 正如你所看到的 列表 b 由一个列表的元素总和组成 其中 window 2 b 0 a 0 a 1 10 20 30 etc 如何获得另一个列表 该列表由
  • 在 MacO 和 Linux 上安装 win32com [重复]

    这个问题在这里已经有答案了 我的问题很简单 我可以安装吗win32com蟒蛇API pywin32特别是 在非 Windows 操作系统上 我一直在Mac上尝试多个版本pip install pywin32 都失败了 下面是一个例子 如果你
  • 如何禁止 celery 中的 pickle 序列化

    Celery 默认使用 pickle 作为任务的序列化方法 如中所述FAQ http ask github com celery faq html isn t using pickle a security concern 这代表一个安全漏
  • 检查字符串是否只有字母和空格 - Python

    试图让 python 返回一个字符串仅包含字母和空格 string input Enter a string if all x isalpha and x isspace for x in string print Only alphabe

随机推荐

  • Unity3D---Vuforia is not enabled解决方案

    在Unity3D实现VR的过程中 需要选择Vuforia官网自己创建的Database中的Target 此时 有的Unity3D会出现如下错误 解决办法如下 选择Edit Project Settings Player 将XR Settin
  • 图像识别小车(jetson nano部分)——电赛学习笔记(3)

    目录 零 前言 1 jetson nano购买商家及技术支持 2 相关环境配置 3 做好系统备份 一 vscode远程ssh操作 局域网连接 二 板载摄像头教程 三 运行例程 四 GPIO使用 GPIO库的API用法 1 导入库 2 引脚编
  • Git 命令行提交代码详细操作

    Git 命令行提交代码操作 安装git后 鼠标右键打开Git Bash 1 查看本地git绑定的用户名和邮箱 git config user name git config user email 2 修改本地git绑定的用户名和邮箱 全局
  • 数据挖掘(知识图谱2019)

    领域 二级分类 三级分类 data mining 数据挖掘 time series analysis 时间序列分析 data streams 数据流 time series data 时间序列数据 real time 实时 time ser
  • Unity之Matrix4x4 矩阵

    Matrix4x4 矩阵 Struct A standard 4x4 transformation matrix 一个标准的4x4变换矩阵 A transformation matrix can perform arbitrary line
  • Qt中关于定时器timerEvent和QTimer

    1 Qt 定时器类 QTimer 在进行窗口程序的处理过程中 经常要周期性的执行某些操作 或者制作一些动画效果 使用定时器类 QTimer 就可以解决 使用 只需创建一个 QTimer 类对象 然后调用其 start 函数开启定时器 此后
  • C#实现多语言切换(通过Resource语言包文件实现)

    点我 下载多语言切换项目最全源码 1 先说说Resources语言包文件是怎么来的 通过Visual Studio 命令提示工具将txt文件转换成resources文件 具体操作 a 打开Visual Studio 命令提示工具 然后输入你
  • EXT2.2 grid行不能复制信息的解决方法

    在ext all js的后面加入如下js if Ext grid GridView prototype templates Ext grid GridView prototype templates Ext grid GridView pr
  • 当下用途最广的计算机语言,目前为止国际上最主流的计算机编程语言是什么?...

    看主流的观察角度 如果是这些语言编写的软件的用户数量最多 那么肯定是C和C 了 因为我们的操作系统 例如WINDOWS IOS LINUX 和核心应用程序 例如OFFICE IE CHROME 以及绝大多数的游戏 几乎全都是C和C 以及少量
  • 除了中国好声音,星空华文冲刺港股IPO还有其他王牌吗?

    回顾国内的综艺节目发展史 中国好声音 曾是里程碑式的存在 曾一度稳坐各大省级卫视综艺节目收视率的头把交椅 更是民间歌手们心中殿堂级的存在 但它背后的制作公司 星空华文似乎却江河日下 5月13日 星空华文再次发起IPO 这一回选择登陆的是港交
  • JSP中,AJAX使用POST方式提交中文乱码问题解决

    本人原创 欢迎转载 转载请保留本人信息 作者 wallimn 电邮 wallimn sohu com 博客 http blog csdn net wallimn 时间 2006 11 15 本人原创 欢迎转载 转载请保留本人信息 今天终于解
  • Python编程:从入门到实践(基础知识)

    第一章 起步 计算机执行源程序的两种方式 编译 一次性执行源代码 生成目标代码 解释 随时需要执行源代码 源代码 采用某种编程语言编写的计算机程序 目标代码 计算机可执行 101010 编程语言分为两类 静态语言 使用编译执行的编程语言 C
  • java.library.path属性在代码中设置不生效问题

    http www blogjava net gembin archive 2008 10 29 237377 html from http daimojingdeyu blogbus com logs 28617218 html 可是在使用
  • 如何用wps制作地图分布图_如何用Power BI制作自己的可视化地图

    作者 AgnesJ 在之前的文章中介绍过Power BI的形状地图 使用形状地图我们可以导入自己想要的任何地图 只要找到对应的TopoJson格式地图文件就可以 但是当我们需要分析某一个销售区域 或服务范围时 如何获取或者创建自己的Json
  • 让div撑满整个屏幕的方法(css)

    在body只有一个div的时候 可以通过这样的方式让div撑满整个屏幕 1 给div设置定位 复习一下 css中position有五种属性 static 默认值 没有定位 absolute 绝对定位 相对于父级元素进行定位 relative
  • aop统一日志输出controller出入参及部分参数

    输出使用的jackson 其中获取iputil放在另一篇文章 gt gt gt gt IpUtil获取ip author cy c date 2022 5 19 16 28 统一日志处理 Component Aspect public cl
  • JTS:04 读取数据库数据

    版本 org locationtech jts jts core 1 19 0 链接 github 数据库 创建数据库方式 postgresql 使用postgis插件 kartoza postgis 15 3 3 使用docker容器 创
  • C++中#pragma once与#ifndef的区别

    为了避免同一个文件被include 多次 可以使用两种方式 1 方式一 ifndef SOMEFILE H define SOMEFILE H 声明语句 endif 2 方式二 pragma once 声明语句 两者的区别 ifndef方式
  • Struts2识别与漏洞利用

    Struts2框架识别 1 通过网页后缀来进行判断 如 do或者 action Struts2漏洞验证 Struts2 045漏洞介绍 安恒信息安全研究院WEBIN实验室高级安全研究员nike zheng发现著名J2EE框架 Struts2
  • 3.【Python】分类算法—Softmax Regression

    3 Python 分类算法 Softmax Regression 文章目录 3 Python 分类算法 Softmax Regression 前言 一 Softmax Regression模型 1 Softmax Regression模型