全面解析并实现逻辑回归(Python)

2023-11-19

本文以模型、学习目标、优化算法的角度解析逻辑回归(LR)模型,并以Python从头实现LR训练及预测。

一、逻辑回归模型结构

逻辑回归是一种广义线性的分类模型且其模型结构可以视为单层的神经网络,由一层输入层、一层仅带有一个sigmoid激活函数的神经元的输出层组成,而无隐藏层。其模型的功能可以简化成两步,“通过模型权重[w]对输入特征[x]线性求和+sigmoid激活输出概率”具体来说,我们输入数据特征x,乘以一一对应的模型权重w后求和,通过输出层神经元激活函数σ(sigmoid函数)将(wx + b)的计算后非线性转换为0~1区间的概率数值后输出。学习训练(优化模型权重)的过程是通过梯度下降学到合适的模型权重[W],使得模型输出值Y=sigmoid(wx + b)与实际值y的误差最小。

附注:sigmoid函数是一个s形的曲线,它的输出值在[0, 1]之间,在远离0的地方函数的值会很快接近0或1。对于sigmoid输出作为概率的合理性,可以参照如下证明:

逻辑回归是一种判别模型,为直接对条件概率P(y|x)建模,假设P(x|y)是高斯分布,P(y)是多项式分布,如果我们考虑二分类问题,通过公式变换可以得到:

可以看到,逻辑回归(或称为对数几率回归)的输出概率和sigmoid形式是一致的。

逻辑回归模型本质上属于广义线性分类器(决策边界为线性)。这点可以从逻辑回归模型的决策函数看出,决策函数Y=sigmoid(wx + b),当wx+b>0,Y>0.5;当wx+b<0,Y<0.5,以wx+b这条线可以区分开Y=0或1(如下图),可见决策边界是线性的。

二、学习目标

逻辑回归是一个经典的分类模型,对于模型预测我们的目标是:预测的概率与实际正负样本的标签是对应的,Sigmoid 函数的输出表示当前样本标签为 1 的概率,y^可以表示为

当前样本预测为0的概率可以表示为1-y^

对于正样本y=1,我们期望预测概率尽量趋近为1 。对于负样本y=0,期望预测概率尽量都趋近为0。也就是,我们希望预测的概率使得下式的概率最大(最大似然法)

我们对 P(y|x) 引入 log 函数,因为 log 运算并不会影响函数本身的单调性。则有:

我们希望 log P(y|x) 越大越好,反过来,只要 log P(y|x) 的负值 -log P(y|x) 越小就行了。那我们就可以引入损失函数,且令 Loss = -log P(y|x),得到损失函数为:

我们已经推导出了单个样本的损失函数,是如果是计算 m 个样本的平均的损失函数,只要将 m 个 Loss 叠累加取平均就可以了:


这就在最大似然法推导出的lr的学习目标——交叉熵损失(或对数损失函数),也就是让最大化使模型预测概率服从真实值的分布,预测概率的分布离真实分布越近,模型越好。可以关注到一个点,如上式逻辑回归在交叉熵为目标以sigmoid输出的预测概率,概率值只能尽量趋近0或1,同理loss也并不会为0。

三、优化算法

我们以极小交叉熵为学习目标,下面要做的就是,使用优化算法去优化参数以达到这个目标。由于最大似然估计下逻辑回归没有(最优)解析解,我们常用梯度下降算法,经过多次迭代,最终学习到的参数也就是较优的数值解。
梯度下降算法可以直观理解成一个下山的方法,将损失函数J(w)比喻成一座山,我们的目标是到达这座山的山脚(即求解出最优模型参数w使得损失函数为最小值)。

