/**
 * Copyright (C) 2007-2008, Jens Lehmann
 *
 * This file is part of DL-Learner.
 *
 * DL-Learner is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * DL-Learner is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 */
package org.dllearner.cli;

import com.google.common.collect.Sets;
import org.dllearner.algorithms.qtl.QTL2Disjunctive;
import org.dllearner.algorithms.qtl.datastructures.impl.QueryTreeImpl.LiteralNodeSubsumptionStrategy;
import org.dllearner.algorithms.qtl.datastructures.impl.RDFResourceTree;
import org.dllearner.core.AbstractClassExpressionLearningProblem;
import org.dllearner.core.ComponentInitException;
import org.dllearner.core.IndividualReasoner;
import org.dllearner.learningproblems.Heuristics;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.learningproblems.PosOnlyLP;
import org.dllearner.reasoning.SPARQLReasoner;
import org.dllearner.utilities.Files;
import org.dllearner.utilities.Helper;
import org.dllearner.utilities.owl.OWLClassExpressionUtils;
import org.dllearner.utilities.statistics.Stat;
import org.semanticweb.owlapi.model.OWLClassExpression;
import org.semanticweb.owlapi.model.OWLIndividual;

import java.io.File;
import java.text.DecimalFormat;
import java.util.*;

/**
 * Performs cross validation for the given problem. Supports
 * k-fold cross-validation and leave-one-out cross-validation.
 *
 * @author Jens Lehmann
 *
 */
public class SPARQLCrossValidation {

	// statistical values aggregated over all folds
	protected Stat runtime = new Stat();
	protected Stat accuracy = new Stat();
	protected Stat length = new Stat();
	protected Stat accuracyTraining = new Stat();
	protected Stat fMeasure = new Stat();
	protected Stat fMeasureTraining = new Stat();
	// when true, outputWriter() additionally appends every line to outputFile
	protected static boolean writeToFile = false;
	protected static File outputFile;

	protected Stat trainingCompletenessStat = new Stat();
	protected Stat trainingCorrectnessStat = new Stat();

	protected Stat testingCompletenessStat = new Stat();
	protected Stat testingCorrectnessStat = new Stat();

	LiteralNodeSubsumptionStrategy literalNodeSubsumptionStrategy = LiteralNodeSubsumptionStrategy.INTERVAL;

	public SPARQLCrossValidation() {

	}

