サポートベクターマシン(SVM)

Last-modified: 2010-01-16 (土) 22:41:51

サポートベクターマシンのサンプルプログラムです.

*************************************************************
svm.sas
2010.1.16 翔
SMOアルゴリズムにより,SVMの学習を行う
参考文献:J.Platt. Fast training of support vector machines using sequential minimal optimization.
In Advances in kernel methods : support vector learning. MIT Press,1998.
*************************************************************;


options nocenter;

proc iml;
***********************************************
モジュール定義
smo()   /*SMOメイン.更新できる未定乗数がなくなるまで繰り返す.*/
examinexample(i1)  /*選ばれたサンプルがKKT違反であれば,もうひとつのサンプルを選び更新を試みる*/
svm(i) global(C,tol,eps,N,d,x,y,b,a,E)/*SVMの出力*/
kernel2(i)  /* IMLモジュールに同じ引数名を複数書けなかったので別に定義した*/
start takestep(i1,i2)  /*2点の更新を試みる*/
(kernel関数の定義は,実行時に行う)
***********************************************;

start smo()   global(data,C,tol,eps,N,d,x,y,b,a,E);
/*SMOメイン.更新できる未定乗数がなくなるまで繰り返す.*/

N=nrow(data);  /*サンプル件数*/
d=ncol(data)-1;/*説明変数の数*/
x=data[,1:d];  /*説明変数x*/
y=data[,d+1];  /*分類変数y*/

b=0;        /*閾値*/
a=j(N,1,0); /*未定乗数ベクトルa*/
E=j(N,1,0); /*エラーキャッシュE*/

numchanged=0; /*更新されたサンプルの数*/


do until(numchanged=0);

  /*部分チェック*/
  do until(numchanged=0);
    numchanged=0;
	do k=1 to n;
	  if (0<a[k,] & a[k,]<C) then do;
        numchanged=numchanged+examinexample(k);
      end;
	end;
  end;

  /*全体チェック*/
  do k=1 to n;
    numchanged=numchanged+examinexample(k);
  end;
end;
finish;


start examinexample(i1)  global(C,tol,eps,N,d,x,y,b,a,E);
/*選ばれたサンプルがKKT違反であれば,もうひとつのサンプルを選び更新を試みる*/

y1=y[i1,];
a1=a[i1,];
if 0<a1 & a1<c then E1=E[i1,];
else E1=SVM(i1)-y1;
r1=E1*y1;

/*KKT-condition violation test 選ばれたサンプルがKKT違反かチェック*/
if (r1<-tol & a1<C) | (tol<r1 & 0<a1) then do;

  /*KKT違反なら以下の3つの方法で,もうひとつのサンプルを選ぶ.違反でなければreturn=0を返す*/
  /*try-1 0<a2<Cを満たすもののうち,E1とE2の相違が最大のサンプルをもうひとつのサンプルとして更新を試みる*/
  i2=-1;
  tmax=0;
  do k=1 to N;
    if (0<a[k,] & a[k,]<C) then do;
	  E2=E[k,];
      temp=abs(E1-E2);
      if (temp>tmax) then do;
        tmax=temp;
        i2=k;
      end;
    end;
  end;
  if (i2>0) then do;
    if takestep(i1,i2) then return(1);
  end;

  /*try-2 更新できなければ,0<a2<Cを満たすサンプルを順次もうひとつのサンプルとして更新を試みる*/
  k0=int(n*uniform(0))+1;
  k=k0;
  do while(k<(k0+N));
    i2=mod(k,N)+1;
    if (0<a[i2,] & a[i2,]<C) then do;
      if takestep(i1,i2) then return(1);
    end;
    k=k+1;
  end;

  /*try-3 更新できなければ,全データから順次もうひとつのサンプルを選び更新を試みる*/
  k0=int(n*uniform(0))+1;
  k=k0;
  do while(k<(k0+N));
	i2=mod(k,N)+1;
    if takestep(i1,i2) then return(1);
	k=k+1;
  end;

end;

return(0);

finish;


start svm(i) global(C,tol,eps,N,d,x,y,b,a,E);
/*SVMの出力*/
  s=0;
  do k=1 to N;
    s = s + a[k,]*y[k,]*kernel(k,i);
  end;
  s=s-b;
  return(s);
finish;


