001/** 002 * Copyright (C) 2007-2008, Jens Lehmann 003 * 004 * This file is part of DL-Learner. 005 * 006 * DL-Learner is free software; you can redistribute it and/or modify 007 * it under the terms of the GNU General Public License as published by 008 * the Free Software Foundation; either version 3 of the License, or 009 * (at your option) any later version. 010 * 011 * DL-Learner is distributed in the hope that it will be useful, 012 * but WITHOUT ANY WARRANTY; without even the implied warranty of 013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 014 * GNU General Public License for more details. 015 * 016 * You should have received a copy of the GNU General Public License 017 * along with this program. If not, see <http://www.gnu.org/licenses/>. 018 * 019 */ 020package org.dllearner.cli; 021 022import com.google.common.collect.Sets; 023import org.dllearner.core.*; 024import org.dllearner.learningproblems.Heuristics; 025import org.dllearner.learningproblems.PosNegLP; 026import org.dllearner.learningproblems.PosOnlyLP; 027import org.dllearner.utilities.Files; 028import org.dllearner.utilities.Helper; 029import org.dllearner.utilities.owl.ManchesterOWLSyntaxOWLObjectRendererImplExt; 030import org.dllearner.utilities.owl.OWLClassExpressionUtils; 031import org.dllearner.utilities.statistics.Stat; 032import org.semanticweb.owlapi.model.OWLClassExpression; 033import org.semanticweb.owlapi.model.OWLIndividual; 034import org.semanticweb.owlapi.util.SimpleShortFormProvider; 035 036import java.io.File; 037import java.lang.reflect.InvocationTargetException; 038import java.text.DecimalFormat; 039import java.util.*; 040import java.util.concurrent.ExecutorService; 041import java.util.concurrent.Executors; 042import java.util.concurrent.TimeUnit; 043 044/** 045 * Performs cross validation for the given problem. Supports 046 * k-fold cross-validation and leave-one-out cross-validation. 047 * 048 * @author Jens Lehmann 049 * 050 */ 051public class CrossValidation { 052 053 // statistical values 054 protected Stat runtime = new Stat(); 055 protected Stat accuracy = new Stat(); 056 protected Stat length = new Stat(); 057 protected Stat accuracyTraining = new Stat(); 058 protected Stat fMeasure = new Stat(); 059 protected Stat fMeasureTraining = new Stat(); 060 public static boolean writeToFile = false; 061 public static File outputFile; 062 public static boolean multiThreaded = false; 063 064 protected Stat trainingCompletenessStat = new Stat(); 065 protected Stat trainingCorrectnessStat = new Stat(); 066 067 protected Stat testingCompletenessStat = new Stat(); 068 protected Stat testingCorrectnessStat = new Stat(); 069 070 DecimalFormat df = new DecimalFormat(); 071 072 073 074 public CrossValidation() { 075 076 } 077 078 public CrossValidation(AbstractCELA la, AbstractClassExpressionLearningProblem lp, final AbstractReasonerComponent rs, int folds, boolean leaveOneOut) { 079 //console rendering of class expressions 080 ManchesterOWLSyntaxOWLObjectRendererImplExt renderer = new ManchesterOWLSyntaxOWLObjectRendererImplExt(); 081 StringRenderer.setRenderer(renderer); 082 StringRenderer.setShortFormProvider(new SimpleShortFormProvider()); 083 084 // the training and test sets used later on 085 List<Set<OWLIndividual>> trainingSetsPos = new LinkedList<>(); 086 List<Set<OWLIndividual>> trainingSetsNeg = new LinkedList<>(); 087 List<Set<OWLIndividual>> testSetsPos = new LinkedList<>(); 088 List<Set<OWLIndividual>> testSetsNeg = new LinkedList<>(); 089 090 // get examples and shuffle them too 091 Set<OWLIndividual> posExamples; 092 Set<OWLIndividual> negExamples; 093 if(lp instanceof PosNegLP){ 094 posExamples = ((PosNegLP)lp).getPositiveExamples(); 095 negExamples = ((PosNegLP)lp).getNegativeExamples(); 096 } else if(lp instanceof PosOnlyLP){ 097 posExamples = ((PosNegLP)lp).getPositiveExamples(); 098 negExamples = new HashSet<>(); 099 } else { 100 throw new IllegalArgumentException("Only PosNeg and PosOnly learning problems are supported"); 101 } 102 List<OWLIndividual> posExamplesList = new LinkedList<>(posExamples); 103 List<OWLIndividual> negExamplesList = new LinkedList<>(negExamples); 104 Collections.shuffle(posExamplesList, new Random(1)); 105 Collections.shuffle(negExamplesList, new Random(2)); 106 107 // sanity check whether nr. of folds makes sense for this benchmark 108 if(!leaveOneOut && (posExamples.size()<folds && negExamples.size()<folds)) { 109 System.out.println("The number of folds is higher than the number of " 110 + "positive/negative examples. This can result in empty test sets. Exiting."); 111 System.exit(0); 112 } 113 114 if(leaveOneOut) { 115 // note that leave-one-out is not identical to k-fold with 116 // k = nr. of examples in the current implementation, because 117 // with n folds and n examples there is no guarantee that a fold 118 // is never empty (this is an implementation issue) 119 int nrOfExamples = posExamples.size() + negExamples.size(); 120 for(int i = 0; i < nrOfExamples; i++) { 121 // ... 122 } 123 System.out.println("Leave-one-out not supported yet."); 124 System.exit(1); 125 } else { 126 // calculating where to split the sets, ; note that we split 127 // positive and negative examples separately such that the 128 // distribution of positive and negative examples remains similar 129 // (note that there are better but more complex ways to implement this, 130 // which guarantee that the sum of the elements of a fold for pos 131 // and neg differs by at most 1 - it can differ by 2 in our implementation, 132 // e.g. with 3 folds, 4 pos. examples, 4 neg. examples) 133 int[] splitsPos = calculateSplits(posExamples.size(),folds); 134 int[] splitsNeg = calculateSplits(negExamples.size(),folds); 135 136// System.out.println(splitsPos[0]); 137// System.out.println(splitsNeg[0]); 138 139 // calculating training and test sets 140 for(int i=0; i<folds; i++) { 141 Set<OWLIndividual> testPos = getTestingSet(posExamplesList, splitsPos, i); 142 Set<OWLIndividual> testNeg = getTestingSet(negExamplesList, splitsNeg, i); 143 testSetsPos.add(i, testPos); 144 testSetsNeg.add(i, testNeg); 145 trainingSetsPos.add(i, getTrainingSet(posExamples, testPos)); 146 trainingSetsNeg.add(i, getTrainingSet(negExamples, testNeg)); 147 } 148 149 } 150 151 // run the algorithm 152 if( multiThreaded && lp instanceof Cloneable && la instanceof Cloneable){ 153 ExecutorService es = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()-1); 154 for(int currFold=0; currFold<folds; currFold++) { 155 try { 156 final AbstractClassExpressionLearningProblem lpClone = (AbstractClassExpressionLearningProblem) lp.getClass().getMethod("clone").invoke(lp); 157 final Set<OWLIndividual> trainPos = trainingSetsPos.get(currFold); 158 final Set<OWLIndividual> trainNeg = trainingSetsNeg.get(currFold); 159 final Set<OWLIndividual> testPos = testSetsPos.get(currFold); 160 final Set<OWLIndividual> testNeg = testSetsNeg.get(currFold); 161 if(lp instanceof PosNegLP){ 162 ((PosNegLP)lpClone).setPositiveExamples(trainPos); 163 ((PosNegLP)lpClone).setNegativeExamples(trainNeg); 164 } else if(lp instanceof PosOnlyLP){ 165 ((PosOnlyLP)lpClone).setPositiveExamples(new TreeSet<>(trainPos)); 166 } 167 final AbstractCELA laClone = (AbstractCELA) la.getClass().getMethod("clone").invoke(la); 168 final int i = currFold; 169 es.submit(new Runnable() { 170 171 @Override 172 public void run() { 173 try { 174 validate(laClone, lpClone, rs, i, trainPos, trainNeg, testPos, testNeg); 175 } catch (Exception e) { 176 e.printStackTrace(); 177 } 178 } 179 }); 180 } catch (IllegalAccessException | SecurityException | NoSuchMethodException | InvocationTargetException | IllegalArgumentException e) { 181 e.printStackTrace(); 182 } 183 } 184 es.shutdown(); 185 try { 186 es.awaitTermination(1, TimeUnit.DAYS); 187 } catch (InterruptedException e) { 188 e.printStackTrace(); 189 } 190 } else { 191 for(int currFold=0; currFold<folds; currFold++) { 192 final Set<OWLIndividual> trainPos = trainingSetsPos.get(currFold); 193 final Set<OWLIndividual> trainNeg = trainingSetsNeg.get(currFold); 194 final Set<OWLIndividual> testPos = testSetsPos.get(currFold); 195 final Set<OWLIndividual> testNeg = testSetsNeg.get(currFold); 196 197 if(lp instanceof PosNegLP){ 198 ((PosNegLP)lp).setPositiveExamples(trainPos); 199 ((PosNegLP)lp).setNegativeExamples(trainNeg); 200 } else if(lp instanceof PosOnlyLP){ 201 ((PosOnlyLP)lp).setPositiveExamples(new TreeSet<>(trainPos)); 202 } 203 204 validate(la, lp, rs, currFold, trainPos, trainNeg, testPos, testNeg); 205 } 206 } 207 208 outputWriter(""); 209 outputWriter("Finished " + folds + "-folds cross-validation."); 210 outputWriter("runtime: " + statOutput(df, runtime, "s")); 211 outputWriter("length: " + statOutput(df, length, "")); 212 outputWriter("F-Measure on training set: " + statOutput(df, fMeasureTraining, "%")); 213 outputWriter("F-Measure: " + statOutput(df, fMeasure, "%")); 214 outputWriter("predictive accuracy on training set: " + statOutput(df, accuracyTraining, "%")); 215 outputWriter("predictive accuracy: " + statOutput(df, accuracy, "%")); 216 217 } 218 219 private void validate(AbstractCELA la, AbstractClassExpressionLearningProblem lp, AbstractReasonerComponent rs, 220 int currFold, Set<OWLIndividual> trainPos, Set<OWLIndividual> trainNeg, Set<OWLIndividual> testPos, Set<OWLIndividual> testNeg){ 221 Set<String> pos = Helper.getStringSet(trainPos); 222 Set<String> neg = Helper.getStringSet(trainNeg); 223 String output = ""; 224 output += "+" + new TreeSet<>(pos) + "\n"; 225 output += "-" + new TreeSet<>(neg) + "\n"; 226 try { 227 lp.init(); 228 la.setLearningProblem(lp); 229 la.init(); 230 } catch (ComponentInitException e) { 231 // TODO Auto-generated catch block 232 e.printStackTrace(); 233 } 234 235 long algorithmStartTime = System.nanoTime(); 236 la.start(); 237 long algorithmDuration = System.nanoTime() - algorithmStartTime; 238 runtime.addNumber(algorithmDuration/(double)1000000000); 239 240 OWLClassExpression concept = la.getCurrentlyBestDescription(); 241 242 Set<OWLIndividual> tmp = rs.hasType(concept, testPos); 243 Set<OWLIndividual> tmp2 = Sets.difference(testPos, tmp); 244 Set<OWLIndividual> tmp3 = rs.hasType(concept, testNeg); 245 246 // calculate training accuracies 247 int trainingCorrectPosClassified = getCorrectPosClassified(rs, concept, trainPos); 248 int trainingCorrectNegClassified = getCorrectNegClassified(rs, concept, trainNeg); 249 int trainingCorrectExamples = trainingCorrectPosClassified + trainingCorrectNegClassified; 250 double trainingAccuracy = 100*((double)trainingCorrectExamples/(trainPos.size()+ 251 trainNeg.size())); 252 accuracyTraining.addNumber(trainingAccuracy); 253 // calculate test accuracies 254 int correctPosClassified = getCorrectPosClassified(rs, concept, testPos); 255 int correctNegClassified = getCorrectNegClassified(rs, concept, testNeg); 256 int correctExamples = correctPosClassified + correctNegClassified; 257 double currAccuracy = 100*((double)correctExamples/(testPos.size()+ 258 testNeg.size())); 259 accuracy.addNumber(currAccuracy); 260 // calculate training F-Score 261 int negAsPosTraining = rs.hasType(concept, trainNeg).size(); 262 double precisionTraining = trainingCorrectPosClassified + negAsPosTraining == 0 ? 0 : trainingCorrectPosClassified / (double) (trainingCorrectPosClassified + negAsPosTraining); 263 double recallTraining = trainingCorrectPosClassified / (double) trainPos.size(); 264 fMeasureTraining.addNumber(100*Heuristics.getFScore(recallTraining, precisionTraining)); 265 // calculate test F-Score 266 int negAsPos = rs.hasType(concept, testNeg).size(); 267 double precision = correctPosClassified + negAsPos == 0 ? 0 : correctPosClassified / (double) (correctPosClassified + negAsPos); 268 double recall = correctPosClassified / (double) testPos.size(); 269// System.out.println(precision);System.out.println(recall); 270 fMeasure.addNumber(100*Heuristics.getFScore(recall, precision)); 271 272 length.addNumber(OWLClassExpressionUtils.getLength(concept)); 273 274 275 output += "test set errors pos: " + tmp2 + "\n"; 276 output += "test set errors neg: " + tmp3 + "\n"; 277 output += "fold " + currFold + ":" + "\n"; 278 output += " training: " + pos.size() + " positive and " + neg.size() + " negative examples"; 279 output += " testing: " + correctPosClassified + "/" + testPos.size() + " correct positives, " 280 + correctNegClassified + "/" + testNeg.size() + " correct negatives" + "\n"; 281 output += " concept: " + concept.toString().replace("\n", " ") + "\n"; 282 output += " accuracy: " + df.format(currAccuracy) + "% (" + df.format(trainingAccuracy) + "% on training set)" + "\n"; 283 output += " length: " + df.format(OWLClassExpressionUtils.getLength(concept)) + "\n"; 284 output += " runtime: " + df.format(algorithmDuration/(double)1000000000) + "s" + "\n"; 285 286 outputWriter(output); 287 } 288 289 protected int getCorrectPosClassified(AbstractReasonerComponent rs, OWLClassExpression concept, Set<OWLIndividual> testSetPos) { 290 return rs.hasType(concept, testSetPos).size(); 291 } 292 293 protected int getCorrectNegClassified(AbstractReasonerComponent rs, OWLClassExpression concept, Set<OWLIndividual> testSetNeg) { 294 return testSetNeg.size() - rs.hasType(concept, testSetNeg).size(); 295 } 296 297 public static Set<OWLIndividual> getTestingSet(List<OWLIndividual> examples, int[] splits, int fold) { 298 int fromIndex; 299 // we either start from 0 or after the last fold ended 300 if(fold == 0) 301 fromIndex = 0; 302 else 303 fromIndex = splits[fold-1]; 304 // the split corresponds to the ends of the folds 305 int toIndex = splits[fold]; 306 307// System.out.println("from " + fromIndex + " to " + toIndex); 308 309 Set<OWLIndividual> testingSet = new HashSet<>(); 310 // +1 because 2nd element is exclusive in subList method 311 testingSet.addAll(examples.subList(fromIndex, toIndex)); 312 return testingSet; 313 } 314 315 public static Set<OWLIndividual> getTrainingSet(Set<OWLIndividual> examples, Set<OWLIndividual> testingSet) { 316 return Sets.difference(examples, testingSet); 317 } 318 319 // takes nr. of examples and the nr. of folds for this examples; 320 // returns an array which says where each fold ends, i.e. 321 // splits[i] is the index of the last element of fold i in the examples 322 public static int[] calculateSplits(int nrOfExamples, int folds) { 323 int[] splits = new int[folds]; 324 for(int i=1; i<=folds; i++) { 325 // we always round up to the next integer 326 splits[i-1] = (int)Math.ceil(i*nrOfExamples/(double)folds); 327 } 328 return splits; 329 } 330 331 public static String statOutput(DecimalFormat df, Stat stat, String unit) { 332 String str = "av. " + df.format(stat.getMean()) + unit; 333 str += " (deviation " + df.format(stat.getStandardDeviation()) + unit + "; "; 334 str += "min " + df.format(stat.getMin()) + unit + "; "; 335 str += "max " + df.format(stat.getMax()) + unit + ")"; 336 return str; 337 } 338 339 public Stat getAccuracy() { 340 return accuracy; 341 } 342 343 public Stat getLength() { 344 return length; 345 } 346 347 public Stat getRuntime() { 348 return runtime; 349 } 350 351 protected void outputWriter(String output) { 352 if(writeToFile) { 353 Files.appendToFile(outputFile, output +"\n"); 354 System.out.println(output); 355 } else { 356 System.out.println(output); 357 } 358 359 } 360 361 public Stat getfMeasure() { 362 return fMeasure; 363 } 364 365 public Stat getfMeasureTraining() { 366 return fMeasureTraining; 367 } 368 369}