2024年6月14日发(作者:遇飞语)
LibSVM(JAVA)二次开发接口调用及源码更改的文档
浙江大学协调服务研究所
文档整理:陈伟 chenweishaoxing#
下载libsvm
方法:google libsvm找到官网下载:
/~cjlin/libsvm/ ,其中图片中椭圆的
解压文档
下载下来libsvm工具包有几个版本的,其中python的最经典,用的人比较多,
还支持matlab,C++等等。我们用的java版的,就到解压开的java文件夹中!
java文件夹
导入到eclipse工程中
创建一个java工程,把上图的源码复制到eclipse中,如同所示
在工程下创建一个文件夹,里面存放训练测试用的数据
首次调用的Demo举例
在java的工程中创建一个属于自己的包,然后写一个mian类。如图
package ;
import ption;
import _predict;
import _train;
public class ComMain {
public static void main(String[] args) throws IOException {
String []arg ={ "", //
存放SVM训练模型用的数据的路径
"trainfilemodel_"};
//存放SVM通过训练数据训/ //练出来的模型的路径
String []parg={"", //这个是存放测试数据
"trainfilemodel_", //调用的是训练以后的模型
"trainfileout_"}; //生成的结果的文件的路径
n("........SVM运行开始..........");
//创建一个训练对象
svm_train t = new svm_train();
//创建一个预测或者分类的对象
svm_predict p= new svm_predict();
(arg); //调用
(parg); //调用
}
}
6.运行工程就可以看到了结果了
Libsvm二次开发的首先要熟悉调用接口的源码
你一定会有疑问:SVM的参数怎么设置,cross-validation怎么用。那么我们首
先来说明一个问题,交叉验证在一般情况下要自己开发自己写。Libsvm内置了
交叉验证,但是如果我希望用同交叉验证的数据用决策树来做,怎么办,显然
Libsvm并没有保存交叉验证的数据。
============================================================
我已经将注释写在了源码中。
Svm_train类的文档说明
package service;
import libsvm.*;
import .*;
import .*;
public class svm_train {
private svm_parameter param; // set by parse_command_line
private svm_problem prob; // set by read_problem
private svm_model model;
private String input_file_name; // set by parse_command_line
private String model_file_name; // set by parse_command_line
private String error_msg;
private int cross_validation;
private int nr_fold;
private static svm_print_interface svm_print_null =
svm_print_interface()
{
public void print(String s) {}
};
private static void exit_with_help()
{
(
"Usage: svm_train [options] training_set_file [model_file]n"
+"options:n"
+"-s svm_type : set type of SVM (default 0)n"
new
+" 0 -- C-SVCn"
+" 1 -- nu-SVCn"
+" 2 -- one-class SVMn"
+" 3 -- epsilon-SVRn"
+" 4 -- nu-SVRn"
+"-t kernel_type : set type of kernel function (default 2)n"
+" 0 -- linear: u'*vn"
+" 1 -- polynomial: (gamma*u'*v + coef0)^degreen"
+" 2 -- radial basis function: exp(-gamma*|u-v|^2)n"
+" 3 -- sigmoid: tanh(gamma*u'*v + coef0)n"
+" 4 -- precomputed kernel (kernel values in training_set_file)n"
+"-d degree : set degree in kernel function (default 3)n"
+"-g gamma : set gamma in kernel function (default 1/num_features)n"
+"-r coef0 : set coef0 in kernel function (default 0)n"
+"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR
(default 1)n"
+"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR
(default 0.5)n"
+"-p epsilon : set the epsilon in loss function of epsilon-SVR (default
0.1)n"
+"-m cachesize : set cache memory size in MB (default 100)n"
+"-e epsilon : set tolerance of termination criterion (default 0.001)n"
+"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default
1)n"
+"-b probability_estimates : whether to train a SVC or SVR model for
probability estimates, 0 or 1 (default 0)n"
+"-wi weight : set the parameter C of class i to weight*C, for C-SVC
(default 1)n"
+"-v n : n-fold cross validation moden"
+"-q : quiet mode (no outputs)n"
);
(1);
}
private void do_cross_validation()
{
int i;
int total_correct = 0;
double total_error = 0;
double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
double[] target = new double[prob.l];
_cross_validation(prob,param,nr_fold,target);
if(_type == svm_N_SVR ||
_type == svm__SVR)
{
for(i=0;i { double y = prob.y[i]; double v = target[i]; total_error += (v-y)*(v-y); sumv += v; sumy += y; sumvv += v*v; sumyy += y*y; sumvy += v*y; } ("Cross Validation Mean squared error = "+total_error/prob.l+"n"); ("Cross Validation Squared correlation coefficient = "+ ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/ ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"n" ); } else { for(i=0;i if(target[i] == prob.y[i]) ++total_correct; ("Cross Validation Accuracy = "+100.0*total_correct/prob.l+"%n"); } } private void run(String argv[]) throws IOException { n("我的数组的长度是:" + ) ; parse_command_line(argv); //解析svm参数的配置,我们去这个方法看 看,你可以按住crlt,然后鼠标点击这个方法 read_problem(); error_msg = _check_parameter(prob,param); if(error_msg != null) { ("ERROR: "+error_msg+"n"); (1); } if(cross_validation != 0) { do_cross_validation(); } else { model = _train(prob,param); _save_model(model_file_name,model); } } public static void main(String argv[]) throws IOException { svm_train t = new svm_train(); (argv); } private static double atof(String s) { double d = f(s).doubleValue(); if ((d) || nite(d)) { ("NaN or Infinity in inputn"); (1); } return(d); } //解析控制台输入的string类型的值,因为svm的参数是由整数来代表的, //那么通过这个方法将控制台输入的字符串解析成为整数的 private static int atoi(String s) { return nt(s); } //欢迎来到解析svm参数的方法 private void parse_command_line(String argv[]) { int i; //设置了一个方法域的一个i变量,用于遍历argv这个字符串数组 的的哦 svm_print_interface print_func = null; // default printing to stdout,这 个是一个接口 //创建一个SVM的参数对象,SVM的参数都在这个对象中。 //具体的参数对象可以看svm_parameter这个类 param = new svm_parameter(); // 默认的SVM设置的值,如果需要修改,那么要从控制台输入,然后 下面的for循环会解析svm的参数设置 //我还没用全部搞懂这些参数的意思,但是这些参数的作用完全可以在 帮助信息中看到。 _type = svm_parameter.C_SVC; //默认的支持向量 //_type = svm__SVC; _type = svm_; //默认的核函数高斯核函数 = 3; = 0; // 1/num_features 0 = 0; = 0.01; _size = 100; param.C = 1; = 1e-3; param.p = 0.1; ing = 1; ility = 0; _weight = 0; _label = new int[0]; = new double[0]; cross_validation = 0; //表示关闭交叉验证,1表示开启交叉验证(这里不 能设置1,因为你设置了也没用) // 解析选项SVM参数的选项,如果控制台没有输入对于的字符串,那 么SVM将使用的是默认的SVM的参数设置 for(i=0;i<;i++) //初始化i { //返回的是argv这个数组第i个字符串第一个字符,这里说明控制 台要输入的时候首先要写一个'-'号.(比如i=4,=10,argv[4]) //如果不写,那么将break本次的循环,跳出的是整个for循环,所 以让文件保存的路径在数组中写到最后 if(argv[i].charAt(0) != '-') break; //如果一遇到不是这个跳出的是整 个for循环 //如果查询到了'-'字符,那么会执行这一步了。 //判断这个i的值是不是大于或者等于argv的长度了,如果是数组 的长度了,那么就打印出帮助信息。并且会中断虚拟机了 //++i >= 这个应该是先用i再加1,那么下面的操作的时 候就是i = i + 1了(i=5) if(++i>=) exit_with_help(); //如果执行了第二个if,那么会执行到这里了。这里的i = 5 switch(argv[i-1].charAt(1)) //用到的字符串仍然是argv[5-1]=argv[4], 解析的是第2个字符。 { case 's': //设置svm的类型 _type = atoi(argv[i]); //这个赋值就是将argv[5], 赋值过去了 break; case 't': //设置svm的核函数类型 _type = atoi(argv[i]); break; case 'd': //设置svm参数d的大小,用于多项式核函数 = atoi(argv[i]); break; case 'g': //赋值gamma的 = atof(argv[i]); break; case 'r': //赋值coef0的值 0 = atof(argv[i]); break; case 'n': //赋值n的值 = atof(argv[i]); break; case 'm': //赋值缓存的值 _size = atof(argv[i]); break; case 'c': //赋值的是惩罚因子的大小 param.C = atof(argv[i]); break; case 'e': //赋值的eps的值 = atof(argv[i]); break; case 'p': //赋值******我不想写下去了,因为在实际的应用中, 我还没用用到下面的参数。抱歉。 param.p = atof(argv[i]); break; case 'h': // ing = atoi(argv[i]); break; case 'b': //要不要打印出分类的准确率的值 ility = atoi(argv[i]); break; case 'q': print_func = svm_print_null; i--; break; case 'v': //设置的交叉验证的值 cross_validation = 1; //开启交叉验证 nr_fold = atoi(argv[i]); if(nr_fold < 2) //交叉验证的值不能小于1 { ("n-fold cross validation: n must >= 2n"); exit_with_help(); } break; case 'w': ++_weight; { int[] old = _label; _label = new int[_weight]; opy(old,0,_label,0,_weight-1); } { double[] old = ; = new double[_weight]; opy(old,0,,0,_weight-1); } _label[_weight-1] = atoi(argv[i-1].substring(2)); [_weight-1] = atof(argv[i]); break; default: //如果一个字符都匹配不到,很遗憾要中断JVM了,并且会 打印出那个位子的字符出现了错误,然后打印出帮助信息 ("Unknown option: " + argv[i-1] + "n"); exit_with_help(); }//end switch } //end for _set_print_string_function(print_func); //打印出是不是静音模 式 // determine filenames决定文件名 /** * 我必须中断下操作来说明控制台应该怎么输入的 * argv = {"-s","1","-t","3","-w","5","我是训练用的文件路径","我 是训练完以后保存模型的路径"} * 具体的1,3,5参数要参考官方说明文档,或者查看设置参数那个类 的参数。 * 看到这,你可以继续看下去了 */ if(i>=)//这里是了防止没有输入存放文件的路径,或者存放文 件的路径不够 exit_with_help(); //到这里,i的应该是字符串数组的倒数第二个了 //其实我一直搞不清楚,为什么for循环完毕了,这个i不是argv数组 的长度呢 ?不是的i的值是数组长度-1也就是数组中倒数第二个位子 input_file_name = argv[i]; //将训练的文件路径赋值 n("POSITION="+(i+1)+"我的训练用的数据存放的路 径是:" + input_file_name) ; if(i<-1) //如果i的值比数组长度-1还小,那么将argv[i]下一 个字符串赋值给存放模型的路径 model_file_name = argv[i+1]; //将训练以后的模型路径赋值 else { int p = argv[i].lastIndexOf('/'); ++p; // model_file_name = argv[i].substring(p)+".model"; } n("我的训练用的数据存放的路径是:" + model_file_name) ; } // end run function // read in a problem (in svmlight format) private void read_problem() throws IOException { BufferedReader fp = new BufferedReader(new FileReader(input_file_name)); Vector Vector int max_index = 0; while(true) { String line = ne(); if(line == null) break; StringTokenizer st = new StringTokenizer(line," tnrf:"); ment(atof(ken())); int m = okens()/2; svm_node[] x = new svm_node[m]; for(int j=0;j { x[j] = new svm_node(); x[j].index = atoi(ken()); x[j].value = atof(ken()); } if(m>0) max_index = (max_index, x[m-1].index); ment(x); } prob = new svm_problem(); prob.l = (); prob.x = new svm_node[prob.l][]; for(int i=0;i prob.x[i] = tAt(i); prob.y = new double[prob.l]; for(int i=0;i prob.y[i] = tAt(i); if( == 0 && max_index > 0) = 1.0/max_index; if(_type == svm_PUTED) for(int i=0;i { if (prob.x[i][0].index != 0) { ("Wrong kernel matrix: first column must be 0:sample_serial_numbern"); (1); } if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index) { ("Wrong input format: sample_serial_number out of rangen"); (1); } } (); } } Svm_predict类的文档说明 package service; import libsvm.*; import .*; import .*; public class svm_predict { private static double atof(String s) { return f(s).doubleValue(); } private static int atoi(String s) { return nt(s); } private static void predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability) throws IOException { //欢迎来到这个预测方法,下面开始分析 //设置方法内局部变量 //这个是预测正确的个数的 int correct = 0; //这个是预测的个数一共有几个 int total = 0; //分类或者预测的准确率,所以用double error = correct / total ; double error = 0; //几个中间变量的参数 double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; int svm_type=_get_svm_type(model); int nr_class=_get_nr_class(model); double[] prob_estimates=null; //如果传入进来的1(默认是0) ,那么从这里开始执行 //这个是不能用回归svm的 if(predict_probability == 1) { if(svm_type == svm_N_SVR || //回归SVM svm_type == svm__SVR) //回归SVM { //打印出出错误了,svm数据不匹配 ("Prob. model for test data: target value = predicted value + z,nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+_get_svr_probability(model)+"n"); } //用于分类的话就执行这个了 else { int[] labels=new int[nr_class]; //取得标签。分类用的标签 _get_labels(model,labels); prob_estimates = new double[nr_class]; ytes("labels");//写入到文件中去 for(int j=0;j ytes(" "+labels[j]); ytes("n"); } } //end if //这个一定会执行的 while(true) { String line = ne(); //一行一行的读取 if(line == null) break;//如果出现空行,那么就停止,所以在文件中中间不能有 空行 StringTokenizer st = new StringTokenizer(line," tnrf:"); double target = atof(ken()); int m = okens()/2; svm_node[] x = new svm_node[m]; for(int j=0;j { x[j] = new svm_node(); x[j].index = atoi(ken()); x[j].value = atof(ken()); } double v; //如果是分类svm就执行这个 if (predict_probability==1 && (svm_type==svm_parameter.C_SVC || svm_type==svm__SVC)) { v = _predict_probability(model,x,prob_estimates); ytes(v+" "); for(int j=0;j ytes(prob_estimates[j]+" "); ytes("n"); } //end id else { v = _predict(model,x); ytes(v+"n"); } /** * 做二次开发,这里可动手脚,你可以输入要具体预测对的类在这里显示出来 等等 */ if(v == target) //如果预测正确,那么分类的正确就加一 ++correct; error += (v-target)*(v-target); sumv += v; sumy += target; sumvv += v*v; sumyy += target*target; sumvy += v*target; ++total; } //end while //如果是回归的svm就用这个 if(svm_type == svm_N_SVR || svm_type == svm__SVR) { /** * 这里打印出来的是用于回归问题的信息regression */ ("Mean squared error = "+error/total+" (regression)n"); ("Squared correlation coefficient = "+ ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/ ((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))+ " (regression)n"); } else //这里打印出来的是用于分类问题的信息classification ("Accuracy = "+(double)correct/total*100+ "% ("+correct+"/"+total+") (classification)n"); }//end function private static void exit_with_help() { ("usage: svm_predict [options] test_file model_file output_filen" +"options:n" +"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yetn"); (1); } //首先从这里读 public static void main(String argv[]) throws IOException { int i, predict_probability=0; //设置两个值,后面一个0表示不开启 // parse options解析选项,解析和train类类似不做说明 for(i=0;i<;i++) { if(argv[i].charAt(0) != '-') break; ++i; switch(argv[i-1].charAt(1)) { case 'b': predict_probability = atoi(argv[i]); break; default: ("Unknown option: " + argv[i-1] + "n"); exit_with_help(); } } if(i>=-2) exit_with_help(); try { BufferedReader input = new BufferedReader(new FileReader(argv[i])); DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(argv[i+2]))); svm_model model = _load_model(argv[i+1]); if(predict_probability == 1) { if(_check_probability_model(model)==0) { ("Model does not support probabiliy estimatesn"); (1); } } else { if(_check_probability_model(model)!=0) { ("Model supports probability estimates, but disabled in prediction.n"); } } /** * 重点来看这个,我们要预测或者分类,中想返回一个预测正确或者分类正 确的类别的 * 你可以按住ctrl,然后用鼠标点击这个类 * 三个个参数: * 一个是模型,已经训练出来的模型 * 一个是输入的测试数据 * 一个是是不是要打印出信息(我没用过,默认是0) */ predict(input,output,model,predict_probability); (); //涉及到文件的操作有关闭的一些操作 (); } catch(FileNotFoundException e) { exit_with_help(); } } } catch(ArrayIndexOutOfBoundsException e) { exit_with_help(); } 下面是一些二次开发的介绍 隔点搜索的代码怎么写? 1. 我们在寻找最佳svm的参数组合的时候不可能自己去手动的去设置.比如高斯核函数有 两个参数要设置,c和gamma.我们要改写train的代码,将c和gama的参数设置到man 方法中去,直接通过调用main就可以改变c和gamma的 打圈的是自己改的。 如果你能看懂上面的意思,那么我想你的java基础完全可以想出来怎么讲correct正确的 作为返回值返回到主main中,你又可以利用这个来写出属于自己的交叉验证 你可以参考一下的调用代码 package _RBF; import edWriter; import ; import iter; import ption; import _predict; import _train; public class ComMain_data_dea { /* * @param args * @throws IOException */ public void main(int ix) throws IOException { // TODO Auto-generated method stub // String []arg ={"file_","file_modeltrain1_"}; // String []parg={"file_","file_modeltrain1_","file_outtrain1_"}; // n("........SVM运行开始.........."); // svm_train t = new svm_train(); // svm_predict p= new svm_predict(); // (arg); // (parg); // n("........SVM运行结束.........."); File f_accurate = new File("file_data_accurateacc_data_"); if(f_accurate == null){ f_NewFile() ; }else{ //f_accurate. } FileWriter fw = new FileWriter(f_accurate); BufferedWriter bw = new BufferedWriter(fw); ("c表示惩罚因子,gamama表示核函数变量, temp表示交叉验证准确验 证的个数,accor表示准确率"); e(); double temp = 0; double c = 0; double val_c = 0 ; double val_gamma = 0 ; double accor = 0; double maxaccor = 0; double max_c = 0; double max_gamama = 0; double max_temp = 0; String tempStr_c =""; String tempStr_g =""; String tempStr_ac =""; for(int i1=0;i1<100;i1++){ val_gamma = 0; val_c = val_c + 0.1 ; for(int i2=0;i2<200;i2++){ val_gamma = val_gamma + 0.01 ; for(int i=1;i<6;i++){ String []arg ={"file_data_ccrtraintrain"+i+".txt","file_data_ccrmodeltrain"+i+"_"}; String []parg={"file_data_ccrtesttest"+i+".txt","file_data_ccrmodeltrain"+i+"_","file_data _ccrouttrain"+i+"_"}; n("........SVM运行开始.........."+i); svm_train t = new svm_train(); svm_predict p= new svm_predict(); (arg,val_c,val_gamma,ix,0); c = (parg); temp = temp + c; }//end for accor = (double)temp/23; tempStr_c = (val_c+"0000000").substring(0, 3); tempStr_g = (val_gamma+"0000000").substring(0, 4); tempStr_ac = (accor+"000000").substring(0, 5); //("c="+val_c+",gamama="+val_gamma+",temp="+temp+",accor="+accor); ("c="+tempStr_c+" ,gamama="+tempStr_g+" ,temp="+temp+" ,accor="+temp Str_ac); e(); if(accor>maxaccor){ maxaccor = accor ; max_c = val_c; max_gamama = val_gamma; max_temp = temp; // break; } c = 0; temp = 0; }//end for if(accor>maxaccor){ maxaccor = accor ; // break; } }//end for ("max_c="+max_c+",max_gamama="+max_gamama+",max_temp="+max_temp+", max_accor="+maxaccor); e(); (); (); //double accor = (double)temp/23; n("........"+maxaccor); n("**********max_c************" + max_c); n("**********max_gamama************" + max_gamama); n("**********max_temp************" + max_temp); // for(int i=1;i<6;i++){ // // String []arg ={"file_traintrain"+i+".txt","file_modeltrain"+i+"_"}; // String []parg={"file_testtest"+i+".txt","file_modeltrain"+i+"_","file_outtrain"+i+"_" }; // n("........SVM运行开始.........."+i); // svm_train t = new svm_train(); // svm_predict p= new svm_predict(); // (arg,val_c,val_gamma); // c = (parg); // temp = temp + c; // n("........C.........."+c); // // }//end for // double accor = (double)temp/23; // n("........"+accor); }//end main }//end class 说明: 你可以在csdn下载到包括demo的说明文档 2024年6月14日发(作者:遇飞语) LibSVM(JAVA)二次开发接口调用及源码更改的文档 浙江大学协调服务研究所 文档整理:陈伟 chenweishaoxing# 下载libsvm 方法:google libsvm找到官网下载: /~cjlin/libsvm/ ,其中图片中椭圆的 解压文档 下载下来libsvm工具包有几个版本的,其中python的最经典,用的人比较多, 还支持matlab,C++等等。我们用的java版的,就到解压开的java文件夹中! java文件夹 导入到eclipse工程中 创建一个java工程,把上图的源码复制到eclipse中,如同所示 在工程下创建一个文件夹,里面存放训练测试用的数据 首次调用的Demo举例 在java的工程中创建一个属于自己的包,然后写一个mian类。如图 package ; import ption; import _predict; import _train; public class ComMain { public static void main(String[] args) throws IOException { String []arg ={ "", // 存放SVM训练模型用的数据的路径 "trainfilemodel_"}; //存放SVM通过训练数据训/ //练出来的模型的路径 String []parg={"", //这个是存放测试数据 "trainfilemodel_", //调用的是训练以后的模型 "trainfileout_"}; //生成的结果的文件的路径 n("........SVM运行开始.........."); //创建一个训练对象 svm_train t = new svm_train(); //创建一个预测或者分类的对象 svm_predict p= new svm_predict(); (arg); //调用 (parg); //调用 } } 6.运行工程就可以看到了结果了 Libsvm二次开发的首先要熟悉调用接口的源码 你一定会有疑问:SVM的参数怎么设置,cross-validation怎么用。那么我们首 先来说明一个问题,交叉验证在一般情况下要自己开发自己写。Libsvm内置了 交叉验证,但是如果我希望用同交叉验证的数据用决策树来做,怎么办,显然 Libsvm并没有保存交叉验证的数据。 ============================================================ 我已经将注释写在了源码中。 Svm_train类的文档说明 package service; import libsvm.*; import .*; import .*; public class svm_train { private svm_parameter param; // set by parse_command_line private svm_problem prob; // set by read_problem private svm_model model; private String input_file_name; // set by parse_command_line private String model_file_name; // set by parse_command_line private String error_msg; private int cross_validation; private int nr_fold; private static svm_print_interface svm_print_null = svm_print_interface() { public void print(String s) {} }; private static void exit_with_help() { ( "Usage: svm_train [options] training_set_file [model_file]n" +"options:n" +"-s svm_type : set type of SVM (default 0)n" new +" 0 -- C-SVCn" +" 1 -- nu-SVCn" +" 2 -- one-class SVMn" +" 3 -- epsilon-SVRn" +" 4 -- nu-SVRn" +"-t kernel_type : set type of kernel function (default 2)n" +" 0 -- linear: u'*vn" +" 1 -- polynomial: (gamma*u'*v + coef0)^degreen" +" 2 -- radial basis function: exp(-gamma*|u-v|^2)n" +" 3 -- sigmoid: tanh(gamma*u'*v + coef0)n" +" 4 -- precomputed kernel (kernel values in training_set_file)n" +"-d degree : set degree in kernel function (default 3)n" +"-g gamma : set gamma in kernel function (default 1/num_features)n" +"-r coef0 : set coef0 in kernel function (default 0)n" +"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)n" +"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)n" +"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)n" +"-m cachesize : set cache memory size in MB (default 100)n" +"-e epsilon : set tolerance of termination criterion (default 0.001)n" +"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)n" +"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)n" +"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)n" +"-v n : n-fold cross validation moden" +"-q : quiet mode (no outputs)n" ); (1); } private void do_cross_validation() { int i; int total_correct = 0; double total_error = 0; double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; double[] target = new double[prob.l]; _cross_validation(prob,param,nr_fold,target); if(_type == svm_N_SVR || _type == svm__SVR) { for(i=0;i { double y = prob.y[i]; double v = target[i]; total_error += (v-y)*(v-y); sumv += v; sumy += y; sumvv += v*v; sumyy += y*y; sumvy += v*y; } ("Cross Validation Mean squared error = "+total_error/prob.l+"n"); ("Cross Validation Squared correlation coefficient = "+ ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/ ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"n" ); } else { for(i=0;i if(target[i] == prob.y[i]) ++total_correct; ("Cross Validation Accuracy = "+100.0*total_correct/prob.l+"%n"); } } private void run(String argv[]) throws IOException { n("我的数组的长度是:" + ) ; parse_command_line(argv); //解析svm参数的配置,我们去这个方法看 看,你可以按住crlt,然后鼠标点击这个方法 read_problem(); error_msg = _check_parameter(prob,param); if(error_msg != null) { ("ERROR: "+error_msg+"n"); (1); } if(cross_validation != 0) { do_cross_validation(); } else { model = _train(prob,param); _save_model(model_file_name,model); } } public static void main(String argv[]) throws IOException { svm_train t = new svm_train(); (argv); } private static double atof(String s) { double d = f(s).doubleValue(); if ((d) || nite(d)) { ("NaN or Infinity in inputn"); (1); } return(d); } //解析控制台输入的string类型的值,因为svm的参数是由整数来代表的, //那么通过这个方法将控制台输入的字符串解析成为整数的 private static int atoi(String s) { return nt(s); } //欢迎来到解析svm参数的方法 private void parse_command_line(String argv[]) { int i; //设置了一个方法域的一个i变量,用于遍历argv这个字符串数组 的的哦 svm_print_interface print_func = null; // default printing to stdout,这 个是一个接口 //创建一个SVM的参数对象,SVM的参数都在这个对象中。 //具体的参数对象可以看svm_parameter这个类 param = new svm_parameter(); // 默认的SVM设置的值,如果需要修改,那么要从控制台输入,然后 下面的for循环会解析svm的参数设置 //我还没用全部搞懂这些参数的意思,但是这些参数的作用完全可以在 帮助信息中看到。 _type = svm_parameter.C_SVC; //默认的支持向量 //_type = svm__SVC; _type = svm_; //默认的核函数高斯核函数 = 3; = 0; // 1/num_features 0 = 0; = 0.01; _size = 100; param.C = 1; = 1e-3; param.p = 0.1; ing = 1; ility = 0; _weight = 0; _label = new int[0]; = new double[0]; cross_validation = 0; //表示关闭交叉验证,1表示开启交叉验证(这里不 能设置1,因为你设置了也没用) // 解析选项SVM参数的选项,如果控制台没有输入对于的字符串,那 么SVM将使用的是默认的SVM的参数设置 for(i=0;i<;i++) //初始化i { //返回的是argv这个数组第i个字符串第一个字符,这里说明控制 台要输入的时候首先要写一个'-'号.(比如i=4,=10,argv[4]) //如果不写,那么将break本次的循环,跳出的是整个for循环,所 以让文件保存的路径在数组中写到最后 if(argv[i].charAt(0) != '-') break; //如果一遇到不是这个跳出的是整 个for循环 //如果查询到了'-'字符,那么会执行这一步了。 //判断这个i的值是不是大于或者等于argv的长度了,如果是数组 的长度了,那么就打印出帮助信息。并且会中断虚拟机了 //++i >= 这个应该是先用i再加1,那么下面的操作的时 候就是i = i + 1了(i=5) if(++i>=) exit_with_help(); //如果执行了第二个if,那么会执行到这里了。这里的i = 5 switch(argv[i-1].charAt(1)) //用到的字符串仍然是argv[5-1]=argv[4], 解析的是第2个字符。 { case 's': //设置svm的类型 _type = atoi(argv[i]); //这个赋值就是将argv[5], 赋值过去了 break; case 't': //设置svm的核函数类型 _type = atoi(argv[i]); break; case 'd': //设置svm参数d的大小,用于多项式核函数 = atoi(argv[i]); break; case 'g': //赋值gamma的 = atof(argv[i]); break; case 'r': //赋值coef0的值 0 = atof(argv[i]); break; case 'n': //赋值n的值 = atof(argv[i]); break; case 'm': //赋值缓存的值 _size = atof(argv[i]); break; case 'c': //赋值的是惩罚因子的大小 param.C = atof(argv[i]); break; case 'e': //赋值的eps的值 = atof(argv[i]); break; case 'p': //赋值******我不想写下去了,因为在实际的应用中, 我还没用用到下面的参数。抱歉。 param.p = atof(argv[i]); break; case 'h': // ing = atoi(argv[i]); break; case 'b': //要不要打印出分类的准确率的值 ility = atoi(argv[i]); break; case 'q': print_func = svm_print_null; i--; break; case 'v': //设置的交叉验证的值 cross_validation = 1; //开启交叉验证 nr_fold = atoi(argv[i]); if(nr_fold < 2) //交叉验证的值不能小于1 { ("n-fold cross validation: n must >= 2n"); exit_with_help(); } break; case 'w': ++_weight; { int[] old = _label; _label = new int[_weight]; opy(old,0,_label,0,_weight-1); } { double[] old = ; = new double[_weight]; opy(old,0,,0,_weight-1); } _label[_weight-1] = atoi(argv[i-1].substring(2)); [_weight-1] = atof(argv[i]); break; default: //如果一个字符都匹配不到,很遗憾要中断JVM了,并且会 打印出那个位子的字符出现了错误,然后打印出帮助信息 ("Unknown option: " + argv[i-1] + "n"); exit_with_help(); }//end switch } //end for _set_print_string_function(print_func); //打印出是不是静音模 式 // determine filenames决定文件名 /** * 我必须中断下操作来说明控制台应该怎么输入的 * argv = {"-s","1","-t","3","-w","5","我是训练用的文件路径","我 是训练完以后保存模型的路径"} * 具体的1,3,5参数要参考官方说明文档,或者查看设置参数那个类 的参数。 * 看到这,你可以继续看下去了 */ if(i>=)//这里是了防止没有输入存放文件的路径,或者存放文 件的路径不够 exit_with_help(); //到这里,i的应该是字符串数组的倒数第二个了 //其实我一直搞不清楚,为什么for循环完毕了,这个i不是argv数组 的长度呢 ?不是的i的值是数组长度-1也就是数组中倒数第二个位子 input_file_name = argv[i]; //将训练的文件路径赋值 n("POSITION="+(i+1)+"我的训练用的数据存放的路 径是:" + input_file_name) ; if(i<-1) //如果i的值比数组长度-1还小,那么将argv[i]下一 个字符串赋值给存放模型的路径 model_file_name = argv[i+1]; //将训练以后的模型路径赋值 else { int p = argv[i].lastIndexOf('/'); ++p; // model_file_name = argv[i].substring(p)+".model"; } n("我的训练用的数据存放的路径是:" + model_file_name) ; } // end run function // read in a problem (in svmlight format) private void read_problem() throws IOException { BufferedReader fp = new BufferedReader(new FileReader(input_file_name)); Vector Vector int max_index = 0; while(true) { String line = ne(); if(line == null) break; StringTokenizer st = new StringTokenizer(line," tnrf:"); ment(atof(ken())); int m = okens()/2; svm_node[] x = new svm_node[m]; for(int j=0;j { x[j] = new svm_node(); x[j].index = atoi(ken()); x[j].value = atof(ken()); } if(m>0) max_index = (max_index, x[m-1].index); ment(x); } prob = new svm_problem(); prob.l = (); prob.x = new svm_node[prob.l][]; for(int i=0;i prob.x[i] = tAt(i); prob.y = new double[prob.l]; for(int i=0;i prob.y[i] = tAt(i); if( == 0 && max_index > 0) = 1.0/max_index; if(_type == svm_PUTED) for(int i=0;i { if (prob.x[i][0].index != 0) { ("Wrong kernel matrix: first column must be 0:sample_serial_numbern"); (1); } if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index) { ("Wrong input format: sample_serial_number out of rangen"); (1); } } (); } } Svm_predict类的文档说明 package service; import libsvm.*; import .*; import .*; public class svm_predict { private static double atof(String s) { return f(s).doubleValue(); } private static int atoi(String s) { return nt(s); } private static void predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability) throws IOException { //欢迎来到这个预测方法,下面开始分析 //设置方法内局部变量 //这个是预测正确的个数的 int correct = 0; //这个是预测的个数一共有几个 int total = 0; //分类或者预测的准确率,所以用double error = correct / total ; double error = 0; //几个中间变量的参数 double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; int svm_type=_get_svm_type(model); int nr_class=_get_nr_class(model); double[] prob_estimates=null; //如果传入进来的1(默认是0) ,那么从这里开始执行 //这个是不能用回归svm的 if(predict_probability == 1) { if(svm_type == svm_N_SVR || //回归SVM svm_type == svm__SVR) //回归SVM { //打印出出错误了,svm数据不匹配 ("Prob. model for test data: target value = predicted value + z,nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+_get_svr_probability(model)+"n"); } //用于分类的话就执行这个了 else { int[] labels=new int[nr_class]; //取得标签。分类用的标签 _get_labels(model,labels); prob_estimates = new double[nr_class]; ytes("labels");//写入到文件中去 for(int j=0;j ytes(" "+labels[j]); ytes("n"); } } //end if //这个一定会执行的 while(true) { String line = ne(); //一行一行的读取 if(line == null) break;//如果出现空行,那么就停止,所以在文件中中间不能有 空行 StringTokenizer st = new StringTokenizer(line," tnrf:"); double target = atof(ken()); int m = okens()/2; svm_node[] x = new svm_node[m]; for(int j=0;j { x[j] = new svm_node(); x[j].index = atoi(ken()); x[j].value = atof(ken()); } double v; //如果是分类svm就执行这个 if (predict_probability==1 && (svm_type==svm_parameter.C_SVC || svm_type==svm__SVC)) { v = _predict_probability(model,x,prob_estimates); ytes(v+" "); for(int j=0;j ytes(prob_estimates[j]+" "); ytes("n"); } //end id else { v = _predict(model,x); ytes(v+"n"); } /** * 做二次开发,这里可动手脚,你可以输入要具体预测对的类在这里显示出来 等等 */ if(v == target) //如果预测正确,那么分类的正确就加一 ++correct; error += (v-target)*(v-target); sumv += v; sumy += target; sumvv += v*v; sumyy += target*target; sumvy += v*target; ++total; } //end while //如果是回归的svm就用这个 if(svm_type == svm_N_SVR || svm_type == svm__SVR) { /** * 这里打印出来的是用于回归问题的信息regression */ ("Mean squared error = "+error/total+" (regression)n"); ("Squared correlation coefficient = "+ ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/ ((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))+ " (regression)n"); } else //这里打印出来的是用于分类问题的信息classification ("Accuracy = "+(double)correct/total*100+ "% ("+correct+"/"+total+") (classification)n"); }//end function private static void exit_with_help() { ("usage: svm_predict [options] test_file model_file output_filen" +"options:n" +"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yetn"); (1); } //首先从这里读 public static void main(String argv[]) throws IOException { int i, predict_probability=0; //设置两个值,后面一个0表示不开启 // parse options解析选项,解析和train类类似不做说明 for(i=0;i<;i++) { if(argv[i].charAt(0) != '-') break; ++i; switch(argv[i-1].charAt(1)) { case 'b': predict_probability = atoi(argv[i]); break; default: ("Unknown option: " + argv[i-1] + "n"); exit_with_help(); } } if(i>=-2) exit_with_help(); try { BufferedReader input = new BufferedReader(new FileReader(argv[i])); DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(argv[i+2]))); svm_model model = _load_model(argv[i+1]); if(predict_probability == 1) { if(_check_probability_model(model)==0) { ("Model does not support probabiliy estimatesn"); (1); } } else { if(_check_probability_model(model)!=0) { ("Model supports probability estimates, but disabled in prediction.n"); } } /** * 重点来看这个,我们要预测或者分类,中想返回一个预测正确或者分类正 确的类别的 * 你可以按住ctrl,然后用鼠标点击这个类 * 三个个参数: * 一个是模型,已经训练出来的模型 * 一个是输入的测试数据 * 一个是是不是要打印出信息(我没用过,默认是0) */ predict(input,output,model,predict_probability); (); //涉及到文件的操作有关闭的一些操作 (); } catch(FileNotFoundException e) { exit_with_help(); } } } catch(ArrayIndexOutOfBoundsException e) { exit_with_help(); } 下面是一些二次开发的介绍 隔点搜索的代码怎么写? 1. 我们在寻找最佳svm的参数组合的时候不可能自己去手动的去设置.比如高斯核函数有 两个参数要设置,c和gamma.我们要改写train的代码,将c和gama的参数设置到man 方法中去,直接通过调用main就可以改变c和gamma的 打圈的是自己改的。 如果你能看懂上面的意思,那么我想你的java基础完全可以想出来怎么讲correct正确的 作为返回值返回到主main中,你又可以利用这个来写出属于自己的交叉验证 你可以参考一下的调用代码 package _RBF; import edWriter; import ; import iter; import ption; import _predict; import _train; public class ComMain_data_dea { /* * @param args * @throws IOException */ public void main(int ix) throws IOException { // TODO Auto-generated method stub // String []arg ={"file_","file_modeltrain1_"}; // String []parg={"file_","file_modeltrain1_","file_outtrain1_"}; // n("........SVM运行开始.........."); // svm_train t = new svm_train(); // svm_predict p= new svm_predict(); // (arg); // (parg); // n("........SVM运行结束.........."); File f_accurate = new File("file_data_accurateacc_data_"); if(f_accurate == null){ f_NewFile() ; }else{ //f_accurate. } FileWriter fw = new FileWriter(f_accurate); BufferedWriter bw = new BufferedWriter(fw); ("c表示惩罚因子,gamama表示核函数变量, temp表示交叉验证准确验 证的个数,accor表示准确率"); e(); double temp = 0; double c = 0; double val_c = 0 ; double val_gamma = 0 ; double accor = 0; double maxaccor = 0; double max_c = 0; double max_gamama = 0; double max_temp = 0; String tempStr_c =""; String tempStr_g =""; String tempStr_ac =""; for(int i1=0;i1<100;i1++){ val_gamma = 0; val_c = val_c + 0.1 ; for(int i2=0;i2<200;i2++){ val_gamma = val_gamma + 0.01 ; for(int i=1;i<6;i++){ String []arg ={"file_data_ccrtraintrain"+i+".txt","file_data_ccrmodeltrain"+i+"_"}; String []parg={"file_data_ccrtesttest"+i+".txt","file_data_ccrmodeltrain"+i+"_","file_data _ccrouttrain"+i+"_"}; n("........SVM运行开始.........."+i); svm_train t = new svm_train(); svm_predict p= new svm_predict(); (arg,val_c,val_gamma,ix,0); c = (parg); temp = temp + c; }//end for accor = (double)temp/23; tempStr_c = (val_c+"0000000").substring(0, 3); tempStr_g = (val_gamma+"0000000").substring(0, 4); tempStr_ac = (accor+"000000").substring(0, 5); //("c="+val_c+",gamama="+val_gamma+",temp="+temp+",accor="+accor); ("c="+tempStr_c+" ,gamama="+tempStr_g+" ,temp="+temp+" ,accor="+temp Str_ac); e(); if(accor>maxaccor){ maxaccor = accor ; max_c = val_c; max_gamama = val_gamma; max_temp = temp; // break; } c = 0; temp = 0; }//end for if(accor>maxaccor){ maxaccor = accor ; // break; } }//end for ("max_c="+max_c+",max_gamama="+max_gamama+",max_temp="+max_temp+", max_accor="+maxaccor); e(); (); (); //double accor = (double)temp/23; n("........"+maxaccor); n("**********max_c************" + max_c); n("**********max_gamama************" + max_gamama); n("**********max_temp************" + max_temp); // for(int i=1;i<6;i++){ // // String []arg ={"file_traintrain"+i+".txt","file_modeltrain"+i+"_"}; // String []parg={"file_testtest"+i+".txt","file_modeltrain"+i+"_","file_outtrain"+i+"_" }; // n("........SVM运行开始.........."+i); // svm_train t = new svm_train(); // svm_predict p= new svm_predict(); // (arg,val_c,val_gamma); // c = (parg); // temp = temp + c; // n("........C.........."+c); // // }//end for // double accor = (double)temp/23; // n("........"+accor); }//end main }//end class 说明: 你可以在csdn下载到包括demo的说明文档