001/**
002 * Copyright (C) 2007-2008, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 * 
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 *
019 */
020package org.dllearner.cli;
021
022import com.google.common.collect.Sets;
023import org.dllearner.algorithms.qtl.QTL2Disjunctive;
024import org.dllearner.algorithms.qtl.datastructures.impl.QueryTreeImpl.LiteralNodeSubsumptionStrategy;
025import org.dllearner.algorithms.qtl.datastructures.impl.RDFResourceTree;
026import org.dllearner.core.AbstractClassExpressionLearningProblem;
027import org.dllearner.core.ComponentInitException;
028import org.dllearner.core.IndividualReasoner;
029import org.dllearner.learningproblems.Heuristics;
030import org.dllearner.learningproblems.PosNegLP;
031import org.dllearner.learningproblems.PosOnlyLP;
032import org.dllearner.reasoning.SPARQLReasoner;
033import org.dllearner.utilities.Files;
034import org.dllearner.utilities.Helper;
035import org.dllearner.utilities.owl.OWLClassExpressionUtils;
036import org.dllearner.utilities.statistics.Stat;
037import org.semanticweb.owlapi.model.OWLClassExpression;
038import org.semanticweb.owlapi.model.OWLIndividual;
039
040import java.io.File;
041import java.text.DecimalFormat;
042import java.util.*;
043
044/**
045 * Performs cross validation for the given problem. Supports
046 * k-fold cross-validation and leave-one-out cross-validation.
047 * 
048 * @author Jens Lehmann
049 *
050 */
051public class SPARQLCrossValidation {
052
053        // statistical values
054        protected Stat runtime = new Stat();
055        protected Stat accuracy = new Stat();
056        protected Stat length = new Stat();
057        protected Stat accuracyTraining = new Stat();
058        protected Stat fMeasure = new Stat();
059        protected Stat fMeasureTraining = new Stat();
060        protected static boolean writeToFile = false;
061        protected static File outputFile;
062        
063        
064        protected Stat trainingCompletenessStat = new Stat();
065        protected Stat trainingCorrectnessStat = new Stat();
066        
067        protected Stat testingCompletenessStat = new Stat();
068        protected Stat testingCorrectnessStat = new Stat();
069        
070        LiteralNodeSubsumptionStrategy literalNodeSubsumptionStrategy = LiteralNodeSubsumptionStrategy.INTERVAL;
071        
072        public SPARQLCrossValidation() {
073                
074        }
075        
076        public SPARQLCrossValidation(QTL2Disjunctive la, AbstractClassExpressionLearningProblem lp, IndividualReasoner rs, int folds, boolean leaveOneOut) {
077                
078                DecimalFormat df = new DecimalFormat();
079
080                // the training and test sets used later on
081                List<Set<OWLIndividual>> trainingSetsPos = new LinkedList<>();
082                List<Set<OWLIndividual>> trainingSetsNeg = new LinkedList<>();
083                List<Set<OWLIndividual>> testSetsPos = new LinkedList<>();
084                List<Set<OWLIndividual>> testSetsNeg = new LinkedList<>();
085                
086                        // get examples and shuffle them too
087                Set<OWLIndividual> posExamples;
088                Set<OWLIndividual> negExamples;
089                        if(lp instanceof PosNegLP){
090                                posExamples = ((PosNegLP)lp).getPositiveExamples();
091                                negExamples = ((PosNegLP)lp).getNegativeExamples();
092                        } else if(lp instanceof PosOnlyLP){
093                                posExamples = ((PosNegLP)lp).getPositiveExamples();
094                                negExamples = new HashSet<>();
095                        } else {
096                                throw new IllegalArgumentException("Only PosNeg and PosOnly learning problems are supported");
097                        }
098                        List<OWLIndividual> posExamplesList = new LinkedList<>(posExamples);
099                        List<OWLIndividual> negExamplesList = new LinkedList<>(negExamples);
100                        Collections.shuffle(posExamplesList, new Random(1));
101                        Collections.shuffle(negExamplesList, new Random(2));
102                        
103                        // sanity check whether nr. of folds makes sense for this benchmark
104                        if(!leaveOneOut && (posExamples.size()<folds && negExamples.size()<folds)) {
105                                System.out.println("The number of folds is higher than the number of "
106                                                + "positive/negative examples. This can result in empty test sets. Exiting.");
107                                System.exit(0);
108                        }
109                        
110                        if(leaveOneOut) {
111                                // note that leave-one-out is not identical to k-fold with
112                                // k = nr. of examples in the current implementation, because
113                                // with n folds and n examples there is no guarantee that a fold
114                                // is never empty (this is an implementation issue)
115                                int nrOfExamples = posExamples.size() + negExamples.size();
116                                for(int i = 0; i < nrOfExamples; i++) {
117                                        // ...
118                                }
119                                System.out.println("Leave-one-out not supported yet.");
120                                System.exit(1);
121                        } else {
122                                // calculating where to split the sets, ; note that we split
123                                // positive and negative examples separately such that the
124                                // distribution of positive and negative examples remains similar
125                                // (note that there are better but more complex ways to implement this,
126                                // which guarantee that the sum of the elements of a fold for pos
127                                // and neg differs by at most 1 - it can differ by 2 in our implementation,
128                                // e.g. with 3 folds, 4 pos. examples, 4 neg. examples)
129                                int[] splitsPos = calculateSplits(posExamples.size(),folds);
130                                int[] splitsNeg = calculateSplits(negExamples.size(),folds);
131                                
132//                              System.out.println(splitsPos[0]);
133//                              System.out.println(splitsNeg[0]);
134                                
135                                // calculating training and test sets
136                                for(int i=0; i<folds; i++) {
137                                        Set<OWLIndividual> testPos = getTestingSet(posExamplesList, splitsPos, i);
138                                        Set<OWLIndividual> testNeg = getTestingSet(negExamplesList, splitsNeg, i);
139                                        testSetsPos.add(i, testPos);
140                                        testSetsNeg.add(i, testNeg);
141                                        trainingSetsPos.add(i, getTrainingSet(posExamples, testPos));
142                                        trainingSetsNeg.add(i, getTrainingSet(negExamples, testNeg));
143                                }
144                                
145                        }
146
147                // run the algorithm
148                for(int currFold=0; currFold<folds; currFold++) {
149
150                        Set<String> pos = Helper.getStringSet(trainingSetsPos.get(currFold));
151                        Set<String> neg = Helper.getStringSet(trainingSetsNeg.get(currFold));
152                        if(lp instanceof PosNegLP){
153                                ((PosNegLP)lp).setPositiveExamples(trainingSetsPos.get(currFold));
154                                ((PosNegLP)lp).setNegativeExamples(trainingSetsNeg.get(currFold));
155                        } else if(lp instanceof PosOnlyLP){
156                                ((PosOnlyLP)lp).setPositiveExamples(new TreeSet<>(trainingSetsPos.get(currFold)));
157                        }
158                        
159
160                        try {
161                                lp.init();
162                                la.init();
163                        } catch (ComponentInitException e) {
164                                // TODO Auto-generated catch block
165                                e.printStackTrace();
166                        }
167                        
168                        long algorithmStartTime = System.nanoTime();
169                        la.start();
170                        long algorithmDuration = System.nanoTime() - algorithmStartTime;
171                        runtime.addNumber(algorithmDuration/(double)1000000000);
172                        
173                        OWLClassExpression concept = la.getCurrentlyBestDescription();
174                        System.out.println(concept);
175//                      Set<OWLIndividual> tmp = rs.hasType(concept, testSetsPos.get(currFold));
176                        Set<OWLIndividual> tmp = hasType(testSetsPos.get(currFold), la);
177                        Set<OWLIndividual> tmp2 = Sets.difference(testSetsPos.get(currFold), tmp);
178//                      Set<OWLIndividual> tmp3 = rs.hasType(concept, testSetsNeg.get(currFold));
179                        Set<OWLIndividual> tmp3 = hasType(testSetsNeg.get(currFold), la);
180                        
181                        outputWriter("test set errors pos: " + tmp2);
182                        outputWriter("test set errors neg: " + tmp3);
183                        
184                        // calculate training accuracies
185                        System.out.println(getCorrectPosClassified(rs, concept, trainingSetsPos.get(currFold)));
186//                      int trainingCorrectPosClassified = getCorrectPosClassified(rs, concept, trainingSetsPos.get(currFold));
187                        int trainingCorrectPosClassified = getCorrectPosClassified(trainingSetsPos.get(currFold), la);
188//                      int trainingCorrectNegClassified = getCorrectNegClassified(rs, concept, trainingSetsNeg.get(currFold));
189                        int trainingCorrectNegClassified = getCorrectNegClassified(trainingSetsNeg.get(currFold), la);
190                        int trainingCorrectExamples = trainingCorrectPosClassified + trainingCorrectNegClassified;
191                        double trainingAccuracy = 100*((double)trainingCorrectExamples/(trainingSetsPos.get(currFold).size()+
192                                        trainingSetsNeg.get(currFold).size()));
193                        accuracyTraining.addNumber(trainingAccuracy);
194                        // calculate test accuracies
195//                      int correctPosClassified = getCorrectPosClassified(rs, concept, testSetsPos.get(currFold));
196                        int correctPosClassified = getCorrectPosClassified(testSetsPos.get(currFold), la);
197//                      int correctNegClassified = getCorrectNegClassified(rs, concept, testSetsNeg.get(currFold));
198                        int correctNegClassified = getCorrectNegClassified(testSetsNeg.get(currFold), la);
199                        int correctExamples = correctPosClassified + correctNegClassified;
200                        double currAccuracy = 100*((double)correctExamples/(testSetsPos.get(currFold).size()+
201                                        testSetsNeg.get(currFold).size()));
202                        accuracy.addNumber(currAccuracy);
203                        // calculate training F-Score
204//                      int negAsPosTraining = rs.hasType(concept, trainingSetsNeg.get(currFold)).size();
205                        int negAsPosTraining = trainingSetsNeg.get(currFold).size() - trainingCorrectNegClassified;
206                        double precisionTraining = trainingCorrectPosClassified + negAsPosTraining == 0 ? 0 : trainingCorrectPosClassified / (double) (trainingCorrectPosClassified + negAsPosTraining);
207                        double recallTraining = trainingCorrectPosClassified / (double) trainingSetsPos.get(currFold).size();
208                        fMeasureTraining.addNumber(100*Heuristics.getFScore(recallTraining, precisionTraining));
209                        // calculate test F-Score
210//                      int negAsPos = rs.hasType(concept, testSetsNeg.get(currFold)).size();
211                        int negAsPos = testSetsNeg.get(currFold).size() - correctNegClassified;
212                        double precision = correctPosClassified + negAsPos == 0 ? 0 : correctPosClassified / (double) (correctPosClassified + negAsPos);
213                        double recall = correctPosClassified / (double) testSetsPos.get(currFold).size();
214//                      System.out.println(precision);System.out.println(recall);
215                        fMeasure.addNumber(100*Heuristics.getFScore(recall, precision));
216                        
217                        length.addNumber(OWLClassExpressionUtils.getLength(concept));
218                        
219                        outputWriter("fold " + currFold + ":");
220                        outputWriter("  training: " + pos.size() + " positive and " + neg.size() + " negative examples");
221                        outputWriter("  testing: " + correctPosClassified + "/" + testSetsPos.get(currFold).size() + " correct positives, "
222                                        + correctNegClassified + "/" + testSetsNeg.get(currFold).size() + " correct negatives");
223                        outputWriter("  concept: " + concept);
224                        outputWriter("  accuracy: " + df.format(currAccuracy) + "% (" + df.format(trainingAccuracy) + "% on training set)");
225                        outputWriter("  length: " + df.format(OWLClassExpressionUtils.getLength(concept)));
226                        outputWriter("  runtime: " + df.format(algorithmDuration/(double)1000000000) + "s");
227                                        
228                }
229                
230                outputWriter("");
231                outputWriter("Finished " + folds + "-folds cross-validation.");
232                outputWriter("runtime: " + statOutput(df, runtime, "s"));
233                outputWriter("length: " + statOutput(df, length, ""));
234                outputWriter("F-Measure on training set: " + statOutput(df, fMeasureTraining, "%"));
235                outputWriter("F-Measure: " + statOutput(df, fMeasure, "%"));
236                outputWriter("predictive accuracy on training set: " + statOutput(df, accuracyTraining, "%"));
237                outputWriter("predictive accuracy: " + statOutput(df, accuracy, "%"));
238                        
239        }
240        
241        protected int getCorrectPosClassified(IndividualReasoner rs, OWLClassExpression concept, Set<OWLIndividual> testSetPos) {
242                return rs.hasType(concept, testSetPos).size();
243        }
244        
245        protected Set<OWLIndividual> hasType(Set<OWLIndividual> individuals, QTL2Disjunctive qtl) {
246                Set<OWLIndividual> coveredIndividuals = new HashSet<>();
247                RDFResourceTree solutionTree = qtl.getBestSolution().getTree();
248                
249                for (OWLIndividual ind : individuals) {
250                        throw new RuntimeException("Not implemented yet.");
251                }
252                return coveredIndividuals;
253        }
254        
255        protected int getCorrectPosClassified(Set<OWLIndividual> testSetPos, QTL2Disjunctive qtl) {
256                return qtl.getBestSolution().getTreeScore().getCoveredPositives().size();
257        }
258        
259        protected int getCorrectNegClassified(SPARQLReasoner rs, OWLClassExpression concept, Set<OWLIndividual> testSetNeg) {
260                return testSetNeg.size() - rs.hasType(concept, testSetNeg).size();
261        }
262        
263        protected int getCorrectNegClassified(Set<OWLIndividual> testSetNeg, QTL2Disjunctive qtl) {
264                return qtl.getBestSolution().getTreeScore().getNotCoveredNegatives().size();
265        }
266        
267        public static Set<OWLIndividual> getTestingSet(List<OWLIndividual> examples, int[] splits, int fold) {
268                int fromIndex;
269                // we either start from 0 or after the last fold ended
270                if(fold == 0)
271                        fromIndex = 0;
272                else
273                        fromIndex = splits[fold-1];
274                // the split corresponds to the ends of the folds
275                int toIndex = splits[fold];
276                
277//              System.out.println("from " + fromIndex + " to " + toIndex);
278                
279                Set<OWLIndividual> testingSet = new HashSet<>();
280                // +1 because 2nd element is exclusive in subList method
281                testingSet.addAll(examples.subList(fromIndex, toIndex));
282                return testingSet;
283        }
284        
285        public static Set<OWLIndividual> getTrainingSet(Set<OWLIndividual> examples, Set<OWLIndividual> testingSet) {
286                return Sets.difference(examples, testingSet);
287        }
288        
289        // takes nr. of examples and the nr. of folds for this examples;
290        // returns an array which says where each fold ends, i.e.
291        // splits[i] is the index of the last element of fold i in the examples
292        public static int[] calculateSplits(int nrOfExamples, int folds) {
293                int[] splits = new int[folds];
294                for(int i=1; i<=folds; i++) {
295                        // we always round up to the next integer
296                        splits[i-1] = (int)Math.ceil(i*nrOfExamples/(double)folds);
297                }
298                return splits;
299        }
300        
301        public static String statOutput(DecimalFormat df, Stat stat, String unit) {
302                String str = "av. " + df.format(stat.getMean()) + unit;
303                str += " (deviation " + df.format(stat.getStandardDeviation()) + unit + "; ";
304                str += "min " + df.format(stat.getMin()) + unit + "; ";
305                str += "max " + df.format(stat.getMax()) + unit + ")";
306                return str;
307        }
308
309        public Stat getAccuracy() {
310                return accuracy;
311        }
312
313        public Stat getLength() {
314                return length;
315        }
316
317        public Stat getRuntime() {
318                return runtime;
319        }
320        
321        protected void outputWriter(String output) {
322                if(writeToFile) {
323                        Files.appendToFile(outputFile, output +"\n");
324                        System.out.println(output);
325                } else {
326                        System.out.println(output);
327                }
328                
329        }
330
331        public Stat getfMeasure() {
332                return fMeasure;
333        }
334
335        public Stat getfMeasureTraining() {
336                return fMeasureTraining;
337        }
338
339}