/**
 * Copyright (C) 2007 - 2016, Jens Lehmann
 *
 * This file is part of DL-Learner.
 *
 * DL-Learner is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * DL-Learner is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package org.dllearner.algorithms.semkernel;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.util.Objects;

import org.dllearner.core.AbstractComponent;

import semlibsvm.svm_predict;
import semlibsvm.svm_train;
import semlibsvm.libsvm.svm;
import semlibsvm.libsvm.svm_model;
import semlibsvm.libsvm.svm_parameter;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel.AllVsAllMode;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel.OneVsAllMode;

/**
 * DL-Learner component that trains and applies SVM models using the semantic
 * kernel ({@code svm_parameter.SEMANTIC}) of semlibsvm.
 *
 * <p>{@link #train()} trains one model per file found in the training
 * directory and writes it to the model directory under the same file name.
 * {@link #predict()} loads, for every file in the prediction-data directory,
 * the model with the same file name and writes the predictions to a file of
 * that name in the results directory.
 *
 * <p>{@link #init()} must be called before {@link #train()}, since it builds
 * the SVM parameter object used for training.
 */
public class SemKernel extends AbstractComponent {

    /** SVM formulations supported by libsvm. */
    public enum SvmType {
        C_SVC,
        NU_SVC,
        ONE_CLASS,
        EPSILON_SVR,
        NU_SVR
    }

    /** Feature scaling modes. */
    public enum ScalingMode { NONE, LINEAR, ZSCORE }

    /** Marker value meaning "gamma was not configured". */
    private static final double UNSPECIFIED_GAMMA = -1;

    /** Currently not read by init(), train() or predict(). */
    private boolean useCrossValidation;
    /** Whether {@link #predict()} requests probability estimates from the model. */
    private boolean predictProbability;

    // SVM params
    /** Parameter object built by {@link #init()} and passed to svm_train. */
    private svm_parameter svmParams;
    private float nu = 0.5f;
    private int cacheSize = 100;
    private float epsilon = 1e-3f;
    private float p = 0.1f;
    private boolean doShrinking = true;
    private boolean doProbabilityEstimates = false;
    /**
     * For unbalanced data, redistribute the misclassification cost C according
     * to the numbers of examples in each class, so that each class has the
     * same total misclassification weight assigned to it and the average is
     * param.C
     * (from edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter.java)
     *
     * <p>Currently not read by init(), train() or predict().
     */
    private boolean redistributeUnbalanbcedCosts = true;
    private SvmType svmType = SvmType.C_SVC;
    /** degree in kernel function */
    private int degree = 3;
    /** gamma in kernel function; {@code UNSPECIFIED_GAMMA} (-1) unless set */
    private double gamma = UNSPECIFIED_GAMMA;
    /** coef0 in kernel function */
    private int coef0 = 0;
    /**
     * The parameter C of C-SVC, epsilon-SVR, and nu-SVR. Defaults to 1, as in
     * libsvm: upstream libsvm rejects C <= 0 for these SVM types, so the
     * previous implicit default of 0 could not be used for training.
     */
    private double cost = 1;
    /** Currently not read by init(), train() or predict(). */
    private int crossValidationFolds;
    // TODO: weights
    /**
     * allVsAllMode: None, AllVsAll, FilteredVsAll, FilteredVsFiltered.
     * Currently not read by init(), train() or predict().
     */
    private AllVsAllMode allVsAllMode;
    /**
     * oneVsAllMode: None, Best, Veto, BreakTies, VetoAndBreakTies.
     * Currently not read by init(), train() or predict().
     */
    private OneVsAllMode oneVsAllMode;
    /** Currently not read by init(), train() or predict(). */
    private double oneVsAllThreshold = -1;
    /**
     * The chosen class must have at least this proportion of the total votes.
     * Currently not read by init(), train() or predict().
     */
    private double minVoteProportion = -1;
    /**
     * scalingmode: none (default), linear, zscore.
     * Currently not read by init(), train() or predict().
     */
    private ScalingMode scalingMode = ScalingMode.NONE;
    /**
     * scalinglimit: maximum examples to use for scaling (default 1000).
     * Currently not read by init(), train() or predict().
     */
    private int scalingLimit = 1000;
    /**
     * Project to unit sphere (normalize L2 distance).
     * Currently not read by init(), train() or predict().
     */
    private boolean normalizeL2 = false;

    // input/output
    // TODO: use fallback for these if not set (create tmp dir)
    private String ontologyFilePath;
    private String trainingDirPath;
    private String modelDirPath;
    private String predictionDataDirPath;
    private String resultsDirPath;