	/**
	 * Runs the cross-validation immediately: splits the examples of the given
	 * learning problem into folds, trains the QTL algorithm on each training
	 * split and evaluates accuracy and F-measure on the corresponding test split.
	 *
	 * @param la the QTL learning algorithm to train per fold
	 * @param lp the learning problem; must be a {@link PosNegLP} or {@link PosOnlyLP}
	 * @param rs reasoner used for the (debug) training-set classification check
	 * @param folds number of folds for k-fold cross-validation
	 * @param leaveOneOut if true, leave-one-out is requested (currently unsupported; exits)
	 * @throws IllegalArgumentException if lp is neither PosNegLP nor PosOnlyLP
	 */
	public SPARQLCrossValidation(QTL2Disjunctive la, AbstractClassExpressionLearningProblem lp, IndividualReasoner rs, int folds, boolean leaveOneOut) {

		DecimalFormat df = new DecimalFormat();

		// the training and test sets used later on
		List<Set<OWLIndividual>> trainingSetsPos = new LinkedList<>();
		List<Set<OWLIndividual>> trainingSetsNeg = new LinkedList<>();
		List<Set<OWLIndividual>> testSetsPos = new LinkedList<>();
		List<Set<OWLIndividual>> testSetsNeg = new LinkedList<>();

		// get examples and shuffle them too
		Set<OWLIndividual> posExamples;
		Set<OWLIndividual> negExamples;
		if(lp instanceof PosNegLP){
			posExamples = ((PosNegLP)lp).getPositiveExamples();
			negExamples = ((PosNegLP)lp).getNegativeExamples();
		} else if(lp instanceof PosOnlyLP){
			// BUG FIX: the original cast lp to PosNegLP here, which threw a
			// ClassCastException for every PosOnlyLP problem
			posExamples = ((PosOnlyLP)lp).getPositiveExamples();
			negExamples = new HashSet<>();
		} else {
			throw new IllegalArgumentException("Only PosNeg and PosOnly learning problems are supported");
		}
		List<OWLIndividual> posExamplesList = new LinkedList<>(posExamples);
		List<OWLIndividual> negExamplesList = new LinkedList<>(negExamples);
		// fixed seeds make the fold assignment reproducible across runs
		Collections.shuffle(posExamplesList, new Random(1));
		Collections.shuffle(negExamplesList, new Random(2));

		// sanity check whether nr. of folds makes sense for this benchmark;
		// note: deliberately '&&' — PosOnlyLP always has 0 negatives, which
		// merely yields empty negative test folds rather than an error
		if(!leaveOneOut && (posExamples.size()<folds && negExamples.size()<folds)) {
			System.out.println("The number of folds is higher than the number of "
					+ "positive/negative examples. This can result in empty test sets. Exiting.");
			System.exit(0);
		}

		if(leaveOneOut) {
			// note that leave-one-out is not identical to k-fold with
			// k = nr. of examples in the current implementation, because
			// with n folds and n examples there is no guarantee that a fold
			// is never empty (this is an implementation issue)
			int nrOfExamples = posExamples.size() + negExamples.size();
			for(int i = 0; i < nrOfExamples; i++) {
				// ...
			}
			System.out.println("Leave-one-out not supported yet.");
			System.exit(1);
		} else {
			// calculating where to split the sets; note that we split
			// positive and negative examples separately such that the
			// distribution of positive and negative examples remains similar
			// (note that there are better but more complex ways to implement this,
			// which guarantee that the sum of the elements of a fold for pos
			// and neg differs by at most 1 - it can differ by 2 in our implementation,
			// e.g. with 3 folds, 4 pos. examples, 4 neg. examples)
			int[] splitsPos = calculateSplits(posExamples.size(),folds);
			int[] splitsNeg = calculateSplits(negExamples.size(),folds);

			// calculating training and test sets
			for(int i=0; i<folds; i++) {
				Set<OWLIndividual> testPos = getTestingSet(posExamplesList, splitsPos, i);
				Set<OWLIndividual> testNeg = getTestingSet(negExamplesList, splitsNeg, i);
				testSetsPos.add(i, testPos);
				testSetsNeg.add(i, testNeg);
				trainingSetsPos.add(i, getTrainingSet(posExamples, testPos));
				trainingSetsNeg.add(i, getTrainingSet(negExamples, testNeg));
			}

		}

		// run the algorithm once per fold
		for(int currFold=0; currFold<folds; currFold++) {

			Set<String> pos = Helper.getStringSet(trainingSetsPos.get(currFold));
			Set<String> neg = Helper.getStringSet(trainingSetsNeg.get(currFold));
			// re-target the learning problem at this fold's training examples
			if(lp instanceof PosNegLP){
				((PosNegLP)lp).setPositiveExamples(trainingSetsPos.get(currFold));
				((PosNegLP)lp).setNegativeExamples(trainingSetsNeg.get(currFold));
			} else if(lp instanceof PosOnlyLP){
				((PosOnlyLP)lp).setPositiveExamples(new TreeSet<>(trainingSetsPos.get(currFold)));
			}

			try {
				lp.init();
				la.init();
			} catch (ComponentInitException e) {
				// TODO proper error handling; for now keep the original best-effort behavior
				e.printStackTrace();
			}

			long algorithmStartTime = System.nanoTime();
			la.start();
			long algorithmDuration = System.nanoTime() - algorithmStartTime;
			runtime.addNumber(algorithmDuration/(double)1000000000);

			OWLClassExpression concept = la.getCurrentlyBestDescription();
			System.out.println(concept);
			// positives covered / not covered on the test set, according to the QTL solution
			Set<OWLIndividual> tmp = hasType(testSetsPos.get(currFold), la);
			Set<OWLIndividual> tmp2 = Sets.difference(testSetsPos.get(currFold), tmp);
			// negatives (wrongly) covered on the test set
			Set<OWLIndividual> tmp3 = hasType(testSetsNeg.get(currFold), la);

			outputWriter("test set errors pos: " + tmp2);
			outputWriter("test set errors neg: " + tmp3);

			// calculate training accuracies (reasoner-based count printed for debugging)
			System.out.println(getCorrectPosClassified(rs, concept, trainingSetsPos.get(currFold)));
			int trainingCorrectPosClassified = getCorrectPosClassified(trainingSetsPos.get(currFold), la);
			int trainingCorrectNegClassified = getCorrectNegClassified(trainingSetsNeg.get(currFold), la);
			int trainingCorrectExamples = trainingCorrectPosClassified + trainingCorrectNegClassified;
			double trainingAccuracy = 100*((double)trainingCorrectExamples/(trainingSetsPos.get(currFold).size()+
					trainingSetsNeg.get(currFold).size()));
			accuracyTraining.addNumber(trainingAccuracy);
			// calculate test accuracies
			int correctPosClassified = getCorrectPosClassified(testSetsPos.get(currFold), la);
			int correctNegClassified = getCorrectNegClassified(testSetsNeg.get(currFold), la);
			int correctExamples = correctPosClassified + correctNegClassified;
			double currAccuracy = 100*((double)correctExamples/(testSetsPos.get(currFold).size()+
					testSetsNeg.get(currFold).size()));
			accuracy.addNumber(currAccuracy);
			// calculate training F-Score; negatives classified as positive = negatives not correctly rejected
			int negAsPosTraining = trainingSetsNeg.get(currFold).size() - trainingCorrectNegClassified;
			double precisionTraining = trainingCorrectPosClassified + negAsPosTraining == 0 ?
					0 : trainingCorrectPosClassified / (double) (trainingCorrectPosClassified + negAsPosTraining);
			double recallTraining = trainingCorrectPosClassified / (double) trainingSetsPos.get(currFold).size();
			fMeasureTraining.addNumber(100*Heuristics.getFScore(recallTraining, precisionTraining));
			// calculate test F-Score
			int negAsPos = testSetsNeg.get(currFold).size() - correctNegClassified;
			double precision = correctPosClassified + negAsPos == 0 ? 0 : correctPosClassified / (double) (correctPosClassified + negAsPos);
			double recall = correctPosClassified / (double) testSetsPos.get(currFold).size();
			fMeasure.addNumber(100*Heuristics.getFScore(recall, precision));

			length.addNumber(OWLClassExpressionUtils.getLength(concept));

			outputWriter("fold " + currFold + ":");
			outputWriter("  training: " + pos.size() + " positive and " + neg.size() + " negative examples");
			outputWriter("  testing: " + correctPosClassified + "/" + testSetsPos.get(currFold).size() + " correct positives, "
					+ correctNegClassified + "/" + testSetsNeg.get(currFold).size() + " correct negatives");
			outputWriter("  concept: " + concept);
			outputWriter("  accuracy: " + df.format(currAccuracy) + "% (" + df.format(trainingAccuracy) + "% on training set)");
			outputWriter("  length: " + df.format(OWLClassExpressionUtils.getLength(concept)));
			outputWriter("  runtime: " + df.format(algorithmDuration/(double)1000000000) + "s");

		}

		outputWriter("");
		outputWriter("Finished " + folds + "-folds cross-validation.");
		outputWriter("runtime: " + statOutput(df, runtime, "s"));
		outputWriter("length: " + statOutput(df, length, ""));
		outputWriter("F-Measure on training set: " + statOutput(df, fMeasureTraining, "%"));
		outputWriter("F-Measure: " + statOutput(df, fMeasure, "%"));
		outputWriter("predictive accuracy on training set: " + statOutput(df, accuracyTraining, "%"));
		outputWriter("predictive accuracy: " + statOutput(df, accuracy, "%"));

	}

