001/**
002 * Copyright (C) 2007-2008, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 * 
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 *
019 */
020package org.dllearner.cli;
021
022import com.google.common.collect.Sets;
023import org.dllearner.core.*;
024import org.dllearner.learningproblems.Heuristics;
025import org.dllearner.learningproblems.PosNegLP;
026import org.dllearner.learningproblems.PosOnlyLP;
027import org.dllearner.utilities.Files;
028import org.dllearner.utilities.Helper;
029import org.dllearner.utilities.owl.ManchesterOWLSyntaxOWLObjectRendererImplExt;
030import org.dllearner.utilities.owl.OWLClassExpressionUtils;
031import org.dllearner.utilities.statistics.Stat;
032import org.semanticweb.owlapi.model.OWLClassExpression;
033import org.semanticweb.owlapi.model.OWLIndividual;
034import org.semanticweb.owlapi.util.SimpleShortFormProvider;
035
036import java.io.File;
037import java.lang.reflect.InvocationTargetException;
038import java.text.DecimalFormat;
039import java.util.*;
040import java.util.concurrent.ExecutorService;
041import java.util.concurrent.Executors;
042import java.util.concurrent.TimeUnit;
043
044/**
045 * Performs cross validation for the given problem. Supports
046 * k-fold cross-validation and leave-one-out cross-validation.
047 * 
048 * @author Jens Lehmann
049 *
050 */
051public class CrossValidation {
052
053        // statistical values
054        protected Stat runtime = new Stat();
055        protected Stat accuracy = new Stat();
056        protected Stat length = new Stat();
057        protected Stat accuracyTraining = new Stat();
058        protected Stat fMeasure = new Stat();
059        protected Stat fMeasureTraining = new Stat();
060        public static boolean writeToFile = false;
061        public static File outputFile;
062        public static boolean multiThreaded = false;
063        
064        protected Stat trainingCompletenessStat = new Stat();
065        protected Stat trainingCorrectnessStat = new Stat();
066        
067        protected Stat testingCompletenessStat = new Stat();
068        protected Stat testingCorrectnessStat = new Stat();
069        
070        DecimalFormat df = new DecimalFormat();
071        
072        
073        
074        public CrossValidation() {
075                
076        }
077        
078        public CrossValidation(AbstractCELA la, AbstractClassExpressionLearningProblem lp, final AbstractReasonerComponent rs, int folds, boolean leaveOneOut) {
079                //console rendering of class expressions
080                ManchesterOWLSyntaxOWLObjectRendererImplExt renderer = new ManchesterOWLSyntaxOWLObjectRendererImplExt();
081                StringRenderer.setRenderer(renderer);
082                StringRenderer.setShortFormProvider(new SimpleShortFormProvider());
083                                
084                // the training and test sets used later on
085                List<Set<OWLIndividual>> trainingSetsPos = new LinkedList<>();
086                List<Set<OWLIndividual>> trainingSetsNeg = new LinkedList<>();
087                List<Set<OWLIndividual>> testSetsPos = new LinkedList<>();
088                List<Set<OWLIndividual>> testSetsNeg = new LinkedList<>();
089                
090                        // get examples and shuffle them too
091                Set<OWLIndividual> posExamples;
092                Set<OWLIndividual> negExamples;
093                        if(lp instanceof PosNegLP){
094                                posExamples = ((PosNegLP)lp).getPositiveExamples();
095                                negExamples = ((PosNegLP)lp).getNegativeExamples();
096                        } else if(lp instanceof PosOnlyLP){
097                                posExamples = ((PosNegLP)lp).getPositiveExamples();
098                                negExamples = new HashSet<>();
099                        } else {
100                                throw new IllegalArgumentException("Only PosNeg and PosOnly learning problems are supported");
101                        }
102                        List<OWLIndividual> posExamplesList = new LinkedList<>(posExamples);
103                        List<OWLIndividual> negExamplesList = new LinkedList<>(negExamples);
104                        Collections.shuffle(posExamplesList, new Random(1));
105                        Collections.shuffle(negExamplesList, new Random(2));
106                        
107                        // sanity check whether nr. of folds makes sense for this benchmark
108                        if(!leaveOneOut && (posExamples.size()<folds && negExamples.size()<folds)) {
109                                System.out.println("The number of folds is higher than the number of "
110                                                + "positive/negative examples. This can result in empty test sets. Exiting.");
111                                System.exit(0);
112                        }
113                        
114                        if(leaveOneOut) {
115                                // note that leave-one-out is not identical to k-fold with
116                                // k = nr. of examples in the current implementation, because
117                                // with n folds and n examples there is no guarantee that a fold
118                                // is never empty (this is an implementation issue)
119                                int nrOfExamples = posExamples.size() + negExamples.size();
120                                for(int i = 0; i < nrOfExamples; i++) {
121                                        // ...
122                                }
123                                System.out.println("Leave-one-out not supported yet.");
124                                System.exit(1);
125                        } else {
126                                // calculating where to split the sets, ; note that we split
127                                // positive and negative examples separately such that the
128                                // distribution of positive and negative examples remains similar
129                                // (note that there are better but more complex ways to implement this,
130                                // which guarantee that the sum of the elements of a fold for pos
131                                // and neg differs by at most 1 - it can differ by 2 in our implementation,
132                                // e.g. with 3 folds, 4 pos. examples, 4 neg. examples)
133                                int[] splitsPos = calculateSplits(posExamples.size(),folds);
134                                int[] splitsNeg = calculateSplits(negExamples.size(),folds);
135                                
136//                              System.out.println(splitsPos[0]);
137//                              System.out.println(splitsNeg[0]);
138                                
139                                // calculating training and test sets
140                                for(int i=0; i<folds; i++) {
141                                        Set<OWLIndividual> testPos = getTestingSet(posExamplesList, splitsPos, i);
142                                        Set<OWLIndividual> testNeg = getTestingSet(negExamplesList, splitsNeg, i);
143                                        testSetsPos.add(i, testPos);
144                                        testSetsNeg.add(i, testNeg);
145                                        trainingSetsPos.add(i, getTrainingSet(posExamples, testPos));
146                                        trainingSetsNeg.add(i, getTrainingSet(negExamples, testNeg));
147                                }
148                                
149                        }
150
151                // run the algorithm
152                        if( multiThreaded && lp instanceof Cloneable && la instanceof Cloneable){
153                                ExecutorService es = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()-1);
154                                for(int currFold=0; currFold<folds; currFold++) {
155                                        try {
156                                                final AbstractClassExpressionLearningProblem lpClone = (AbstractClassExpressionLearningProblem) lp.getClass().getMethod("clone").invoke(lp);
157                                                final Set<OWLIndividual> trainPos = trainingSetsPos.get(currFold);
158                                                final Set<OWLIndividual> trainNeg = trainingSetsNeg.get(currFold);
159                                                final Set<OWLIndividual> testPos = testSetsPos.get(currFold);
160                                                final Set<OWLIndividual> testNeg = testSetsNeg.get(currFold);
161                                                if(lp instanceof PosNegLP){
162                                                        ((PosNegLP)lpClone).setPositiveExamples(trainPos);
163                                                        ((PosNegLP)lpClone).setNegativeExamples(trainNeg);
164                                                } else if(lp instanceof PosOnlyLP){
165                                                        ((PosOnlyLP)lpClone).setPositiveExamples(new TreeSet<>(trainPos));
166                                                }
167                                                final AbstractCELA laClone = (AbstractCELA) la.getClass().getMethod("clone").invoke(la);
168                                                final int i = currFold;
169                                                es.submit(new Runnable() {
170
171                                                        @Override
172                                                        public void run() {
173                                                                try {
174                                                                        validate(laClone, lpClone, rs, i, trainPos, trainNeg, testPos, testNeg);
175                                                                } catch (Exception e) {
176                                                                        e.printStackTrace();
177                                                                }
178                                                        }
179                                                });
180                                        } catch (IllegalAccessException | SecurityException | NoSuchMethodException | InvocationTargetException | IllegalArgumentException e) {
181                                                e.printStackTrace();
182                                        }
183                                }
184                                es.shutdown();
185                                try {
186                                        es.awaitTermination(1, TimeUnit.DAYS);
187                                } catch (InterruptedException e) {
188                                        e.printStackTrace();
189                                }
190                        } else {
191                                for(int currFold=0; currFold<folds; currFold++) {
192                                        final Set<OWLIndividual> trainPos = trainingSetsPos.get(currFold);
193                                        final Set<OWLIndividual> trainNeg = trainingSetsNeg.get(currFold);
194                                        final Set<OWLIndividual> testPos = testSetsPos.get(currFold);
195                                        final Set<OWLIndividual> testNeg = testSetsNeg.get(currFold);
196                                        
197                                        if(lp instanceof PosNegLP){
198                                                ((PosNegLP)lp).setPositiveExamples(trainPos);
199                                                ((PosNegLP)lp).setNegativeExamples(trainNeg);
200                                        } else if(lp instanceof PosOnlyLP){
201                                                ((PosOnlyLP)lp).setPositiveExamples(new TreeSet<>(trainPos));
202                                        }
203                                        
204                                        validate(la, lp, rs, currFold, trainPos, trainNeg, testPos, testNeg);
205                                }
206                        }
207                
208                outputWriter("");
209                outputWriter("Finished " + folds + "-folds cross-validation.");
210                outputWriter("runtime: " + statOutput(df, runtime, "s"));
211                outputWriter("length: " + statOutput(df, length, ""));
212                outputWriter("F-Measure on training set: " + statOutput(df, fMeasureTraining, "%"));
213                outputWriter("F-Measure: " + statOutput(df, fMeasure, "%"));
214                outputWriter("predictive accuracy on training set: " + statOutput(df, accuracyTraining, "%"));
215                outputWriter("predictive accuracy: " + statOutput(df, accuracy, "%"));
216                        
217        }
218        
219        private void validate(AbstractCELA la, AbstractClassExpressionLearningProblem lp, AbstractReasonerComponent rs,
220                        int currFold, Set<OWLIndividual> trainPos, Set<OWLIndividual> trainNeg, Set<OWLIndividual> testPos, Set<OWLIndividual> testNeg){
221                Set<String> pos = Helper.getStringSet(trainPos);
222                Set<String> neg = Helper.getStringSet(trainNeg);
223                String output = "";
224                output += "+" + new TreeSet<>(pos) + "\n";
225                output += "-" + new TreeSet<>(neg) + "\n";
226                try {
227                        lp.init();
228                        la.setLearningProblem(lp);
229                        la.init();
230                } catch (ComponentInitException e) {
231                        // TODO Auto-generated catch block
232                        e.printStackTrace();
233                }
234                
235                long algorithmStartTime = System.nanoTime();
236                la.start();
237                long algorithmDuration = System.nanoTime() - algorithmStartTime;
238                runtime.addNumber(algorithmDuration/(double)1000000000);
239                
240                OWLClassExpression concept = la.getCurrentlyBestDescription();
241                
242                Set<OWLIndividual> tmp = rs.hasType(concept, testPos);
243                Set<OWLIndividual> tmp2 = Sets.difference(testPos, tmp);
244                Set<OWLIndividual> tmp3 = rs.hasType(concept, testNeg);
245                
246                // calculate training accuracies
247                int trainingCorrectPosClassified = getCorrectPosClassified(rs, concept, trainPos);
248                int trainingCorrectNegClassified = getCorrectNegClassified(rs, concept, trainNeg);
249                int trainingCorrectExamples = trainingCorrectPosClassified + trainingCorrectNegClassified;
250                double trainingAccuracy = 100*((double)trainingCorrectExamples/(trainPos.size()+
251                                trainNeg.size()));
252                accuracyTraining.addNumber(trainingAccuracy);
253                // calculate test accuracies
254                int correctPosClassified = getCorrectPosClassified(rs, concept, testPos);
255                int correctNegClassified = getCorrectNegClassified(rs, concept, testNeg);
256                int correctExamples = correctPosClassified + correctNegClassified;
257                double currAccuracy = 100*((double)correctExamples/(testPos.size()+
258                                testNeg.size()));
259                accuracy.addNumber(currAccuracy);
260                // calculate training F-Score
261                int negAsPosTraining = rs.hasType(concept, trainNeg).size();
262                double precisionTraining = trainingCorrectPosClassified + negAsPosTraining == 0 ? 0 : trainingCorrectPosClassified / (double) (trainingCorrectPosClassified + negAsPosTraining);
263                double recallTraining = trainingCorrectPosClassified / (double) trainPos.size();
264                fMeasureTraining.addNumber(100*Heuristics.getFScore(recallTraining, precisionTraining));
265                // calculate test F-Score
266                int negAsPos = rs.hasType(concept, testNeg).size();
267                double precision = correctPosClassified + negAsPos == 0 ? 0 : correctPosClassified / (double) (correctPosClassified + negAsPos);
268                double recall = correctPosClassified / (double) testPos.size();
269//              System.out.println(precision);System.out.println(recall);
270                fMeasure.addNumber(100*Heuristics.getFScore(recall, precision));
271                
272                length.addNumber(OWLClassExpressionUtils.getLength(concept));
273                
274                
275                output += "test set errors pos: " + tmp2 + "\n";
276                output += "test set errors neg: " + tmp3 + "\n";
277                output += "fold " + currFold + ":" + "\n";
278                output += "  training: " + pos.size() + " positive and " + neg.size() + " negative examples";
279                output += "  testing: " + correctPosClassified + "/" + testPos.size() + " correct positives, "
280                                + correctNegClassified + "/" + testNeg.size() + " correct negatives" + "\n";
281                output += "  concept: " + concept.toString().replace("\n", " ") + "\n";
282                output += "  accuracy: " + df.format(currAccuracy) + "% (" + df.format(trainingAccuracy) + "% on training set)" + "\n";
283                output += "  length: " + df.format(OWLClassExpressionUtils.getLength(concept)) + "\n";
284                output += "  runtime: " + df.format(algorithmDuration/(double)1000000000) + "s" + "\n";
285                
286                outputWriter(output);
287        }
288        
289        protected int getCorrectPosClassified(AbstractReasonerComponent rs, OWLClassExpression concept, Set<OWLIndividual> testSetPos) {
290                return rs.hasType(concept, testSetPos).size();
291        }
292        
293        protected int getCorrectNegClassified(AbstractReasonerComponent rs, OWLClassExpression concept, Set<OWLIndividual> testSetNeg) {
294                return testSetNeg.size() - rs.hasType(concept, testSetNeg).size();
295        }
296        
297        public static Set<OWLIndividual> getTestingSet(List<OWLIndividual> examples, int[] splits, int fold) {
298                int fromIndex;
299                // we either start from 0 or after the last fold ended
300                if(fold == 0)
301                        fromIndex = 0;
302                else
303                        fromIndex = splits[fold-1];
304                // the split corresponds to the ends of the folds
305                int toIndex = splits[fold];
306                
307//              System.out.println("from " + fromIndex + " to " + toIndex);
308                
309                Set<OWLIndividual> testingSet = new HashSet<>();
310                // +1 because 2nd element is exclusive in subList method
311                testingSet.addAll(examples.subList(fromIndex, toIndex));
312                return testingSet;
313        }
314        
315        public static Set<OWLIndividual> getTrainingSet(Set<OWLIndividual> examples, Set<OWLIndividual> testingSet) {
316                return Sets.difference(examples, testingSet);
317        }
318        
319        // takes nr. of examples and the nr. of folds for this examples;
320        // returns an array which says where each fold ends, i.e.
321        // splits[i] is the index of the last element of fold i in the examples
322        public static int[] calculateSplits(int nrOfExamples, int folds) {
323                int[] splits = new int[folds];
324                for(int i=1; i<=folds; i++) {
325                        // we always round up to the next integer
326                        splits[i-1] = (int)Math.ceil(i*nrOfExamples/(double)folds);
327                }
328                return splits;
329        }
330        
331        public static String statOutput(DecimalFormat df, Stat stat, String unit) {
332                String str = "av. " + df.format(stat.getMean()) + unit;
333                str += " (deviation " + df.format(stat.getStandardDeviation()) + unit + "; ";
334                str += "min " + df.format(stat.getMin()) + unit + "; ";
335                str += "max " + df.format(stat.getMax()) + unit + ")";
336                return str;
337        }
338
339        public Stat getAccuracy() {
340                return accuracy;
341        }
342
343        public Stat getLength() {
344                return length;
345        }
346
347        public Stat getRuntime() {
348                return runtime;
349        }
350        
351        protected void outputWriter(String output) {
352                if(writeToFile) {
353                        Files.appendToFile(outputFile, output +"\n");
354                        System.out.println(output);
355                } else {
356                        System.out.println(output);
357                }
358                
359        }
360
361        public Stat getfMeasure() {
362                return fMeasure;
363        }
364
365        public Stat getfMeasureTraining() {
366                return fMeasureTraining;
367        }
368
369}