最新消息: USBMI致力于为网友们分享Windows、安卓、IOS等主流手机系统相关的资讯以及评测、同时提供相关教程、应用、软件下载等服务。

我的libsvm文档java 文档

IT圈 admin 35浏览 0评论

2024年6月14日发(作者:遇飞语)

LibSVM(JAVA)二次开发接口调用及源码更改的文档

浙江大学协调服务研究所

文档整理:陈伟 chenweishaoxing#

下载libsvm

方法:google libsvm找到官网下载:

/~cjlin/libsvm/ ,其中图片中椭圆的

解压文档

下载下来libsvm工具包有几个版本的,其中python的最经典,用的人比较多,

还支持matlab,C++等等。我们用的java版的,就到解压开的java文件夹中!

java文件夹

导入到eclipse工程中

创建一个java工程,把上图的源码复制到eclipse中,如同所示

在工程下创建一个文件夹,里面存放训练测试用的数据

首次调用的Demo举例

在java的工程中创建一个属于自己的包,然后写一个mian类。如图

package ;

import ption;

import _predict;

import _train;

public class ComMain {

public static void main(String[] args) throws IOException {

String []arg ={ "", //

存放SVM训练模型用的数据的路径

"trainfilemodel_"};

//存放SVM通过训练数据训/ //练出来的模型的路径

String []parg={"", //这个是存放测试数据

"trainfilemodel_", //调用的是训练以后的模型

"trainfileout_"}; //生成的结果的文件的路径

n("........SVM运行开始..........");

//创建一个训练对象

svm_train t = new svm_train();

//创建一个预测或者分类的对象

svm_predict p= new svm_predict();

(arg); //调用

(parg); //调用

}

}

6.运行工程就可以看到了结果了

Libsvm二次开发的首先要熟悉调用接口的源码

你一定会有疑问:SVM的参数怎么设置,cross-validation怎么用。那么我们首

先来说明一个问题,交叉验证在一般情况下要自己开发自己写。Libsvm内置了

交叉验证,但是如果我希望用同交叉验证的数据用决策树来做,怎么办,显然

Libsvm并没有保存交叉验证的数据。

============================================================

我已经将注释写在了源码中。

Svm_train类的文档说明

package service;

import libsvm.*;

import .*;

import .*;

public class svm_train {

private svm_parameter param; // set by parse_command_line

private svm_problem prob; // set by read_problem

private svm_model model;

private String input_file_name; // set by parse_command_line

private String model_file_name; // set by parse_command_line

private String error_msg;

private int cross_validation;

private int nr_fold;

private static svm_print_interface svm_print_null =

svm_print_interface()

{

public void print(String s) {}

};

private static void exit_with_help()

{

(

"Usage: svm_train [options] training_set_file [model_file]n"

+"options:n"

+"-s svm_type : set type of SVM (default 0)n"

new

+" 0 -- C-SVCn"

+" 1 -- nu-SVCn"

+" 2 -- one-class SVMn"

+" 3 -- epsilon-SVRn"

+" 4 -- nu-SVRn"

+"-t kernel_type : set type of kernel function (default 2)n"

+" 0 -- linear: u'*vn"

+" 1 -- polynomial: (gamma*u'*v + coef0)^degreen"

+" 2 -- radial basis function: exp(-gamma*|u-v|^2)n"

+" 3 -- sigmoid: tanh(gamma*u'*v + coef0)n"

+" 4 -- precomputed kernel (kernel values in training_set_file)n"

+"-d degree : set degree in kernel function (default 3)n"

+"-g gamma : set gamma in kernel function (default 1/num_features)n"

+"-r coef0 : set coef0 in kernel function (default 0)n"

+"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR

(default 1)n"

+"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR

(default 0.5)n"

+"-p epsilon : set the epsilon in loss function of epsilon-SVR (default

0.1)n"

+"-m cachesize : set cache memory size in MB (default 100)n"

+"-e epsilon : set tolerance of termination criterion (default 0.001)n"

+"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default

1)n"

+"-b probability_estimates : whether to train a SVC or SVR model for

probability estimates, 0 or 1 (default 0)n"

+"-wi weight : set the parameter C of class i to weight*C, for C-SVC

(default 1)n"

+"-v n : n-fold cross validation moden"

+"-q : quiet mode (no outputs)n"

);

(1);

}

private void do_cross_validation()

