python 实现 softmax分类器(MNIST数据集)

2023-11-20

最近一直在外面,李航那本书没带在身上,所以那本书的算法实现估计要拖后了。
这几天在看Andrew Ng 机器学习的课程视频,正好看到了Softmax分类器那块,发现自己之前理解perceptron与logistic regression是有问题的。这两个算法真正核心的不同在于其分类函数的不同,perceptron采用一个分段函数作为分类器,logistic regression采用sigmod函数作为分类器,这才是这两个函数真正的不同。

废话不多说了,今天打算实现softmax分类器。

算法

算法参考的是Andrew 的课件与这篇文章
具体实现的时候发现加入权重衰减效果会更好。

这里为了防止大家看不懂我的程序,我在这里做一些定义

ΘjJ(Θ)=x(i)(1{y(i)=j}p(y(i)=j|x(i);Θ))+λΘj(1)

p(y(i)=j|x(i);Θ)=eΘTjx(i)kl=1eΘTlx(i)(2)

eΘTlx(i)(3)

数据集

数据集和KNN那个博文用的是同样的数据集。
数据地址:https://github.com/WenDesi/lihang_book_algorithm/blob/master/data/train.csv

特征

将整个图作为特征

代码

代码已上传GitHub

这次的代码是python3的,有可能需要稍微改一改,不好意思了,我要背叛python2了。

# encoding=utf8

import math
import pandas as pd
import numpy as np
import random
import time

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score


class Softmax(object):

    def __init__(self):
        self.learning_step = 0.000001           # 学习速率
        self.max_iteration = 100000             # 最大迭代次数
        self.weight_lambda = 0.01               # 衰退权重

    def cal_e(self,x,l):
        '''
        计算博客中的公式3
        '''

        theta_l = self.w[l]
        product = np.dot(theta_l,x)

        return math.exp(product)

    def cal_probability(self,x,j):
        '''
        计算博客中的公式2
        '''

        molecule = self.cal_e(x,j)
        denominator = sum([self.cal_e(x,i) for i in range(self.k)])

        return molecule/denominator


    def cal_partial_derivative(self,x,y,j):
        '''
        计算博客中的公式1
        '''

        first = int(y==j)                           # 计算示性函数
        second = self.cal_probability(x,j)          # 计算后面那个概率

        return -x*(first-second) + self.weight_lambda*self.w[j]

    def predict_(self, x):
        result = np.dot(self.w,x)
        row, column = result.shape

        # 找最大值所在的列
        _positon = np.argmax(result)
        m, n = divmod(_positon, column)

        return m

    def train(self, features, labels):
        self.k = len(set(labels))

        self.w = np.zeros((self.k,len(features[0])+1))
        time = 0

        while time < self.max_iteration:
            print('loop %d' % time)
            time += 1
            index = random.randint(0, len(labels) - 1)

            x = features[index]
            y = labels[index]

            x = list(x)
            x.append(1.0)
            x = np.array(x)

            derivatives = [self.cal_partial_derivative(x,y,j) for j in range(self.k)]

            for j in range(self.k):
                self.w[j] -= self.learning_step * derivatives[j]

    def predict(self,features):
        labels = []
        for feature in features:
            x = list(feature)
            x.append(1)

            x = np.matrix(x)
            x = np.transpose(x)

            labels.append(self.predict_(x))
        return labels


if __name__ == '__main__':

    print('Start read data')

    time_1 = time.time()

    raw_data = pd.read_csv('../data/train.csv', header=0)
    data = raw_data.values

    imgs = data[0::, 1::]
    labels = data[::, 0]

    # 选取 2/3 数据作为训练集, 1/3 数据作为测试集
    train_features, test_features, train_labels, test_labels = train_test_split(
        imgs, labels, test_size=0.33, random_state=23323)
    # print train_features.shape
    # print train_features.shape

    time_2 = time.time()
    print('read data cost '+ str(time_2 - time_1)+' second')

    print('Start training')
    p = Softmax()
    p.train(train_features, train_labels)

    time_3 = time.time()
    print('training cost '+ str(time_3 - time_2)+' second')

    print('Start predicting')
    test_predict = p.predict(test_features)
    time_4 = time.time()
    print('predicting cost ' + str(time_4 - time_3) +' second')

    score = accuracy_score(test_labels, test_predict)
    print("The accruacy socre is " + str(score))

运行结果

这里写图片描述

速度挺快,正确率一般吧,比决策树之类的要高。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

python 实现 softmax分类器(MNIST数据集) 的相关文章

