001/**
002 * Copyright (C) 2007 - 2016, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 *
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 */
019package org.dllearner.algorithms.semkernel;
020
021import java.io.BufferedOutputStream;
022import java.io.BufferedReader;
023import java.io.DataOutputStream;
024import java.io.File;
025import java.io.FileOutputStream;
026import java.io.FileReader;
027import java.io.IOException;
028
029import org.dllearner.core.AbstractComponent;
030
031import semlibsvm.svm_predict;
032import semlibsvm.svm_train;
033import semlibsvm.libsvm.svm;
034import semlibsvm.libsvm.svm_model;
035import semlibsvm.libsvm.svm_parameter;
036import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel.AllVsAllMode;
037import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel.OneVsAllMode;
038
039public class SemKernel extends AbstractComponent {
040    public enum SvmType {
041        C_SVC,
042        NU_SVC,
043        ONE_CLASS,
044        EPSILON_SVR,
045        NU_SVR
046    }
047
048    public enum ScalingMode { NONE, LINEAR, ZSCORE }
049
050    private boolean useCrossValidation;
051    private static final Float UNSPECIFIED_GAMMA = -1F;
052    private boolean predictProbability;
053
054    // SVM params
055    private svm_parameter svmParams;
056    private float nu = 0.5f;
057    private int cacheSize = 100;
058    private float epsilon = 1e-3f;
059    private float p = 0.1f;
060    private boolean doShrinking = true;
061    private boolean doProbabilityEstimates = false;
062    /**
063     * For unbalanced data, redistribute the misclassification cost C according
064     * to the numbers of examples in each class, so that each class has the
065     * same total misclassification weight assigned to it and the average is
066     * param.C
067     * (from edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter.java)
068     * */
069    private boolean redistributeUnbalanbcedCosts = true;
070    private SvmType svmType = SvmType.C_SVC;
071    /** degree in kernel function */
072    private int degree = 3;
073    /** gammas in kernel function */
074    private double gamma = UNSPECIFIED_GAMMA;
075    /** coef0 in kernel function */
076    private int coef0 = 0;
077    /** the parameter C of C-SVC, epsilon-SVR, and nu-SVR */
078    private double cost;
079    private int crossValidationFolds;
080    // TODO: weights
081    /** allVsAllMode: None, AllVsAll, FilteredVsAll, FilteredVsFiltered */
082    private AllVsAllMode allVsAllMode;
083    /** oneVsAllMode: None, Best, Veto, BreakTies, VetoAndBreakTies */
084    private OneVsAllMode oneVsAllMode;
085    private double oneVsAllThreshold = -1;
086    /** the chosen class must have at least this proportion of the total votes */
087    private double minVoteProportion = -1;
088    /** scalingmode : none (default), linear, zscore */
089    private ScalingMode scalingMode = ScalingMode.NONE;
090    /** scalinglimit : maximum examples to use for scaling (default 1000) */
091    private int scalingLimit = 1000;
092    /** project to unit sphere (normalize L2 distance) */
093    private boolean normalizeL2 = false;
094
095    // input/output
096    // TODO: use fallback for these if not set (create tmp dir)
097    private String ontologyFilePath;
098    private String trainingDirPath;
099    private String modelDirPath;
100    private String predictionDataDirPath;
101    private String resultsDirPath;
102
103    @Override
104    public void init() {
105        svmParams = new svm_parameter();
106        svmParams.C = cost;
107        svmParams.cache_size = cacheSize;
108        // params.class2id  // set by svm_train.read_problem
109        svmParams.coef0 = coef0;
110        svmParams.degree = degree;
111        svmParams.eps = epsilon;
112        svmParams.gamma = gamma;
113        svmParams.kernel_type = svm_parameter.SEMANTIC;
114        svmParams.nr_weight = 0;  // TODO: make configurable
115        svmParams.nu = nu;
116        svmParams.ontology_file = ontologyFilePath;
117        svmParams.p = p;
118        svmParams.probability = doProbabilityEstimates ? 1 : 0;
119        svmParams.shrinking = doShrinking ? 1 : 0;
120
121        switch (svmType) {
122            case C_SVC:
123                svmParams.svm_type = svm_parameter.C_SVC;
124                break;
125            case NU_SVC:
126                svmParams.svm_type = svm_parameter.NU_SVC;
127                break;
128            case ONE_CLASS:
129                svmParams.svm_type = svm_parameter.ONE_CLASS;
130                break;
131            case EPSILON_SVR:
132                svmParams.svm_type = svm_parameter.EPSILON_SVR;
133                break;
134            case NU_SVR:
135                svmParams.svm_type = svm_parameter.NU_SVR;
136                break;
137        }
138
139        svmParams.weight = new double[0];  // TODO: make configurable
140        svmParams.weight_label = new int[0];  // TODO make configurable
141        
142        initialized = true;
143    }
144
145    public void train() {
146        svm_train svmTrain = new svm_train();
147        File trainDir = new File(trainingDirPath);
148
149        for (String trainFileName : trainDir.list()) {
150            String modelFilePath;
151            if (!modelDirPath.endsWith(File.separator)) {
152                modelFilePath = modelDirPath + File.separator + trainFileName;
153            } else {
154                modelFilePath = modelDirPath + trainFileName;
155            }
156
157            String trainFilePath;
158            if (!trainingDirPath.endsWith(File.separator)) {
159                trainFilePath = trainingDirPath + File.separator + trainFileName;
160            } else {
161                trainFilePath = trainingDirPath + trainFileName;
162            }
163            try {
164                svmTrain.run(svmParams, trainFilePath, modelFilePath);
165            } catch (IOException e) {
166                e.printStackTrace();
167                System.exit(1);
168            }
169        }
170    }
171
172    public void predict() {
173        File predDataDir = new File(predictionDataDirPath);
174
175        for (String predFileName : predDataDir.list()) {
176            String predFilePath;
177            if (!predictionDataDirPath.endsWith(File.separator)) {
178                predFilePath = predictionDataDirPath + File.separator +
179                        predFileName;
180            } else {
181                predFilePath = predictionDataDirPath + predFileName;
182            }
183
184            String modelFilePath;
185            if (!modelDirPath.endsWith(File.separator)) {
186                modelFilePath = modelDirPath + File.separator + predFileName;
187            } else {
188                modelFilePath = modelDirPath + predFileName;
189            }
190
191            String resultFilePath;
192            if (!resultsDirPath.endsWith(File.separator)) {
193                resultFilePath = resultsDirPath + File.separator + predFileName;
194            } else {
195                resultFilePath = resultsDirPath + predFileName;
196            }
197
198            try {
199                svm_model model = svm.svm_load_model(modelFilePath);
200
201                if (model == null) {
202                    final String msg = String.format(
203                            "can't open model file %s", modelFilePath);
204                    throw new Exception(msg);
205                }
206
207                model.param.ontology_file = ontologyFilePath;
208                svm.initSimilarityEngine(ontologyFilePath);
209
210                if(predictProbability) {
211                    if(svm.svm_check_probability_model(model)==0) {
212                        final String msg =
213                                "Model does not support probabiliy estimates";
214                        throw new Exception(msg);
215                    }
216                } else {
217                    if(svm.svm_check_probability_model(model)!=0) {
218                        svm_predict.info("Model supports probability " +
219                                "estimates, but disabled in prediction.\n");
220                    }
221                }
222
223                BufferedReader predFileReader = new BufferedReader(
224                        new FileReader(predFilePath));
225                DataOutputStream resStream = new DataOutputStream(
226                        new BufferedOutputStream(
227                                new FileOutputStream(resultFilePath)));
228
229                int predProbInt = predictProbability ? 1 : 0;
230
231                svm_predict.predict(predFileReader, resStream, model, predProbInt);
232
233                predFileReader.close();
234                resStream.close();
235
236            } catch (Exception e) {
237                e.printStackTrace();
238                System.exit(1);
239            }
240        }
241    }
242    // ------------------- only getters and setters below ---------------------
243    public boolean isUseCrossValidation() {
244        return useCrossValidation;
245    }
246
247    public void setUseCrossValidation(boolean useCrossValidation) {
248        this.useCrossValidation = useCrossValidation;
249    }
250
251    public float getNu() {
252        return nu;
253    }
254
255    public void setNu(float nu) {
256        this.nu = nu;
257    }
258
259    public int getCacheSize() {
260        return cacheSize;
261    }
262
263    public void setCacheSize(int cacheSize) {
264        this.cacheSize = cacheSize;
265    }
266
267    public float getEpsilon() {
268        return epsilon;
269    }
270
271    public void setEpsilon(float epsilon) {
272        this.epsilon = epsilon;
273    }
274
275    public float getP() {
276        return p;
277    }
278
279    public void setP(float p) {
280        this.p = p;
281    }
282
283    public boolean isDoShrinking() {
284        return doShrinking;
285    }
286
287    public void setDoShrinking(boolean doShrinking) {
288        this.doShrinking = doShrinking;
289    }
290
291    public boolean isDoProbabilityEstimates() {
292        return doProbabilityEstimates;
293    }
294
295    public void setDoProbabilityEstimates(boolean doProbabilityEstimates) {
296        this.doProbabilityEstimates = doProbabilityEstimates;
297    }
298
299    public boolean isRedistributeUnbalanbcedCosts() {
300        return redistributeUnbalanbcedCosts;
301    }
302
303    public void setRedistributeUnbalanbcedCosts(boolean redistributeUnbalanbcedCosts) {
304        this.redistributeUnbalanbcedCosts = redistributeUnbalanbcedCosts;
305    }
306
307    public SvmType getSvmType() {
308        return svmType;
309    }
310
311    public void setSvmType(SvmType svmType) {
312        this.svmType = svmType;
313    }
314
315    public int getDegree() {
316        return degree;
317    }
318
319    public void setDegree(int degree) {
320        this.degree = degree;
321    }
322
323    public int getCoef0() {
324        return coef0;
325    }
326
327    public void setCoef0(int coef0) {
328        this.coef0 = coef0;
329    }
330
331    public double getGamma() {
332        return gamma;
333    }
334
335    public void setGamma(double gammaSet) {
336        this.gamma = gammaSet;
337    }
338
339    public double getCost() {
340        return cost;
341    }
342
343    public void setCost(double costs) {
344        this.cost = costs;
345    }
346
347    public int getCrossValidationFolds() {
348        return crossValidationFolds;
349    }
350
351    public void setCrossValidationFolds(int crossValidationFolds) {
352        this.useCrossValidation = true;
353        this.crossValidationFolds = crossValidationFolds;
354    }
355
356    public AllVsAllMode getAllVsAllMode() {
357        return allVsAllMode;
358    }
359
360    public void setAllVsAllMode(AllVsAllMode allVsAllMode) {
361        this.allVsAllMode = allVsAllMode;
362    }
363
364    public OneVsAllMode getOneVsAllMode() {
365        return oneVsAllMode;
366    }
367
368    public void setOneVsAllMode(OneVsAllMode oneVsAllMode) {
369        this.oneVsAllMode = oneVsAllMode;
370    }
371
372    public double getOneVsAllThreshold() {
373        return oneVsAllThreshold;
374    }
375
376    public void setOneVsAllThreshold(double oneVsAllThreshold) {
377        this.oneVsAllThreshold = oneVsAllThreshold;
378    }
379
380    public double getMinVoteProportion() {
381        return minVoteProportion;
382    }
383
384    public void setMinVoteProportion(double minVoteProportion) {
385        this.minVoteProportion = minVoteProportion;
386    }
387
388    public ScalingMode getScalingMode() {
389        return scalingMode;
390    }
391
392    public void setScalingMode(ScalingMode scalingMode) {
393        this.scalingMode = scalingMode;
394    }
395
396    public int getScalingLimit() {
397        return scalingLimit;
398    }
399
400    public void setScalingLimit(int scalingLimit) {
401        this.scalingLimit = scalingLimit;
402    }
403
404    public boolean isNormalizeL2() {
405        return normalizeL2;
406    }
407
408    public void setNormalizeL2(boolean normalizeL2) {
409        this.normalizeL2 = normalizeL2;
410    }
411
412    public String getTrainingOutputDirPath() {
413        return trainingDirPath;
414    }
415
416    public void setTrainingDirPath(String trainingOutputDirPath) {
417        this.trainingDirPath = trainingOutputDirPath;
418    }
419
420    public String getOntologyFilePath() {
421        return ontologyFilePath;
422    }
423
424    public void setOntologyFilePath(String ontologyFilePath) {
425        this.ontologyFilePath = ontologyFilePath;
426    }
427
428    public String getModelDirPath() {
429        return modelDirPath;
430    }
431
432    public void setModelDirPath(String modelDirPath) {
433        this.modelDirPath = modelDirPath;
434    }
435
436    public String getPredictionDataDirPath() {
437        return predictionDataDirPath;
438    }
439
440    public void setPredictionDataDirPath(String predictionDataDirPath) {
441        this.predictionDataDirPath = predictionDataDirPath;
442    }
443
444    public String getResultsDirPath() {
445        return resultsDirPath;
446    }
447
448    public void setResultsDirPath(String resultsDirPath) {
449        this.resultsDirPath = resultsDirPath;
450    }
451
452    public boolean isPredictProbability() {
453        return predictProbability;
454    }
455
456    public void setPredictProbability(boolean predictProbability) {
457        this.predictProbability = predictProbability;
458    }
459}