    /**
     * Builds the libsvm parameter object from the configured fields and marks
     * the component as initialized. The kernel is always the semantic kernel.
     */
    @Override
    public void init() {
        svmParams = new svm_parameter();
        svmParams.C = cost;
        svmParams.cache_size = cacheSize;
        // params.class2id // set by svm_train.read_problem
        svmParams.coef0 = coef0;
        svmParams.degree = degree;
        svmParams.eps = epsilon;
        svmParams.gamma = gamma;
        svmParams.kernel_type = svm_parameter.SEMANTIC;
        svmParams.nr_weight = 0; // TODO: make configurable
        svmParams.nu = nu;
        svmParams.ontology_file = ontologyFilePath;
        svmParams.p = p;
        svmParams.probability = doProbabilityEstimates ? 1 : 0;
        svmParams.shrinking = doShrinking ? 1 : 0;

        switch (svmType) {
            case C_SVC:
                svmParams.svm_type = svm_parameter.C_SVC;
                break;
            case NU_SVC:
                svmParams.svm_type = svm_parameter.NU_SVC;
                break;
            case ONE_CLASS:
                svmParams.svm_type = svm_parameter.ONE_CLASS;
                break;
            case EPSILON_SVR:
                svmParams.svm_type = svm_parameter.EPSILON_SVR;
                break;
            case NU_SVR:
                svmParams.svm_type = svm_parameter.NU_SVR;
                break;
        }

        svmParams.weight = new double[0]; // TODO: make configurable
        svmParams.weight_label = new int[0]; // TODO make configurable

        initialized = true;
    }

    /**
     * Trains one model for every file in the training directory and writes it
     * to the model directory under the same file name.
     *
     * <p>An I/O error during training prints the stack trace and terminates
     * the JVM with exit status 1.
     *
     * @throws NullPointerException if the training directory is not set
     * @throws IllegalStateException if the training directory does not exist
     *         or cannot be listed
     */
    public void train() {
        svm_train svmTrain = new svm_train();

        for (String trainFileName : listDirectory(trainingDirPath, "training directory")) {
            String modelFilePath = resolvePath(modelDirPath, trainFileName);
            String trainFilePath = resolvePath(trainingDirPath, trainFileName);
            try {
                svmTrain.run(svmParams, trainFilePath, modelFilePath);
            } catch (IOException e) {
                e.printStackTrace();
                System.exit(1);
            }
        }
    }

    /**
     * For every file in the prediction-data directory, loads the model with
     * the same file name from the model directory, runs the prediction, and
     * writes the output to a file of that name in the results directory.
     *
     * <p>Any error while predicting a file prints the stack trace and
     * terminates the JVM with exit status 1.
     *
     * @throws NullPointerException if the prediction-data directory is not set
     * @throws IllegalStateException if the prediction-data directory does not
     *         exist or cannot be listed
     */
    public void predict() {
        for (String predFileName
                : listDirectory(predictionDataDirPath, "prediction data directory")) {
            String predFilePath = resolvePath(predictionDataDirPath, predFileName);
            String modelFilePath = resolvePath(modelDirPath, predFileName);
            String resultFilePath = resolvePath(resultsDirPath, predFileName);

            try {
                svm_model model = svm.svm_load_model(modelFilePath);

                if (model == null) {
                    throw new IOException(String.format(
                            "can't open model file %s", modelFilePath));
                }

                model.param.ontology_file = ontologyFilePath;
                svm.initSimilarityEngine(ontologyFilePath);

                boolean modelSupportsProbability =
                        svm.svm_check_probability_model(model) != 0;
                if (predictProbability) {
                    if (!modelSupportsProbability) {
                        throw new IllegalStateException(
                                "Model does not support probability estimates");
                    }
                } else if (modelSupportsProbability) {
                    svm_predict.info("Model supports probability " +
                            "estimates, but disabled in prediction.\n");
                }

                // try-with-resources closes both files even if prediction fails.
                try (BufferedReader predFileReader = new BufferedReader(
                             new FileReader(predFilePath));
                     DataOutputStream resStream = new DataOutputStream(
                             new BufferedOutputStream(
                                     new FileOutputStream(resultFilePath)))) {
                    int predProbInt = predictProbability ? 1 : 0;
                    svm_predict.predict(predFileReader, resStream, model, predProbInt);
                }

            } catch (Exception e) {
                e.printStackTrace();
                System.exit(1);
            }
        }
    }

    /**
     * Joins a directory path and a file name, inserting
     * {@link File#separator} only if the directory path does not already end
     * with it.
     */
    private static String resolvePath(String dirPath, String fileName) {
        if (dirPath.endsWith(File.separator)) {
            return dirPath + fileName;
        }
        return dirPath + File.separator + fileName;
    }

    /**
     * Lists the entry names of a directory, failing with a descriptive
     * exception instead of returning {@code null} (which
     * {@link File#list()} does for missing or unreadable directories).
     */
    private static String[] listDirectory(String dirPath, String description) {
        Objects.requireNonNull(dirPath, description + " is not set");
        String[] fileNames = new File(dirPath).list();
        if (fileNames == null) {
            throw new IllegalStateException(String.format(
                    "%s %s does not exist or is not a readable directory",
                    description, dirPath));
        }
        return fileNames;
    }

    // ------------------- only getters and setters below ---------------------
    public boolean isUseCrossValidation() {
        return useCrossValidation;
    }

    public void setUseCrossValidation(boolean useCrossValidation) {
        this.useCrossValidation = useCrossValidation;
    }