随机推荐

  • lua和测试(一)

    lua做为一门高级语言 在游戏产业运用到机会越来越多了 测试掌握几门脚本语言也有一定的重要性 以下对于lua组合输入做出一些引导 测试需要掌握的关于返回数值 主要用到布尔类 前言的指引 lua的语法比较简单和清晰 学过c语言的可以很好的掌握
  • 并发编程系列之自定义线程池

    前言 前面我们在讲并发工具类的时候 多次提到线程池 今天我们就来走进线程池的旅地 首先我们先不讲线程池框架Executors 我们今天先来介绍如何自己定义一个线程池 是不是已经迫不及待了 那么就让我们开启今天的旅途吧 什么是线程池 线程池可
  • selenium+python 对输入框的输入处理

    最近自己在做项目的自动化测试 公司无此要求 在用户管理模块做修改用户信息时 脚本已经跑成功 并且的确做了update操作 但是自己登陆页面检查 信息却没有被修改 再次确定系统该模块的编辑功能可用 脚本如下 if result num gt
  • 近千万EOS被盗事件回顾,大家请保护好自己的EOS私钥

    最近有伙伴被盗了价值近千万的EOS 于是查看了这次被盗活动账号记录 这次分享出来 一是有可能大家有线索 二是也让大家意识到数字货币私钥安全的重要性 事件回顾 受害人在7 9号被偷盗人通过update auth更换了账号授权公私钥 紧接着被转
  • 零基础到GPT高手:快速学习与利用ChatGPT的完全指南

    进入人工智能时代 令人惊叹的ChatGPT技术正在引爆全球 您是否想象过能够与智能语言模型对话 提升工作效率 解锁创意 甚至实现商业化变现 在本篇文章中 我将向你揭示ChatGPT的原理 学习技巧 并展示如何利用ChatGPT提升工作效率和
  • Windows11:QT5.14.2+PCL1.12.0+VS2019环境配置

    之前在win10系统下配置了PCL1 8 1 QT5 9 1 VS2015的开发环境 由于PCL库已经更新到了1 12 1而且1 8 1一直有bug 为了使用下新的算法库 今天配置一下新的开发环境 1 安装Qt5 14 2 Qt5 14 2
  • 【b站雅思笔记】Simon‘s IELTS Course - 听力部分

    前情提要 b站up主贼开心的小林上传的Simon的听力课 资料均来源于她 参考 雅思阅读 最好的雅思课程 阅读部分全集 https www bilibili com video BV1ea4y1x7qR spm id from 333 78
  • Spring为什么要用的三级缓存解决循环依赖

    一 代码准备 Component aService public class AService Autowired private BService bService public void test System out println
  • 哈工大2020软件构造Lab3实验报告

    本项目于4 21日实验课验收 更新完成 如果有所参考 请点点关注 点点赞GitHub Follow一下谢谢 2020春计算机学院 软件构造 课程Lab3实验报告 Software Construction 2020 Spring Lab 3
  • react_hooks系列05_useRef,useImperativeHandle,高阶组件forwordRef

    一 useRef 1 uesRef使用在官方标签上 useRef 返回一个可变的 ref 对象 其 ref 对象 current 属性被初始化为传入的参数 initialValue 返回的 ref 对象在组件的整个生命周期内保持不变 imp
  • 蓝桥杯字母阵列

    字母阵列 递归解法 仔细寻找 会发现 在下面的8x8的方阵中 隐藏着字母序列 LANQIAO SLANQIAO ZOEXCCGB MOAYWKHI BCCIPLJQ SLANQIAO RSFWFNYA XIFZVWAL COAIQNAL 我
  • 教你怎么导入导出数据

    最近在做一个项目 需要对数据进行导入导出 实现之后 自己也做了一个总结 总体来说还是比较容易的 第一次的话肯定有许多坑的 细节真的很重要 当你踏过一个又一个坑 一路路走来 你会发现自己的信心越来越强 对于数据的导入导出 我们首先写一个工具类
  • 代码检查、评审、单元测试工具 大搜集

    看书真是迅速进入一个陌生领域的最快办法 系统的 体系完整的知识比起在互联网上七拼八凑出的认识强太多了 先记下一些理论概念 软件生命周期模型 分析 设计与文档 编码与审查 测试与调试 发布与维护 软件测试对象的6种分类 单元测试 静态检查 动
  • 数据结构---线性表的静态/动态分配与顺序/链式存储

    线性表 基于严魏敏版数据结构c语言实现 谭浩强版c语言 数据元素在计算机中的存储分为顺序存储和链式存储 顺序存储 借助元素在存储器中的相对位置来表示数据元素之间的逻辑关系 链式存储 借助指示元素存储地址的指针表示数据元素之间的逻辑关系 ps
  • matlab定义机器人位置,机器人自定位问题(数学建模)

    形形色色 各式各样的机器人正在走进人们的生产与生活 发挥着越来越重要的作用 这些机器人 一般都拥有 感官 各种传感器 大脑 智能计算的软硬件 和 执行器 各种操控设备 等 它们在自己的工作场合内 能自主感知 自主决策并完成使命 为达到这样的
  • 笔记---Linux安装OpenCV及VSCode的配置编译

    学更好的别人 做更好的自己 微卡智享 本文长度为4250字 预计阅读10分钟 前言 最近在学点新东西 教程中主要也是在Linux中使用 对于我这个以前从未接触Linux系统的人来说 正好也是个机会掌握下LInux系统 这篇就是记录在Linu
  • 批量创建文件与文件夹

    1 批量创建文件 下面们来说一下如何在pyhton中去批量创建文件 假设我要新建10个txt文件 这里我用一个for循环 for i in range 10 这里的 指代的是当前文件夹 i表示文件的名称 a表示没有该文件就新建 f open
  • Java 泛型 T,E,K,V,?

    泛型带来的好处 在没有泛型的情况的下 通过对类型 Object 的引用来实现参数的 任意化 任意化 带来的缺点是要做显式的强制类型转换 而这种转换是要求开发者对实际参数类型可以预知的情况下进行的 对于强制类型转换错误的情况 编译器可能不提示
  • 入门力扣自学笔记279 C++ (题目编号:1123)

    1123 最深叶节点的最近公共祖先 题目 给你一个有根节点 root 的二叉树 返回它 最深的叶节点的最近公共祖先 回想一下 叶节点 是二叉树中没有子节点的节点 树的根节点的 深度 为 0 如果某一节点的深度为 d 那它的子节点的深度就是
  • python 实现 softmax分类器(MNIST数据集)

    最近一直在外面 李航那本书没带在身上 所以那本书的算法实现估计要拖后了 这几天在看Andrew Ng 机器学习的课程视频 正好看到了Softmax分类器那块 发现自己之前理解perceptron与logistic regression是有问