	/**
	 * @return the number of individuals in testSetPos that the reasoner
	 *         classifies as instances of concept
	 */
	protected int getCorrectPosClassified(IndividualReasoner rs, OWLClassExpression concept, Set<OWLIndividual> testSetPos) {
		return rs.hasType(concept, testSetPos).size();
	}

	/**
	 * Determines which of the given individuals are covered by the best QTL
	 * solution tree. NOT IMPLEMENTED YET: throws for any non-empty input;
	 * returns the empty set for empty input.
	 */
	protected Set<OWLIndividual> hasType(Set<OWLIndividual> individuals, QTL2Disjunctive qtl) {
		Set<OWLIndividual> coveredIndividuals = new HashSet<>();
		RDFResourceTree solutionTree = qtl.getBestSolution().getTree();

		for (OWLIndividual ind : individuals) {
			throw new RuntimeException("Not implemented yet.");
		}
		return coveredIndividuals;
	}

	/**
	 * NOTE(review): this ignores testSetPos and reports the positives covered
	 * on the TRAINING data of the QTL solution — presumably a placeholder until
	 * {@link #hasType(Set, QTL2Disjunctive)} is implemented; verify before relying
	 * on the reported test accuracy.
	 */
	protected int getCorrectPosClassified(Set<OWLIndividual> testSetPos, QTL2Disjunctive qtl) {
		return qtl.getBestSolution().getTreeScore().getCoveredPositives().size();
	}