{

int i;

int total_correct = 0;

double total_error = 0;

double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;

double[] target = new double[prob.l];

_cross_validation(prob,param,nr_fold,target);

if(_type == svm_N_SVR ||

_type == svm__SVR)

{

for(i=0;i

{

double y = prob.y[i];

double v = target[i];

total_error += (v-y)*(v-y);

sumv += v;

sumy += y;

sumvv += v*v;

sumyy += y*y;

sumvy += v*y;

}

("Cross Validation Mean squared error =

"+total_error/prob.l+"n");

("Cross Validation Squared correlation coefficient =

"+

((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/

((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"n"

);

}

else

{

for(i=0;i

if(target[i] == prob.y[i])

++total_correct;

("Cross Validation Accuracy =

"+100.0*total_correct/prob.l+"%n");

}

}

private void run(String argv[]) throws IOException

{

n("我的数组的长度是:" + ) ;

parse_command_line(argv); //解析svm参数的配置,我们去这个方法看

看,你可以按住crlt,然后鼠标点击这个方法

read_problem();

error_msg = _check_parameter(prob,param);

if(error_msg != null)

{

("ERROR: "+error_msg+"n");

(1);

}

if(cross_validation != 0)

{

do_cross_validation();

}

else

{

model = _train(prob,param);

_save_model(model_file_name,model);

}

}

public static void main(String argv[]) throws IOException

{

svm_train t = new svm_train();

(argv);

}

private static double atof(String s)

{

double d = f(s).doubleValue();

if ((d) || nite(d))

{

("NaN or Infinity in inputn");

(1);

}

return(d);

}

//解析控制台输入的string类型的值,因为svm的参数是由整数来代表的,

//那么通过这个方法将控制台输入的字符串解析成为整数的

private static int atoi(String s)

{

return nt(s);

}

//欢迎来到解析svm参数的方法

private void parse_command_line(String argv[])

{

int i; //设置了一个方法域的一个i变量,用于遍历argv这个字符串数组

的的哦

svm_print_interface print_func = null; // default printing to stdout,这

个是一个接口

//创建一个SVM的参数对象,SVM的参数都在这个对象中。

//具体的参数对象可以看svm_parameter这个类

param = new svm_parameter();

// 默认的SVM设置的值,如果需要修改,那么要从控制台输入,然后

下面的for循环会解析svm的参数设置

//我还没用全部搞懂这些参数的意思,但是这些参数的作用完全可以在

帮助信息中看到。

_type = svm_parameter.C_SVC; //默认的支持向量

//_type = svm__SVC;

_type = svm_; //默认的核函数高斯核函数

= 3;

= 0; // 1/num_features

0 = 0;

= 0.01;

_size = 100;

param.C = 1;

= 1e-3;

param.p = 0.1;

ing = 1;

ility = 0;

_weight = 0;

_label = new int[0];

= new double[0];

cross_validation = 0; //表示关闭交叉验证,1表示开启交叉验证(这里不

能设置1,因为你设置了也没用)

// 解析选项SVM参数的选项,如果控制台没有输入对于的字符串,那

么SVM将使用的是默认的SVM的参数设置

for(i=0;i<;i++) //初始化i

{

//返回的是argv这个数组第i个字符串第一个字符,这里说明控制

台要输入的时候首先要写一个'-'号.(比如i=4,=10,argv[4])

//如果不写,那么将break本次的循环,跳出的是整个for循环,所

以让文件保存的路径在数组中写到最后

if(argv[i].charAt(0) != '-') break; //如果一遇到不是这个跳出的是整

个for循环

//如果查询到了'-'字符,那么会执行这一步了。

//判断这个i的值是不是大于或者等于argv的长度了,如果是数组

的长度了,那么就打印出帮助信息。并且会中断虚拟机了

//++i >= 这个应该是先用i再加1,那么下面的操作的时

候就是i = i + 1了(i=5)

if(++i>=)

exit_with_help();

//如果执行了第二个if,那么会执行到这里了。这里的i = 5

switch(argv[i-1].charAt(1)) //用到的字符串仍然是argv[5-1]=argv[4],

解析的是第2个字符。

{

case 's': //设置svm的类型

_type = atoi(argv[i]); //这个赋值就是将argv[5],

赋值过去了

break;

case 't': //设置svm的核函数类型

_type = atoi(argv[i]);

break;

case 'd': //设置svm参数d的大小,用于多项式核函数

= atoi(argv[i]);

break;

case 'g': //赋值gamma的

= atof(argv[i]);

break;

case 'r': //赋值coef0的值

0 = atof(argv[i]);

break;

case 'n': //赋值n的值

= atof(argv[i]);

break;

case 'm': //赋值缓存的值

_size = atof(argv[i]);

break;

case 'c': //赋值的是惩罚因子的大小

param.C = atof(argv[i]);

break;

case 'e': //赋值的eps的值

= atof(argv[i]);

break;

case 'p': //赋值******我不想写下去了,因为在实际的应用中,

我还没用用到下面的参数。抱歉。

param.p = atof(argv[i]);

break;

case 'h': //

ing = atoi(argv[i]);

break;

case 'b': //要不要打印出分类的准确率的值

ility = atoi(argv[i]);

break;

case 'q':

print_func = svm_print_null;

i--;

break;

case 'v': //设置的交叉验证的值

cross_validation = 1; //开启交叉验证

nr_fold = atoi(argv[i]);

if(nr_fold < 2) //交叉验证的值不能小于1

{

("n-fold cross validation: n must >=

2n");

exit_with_help();

}

break;

case 'w':

++_weight;

{

int[] old = _label;

_label = new int[_weight];

opy(old,0,_label,0,_weight-1);

}

{

double[] old = ;

= new double[_weight];

opy(old,0,,0,_weight-1);

}

_label[_weight-1] =

atoi(argv[i-1].substring(2));

[_weight-1] = atof(argv[i]);

break;

default:

//如果一个字符都匹配不到,很遗憾要中断JVM了,并且会

打印出那个位子的字符出现了错误,然后打印出帮助信息

("Unknown option: " + argv[i-1] + "n");

exit_with_help();

}//end switch

} //end for

_set_print_string_function(print_func); //打印出是不是静音模

// determine filenames决定文件名

/**

* 我必须中断下操作来说明控制台应该怎么输入的

* argv = {"-s","1","-t","3","-w","5","我是训练用的文件路径","我

是训练完以后保存模型的路径"}

* 具体的1,3,5参数要参考官方说明文档,或者查看设置参数那个类

的参数。

* 看到这,你可以继续看下去了

*/

if(i>=)//这里是了防止没有输入存放文件的路径,或者存放文

件的路径不够

exit_with_help();

//到这里,i的应该是字符串数组的倒数第二个了

//其实我一直搞不清楚,为什么for循环完毕了,这个i不是argv数组

的长度呢 ?不是的i的值是数组长度-1也就是数组中倒数第二个位子

input_file_name = argv[i]; //将训练的文件路径赋值

n("POSITION="+(i+1)+"我的训练用的数据存放的路

径是:" + input_file_name) ;

if(i<-1) //如果i的值比数组长度-1还小,那么将argv[i]下一

个字符串赋值给存放模型的路径

model_file_name = argv[i+1]; //将训练以后的模型路径赋值

else

{

int p = argv[i].lastIndexOf('/');

++p; //

model_file_name = argv[i].substring(p)+".model";

}

n("我的训练用的数据存放的路径是:" +

model_file_name) ;

} // end run function

// read in a problem (in svmlight format)

private void read_problem() throws IOException

{

BufferedReader fp = new BufferedReader(new

FileReader(input_file_name));

Vector vy = new Vector();

Vector vx = new Vector();

int max_index = 0;

while(true)

{

String line = ne();

if(line == null) break;

StringTokenizer st = new StringTokenizer(line," tnrf:");

ment(atof(ken()));

int m = okens()/2;

svm_node[] x = new svm_node[m];

for(int j=0;j

{

x[j] = new svm_node();

x[j].index = atoi(ken());

x[j].value = atof(ken());

}

if(m>0) max_index = (max_index, x[m-1].index);

ment(x);

}

prob = new svm_problem();

prob.l = ();

prob.x = new svm_node[prob.l][];

for(int i=0;i

prob.x[i] = tAt(i);

prob.y = new double[prob.l];

for(int i=0;i

prob.y[i] = tAt(i);

if( == 0 && max_index > 0)

= 1.0/max_index;

if(_type == svm_PUTED)

for(int i=0;i

{

if (prob.x[i][0].index != 0)

{

("Wrong kernel matrix: first column must

be 0:sample_serial_numbern");

(1);

}

if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value >

max_index)

{

("Wrong input format:

sample_serial_number out of rangen");

(1);

}

}

();

}

}

Svm_predict类的文档说明

package service;

import libsvm.*;

import .*;

import .*;

public class svm_predict {

private static double atof(String s)

{

return f(s).doubleValue();

}

private static int atoi(String s)

{

return nt(s);

}

private static void predict(BufferedReader input, DataOutputStream output, svm_model

model, int predict_probability) throws IOException

{

//欢迎来到这个预测方法,下面开始分析

//设置方法内局部变量

//这个是预测正确的个数的

int correct = 0;

//这个是预测的个数一共有几个

int total = 0;

//分类或者预测的准确率,所以用double error = correct / total ;

double error = 0;

//几个中间变量的参数

double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;

int svm_type=_get_svm_type(model);

int nr_class=_get_nr_class(model);

double[] prob_estimates=null;

//如果传入进来的1(默认是0) ,那么从这里开始执行

//这个是不能用回归svm的

if(predict_probability == 1)

{

if(svm_type == svm_N_SVR || //回归SVM

svm_type == svm__SVR) //回归SVM

{

//打印出出错误了,svm数据不匹配

("Prob. model for test data: target value = predicted value +

z,nz: Laplace distribution

e^(-|z|/sigma)/(2sigma),sigma="+_get_svr_probability(model)+"n");

}

//用于分类的话就执行这个了

else

{

int[] labels=new int[nr_class]; //取得标签。分类用的标签

_get_labels(model,labels);

prob_estimates = new double[nr_class];

ytes("labels");//写入到文件中去

for(int j=0;j

ytes(" "+labels[j]);

ytes("n");

}

} //end if

//这个一定会执行的

while(true)

{

String line = ne(); //一行一行的读取

if(line == null) break;//如果出现空行,那么就停止,所以在文件中中间不能有

空行

StringTokenizer st = new StringTokenizer(line," tnrf:");

double target = atof(ken());

int m = okens()/2;

svm_node[] x = new svm_node[m];

for(int j=0;j

{

x[j] = new svm_node();

x[j].index = atoi(ken());

x[j].value = atof(ken());

}

double v;

//如果是分类svm就执行这个

if (predict_probability==1 && (svm_type==svm_parameter.C_SVC ||

svm_type==svm__SVC))

{

v = _predict_probability(model,x,prob_estimates);

ytes(v+" ");

for(int j=0;j

ytes(prob_estimates[j]+" ");

ytes("n");

} //end id

else

{

v = _predict(model,x);

ytes(v+"n");

}

/**

* 做二次开发,这里可动手脚,你可以输入要具体预测对的类在这里显示出来

等等

*/

if(v == target) //如果预测正确,那么分类的正确就加一

++correct;

error += (v-target)*(v-target);

sumv += v;

sumy += target;

sumvv += v*v;

sumyy += target*target;

sumvy += v*target;

++total;

} //end while

//如果是回归的svm就用这个

if(svm_type == svm_N_SVR ||

svm_type == svm__SVR)

{

/**

* 这里打印出来的是用于回归问题的信息regression

*/

("Mean squared error = "+error/total+" (regression)n");

("Squared correlation coefficient = "+

((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/

((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))+

" (regression)n");

}

else //这里打印出来的是用于分类问题的信息classification

("Accuracy = "+(double)correct/total*100+

"% ("+correct+"/"+total+") (classification)n");

}//end function

private static void exit_with_help()

{

("usage: svm_predict [options] test_file model_file output_filen"

+"options:n"

+"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0);

one-class SVM not supported yetn");

(1);

}

//首先从这里读

public static void main(String argv[]) throws IOException

{

int i, predict_probability=0; //设置两个值,后面一个0表示不开启

// parse options解析选项,解析和train类类似不做说明

for(i=0;i<;i++)

{

if(argv[i].charAt(0) != '-') break;

++i;

switch(argv[i-1].charAt(1))

{

case 'b':

predict_probability = atoi(argv[i]);

break;

default:

("Unknown option: " + argv[i-1] + "n");

exit_with_help();

}

}

if(i>=-2)

exit_with_help();

try

{

BufferedReader input = new BufferedReader(new FileReader(argv[i]));

DataOutputStream output = new DataOutputStream(new

BufferedOutputStream(new FileOutputStream(argv[i+2])));

svm_model model = _load_model(argv[i+1]);

if(predict_probability == 1)

{

if(_check_probability_model(model)==0)

{

("Model does not support probabiliy estimatesn");

(1);

}

}

else

{

if(_check_probability_model(model)!=0)

{

("Model supports probability estimates, but disabled in

prediction.n");

}

}

/**

* 重点来看这个,我们要预测或者分类,中想返回一个预测正确或者分类正

确的类别的

* 你可以按住ctrl,然后用鼠标点击这个类

* 三个个参数:

* 一个是模型,已经训练出来的模型

* 一个是输入的测试数据

* 一个是是不是要打印出信息(我没用过,默认是0)

*/

predict(input,output,model,predict_probability);

(); //涉及到文件的操作有关闭的一些操作

();

}

catch(FileNotFoundException e)

{

exit_with_help();

}

}

}

catch(ArrayIndexOutOfBoundsException e)

{

exit_with_help();

}

下面是一些二次开发的介绍

隔点搜索的代码怎么写?

1. 我们在寻找最佳svm的参数组合的时候不可能自己去手动的去设置.比如高斯核函数有

两个参数要设置,c和gamma.我们要改写train的代码,将c和gama的参数设置到man

方法中去,直接通过调用main就可以改变c和gamma的

打圈的是自己改的。

如果你能看懂上面的意思,那么我想你的java基础完全可以想出来怎么讲correct正确的

作为返回值返回到主main中,你又可以利用这个来写出属于自己的交叉验证

你可以参考一下的调用代码

package _RBF;

import edWriter;

import ;

import iter;

import ption;

import _predict;

import _train;

public class ComMain_data_dea {

/*

* @param args

* @throws IOException

*/

public void main(int ix) throws IOException {

// TODO Auto-generated method stub

// String []arg ={"file_","file_modeltrain1_"};

// String

[]parg={"file_","file_modeltrain1_","file_outtrain1_"};

// n("........SVM运行开始..........");

// svm_train t = new svm_train();

// svm_predict p= new svm_predict();

// (arg);

// (parg);

// n("........SVM运行结束..........");

File f_accurate = new File("file_data_accurateacc_data_");

if(f_accurate == null){

f_NewFile() ;

}else{

//f_accurate.

}

FileWriter fw = new FileWriter(f_accurate);

BufferedWriter bw = new BufferedWriter(fw);

("c表示惩罚因子,gamama表示核函数变量, temp表示交叉验证准确验

证的个数,accor表示准确率");

e();

double temp = 0;

double c = 0;

double val_c = 0 ;

double val_gamma = 0 ;

double accor = 0;

double maxaccor = 0;

double max_c = 0;

double max_gamama = 0;

double max_temp = 0;

String tempStr_c ="";

String tempStr_g ="";

String tempStr_ac ="";

for(int i1=0;i1<100;i1++){

val_gamma = 0;

val_c = val_c + 0.1 ;

for(int i2=0;i2<200;i2++){

val_gamma = val_gamma + 0.01 ;

for(int i=1;i<6;i++){

String []arg

={"file_data_ccrtraintrain"+i+".txt","file_data_ccrmodeltrain"+i+"_"};

String

[]parg={"file_data_ccrtesttest"+i+".txt","file_data_ccrmodeltrain"+i+"_","file_data

_ccrouttrain"+i+"_"};

n("........SVM运行开始.........."+i);

svm_train t = new svm_train();

svm_predict p= new svm_predict();

(arg,val_c,val_gamma,ix,0);

c = (parg);

temp = temp + c;

}//end for

accor = (double)temp/23;

tempStr_c = (val_c+"0000000").substring(0, 3);

tempStr_g = (val_gamma+"0000000").substring(0, 4);

tempStr_ac = (accor+"000000").substring(0, 5);

//("c="+val_c+",gamama="+val_gamma+",temp="+temp+",accor="+accor);

("c="+tempStr_c+" ,gamama="+tempStr_g+" ,temp="+temp+" ,accor="+temp

Str_ac);

e();

if(accor>maxaccor){

maxaccor = accor ;

max_c = val_c;

max_gamama = val_gamma;

max_temp = temp;

// break;

}

c = 0;

temp = 0;

}//end for

if(accor>maxaccor){

maxaccor = accor ;

// break;

}

}//end for

("max_c="+max_c+",max_gamama="+max_gamama+",max_temp="+max_temp+",

max_accor="+maxaccor);

e();

();

();

//double accor = (double)temp/23;

n("........"+maxaccor);

n("**********max_c************" + max_c);

n("**********max_gamama************" + max_gamama);

n("**********max_temp************" + max_temp);

// for(int i=1;i<6;i++){

//

// String []arg ={"file_traintrain"+i+".txt","file_modeltrain"+i+"_"};

// String

[]parg={"file_testtest"+i+".txt","file_modeltrain"+i+"_","file_outtrain"+i+"_"

};

// n("........SVM运行开始.........."+i);

// svm_train t = new svm_train();

// svm_predict p= new svm_predict();

// (arg,val_c,val_gamma);

// c = (parg);

// temp = temp + c;

// n("........C.........."+c);

//

// }//end for

// double accor = (double)temp/23;

// n("........"+accor);

}//end main

}//end class

说明:

你可以在csdn下载到包括demo的说明文档

2024年6月14日发(作者:遇飞语)

LibSVM(JAVA)二次开发接口调用及源码更改的文档

浙江大学协调服务研究所

文档整理:陈伟 chenweishaoxing#

下载libsvm

方法:google libsvm找到官网下载:

/~cjlin/libsvm/ ,其中图片中椭圆的

解压文档

下载下来libsvm工具包有几个版本的,其中python的最经典,用的人比较多,

还支持matlab,C++等等。我们用的java版的,就到解压开的java文件夹中!

java文件夹

导入到eclipse工程中

创建一个java工程,把上图的源码复制到eclipse中,如同所示

在工程下创建一个文件夹,里面存放训练测试用的数据

首次调用的Demo举例

在java的工程中创建一个属于自己的包,然后写一个mian类。如图

package ;

import ption;

import _predict;

import _train;

public class ComMain {

public static void main(String[] args) throws IOException {

String []arg ={ "", //

存放SVM训练模型用的数据的路径

"trainfilemodel_"};

//存放SVM通过训练数据训/ //练出来的模型的路径

String []parg={"", //这个是存放测试数据

"trainfilemodel_", //调用的是训练以后的模型

"trainfileout_"}; //生成的结果的文件的路径

n("........SVM运行开始..........");

//创建一个训练对象

svm_train t = new svm_train();

//创建一个预测或者分类的对象

svm_predict p= new svm_predict();

(arg); //调用

(parg); //调用

}

}

6.运行工程就可以看到了结果了

Libsvm二次开发的首先要熟悉调用接口的源码

你一定会有疑问:SVM的参数怎么设置,cross-validation怎么用。那么我们首

先来说明一个问题,交叉验证在一般情况下要自己开发自己写。Libsvm内置了

交叉验证,但是如果我希望用同交叉验证的数据用决策树来做,怎么办,显然

Libsvm并没有保存交叉验证的数据。

============================================================

我已经将注释写在了源码中。

Svm_train类的文档说明

package service;

import libsvm.*;

import .*;

import .*;

public class svm_train {

private svm_parameter param; // set by parse_command_line

private svm_problem prob; // set by read_problem

private svm_model model;

private String input_file_name; // set by parse_command_line

private String model_file_name; // set by parse_command_line

private String error_msg;

private int cross_validation;

private int nr_fold;

private static svm_print_interface svm_print_null =

svm_print_interface()

{

public void print(String s) {}

};

private static void exit_with_help()

{

(

"Usage: svm_train [options] training_set_file [model_file]n"

+"options:n"

+"-s svm_type : set type of SVM (default 0)n"

new

+" 0 -- C-SVCn"

+" 1 -- nu-SVCn"

+" 2 -- one-class SVMn"

+" 3 -- epsilon-SVRn"

+" 4 -- nu-SVRn"

+"-t kernel_type : set type of kernel function (default 2)n"

+" 0 -- linear: u'*vn"

+" 1 -- polynomial: (gamma*u'*v + coef0)^degreen"

+" 2 -- radial basis function: exp(-gamma*|u-v|^2)n"

+" 3 -- sigmoid: tanh(gamma*u'*v + coef0)n"

+" 4 -- precomputed kernel (kernel values in training_set_file)n"

+"-d degree : set degree in kernel function (default 3)n"

+"-g gamma : set gamma in kernel function (default 1/num_features)n"

+"-r coef0 : set coef0 in kernel function (default 0)n"

+"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR

(default 1)n"

+"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR

(default 0.5)n"

+"-p epsilon : set the epsilon in loss function of epsilon-SVR (default

0.1)n"

+"-m cachesize : set cache memory size in MB (default 100)n"

+"-e epsilon : set tolerance of termination criterion (default 0.001)n"

+"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default

1)n"

+"-b probability_estimates : whether to train a SVC or SVR model for

probability estimates, 0 or 1 (default 0)n"

+"-wi weight : set the parameter C of class i to weight*C, for C-SVC

(default 1)n"

+"-v n : n-fold cross validation moden"

+"-q : quiet mode (no outputs)n"

);

(1);

}

private void do_cross_validation()

{

int i;

int total_correct = 0;

double total_error = 0;

double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;

double[] target = new double[prob.l];

_cross_validation(prob,param,nr_fold,target);

if(_type == svm_N_SVR ||

_type == svm__SVR)

{

for(i=0;i

{

double y = prob.y[i];

double v = target[i];

total_error += (v-y)*(v-y);

sumv += v;

sumy += y;

sumvv += v*v;

sumyy += y*y;

sumvy += v*y;

}

("Cross Validation Mean squared error =

"+total_error/prob.l+"n");

("Cross Validation Squared correlation coefficient =

"+

((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/

((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"n"

);

}

else

{

for(i=0;i

if(target[i] == prob.y[i])

++total_correct;

("Cross Validation Accuracy =

"+100.0*total_correct/prob.l+"%n");

}

}

private void run(String argv[]) throws IOException

{

n("我的数组的长度是:" + ) ;

parse_command_line(argv); //解析svm参数的配置,我们去这个方法看

看,你可以按住crlt,然后鼠标点击这个方法

read_problem();

error_msg = _check_parameter(prob,param);

if(error_msg != null)

{

("ERROR: "+error_msg+"n");

(1);

}

if(cross_validation != 0)

{

do_cross_validation();

}

else

{

model = _train(prob,param);

_save_model(model_file_name,model);

}

}

public static void main(String argv[]) throws IOException

{

svm_train t = new svm_train();

(argv);

}

private static double atof(String s)

{

double d = f(s).doubleValue();

if ((d) || nite(d))

{

("NaN or Infinity in inputn");

(1);

}

return(d);

}

//解析控制台输入的string类型的值,因为svm的参数是由整数来代表的,

//那么通过这个方法将控制台输入的字符串解析成为整数的

private static int atoi(String s)

{

return nt(s);

}

//欢迎来到解析svm参数的方法

private void parse_command_line(String argv[])

{

int i; //设置了一个方法域的一个i变量,用于遍历argv这个字符串数组

的的哦

svm_print_interface print_func = null; // default printing to stdout,这

个是一个接口

//创建一个SVM的参数对象,SVM的参数都在这个对象中。

//具体的参数对象可以看svm_parameter这个类

param = new svm_parameter();

// 默认的SVM设置的值,如果需要修改,那么要从控制台输入,然后

下面的for循环会解析svm的参数设置

//我还没用全部搞懂这些参数的意思,但是这些参数的作用完全可以在

帮助信息中看到。

_type = svm_parameter.C_SVC; //默认的支持向量

//_type = svm__SVC;

_type = svm_; //默认的核函数高斯核函数

= 3;

= 0; // 1/num_features

0 = 0;

= 0.01;

_size = 100;

param.C = 1;

= 1e-3;

param.p = 0.1;

ing = 1;

ility = 0;

_weight = 0;

_label = new int[0];

= new double[0];

cross_validation = 0; //表示关闭交叉验证,1表示开启交叉验证(这里不

能设置1,因为你设置了也没用)

// 解析选项SVM参数的选项,如果控制台没有输入对于的字符串,那

么SVM将使用的是默认的SVM的参数设置

for(i=0;i<;i++) //初始化i

{

//返回的是argv这个数组第i个字符串第一个字符,这里说明控制

台要输入的时候首先要写一个'-'号.(比如i=4,=10,argv[4])

//如果不写,那么将break本次的循环,跳出的是整个for循环,所

以让文件保存的路径在数组中写到最后

if(argv[i].charAt(0) != '-') break; //如果一遇到不是这个跳出的是整

个for循环

//如果查询到了'-'字符,那么会执行这一步了。

//判断这个i的值是不是大于或者等于argv的长度了,如果是数组

的长度了,那么就打印出帮助信息。并且会中断虚拟机了

//++i >= 这个应该是先用i再加1,那么下面的操作的时

候就是i = i + 1了(i=5)

if(++i>=)

exit_with_help();

//如果执行了第二个if,那么会执行到这里了。这里的i = 5

switch(argv[i-1].charAt(1)) //用到的字符串仍然是argv[5-1]=argv[4],

解析的是第2个字符。

{

case 's': //设置svm的类型

_type = atoi(argv[i]); //这个赋值就是将argv[5],

赋值过去了

break;

case 't': //设置svm的核函数类型

_type = atoi(argv[i]);

break;

case 'd': //设置svm参数d的大小,用于多项式核函数

= atoi(argv[i]);

break;

case 'g': //赋值gamma的

= atof(argv[i]);

break;

case 'r': //赋值coef0的值

0 = atof(argv[i]);

break;

case 'n': //赋值n的值

= atof(argv[i]);

break;

case 'm': //赋值缓存的值

_size = atof(argv[i]);

break;

case 'c': //赋值的是惩罚因子的大小

param.C = atof(argv[i]);

break;

case 'e': //赋值的eps的值

= atof(argv[i]);

break;

case 'p': //赋值******我不想写下去了,因为在实际的应用中,

我还没用用到下面的参数。抱歉。

param.p = atof(argv[i]);

break;

case 'h': //

ing = atoi(argv[i]);

break;

case 'b': //要不要打印出分类的准确率的值

ility = atoi(argv[i]);

break;

case 'q':

print_func = svm_print_null;

i--;

break;

case 'v': //设置的交叉验证的值

cross_validation = 1; //开启交叉验证

nr_fold = atoi(argv[i]);

if(nr_fold < 2) //交叉验证的值不能小于1

{

("n-fold cross validation: n must >=

2n");

exit_with_help();

}

break;

case 'w':

++_weight;

{

int[] old = _label;

_label = new int[_weight];

opy(old,0,_label,0,_weight-1);

}

{

double[] old = ;

= new double[_weight];

opy(old,0,,0,_weight-1);

}

_label[_weight-1] =

atoi(argv[i-1].substring(2));

[_weight-1] = atof(argv[i]);

break;

default:

//如果一个字符都匹配不到,很遗憾要中断JVM了,并且会

打印出那个位子的字符出现了错误,然后打印出帮助信息

("Unknown option: " + argv[i-1] + "n");

exit_with_help();

}//end switch

} //end for

_set_print_string_function(print_func); //打印出是不是静音模

// determine filenames决定文件名

/**

* 我必须中断下操作来说明控制台应该怎么输入的

* argv = {"-s","1","-t","3","-w","5","我是训练用的文件路径","我

是训练完以后保存模型的路径"}

* 具体的1,3,5参数要参考官方说明文档,或者查看设置参数那个类

的参数。

* 看到这,你可以继续看下去了

*/

if(i>=)//这里是了防止没有输入存放文件的路径,或者存放文

件的路径不够

exit_with_help();

//到这里,i的应该是字符串数组的倒数第二个了

//其实我一直搞不清楚,为什么for循环完毕了,这个i不是argv数组

的长度呢 ?不是的i的值是数组长度-1也就是数组中倒数第二个位子

input_file_name = argv[i]; //将训练的文件路径赋值

n("POSITION="+(i+1)+"我的训练用的数据存放的路

径是:" + input_file_name) ;

if(i<-1) //如果i的值比数组长度-1还小,那么将argv[i]下一

个字符串赋值给存放模型的路径

model_file_name = argv[i+1]; //将训练以后的模型路径赋值

else

{

int p = argv[i].lastIndexOf('/');

++p; //

model_file_name = argv[i].substring(p)+".model";

}

n("我的训练用的数据存放的路径是:" +

model_file_name) ;

} // end run function

// read in a problem (in svmlight format)

private void read_problem() throws IOException

{

BufferedReader fp = new BufferedReader(new

FileReader(input_file_name));

Vector vy = new Vector();

Vector vx = new Vector();

int max_index = 0;

while(true)

{

String line = ne();

if(line == null) break;

StringTokenizer st = new StringTokenizer(line," tnrf:");

ment(atof(ken()));

int m = okens()/2;

svm_node[] x = new svm_node[m];

for(int j=0;j

{

x[j] = new svm_node();

x[j].index = atoi(ken());

x[j].value = atof(ken());

}

if(m>0) max_index = (max_index, x[m-1].index);

ment(x);

}

prob = new svm_problem();

prob.l = ();

prob.x = new svm_node[prob.l][];

for(int i=0;i

prob.x[i] = tAt(i);

prob.y = new double[prob.l];

for(int i=0;i

prob.y[i] = tAt(i);

if( == 0 && max_index > 0)

= 1.0/max_index;

if(_type == svm_PUTED)

for(int i=0;i

{

if (prob.x[i][0].index != 0)

{

("Wrong kernel matrix: first column must

be 0:sample_serial_numbern");

(1);

}

if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value >

max_index)

{

("Wrong input format:

sample_serial_number out of rangen");

(1);

}

}

();

}

}

Svm_predict类的文档说明

package service;

import libsvm.*;

import .*;

import .*;

public class svm_predict {

private static double atof(String s)

{

return f(s).doubleValue();

}

private static int atoi(String s)

{

return nt(s);

}

private static void predict(BufferedReader input, DataOutputStream output, svm_model

model, int predict_probability) throws IOException

{

//欢迎来到这个预测方法,下面开始分析

//设置方法内局部变量

//这个是预测正确的个数的

int correct = 0;

//这个是预测的个数一共有几个

int total = 0;

//分类或者预测的准确率,所以用double error = correct / total ;

double error = 0;

//几个中间变量的参数

double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;

int svm_type=_get_svm_type(model);

int nr_class=_get_nr_class(model);

double[] prob_estimates=null;

//如果传入进来的1(默认是0) ,那么从这里开始执行

//这个是不能用回归svm的

if(predict_probability == 1)

{

if(svm_type == svm_N_SVR || //回归SVM

svm_type == svm__SVR) //回归SVM

{

//打印出出错误了,svm数据不匹配

("Prob. model for test data: target value = predicted value +

z,nz: Laplace distribution

e^(-|z|/sigma)/(2sigma),sigma="+_get_svr_probability(model)+"n");

}

//用于分类的话就执行这个了

else

{

int[] labels=new int[nr_class]; //取得标签。分类用的标签

_get_labels(model,labels);

prob_estimates = new double[nr_class];

ytes("labels");//写入到文件中去

for(int j=0;j

ytes(" "+labels[j]);

ytes("n");

}

} //end if

//这个一定会执行的

while(true)

{

String line = ne(); //一行一行的读取

if(line == null) break;//如果出现空行,那么就停止,所以在文件中中间不能有

空行

StringTokenizer st = new StringTokenizer(line," tnrf:");

double target = atof(ken());

int m = okens()/2;

svm_node[] x = new svm_node[m];

for(int j=0;j

{

x[j] = new svm_node();

x[j].index = atoi(ken());

x[j].value = atof(ken());

}

double v;

//如果是分类svm就执行这个

if (predict_probability==1 && (svm_type==svm_parameter.C_SVC ||

svm_type==svm__SVC))

{

v = _predict_probability(model,x,prob_estimates);

ytes(v+" ");

for(int j=0;j

ytes(prob_estimates[j]+" ");

ytes("n");

} //end id

else

{

v = _predict(model,x);

ytes(v+"n");

}

/**

* 做二次开发,这里可动手脚,你可以输入要具体预测对的类在这里显示出来

等等

*/

if(v == target) //如果预测正确,那么分类的正确就加一

++correct;

error += (v-target)*(v-target);

sumv += v;

sumy += target;

sumvv += v*v;

sumyy += target*target;

sumvy += v*target;

++total;

} //end while

//如果是回归的svm就用这个

if(svm_type == svm_N_SVR ||

svm_type == svm__SVR)

{

/**

* 这里打印出来的是用于回归问题的信息regression

*/

("Mean squared error = "+error/total+" (regression)n");

("Squared correlation coefficient = "+

((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/

((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))+

" (regression)n");

}

else //这里打印出来的是用于分类问题的信息classification

("Accuracy = "+(double)correct/total*100+

"% ("+correct+"/"+total+") (classification)n");

}//end function

private static void exit_with_help()

{

("usage: svm_predict [options] test_file model_file output_filen"

+"options:n"

+"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0);

one-class SVM not supported yetn");

(1);

}

//首先从这里读

public static void main(String argv[]) throws IOException

{

int i, predict_probability=0; //设置两个值,后面一个0表示不开启

// parse options解析选项,解析和train类类似不做说明

for(i=0;i<;i++)

{

if(argv[i].charAt(0) != '-') break;

++i;

switch(argv[i-1].charAt(1))

{

case 'b':

predict_probability = atoi(argv[i]);

break;

default:

("Unknown option: " + argv[i-1] + "n");

exit_with_help();

}

}

if(i>=-2)

exit_with_help();

try

{

BufferedReader input = new BufferedReader(new FileReader(argv[i]));

DataOutputStream output = new DataOutputStream(new

BufferedOutputStream(new FileOutputStream(argv[i+2])));

svm_model model = _load_model(argv[i+1]);

if(predict_probability == 1)

{

if(_check_probability_model(model)==0)

{

("Model does not support probabiliy estimatesn");

(1);

}

}

else

{

if(_check_probability_model(model)!=0)

{

("Model supports probability estimates, but disabled in

prediction.n");

}

}

/**

* 重点来看这个,我们要预测或者分类,中想返回一个预测正确或者分类正

确的类别的

* 你可以按住ctrl,然后用鼠标点击这个类

* 三个个参数:

* 一个是模型,已经训练出来的模型

* 一个是输入的测试数据

* 一个是是不是要打印出信息(我没用过,默认是0)

*/

predict(input,output,model,predict_probability);

(); //涉及到文件的操作有关闭的一些操作

();

}

catch(FileNotFoundException e)

{

exit_with_help();

}

}

}

catch(ArrayIndexOutOfBoundsException e)

{

exit_with_help();

}

下面是一些二次开发的介绍

隔点搜索的代码怎么写?

1. 我们在寻找最佳svm的参数组合的时候不可能自己去手动的去设置.比如高斯核函数有

两个参数要设置,c和gamma.我们要改写train的代码,将c和gama的参数设置到man

方法中去,直接通过调用main就可以改变c和gamma的

打圈的是自己改的。

如果你能看懂上面的意思,那么我想你的java基础完全可以想出来怎么讲correct正确的

作为返回值返回到主main中,你又可以利用这个来写出属于自己的交叉验证

你可以参考一下的调用代码

package _RBF;

import edWriter;

import ;

import iter;

import ption;

import _predict;

import _train;

public class ComMain_data_dea {

/*

* @param args

* @throws IOException

*/

public void main(int ix) throws IOException {

// TODO Auto-generated method stub

// String []arg ={"file_","file_modeltrain1_"};

// String

[]parg={"file_","file_modeltrain1_","file_outtrain1_"};

// n("........SVM运行开始..........");

// svm_train t = new svm_train();

// svm_predict p= new svm_predict();

// (arg);

// (parg);

// n("........SVM运行结束..........");

File f_accurate = new File("file_data_accurateacc_data_");

if(f_accurate == null){

f_NewFile() ;

}else{

//f_accurate.

}

FileWriter fw = new FileWriter(f_accurate);

BufferedWriter bw = new BufferedWriter(fw);

("c表示惩罚因子,gamama表示核函数变量, temp表示交叉验证准确验

证的个数,accor表示准确率");

e();

double temp = 0;

double c = 0;

double val_c = 0 ;

double val_gamma = 0 ;

double accor = 0;

double maxaccor = 0;

double max_c = 0;

double max_gamama = 0;

double max_temp = 0;

String tempStr_c ="";

String tempStr_g ="";

String tempStr_ac ="";

for(int i1=0;i1<100;i1++){

val_gamma = 0;

val_c = val_c + 0.1 ;

for(int i2=0;i2<200;i2++){

val_gamma = val_gamma + 0.01 ;

for(int i=1;i<6;i++){

String []arg

={"file_data_ccrtraintrain"+i+".txt","file_data_ccrmodeltrain"+i+"_"};

String

[]parg={"file_data_ccrtesttest"+i+".txt","file_data_ccrmodeltrain"+i+"_","file_data

_ccrouttrain"+i+"_"};

n("........SVM运行开始.........."+i);

svm_train t = new svm_train();

svm_predict p= new svm_predict();

(arg,val_c,val_gamma,ix,0);

c = (parg);

temp = temp + c;

}//end for

accor = (double)temp/23;

tempStr_c = (val_c+"0000000").substring(0, 3);

tempStr_g = (val_gamma+"0000000").substring(0, 4);

tempStr_ac = (accor+"000000").substring(0, 5);

//("c="+val_c+",gamama="+val_gamma+",temp="+temp+",accor="+accor);

("c="+tempStr_c+" ,gamama="+tempStr_g+" ,temp="+temp+" ,accor="+temp

Str_ac);

e();

if(accor>maxaccor){

maxaccor = accor ;

max_c = val_c;

max_gamama = val_gamma;

max_temp = temp;

// break;

}

c = 0;

temp = 0;

}//end for

if(accor>maxaccor){

maxaccor = accor ;

// break;

}

}//end for

("max_c="+max_c+",max_gamama="+max_gamama+",max_temp="+max_temp+",

max_accor="+maxaccor);

e();

();

();

//double accor = (double)temp/23;

n("........"+maxaccor);

n("**********max_c************" + max_c);

n("**********max_gamama************" + max_gamama);

n("**********max_temp************" + max_temp);

// for(int i=1;i<6;i++){

//

// String []arg ={"file_traintrain"+i+".txt","file_modeltrain"+i+"_"};

// String

[]parg={"file_testtest"+i+".txt","file_modeltrain"+i+"_","file_outtrain"+i+"_"

};

// n("........SVM运行开始.........."+i);

// svm_train t = new svm_train();

// svm_predict p= new svm_predict();

// (arg,val_c,val_gamma);

// c = (parg);

// temp = temp + c;

// n("........C.........."+c);

//

// }//end for

// double accor = (double)temp/23;

// n("........"+accor);

}//end main

}//end class

说明:

你可以在csdn下载到包括demo的说明文档

发布评论

评论列表 (0)

  1. 暂无评论