python实现次梯度(subgradient)和近端梯度下降法 (proximal gradient descent)方法求解L1正则化

2023-10-27

l1范数最小化

考虑函数f(x)=\frac{1}{2}(Ax - b)^2+\left | x \right |,显然其在零点不可微,其对应的的次微分为:

\begin{equation} g = A^T(Ax - b) + \mu \left\{ \begin{array}{lr} 1, & x >0 \\ \left [ -1,1 \right ], & x =0 \\ -1, & x < 0 \end{array} \right. \end{equation}

注意x = 0\partial f(x)的取值为一个区间\left [ -1, 1 \right ]

两个重要定理:

1)一个凸函数,当且仅当0 \in \partial f(x^*)x^*为全局最小值,即 x^*为最小值点 \Leftrightarrow 0 \in \partial f(x^*)

2)x^*为函数f不一定是凸函数)的最小值点,当且仅当f在该点可次微分且0 \in \partial f(x^*)

考虑最简单的一种情况,目标函数为:

f(x) = \frac{1}{2}(x - \hat{x})^2 + \lambda \left | x \right |

对应的次微分为:

\partial f(x) = x - \hat{x} + \lambda z(x)

进一步可以表示为:

\begin{equation} \partial f(x) = \left\{ \begin{array}{lr} x - \hat{x}_i + \lambda, & x >0,(-\hat{x}_i+\lambda,+ \infty ) \\ -\hat{x}_i + \left[ -\lambda,\lambda \right ], & x =0,[-\hat{x}_i-\lambda, -\hat{x}_i+\lambda] \\ x - \hat{x}_i - \lambda, & x < 0,(-\infty, -\hat{x}_i-\lambda) \end{array} \right. \end{equation}

故,若-\hat{x}_i + \lambda \leqslant 0 \Rightarrow \hat{x}_i \geqslant \lambda,最小值点x_i^*为:x_i^* = \hat{x}_i-\lambda

-\hat{x}_i - \lambda \geqslant 0 \Rightarrow \hat{x}_i \leqslant -\lambda,最小值点x_i^*为:x_i^* = \hat{x}_i+\lambda

\left | \hat{x}_i \right | < \lambda, 最小值点x_i^*为:x_i^* = 0

简而言之,最优解x_i^* = S(\hat{x}, \lambda)=sgn(\hat{x})max(\left | \hat{x} \right | - \lambda,0)S(\hat{x}, \lambda)通常被称为软阈值(soft threshold)算子

次梯度(subgradient)求解L1正则化问题

考虑最小化问题:

min \left \| Ax - b \right \|^2 + \mu \left \| x \right \|_1

随机生成矩阵A和向量\mub=A\mu\mu = 1e^{-3},仅Ab\mu为已知量,x为参数。

那么次梯度为:

\begin{equation} g = A^T(Ax - b) + \mu \left\{ \begin{array}{lr} 1, & x >0 \\ \left [ -1,1 \right ], & x =0 \\ -1, & x < 0 \end{array} \right. \end{equation}

对于次梯度求解方法,次梯度g可以进一步表示为:

g(x) = A^T(Ax - b) + \mu \cdot \textbf{sgn}(x)

进一步,可以得到次梯度的更新公式为:

x^{k+1} = x^k - t\cdot g(x^k)

python代码如下:

# -*- coding: utf-8 -*-

import numpy as np
import scipy as spy
from scipy.sparse import csc_matrix
import matplotlib.pyplot as plt
import time   #用来计算运行时间

#=======模拟数据======================

m = 512
n = 1024


#稀疏矩阵的产生,A使用的是正态稀疏矩阵
u= spy.sparse.rand(n,1,density=0.1,format='csc',dtype=None)
u1 = u.nonzero()
row = u1[0]
col = u1[1]
data = np.random.randn(int(0.1*n))
u = csc_matrix((data, (row, col)), shape=(n,1)).toarray() #1024 * 1

#u1 = u.nonzero()        #观察是否是正态分布
#plt.hist(u[u1[0],u1[1]].tolist())

#u = u.todense()  #转为非稀疏形式

a = np.random.randn(m,n) #512 * 1024
b = np.dot(a,u) # a * u, 512 * 1
v = 1e-3      #v为题目里面的miu

def f(x0):    #目标函数 1/2*||Ax - b||^2 + mu*||x||1
    return 1/2*np.dot((np.dot(a,x0)-b).T,np.dot(a,x0)-b)+v*sum(abs(x0))

#==========初始值=============================
x0 = np.zeros((n,1)) #1024 * 1

y = []
time1 = []

start = time.clock()
print("begin to train!")
#=========开始迭代==========================
for i in range(1000):
    if i %100 == 0:
        if len(y) > 0:
            print("step " + str(i) + "val: " + str(y[len(y) - 1]))
    mid_result = f(x0) 
    y.append(f(x0)[0,0])    #存放每次的迭代函数值
    
    g0 = (np.dot(np.dot(a.T,a),x0)-np.dot(a.T,b) + v*np.sign(x0)) 
    #次梯度, A^T(Ax - b) + mu * sign(x)
    t = 0.01/np.sqrt(sum(np.dot(g0.T,g0)))    #设为0.01效果比0.1好很多,步长

       
    x1 = x0 - t[0]*g0  
    x0 = x1
    
    end = time.clock()
    time1.append(end)

y = np.array(y).reshape((1000,1))    

time1 = np.array(time1)
time1 = time1 - start
time2 = time1[np.where(y - y[999] < 10e-4)[0][0]]

plt.plot(y)
plt.show()
for val in y:
    print(val)
    
#    if i % 100 == 0: 
#        f = 1/2*np.dot((np.dot(a,x0)-b).T,np.dot(a,x0)-b)+v*sum(abs(x0))
#        print(f)
        
#在计算机计算时,可以明显感受到proximal gradient方法比次梯度方法快

结果如下:

前几个参数值:

可以看到参数具有稀疏性。 

近端梯度下降法 (proximal gradient descent)方法求解L1正则化

proximal gradient算法推导

对于梯度下降的每一步,可以看做是一个平方模型的局部最小化:

x_{k+1} = arg\mathop{min}\limits_{x}\left \{\frac{1}{2t_k} \left \| x - (x_k - t_k\bigtriangledown{f(x_k)}) \right \|^2_2 \right \}

其中,t_k>0为合适的步长,那么,同理对于有L1正则项的情况,上式变为:

x_{k+1} = arg\mathop{min}\limits_{x}\left \{\frac{1}{2t_k} \left \| x - (x_k - t_k\bigtriangledown{f(x_k)}) \right \|^2_2 + \mu \left \| x \right \| _1\right \}

考虑前面介绍的soft threshold,利用soft threshold的方法,我们可以马上得到最优解:

x_{k+1}^* = S(x_k - t_k\bigtriangledown{f(x_k)}, \mu)=sgn(x_k - t_k\bigtriangledown{f(x_k)},)max(\left | x_k - t_k\bigtriangledown{f(x_k)},\right | - \mu,0)

对与上面的例子f(x) = \frac{1}{2}(x - \hat{x})^2 + \lambda \left | x \right |,我们可以得到基于此方法的梯度更新公式为:

g(x) = A^T(Ax - b)

x^{k+1} = sgn(x^k - t\cdot g(x^k))max(x^k - t\cdot g(x^k) - \mu, 0)

这就是通常所说的Iterative Shrinkage Thresholding Algorithm (ISTA)算法,与subgradient算法相比,由于soft threshold的限制,使得参数的稀疏性更强

相应的python代码为:

# -*- coding: utf-8 -*-

import numpy as np
import scipy as spy
from scipy.sparse import csc_matrix
import matplotlib.pyplot as plt
import time   #用来计算运行时间

#=======模拟数据======================

m = 512
n = 1024


#稀疏矩阵的产生,A使用的是正态稀疏矩阵
u= spy.sparse.rand(n,1,density=0.1,format='csc',dtype=None)
u1 = u.nonzero()
row = u1[0]
col = u1[1]
data = np.random.randn(int(0.1*n))
u = csc_matrix((data, (row, col)), shape=(n,1)).toarray() #1024 * 1

#u1 = u.nonzero()        #观察是否是正态分布
#plt.hist(u[u1[0],u1[1]].tolist())

#u = u.todense()  #转为非稀疏形式

a = np.random.randn(m,n) #512 * 1024
b = np.dot(a,u) # a * u, 512 * 1
v = 1e-3      #v为题目里面的miu

def f(x0):    #目标函数 1/2*||Ax - b||^2 + mu*||x||1
    return 1/2*np.dot((np.dot(a,x0)-b).T,np.dot(a,x0)-b)+v*sum(abs(x0))

def S(x1,v):
    for i in range(len(x1)):
        if np.abs(x1[i]) - v > 0:
            x1[i] = np.sign(x1[i]) * (np.abs(x1[i]) - v)
        else:
            x1[i] = 0
    return x1


#==========初始值=============================
#x0 = np.zeros((n,1)) #1024 * 1
x0 = (2.0*np.random.random((n,1)) - 1.0) * 0.01

y = []
time1 = []

start = time.clock()
print("begin to train!")
#=========开始迭代==========================
for i in range(2000):
    if i %10 == 0:
        if len(y) > 0:
            print("step " + str(i) + "val: " + str(y[len(y) - 1]))
    mid_result = f(x0) 
    y.append(f(x0)[0,0])    #存放每次的迭代函数值
    
    #g0 = (np.dot(np.dot(a.T,a),x0)-np.dot(a.T,b) + v*np.sign(x0)) 
    #次梯度, A^T(Ax - b) + mu * sign(x)
    g0 = np.dot(np.dot(a.T,a),x0)-np.dot(a.T,b)
    t = 0.025/np.sqrt(sum(np.dot(g0.T,g0)))    #设为0.01效果比0.1好很多,步长   
    x1 = S(x0 - t[0]*g0, v)  
    x0 = x1
    
    end = time.clock()
    time1.append(end)

y = np.array(y).reshape((2000,1))    

time1 = np.array(time1)
time1 = time1 - start
time2 = time1[np.where(y - y[999] < 10e-4)[0][0]]

plt.plot(y)
plt.show()
for val in y:
    print(val)
    
#    if i % 100 == 0: 
#        f = 1/2*np.dot((np.dot(a,x0)-b).T,np.dot(a,x0)-b)+v*sum(abs(x0))
#        print(f)
        
#在计算机计算时,可以明显感受到proximal gradient方法比次梯度方法快

 

前几个参数如下:

显然与基础次梯度算法相比,此方法稀疏性明显更好。 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

python实现次梯度(subgradient)和近端梯度下降法 (proximal gradient descent)方法求解L1正则化 的相关文章

  • 是否有解决方法可以通过 CoinGecko API 安全检查?

    我在工作中运行我的代码 一切都很顺利 但在不同的网络 家庭 WiFi 上 我不断收到403访问时出错CoinGecko V3 API https www coingecko com api documentations v3 可以观察到 在
  • Python 中的哈希映射

    我想用Python实现HashMap 我想请求用户输入 根据他的输入 我从 HashMap 中检索一些信息 如果用户输入HashMap的某个键 我想检索相应的值 如何在 Python 中实现此功能 HashMap
  • Pandas/Google BigQuery:架构不匹配导致上传失败

    我的谷歌表中的架构如下所示 price datetime DATETIME symbol STRING bid open FLOAT bid high FLOAT bid low FLOAT bid close FLOAT ask open
  • Pandas 日期时间格式

    是否可以用零后缀表示 pd to datetime 似乎零被删除了 print pd to datetime 2000 07 26 14 21 00 00000 format Y m d H M S f 结果是 2000 07 26 14
  • 如何使用 Pandas、Numpy 加速 Python 中的嵌套 for 循环逻辑?

    我想检查一下表的字段是否TestProject包含了Client端传入的参数 嵌套for循环很丑陋 有什么高效简单的方法来实现吗 非常感谢您的任何建议 def test parameter a list parameter b list g
  • YOLOv8获取预测边界框

    我想将 OpenCV 与 YOLOv8 集成ultralytics 所以我想从模型预测中获取边界框坐标 我该怎么做呢 from ultralytics import YOLO import cv2 model YOLO yolov8n pt
  • Pandas Merge (pd.merge) 如何设置索引和连接

    我有两个 pandas 数据框 dfLeft 和 dfRight 以日期作为索引 dfLeft cusip factorL date 2012 01 03 XXXX 4 5 2012 01 03 YYYY 6 2 2012 01 04 XX
  • Python 2:SMTPServerDisconnected:连接意外关闭

    我在用 Python 发送电子邮件时遇到一个小问题 me my email address you recipient s email address me email protected cdn cgi l email protectio
  • 在Python中检索PostgreSQL数据库的新记录

    在数据库表中 第二列和第三列有数字 将会不断添加新行 每次 每当数据库表中添加新行时 python 都需要不断检查它们 当 sql 表中收到的新行数低于 105 时 python 应打印一条通知消息 警告 数量已降至 105 以下 另一方面
  • 如何使用 Mysql Python 连接器检索二进制数据?

    如果我在 MySQL 中创建一个包含二进制数据的简单表 CREATE TABLE foo bar binary 4 INSERT INTO foo bar VALUES UNHEX de12 然后尝试使用 MySQL Connector P
  • Docker 中的 Python 日志记录

    我正在 Ubuntu Web 服务器上的 Docker 容器中测试运行 python 脚本 我正在尝试查找由 Python Logger 模块生成的日志文件 下面是我的Python脚本 import time import logging
  • pyspark 将 twitter json 流式传输到 DF

    我正在从事集成工作spark streaming with twitter using pythonAPI 我看到的大多数示例或代码片段和博客是他们从Twitter JSON文件进行最终处理 但根据我的用例 我需要所有字段twitter J
  • Numpy - 根据表示一维的坐标向量的条件替换数组中的值

    我有一个data多维数组 最后一个是距离 另一方面 我有距离向量r 例如 Data np ones 20 30 100 r np linspace 10 50 100 最后 我还有一个临界距离值列表 称为r0 使得 r0 shape Dat
  • 加快网络抓取速度

    我正在使用一个非常简单的网络抓取工具抓取 23770 个网页scrapy 我对 scrapy 甚至 python 都很陌生 但设法编写了一个可以完成这项工作的蜘蛛 然而 它确实很慢 爬行 23770 个页面大约需要 28 小时 我看过scr
  • 如何使用原始 SQL 查询实现搜索功能

    我正在创建一个由 CS50 的网络系列指导的应用程序 这要求我仅使用原始 SQL 查询而不是 ORM 我正在尝试创建一个搜索功能 用户可以在其中查找存储在数据库中的书籍列表 我希望他们能够查询 书籍 表中的 ISBN 标题 作者列 目前 它
  • Pandas 将多行列数据帧转换为单行多列数据帧

    我的数据框如下 code df Car measurements Before After amb temp 30 268212 26 627491 engine temp 41 812730 39 254255 engine eff 15
  • 在本地网络上运行 Bokeh 服务器

    我有一个简单的 Bokeh 应用程序 名为app py如下 contents of app py from bokeh client import push session from bokeh embed import server do
  • 将 Python 中的日期与日期时间进行比较

    所以我有一个日期列表 datetime date 2013 7 9 datetime date 2013 7 12 datetime date 2013 7 15 datetime date 2013 7 18 datetime date
  • 模拟pytest中的异常终止

    我的多线程应用程序遇到了一个错误 主线程的任何异常终止 例如 未捕获的异常或某些信号 都会导致其他线程之一死锁 并阻止进程干净退出 我解决了这个问题 但我想添加一个测试来防止回归 但是 我不知道如何在 pytest 中模拟异常终止 如果我只
  • 如何计算Python中字典中最常见的前10个值

    我对 python 和一般编程都很陌生 所以请友善 我正在尝试分析包含音乐信息的 csv 文件并返回最常听的前 n 个乐队 从下面的代码中 每听一首歌曲都是一个列表中的字典条目 格式如下 album Exile on Main Street

随机推荐

  • STL中常用的排序算法

    merge 以下是排序和通用算法 提供元素排序策略 merge 合并两个有序序列 存放到另一个序列 例如 vecIntA vecIntB vecIntC是用vector
  • Git 版本回退与前进(03)

    现在 你已经学会了修改文件 然后把修改提交到Git版本库 现在 再练习一次 修改readme txt文件如下 Git is a distributed version control system Git is free software
  • 理解attention的image to caption(图片的文字描述)

    更多查看 https github com B C WANG AI Storage 4 1 理解attention的image to caption 图片的文字描述 4 1 1 一 一个简单模型 Encoder 使用预训练的CNN进行fin
  • flex局部的知识总结(转载)

    版权声明 本文为CSDN博主 Coralpapy 的原创文章 遵循CC 4 0 BY SA版权协议 转载请附上原文出处链接及本声明 原文链接 https blog csdn net Coralpapy article details 120
  • 用limma包的voom方法来做RNA-seq 差异分析

    用limma包的voom方法来做RNA seq 差异分析 大家都知道 这十几年来最流行的差异分析软件就是R的limma包了 但是它以前只支持microarray的表达数据 考虑到大家都熟悉了它 它又发了一个voom的方法 让它从此支持RNA
  • Python-绘制七段数码管

    SevenDigitsDrawV2 py import turtle time def drawGap 绘制数码管间隔 turtle penup turtle fd 5 def drawLine draw 绘制单段数码管 drawGap t
  • vue踩坑填坑(四):在vue单页中修改title

    由于在vue单页应用中title只设定在入口文件index html 如果切换路由 title怎么更换 在路由router中设置meta path chooseBrand component resolve gt require compo
  • 数据链路层简介

    1 数据链路层的基本概念 数据链路层在物理层提供服务的基础上向网络层提供服务 其最基本的服务是将源自网络层来的数据可靠地传输到相邻节点的目标机网络层 其主要作用是加强物理层传输原始比特流的功能 将物理层提供的可能出错的物理连接改造成为逻辑上
  • Python 保存数据的方法(4种方法)

    Python 保存数据的方法 open函数保存 使用with open 新建对象 写入数据 这里使用的是爬取豆瓣读书中一本书的豆瓣短评作为例子 import requests from lxml import etree 发送Request
  • 无线连接打印服务器,如何用旧电脑架设无线网络打印服务器

    如何用旧电脑架设无线网络打印服务器 由会员分享 可在线阅读 更多相关 如何用旧电脑架设无线网络打印服务器 4页珍藏版 请在人人文库网上搜索 1 如何用旧电脑架设无线网络打印服务器在工作中 单位需要打印的文件还是不少的 可是笔记本电脑连接一个
  • input框限输入数字并保留两位小数

    先把非数字的都替换掉 除了数字和 obj value obj value replace d g 保证只有出现一个 而没有多个 obj value obj value replace 2 g 必须保证第一个为数字而不是 obj value
  • iOS上架及ipa包上传到AppStore

    概述 由于苹果的机制 在非越狱机器上安装应用必须通过官方的Appstore 开发者开发好应用后上传Appstore 也需要通过审核等环节 AppCan作为一个跨主流平台的一个开发平台 也对ipa包上传Appstore作了支持 本文从三个流程
  • 通过canvas实现将html的某些元素转为png图片

    有时候我们需要把html或者某些html元素转换为图片 并且支持下载 下面是学习之后的总结 希望能给大家带来帮助 所需插件库 html2canvas js canvas2image js base64 js 资源地址 链接 https pa
  • 蛇形矩阵(完全)

    画 n阶蛇形方阵 比如如图是5阶方阵 5条对角线 1 2 6 7 15 3 5 8 14 16 4 9 13 17 22 10 12 18 21 23 11 19 20 24 25 解题思路 1 分为上三角和下三角 上三角的思路是同蛇形矩阵
  • 训练自己的ai模型(一)学习笔记与项目实操

    ai模型大火 作为普通人 我也想做个自己的ai模型 训练自己的ai模型通常需要接下来的的六步 一 收集和准备数据集 需要收集和准备一个数据集 其中包含想要训练模型的数据 这可能需要一些数据清理和预处理 以确保数据集的质量和一致性 二 选择和
  • clash设置代理后内网访问慢及访问不到问题

    配置忽略代理的ip及域名即可 在 config clash文件夹下新建 proxyIgnoreList plist文件 如果不知道 config clash在哪的 可以通过打开本地文件夹来定位 然后在新创建的文件内写入要忽略代理的域名及ip
  • 链表与节点

    链表 java中通过 node next 表示 node的下一个节点 同理 node next next 表示 node后的第二个节点 通过链表这种数据结构 可以实现许多奇妙的组合 这里我通过接口的方式 把重要的方法进行了封装 虽然只有三个
  • 逆流而上——泛谈对二进制可执行程序的静态反编译

    欢迎对本blog相关主题感兴趣的团体或单位转载相关文章 但转载时请注明出处 谢谢 一 概述 首先应该声明的是 这里讨论的反编译是针对二进制可执行程序进行的静态反向编译操作 虽然对于类似Java Bytecode和MSIL的虚拟机中间代码的反
  • 【layui】 灵活使用弹出层iframe 让你的父界面代码更加清晰

    第一次使用layui框架时候 layer open 使用了最最累赘的 静态布局div 在写入content中不但让你的代码逻辑很乱 而且不利于开发 所以整理一款 弹出层js 是在开发中必不可少的 在此感谢 X admin2 0 提供的源码
  • python实现次梯度(subgradient)和近端梯度下降法 (proximal gradient descent)方法求解L1正则化

    l1范数最小化 考虑函数 显然其在零点不可微 其对应的的次微分为 注意 的取值为一个区间 两个重要定理 1 一个凸函数 当且仅当 为全局最小值 即 为最小值点 2 为函数 不一定是凸函数 的最小值点 当且仅当在该点可次微分且 考虑最简单的一