start kernel2(i)  global(C,tol,eps,N,d,x,y,b,a,E);
/* IMLモジュールに同じ引数名を複数書けなかったので別に定義した*/
j=i;
inner_product=kernel(i,j);
return(inner_product);
finish;


start takestep(i1,i2)   global(C,tol,eps,N,d,x,y,b,a,E);
/*2点の更新を試みる*/
if i1=i2 then return(0);/* 更新失敗--->更新数0を返す */

y1=y[i1,];
a1=a[i1,];
if 0<a1 & a1<c then E1=E[i1,];
else E1=SVM(i1)-y1;
y2=y[i2,];
a2=a[i2,];
if 0<a2 & a2<c then E2=E[i2,];
else E2=SVM(i2)-y2;
s=y1*y2;

if y1=y2 then do;L=max(0,a1+a2-C);
                 H=min(C,a1+a2);end;
else          do;L=max(0,-a1+a2);
                 H=min(C,C-a1+a2);end;

if L=H then return(0);/* 更新失敗--->更新数0を返す */

k11=kernel2(i1);
k12=kernel(i1,i2);
k22=kernel2(i2);
eta=2*k12-k11-k22;

if eta<0 then do;
  a2new=a2+y2*(E2-E1)/eta;
  a2new=max(L,min(H,a2new));
end;
else do;
  c1=eta/2;
  c2=y2*(E1-E2)-eta*a2;
  Lobj=c1*L*L+c2*L;
  Hobj=c1*H*H+c2*H;
  if Lobj>(Hobj-eps) then a2new=H;
  else                    a2new=a2;
end;

if abs(a2new-a2)<eps*(a2new+a2+eps) then return(0);

a1new=a1-s*(a2new-a2);
if a1new<0 then do;
  a2new=a2new+s*a1new;
  a1new=0;
end;
else if a1new>C then do;
  t=a1new-C;
  a2new=a2new+s*t;
  a1new=C;
end;

/* bの更新 */
if (0<a1new & a1new<C) then
  bnew = b + E1 + y1*(a1new-a1)*k11 + y2*(a2new-a2)*k12;
else if (0<a2new & a2new<C) then
  bnew = b + E2 + y1*(a1new-a1)*k12 + y2*(a2new-a2)*k22;
else do;
  b1 = b + E1 + y1*(a1new-a1)*k11 + y2*(a2new-a2)*k12;
  b2 = b + E2 + y1*(a1new-a1)*k12 + y2*(a2new-a2)*k22;
  bnew=(b1+b2)/2;
end;
delta_b=bnew-b;
b=bnew;

/* Eの更新 */
t1=y1*(a1new-a1);
t2=y2*(a2new-a2);
do i=1 to N;
  if (0<a[i,] & a[i,]<C) then
     E[i,] = E[i,] + t1*kernel(i1,i) + t2*kernel(i2,i) - delta_b;
end;
E[i1,]=0;
E[i2,]=0;

/* aの更新 */
a[i1,]=a1new;
a[i2,]=a2new;

*print i1 i2 a;

return(1);

finish;


**************************************************
実行手順
1.学習データを与える
2.学習パラメータを指定
 C:ハードマージンなら大きく,ソフトマージンなら小さくとる
3.カーネルを選び,そのパラメータを指定
4.SMOアルゴリズムを実行
**************************************************;

/* 1.学習データ */
data={/* 1行1サンプルで,各行の最後の列が分類変数(1 or -1)*/
3 0  1,
3 1  1,
3 2  1,
3 3  1,
2 3  1,
1 0 -1,
0 1 -1
};

/* 2.学習パラメータ */
C=100000;/* 数字が大きいほどハード(分離失敗点を許さない)*/
tol=0.001;
eps=0.001;

/* 3.カーネル関数 */
start kernel(i1,i2)  global(C,tol,eps,N,d,x,y,b,a,E);
inner_product=x[i1,]*t(x[i2,]);              /*ドットカーネル(線形SVM)*/
*inner_product=exp(-ssq(x[i1,]-x[i2,])/2);    /*ガウシアンカーネル*/
*inner_product=1/(1+exp(-x[i1,]*t(x[i2,])/2));/*シグモイドカーネル*/
*inner_product=(x[i1,]*t(x[i2,])+2)**3;       /*多項式カーネル*/
return(inner_product);
finish;

/* sequential minimal optimazation for SVM */
run smo();


print a;
/*
    A

    0.444
        0
        0
        0
    0.111
    0.555
        0
 */
quit;