	/**
	 * @return the number of individuals in testSetNeg that the reasoner does
	 *         NOT classify as instances of concept (correct rejections)
	 */
	protected int getCorrectNegClassified(SPARQLReasoner rs, OWLClassExpression concept, Set<OWLIndividual> testSetNeg) {
		return testSetNeg.size() - rs.hasType(concept, testSetNeg).size();
	}

	/**
	 * NOTE(review): like its positive counterpart, this ignores testSetNeg and
	 * reports the not-covered negatives of the QTL solution's training score.
	 */
	protected int getCorrectNegClassified(Set<OWLIndividual> testSetNeg, QTL2Disjunctive qtl) {
		return qtl.getBestSolution().getTreeScore().getNotCoveredNegatives().size();
	}

	/**
	 * Extracts the test set of the given fold from the (shuffled) example list.
	 *
	 * @param examples all examples, in fold order
	 * @param splits exclusive end index of each fold, as produced by {@link #calculateSplits(int, int)}
	 * @param fold zero-based fold index
	 * @return the examples belonging to the given fold
	 */
	public static Set<OWLIndividual> getTestingSet(List<OWLIndividual> examples, int[] splits, int fold) {
		int fromIndex;
		// we either start from 0 or after the last fold ended
		if(fold == 0)
			fromIndex = 0;
		else
			fromIndex = splits[fold-1];
		// the split corresponds to the (exclusive) end of the fold, which matches
		// subList's exclusive toIndex — no adjustment needed
		int toIndex = splits[fold];

		Set<OWLIndividual> testingSet = new HashSet<>();
		testingSet.addAll(examples.subList(fromIndex, toIndex));
		return testingSet;
	}

	/**
	 * @return all examples not in the testing set (a live view backed by the inputs)
	 */
	public static Set<OWLIndividual> getTrainingSet(Set<OWLIndividual> examples, Set<OWLIndividual> testingSet) {
		return Sets.difference(examples, testingSet);
	}

	// takes nr. of examples and the nr. of folds for this examples;
	// returns an array which says where each fold ends, i.e.
	// splits[i] is the exclusive end index of fold i in the examples
	public static int[] calculateSplits(int nrOfExamples, int folds) {
		int[] splits = new int[folds];
		for(int i=1; i<=folds; i++) {
			// we always round up to the next integer
			splits[i-1] = (int)Math.ceil(i*nrOfExamples/(double)folds);
		}
		return splits;
	}

	/**
	 * Formats a statistic as "av. MEAN (deviation SD; min MIN; max MAX)" with the given unit.
	 */
	public static String statOutput(DecimalFormat df, Stat stat, String unit) {
		String str = "av. " + df.format(stat.getMean()) + unit;
		str += " (deviation " + df.format(stat.getStandardDeviation()) + unit + "; ";
		str += "min " + df.format(stat.getMin()) + unit + "; ";
		str += "max " + df.format(stat.getMax()) + unit + ")";
		return str;
	}

	public Stat getAccuracy() {
		return accuracy;
	}

	public Stat getLength() {
		return length;
	}

	public Stat getRuntime() {
		return runtime;
	}

	/**
	 * Writes a line to standard out and, if {@link #writeToFile} is set,
	 * appends it to {@link #outputFile} as well.
	 */
	protected void outputWriter(String output) {
		if(writeToFile) {
			Files.appendToFile(outputFile, output + "\n");
		}
		System.out.println(output);
	}

	public Stat getfMeasure() {
		return fMeasure;
	}

	public Stat getfMeasureTraining() {
		return fMeasureTraining;
	}

}