下山要做的无非就是“往下坡的方向走,走一步算一步”,而在损失函数这座山上,每一位置的下坡的方向也就是它的负梯度方向(直白点,也就是山的斜向下的方向)。在每往下走一步(步长由α控制)到一个位置的时候,求解当前位置的梯度,向这一步所在位置沿着最陡峭最易下山的位置再走一步。这样一步步地走下去,一直走到觉得我们已经到了山脚。
当然这样走下去,有可能我们不是走到山脚(全局最优,Global cost minimun),而是到了某一个的小山谷(局部最优,Local cost minimun),这也梯度下降算法的可进一步优化的地方。
对应的算法步骤:

另外的,以非极大似然估计角度,去求解逻辑回归(最优)解析解,可见kexue.fm/archives/8578

四、Python实现逻辑回归

本项目的数据集为癌细胞分类数据。基于Python的numpy库实现逻辑回归模型,定义目标函数为交叉熵,使用梯度下降迭代优化模型,并验证分类效果:

# coding: utf-8

import numpy as np 
import matplotlib.pyplot as plt
import h5py
import scipy
from sklearn import datasets

# 加载数据并简单划分为训练集/测试集
def load_dataset():
    dataset = datasets.load_breast_cancer()  
    train_x,train_y = dataset['data'][0:400], dataset['target'][0:400]
    test_x, test_y = dataset['data'][400:-1], dataset['target'][400:-1]
    return train_x, train_y, test_x, test_y

# logit激活函数
def sigmoid(z):
    s = 1 / (1 + np.exp(-z))    
    return s
    
# 权重初始化0
def initialize_with_zeros(dim):
    w = np.zeros((dim, 1))
    b = 0
    assert(w.shape == (dim, 1))
    assert(isinstance(b, float) or isinstance(b, int))
    return w, b

# 定义学习的目标函数,计算梯度
def propagate(w, b, X, Y):
    m = X.shape[1]      
    A = sigmoid(np.dot(w.T, X) + b)         # 逻辑回归输出预测值  
    cost = -1 / m *  np.sum(Y * np.log(A) + (1 - Y) * np.log(1 - A))   # 交叉熵损失为目标函数
    dw = 1 / m * np.dot(X, (A - Y).T)   # 计算权重w梯度
    db = 1 / m * np.sum(A - Y)   
    assert(dw.shape == w.shape)
    assert(db.dtype == float)
    cost = np.squeeze(cost)
    assert(cost.shape == ())    
    grads = {"dw": dw,
             "db": db}    
    return grads, cost

# 定义优化算法
def optimize(w, b, X, Y, num_iterations, learning_rate, print_cost):
    costs = []    
    for i in range(num_iterations):    # 梯度下降迭代优化
        grads, cost = propagate(w, b, X, Y)
        dw = grads["dw"]              # 权重w梯度
        db = grads["db"]
        w = w - learning_rate * dw   # 按学习率(learning_rate)负梯度(dw)方向更新w
        b = b - learning_rate * db
        if i % 50 == 0:
            costs.append(cost)
        if print_cost and i % 100 == 0:
            print ("Cost after iteration %i: %f" %(i, cost))
    params = {"w": w,
              "b": b}
    grads = {"dw": dw,
             "db": db}
    return params, grads, costs

#传入优化后的模型参数w,b,模型预测   
def predict(w, b, X):
	m = X.shape[1]
	Y_prediction = np.zeros((1,m))
	A = sigmoid(np.dot(w.T, X) + b)
	for i in range(A.shape[1]):
		if A[0, i] <= 0.5:
			Y_prediction[0, i] = 0
		else:
			Y_prediction[0, i] = 1
	assert(Y_prediction.shape == (1, m))
	return Y_prediction