    public float getNu() {
        return nu;
    }

    public void setNu(float nu) {
        this.nu = nu;
    }

    public int getCacheSize() {
        return cacheSize;
    }

    public void setCacheSize(int cacheSize) {
        this.cacheSize = cacheSize;
    }

    public float getEpsilon() {
        return epsilon;
    }

    public void setEpsilon(float epsilon) {
        this.epsilon = epsilon;
    }

    public float getP() {
        return p;
    }

    public void setP(float p) {
        this.p = p;
    }

    public boolean isDoShrinking() {
        return doShrinking;
    }

    public void setDoShrinking(boolean doShrinking) {
        this.doShrinking = doShrinking;
    }

    public boolean isDoProbabilityEstimates() {
        return doProbabilityEstimates;
    }

    public void setDoProbabilityEstimates(boolean doProbabilityEstimates) {
        this.doProbabilityEstimates = doProbabilityEstimates;
    }

    public boolean isRedistributeUnbalanbcedCosts() {
        return redistributeUnbalanbcedCosts;
    }

    public void setRedistributeUnbalanbcedCosts(boolean redistributeUnbalanbcedCosts) {
        this.redistributeUnbalanbcedCosts = redistributeUnbalanbcedCosts;
    }

    public SvmType getSvmType() {
        return svmType;
    }

    public void setSvmType(SvmType svmType) {
        this.svmType = svmType;
    }

    public int getDegree() {
        return degree;
    }

    public void setDegree(int degree) {
        this.degree = degree;
    }

    public int getCoef0() {
        return coef0;
    }

    public void setCoef0(int coef0) {
        this.coef0 = coef0;
    }

    public double getGamma() {
        return gamma;
    }

    public void setGamma(double gamma) {
        this.gamma = gamma;
    }

    public double getCost() {
        return cost;
    }

    public void setCost(double cost) {
        this.cost = cost;
    }

    public int getCrossValidationFolds() {
        return crossValidationFolds;
    }

    /** Sets the number of folds and, as a side effect, enables cross validation. */
    public void setCrossValidationFolds(int crossValidationFolds) {
        this.useCrossValidation = true;
        this.crossValidationFolds = crossValidationFolds;
    }

    public AllVsAllMode getAllVsAllMode() {
        return allVsAllMode;
    }

    public void setAllVsAllMode(AllVsAllMode allVsAllMode) {
        this.allVsAllMode = allVsAllMode;
    }

    public OneVsAllMode getOneVsAllMode() {
        return oneVsAllMode;
    }

    public void setOneVsAllMode(OneVsAllMode oneVsAllMode) {
        this.oneVsAllMode = oneVsAllMode;
    }

    public double getOneVsAllThreshold() {
        return oneVsAllThreshold;
    }

    public void setOneVsAllThreshold(double oneVsAllThreshold) {
        this.oneVsAllThreshold = oneVsAllThreshold;
    }

    public double getMinVoteProportion() {
        return minVoteProportion;
    }

    public void setMinVoteProportion(double minVoteProportion) {
        this.minVoteProportion = minVoteProportion;
    }

    public ScalingMode getScalingMode() {
        return scalingMode;
    }

    public void setScalingMode(ScalingMode scalingMode) {
        this.scalingMode = scalingMode;
    }

    public int getScalingLimit() {
        return scalingLimit;
    }

    public void setScalingLimit(int scalingLimit) {
        this.scalingLimit = scalingLimit;
    }

    public boolean isNormalizeL2() {
        return normalizeL2;
    }

    public void setNormalizeL2(boolean normalizeL2) {
        this.normalizeL2 = normalizeL2;
    }

    /**
     * Returns the training directory path. Kept for backward compatibility;
     * {@link #getTrainingDirPath()} is the getter matching
     * {@link #setTrainingDirPath(String)}.
     */
    public String getTrainingOutputDirPath() {
        return trainingDirPath;
    }

    /** Returns the directory whose files are used as training data. */
    public String getTrainingDirPath() {
        return trainingDirPath;
    }

    public void setTrainingDirPath(String trainingDirPath) {
        this.trainingDirPath = trainingDirPath;
    }

    public String getOntologyFilePath() {
        return ontologyFilePath;
    }

    public void setOntologyFilePath(String ontologyFilePath) {
        this.ontologyFilePath = ontologyFilePath;
    }

    public String getModelDirPath() {
        return modelDirPath;
    }

    public void setModelDirPath(String modelDirPath) {
        this.modelDirPath = modelDirPath;
    }

    public String getPredictionDataDirPath() {
        return predictionDataDirPath;
    }

    public void setPredictionDataDirPath(String predictionDataDirPath) {
        this.predictionDataDirPath = predictionDataDirPath;
    }

    public String getResultsDirPath() {
        return resultsDirPath;
    }

    public void setResultsDirPath(String resultsDirPath) {
        this.resultsDirPath = resultsDirPath;
    }

    public boolean isPredictProbability() {
        return predictProbability;
    }

    public void setPredictProbability(boolean predictProbability) {
        this.predictProbability = predictProbability;
    }
}