本文转载自:http://blog.csdn.net/orangehdc/article/details/38682501
随机梯度下降法(Stochastic Gradient Descent)求解以下的线性SVM模型:
w的梯度为:
传统的梯度下降法需要把所有样本都带入计算,对于一个样本数为n的d维样本,每次迭代求一次梯度,计算复杂度为O(nd) ,当处理的数据量很大而且迭代次数比较多的时候,程序运行时间就会非常慢。
随机梯度下降法每次迭代不再是找到一个全局最优的下降方向,而是用梯度的无偏估计 来代替梯度。每次更新过程为:
由于随机梯度每次迭代采用单个样本来近似全局最优的梯度方向,迭代的步长应适当选小一些以使得随机梯度下降过程尽可能接近于真实的梯度下降法。
下面我用matlab写的一个demo,速度不是很快,跑USPS数据库(二进制格式)csdn下载链接(mat格式),要五分钟,准确率88%左右,效果一般:
- clear;
- load E:\dataset\USPS\USPS.mat;
- % data format:
- % Xtr n1*dim
- % Xte n2*dim
- % Ytr n1*1
- % Yte n2*1
- % warning: labels must range from 1 to n, n is the number of labels
- % other label values will make mistakes
- u=unique(Ytr);
- Nclass=length(u);
-
- allw=[];allb=[];
- step=0.01;C=0.1;
- param.iterations=1;
- param.lambda=1e-3;
- param.biaScale=1;
- param.t0=100;
-
- tic;
- for classname=1:1:Nclass
- temp_Ytr=change_label(Ytr,classname);
- [w,b] = sgd_svm(Xtr,temp_Ytr, param);
- allw=[allw;w];
- allb=[allb;b];
- fprintf('class %d is done \n', classname);
- end
-
- [accuracy predict_label]=my_svmpredict(Xte, Yte, allw, allb);
- fprintf(' accuracy is %.2f percent.\n' , accuracy*100 );
- toc;
- function [temp_Ytr] = change_label(Ytr,classname)
- temp_Ytr=Ytr;
- tep2=find(Ytr~=classname);
- tep1=find(Ytr==classname);
- temp_Ytr(tep2)=-1;
- temp_Ytr(tep1)= 1;
- function [true_W,b]=sgd_svm(X,Y,param)
- % input:
- % X is n*dim
- % Y is n*1 (label is 1 or 0)
- % output:
- % true_W is dim*1 ,so the score is X*W'+b
- % b is 1*1 number
- iterations=param.iterations;%10
- lambda=param.lambda;%1e-3
- biaScale=param.biaScale;%0
- t0=param.t0;%100
- t=t0;
-
- w=zeros(1,size(X,2));
- bias=0;
-
- for k=1:1:iterations
- for i=1:1:size(X,1)
- t=t+1;
- alpha = (1.0/(lambda*t));
- if(Y(i)*(X(i,:)*w'+bias)<1)
- bias=bias+alpha*Y(i)*biaScale;
- w=w+alpha*Y(i,1).*X(i,:);
- end
- end
- end
- b=bias;
- true_W=w;
- function [accuracy predict_label]=my_svmpredict(Xte, Yte, allw, allb)
- % allw is nclass * dim
- % allb is nclass * 1
- % Yte must range from 1 to nclass, other label values will make mistakes
- score = Xte * allw'+repmat(allb',[size(Bte,1),1]);
- [bb c]=sort(score,2,'descend');
- predict_label=c(:,1);
- temp = predict_label((predict_label-Yte)==0);
- right=size( temp,1 );
- accuracy=right/size(Yte,1);