def model(X_train, Y_train, X_test, Y_test, num_iterations, learning_rate, print_cost):
    # 初始化
    w, b = initialize_with_zeros(X_train.shape[0]) 
    # 梯度下降优化模型参数
    parameters, grads, costs = optimize(w, b, X_train, Y_train, num_iterations, learning_rate, print_cost)
    w = parameters["w"]
    b = parameters["b"]
    # 模型预测结果
    Y_prediction_test = predict(w, b, X_test)
    Y_prediction_train = predict(w, b, X_train)
    # 模型评估准确率
    print("train accuracy: {} %".format(100 - np.mean(np.abs(Y_prediction_train - Y_train)) * 100))
    print("test accuracy: {} %".format(100 - np.mean(np.abs(Y_prediction_test - Y_test)) * 100))    
    d = {"costs": costs,
         "Y_prediction_test": Y_prediction_test, 
         "Y_prediction_train" : Y_prediction_train, 
         "w" : w, 
         "b" : b,
         "learning_rate" : learning_rate,
         "num_iterations": num_iterations}    
    return d
    
# 加载癌细胞数据集
train_set_x, train_set_y, test_set_x, test_set_y = load_dataset()   

# reshape
train_set_x = train_set_x.reshape(train_set_x.shape[0], -1).T
test_set_x = test_set_x.reshape(test_set_x.shape[0], -1).T

print(train_set_x.shape)
print(test_set_x.shape)

#训练模型并评估准确率
paras = model(train_set_x, train_set_y, test_set_x, test_set_y, num_iterations = 100, learning_rate = 0.001, print_cost = False)


(END)

文章首发公众号“算法进阶”,阅读原文可访问文章相关代码

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

全面解析并实现逻辑回归(Python) 的相关文章

随机推荐

  • 初学树莓派——(六)树莓派安装OpenCV及USB摄像头配置

    目录 1 安装OpenCV 1 1前言 1 2换源及源内容更新 1 3安装依赖 1 4下载whl包 1 5安装OpenCV 1 6检查安装 2 USB摄像头配置 同时检查OpenCV安装情况 2 1前言 2 2Python调用cv2库来检查
  • sem_init函数用法

    sem init函数 sem init函数是Posix信号量操作中的函数 sem init 初始化一个定位在 sem 的匿名信号量 value 参数指定信号量的初始值 pshared 参数指明信号量是由进程内线程共享 还是由进程之间共享 如
  • 最优化算法概述以及常见分类

    1 最优化问题概述 通俗的来说 最优化问题就是在一定的条件约束下 使得效果最好 最优化问题是一种数学问题 是研究在给定的约束之下如何求得某些因素的量 来使得某一指标达到最优的学科 工程设计中最优化问题的一般说法是 选择一组参数 在满足一系列
  • 数据结构笔记(六)——散列(Hash Table)之散列函数(1)

    散列表 hash table 的实现叫做散列 hashing 这是以常数平均时间O 1 进行插入 删除和查找的技术 散列表没有顺序 需要元素间排序信息的操作 如findMin findMax不会得到有效支持 就是这东西不是这么用的 你可以实
  • RocketMq顺序发送消息

    错乱消息出现的原因 1 在RocketMq为啥消息不是按照顺序来的呢 首先您需要了解 队列是一个先进先出的一个数据的结构 生产者 您可以将topic理解为里面有一个一个的队列 你将一个消息发送到topic的时候 当前的消息不一定是往当前的这
  • win 10 搭建FTP服务,并使用的FTP进行传输文件(很详细)

    1 安装IIS工具 打开控制面板 点击 程序 点击 启用或关闭Windows功能 找到 internet information services 全部都选上 如下图 点击 确定 会出现以下页面 点击 关闭 即可 2 设置开机启动FTP服务
  • 高光谱图像中的Hughes(休斯)现象

    注解 在高光谱图像的分析中 随着参与运算波段数目的增加 分类精度 先增后降 的现象 场景 高光谱影像 由于维数的大幅度增加 在深度学习中 可以理解成模型提取的特征维数的增加 导致用于参数训练的所需样本数也急剧增加 如果样本数过少 那么估计出
  • Fiddler 详尽教程与抓取移动端数据包

    转载自 http blog csdn net qq 21445563 article details 51017605 阅读目录 1 Fiddler 抓包简介 1 字段说明 2 Statistics 请求的性能数据分析 3 Inspecto
  • C++面试题目集合(持续跟新)

    与我前面写的C语言进阶知识点遥相呼应 这才是C 面试 网上的面试题有些太简单了 C 面试题目最多集中在对象的内存模型 记住了 如果用c c 内存都不清楚 还写个屁的程序 1 C 的虚函数是怎样实现的 C 的虚函数使用了一个虚函数表来存放了每
  • 我的世界服务器物品不掉落指令是什么,我的世界死亡物品怎么不掉落 我的世界物品不掉落指令...

    我的世界死亡不了多指令是gamerulekeepInventorytrue 玩家们要注意我的世界死亡不掉落指令默认是关闭状态的哦 死亡不掉落指令在 我的世界 游戏里面就是当玩家们死亡以后仍然保留其物品栏中的所有物品 包括附魔死亡消失魔咒的物
  • 安装WSL + zsh & Pure (ZSH prompt) 美化【Windows11】

    文章目录 前言 WSL 安装 ZSH 安装ZSH Pure ZSH prompt 安装插件 下载插件 编辑配置文件 插件作用 啊 PS 如果在启动过程中提示 请启用虚拟机平台 windows 功能并确保在 bios 中启用虚拟化 前言 之前
  • 数据分析36计(28):Python 使用 Flask+Docker, 100行代码内实现机器学习实时预测​...

    本文的想法是快速轻松地构建 Docker 容器 Python 以使用 Flask 实现机器学习模型执行在线预测 API 我们将使用 Docker 和 Flask RESTful 实现线性判别分析和多层感知器神经网络模型的实时预测 项目包括的
  • Android中的自绘View的那些事儿(八)之 Paint的高级用法

    我们在 Android中的自绘View的那些事儿 一 中简单介绍过Paint和Canvas的一些常用方法和实例使用 其中 一句话提到Paint中有方法 setStrokeCap setStrokeJoin 和 setPathEffect 今
  • nodejs如何利用libuv实现事件循环和异步

    本文是根据之前在公司内部做的分享整理而成 是早期对nodejs的一个认识 源码版本10 x nodejs是什么 libuv的工作原理 nodejs的工作原理 nodejs如何使用libuv实现事件循环和异步 1 nodejs是什么 Node
  • pyinstaller打包最小体积安装python程序 命令行传参执行

    文章目录 创建虚拟环境 进入虚拟环境安装库 pycharm配置虚拟环境 pycharm 打开terminal进入虚拟环境 运行参数传入 sys argv 是获取运行python文件的时候命令行参数 且以list形式存储参数 打包后的文件运行
  • js记录密码出错次数并锁定账号30分钟

    下面要说的是网站中一个常见的功能 在客人使用抵用券或者其他来支付的时候需要验证密码 如果密码输入错误5次就锁定 不在让客人使用抵用券了 在这里是使用的cookie来实现的 不太严谨 思路很简单 在输入密码错误的时候 使用cookie保存2个
  • 基于vue项目的上拉刷新,下拉加载的效果

    使用插件 better scroll 安装使用教程http ustbhuangyi github io better scroll doc installation html npm 还是看官网比较好 子组件
  • 28_content 阶段的 index 模块

    文章目录 content 阶段的 index 模块 显示目录内容 content 阶段的 autoindex 模块 autoindex 模块的指令 index autoindex 示例配置 content 阶段的 index 模块 ngx
  • 6、基于STM32呼吸灯(PWM)

    之前定时器中有提到输入和输出比较部分 https blog csdn net qq 45764141 article details 125286260 参考有江科大自化协的视频和正电原子的视频 这个文章主要讲输出部分 文章目录 一 OC
  • 全面解析并实现逻辑回归(Python)

    本文以模型 学习目标 优化算法的角度解析逻辑回归 LR 模型 并以Python从头实现LR训练及预测 一 逻辑回归模型结构 逻辑回归是一种广义线性的分类模型且其模型结构可以视为单层的神经网络 由一层输入层 一层仅带有一个sigmoid激活函