001/**
002 * This file is part of LEAP.
003 *
004 * LEAP was implemented as a plugin of DL-Learner http://dl-learner.org, but
005 * some components can be used as stand-alone.
006 *
007 * LEAP is free software; you can redistribute it and/or modify it under the
008 * terms of the GNU General Public License as published by the Free Software
009 * Foundation; either version 3 of the License, or (at your option) any later
010 * version.
011 *
012 * LEAP is distributed in the hope that it will be useful, but WITHOUT ANY
013 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
014 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License along with
017 * this program. If not, see <http://www.gnu.org/licenses/>.
018 *
019 */
020package org.dllearner.cli.unife;
021
022import java.text.DecimalFormat;
023import java.util.Collections;
024import java.util.HashSet;
025import java.util.LinkedList;
026import java.util.List;
027import java.util.Random;
028import java.util.Set;
029import java.util.TreeSet;
030import mpi.MPI;
031import org.apache.log4j.Logger;
032//import org.dllearner.core.AbstractCELA;
033import org.dllearner.core.AbstractLearningProblem;
034import org.dllearner.core.AbstractReasonerComponent;
035import org.dllearner.core.ComponentInitException;
036import org.dllearner.learningproblems.ClassLearningProblem;
037//import org.dllearner.learningproblems.Heuristics;
038import org.dllearner.learningproblems.PosNegLP;
039import org.dllearner.learningproblems.PosOnlyLP;
040import org.apache.commons.io.FilenameUtils;
041import org.dllearner.cli.CrossValidation;
042import org.dllearner.core.probabilistic.unife.AbstractPSLA;
043import org.dllearner.algorithms.probabilistic.parameter.unife.edge.AbstractEDGE;
044import org.dllearner.algorithms.probabilistic.structure.unife.leap.AbstractLEAP;
045import org.dllearner.utils.unife.OWLUtils;
046import org.dllearner.utils.unife.ReflectionHelper;
047import org.semanticweb.owlapi.apibinding.OWLManager;
048import org.semanticweb.owlapi.model.AxiomType;
049import org.semanticweb.owlapi.model.OWLAxiom;
050import org.semanticweb.owlapi.model.OWLClass;
051import org.semanticweb.owlapi.model.OWLClassAssertionAxiom;
052import org.semanticweb.owlapi.model.OWLDataFactory;
053import org.semanticweb.owlapi.model.OWLIndividual;
054import org.semanticweb.owlapi.model.OWLOntology;
055import org.semanticweb.owlapi.model.OWLOntologyCreationException;
056import org.semanticweb.owlapi.model.OWLOntologyStorageException;
057import unife.bundle.utilities.BundleUtilities;
058import unife.edge.mpi.MPIUtilities;
059
060/**
061 * Performs a pseudo cross validation for the given problem using the LEAP
062 * system. It is not a real k-fold cross validation, because this class executes
063 * only a k-fold training. It produces k output file which must be submitted to
064 * testing
065 *
066 * @author Jens Lehmann
067 * @author Giuseppe Cota <giuseppe.cota@unife.it>
068 */
069public class LEAPCrossValidation extends CrossValidation {
070
071    private static final Logger logger = Logger.getLogger(LEAPCrossValidation.class);
072
073    public LEAPCrossValidation(AbstractPSLA psla, int folds, boolean leaveOneOut, boolean parallel) throws OWLOntologyStorageException, OWLOntologyCreationException {
074
075        boolean master = true;
076        
077        if (parallel) {
078            master = MPIUtilities.isMaster(MPI.COMM_WORLD);
079        }
080        
081        AbstractLearningProblem lp = psla.getLearningProblem();
082
083        DecimalFormat df = new DecimalFormat();
084
085        // the training sets used later on
086        List<Set<OWLIndividual>> trainingSetsPos = new LinkedList<>();
087        List<Set<OWLIndividual>> trainingSetsNeg = new LinkedList<>();
088        List<Set<OWLIndividual>> testSetsPos = new LinkedList<>();
089        List<Set<OWLIndividual>> testSetsNeg = new LinkedList<>();
090
091        // get individuals and shuffle them too
092        Set<OWLIndividual> posExamples = new HashSet();
093        Set<OWLIndividual> negExamples = new HashSet();
094        logger.debug("Setting cross validation");
095        if (lp instanceof PosNegLP) {
096            posExamples = ((PosNegLP) lp).getPositiveExamples();
097            negExamples = ((PosNegLP) lp).getNegativeExamples();
098        } else if (lp instanceof PosOnlyLP) {
099            posExamples = ((PosNegLP) lp).getPositiveExamples();
100            //negExamples = Helper.difference(lp.getReasoner().getIndividuals(), posExamples);
101            negExamples = new HashSet<>();
102        } else if (lp instanceof ClassLearningProblem) {
103            try {
104                posExamples = new HashSet((List<OWLIndividual>) ReflectionHelper.getPrivateField(lp, "classInstances"));
105                negExamples = new HashSet((List<OWLIndividual>) ReflectionHelper.getPrivateField(lp, "superClassInstances"));
106                // if the number of negative examples is lower than the number of folds 
107                // get as negative examples all the individuals that are not instances of ClassToDescribe
108                if (negExamples.size() < folds) {
109                    logger.info("The number of folds is higher than the number of "
110                            + "negative examples. Selecting the instances of Thing which "
111                            + "are non instances of ClasstoDescribe as negative Examples");
112                    AbstractReasonerComponent reasoner = lp.getReasoner();
113                    // get as negative examples all the individuals which belong to the class Thing
114                    // but not to the ClassToDescribe
115                    negExamples = reasoner.getIndividuals(OWLManager.getOWLDataFactory().getOWLThing());
116                    negExamples.removeAll(posExamples);
117                }
118            } catch (Exception e) {
119                logger.error("Cannot get positive and negative individuals for the cross validation");
120                logger.error(e);
121                System.exit(-2);
122            }
123        } else {
124            throw new IllegalArgumentException("Only ClassLearningProblem, PosNeg and PosOnly learning problems are supported");
125        }
126        List<OWLIndividual> posExamplesList = new LinkedList<>(posExamples);
127        List<OWLIndividual> negExamplesList = new LinkedList<>(negExamples);
128        Collections.shuffle(posExamplesList, new Random(1));
129        Collections.shuffle(negExamplesList, new Random(2));
130
131        // sanity check whether nr. of folds makes sense for this benchmark
132        if (!leaveOneOut && (posExamples.size() < folds || negExamples.size() < folds)) {
133            logger.error("The number of folds is higher than the number of "
134                    + "positive/negative examples. This can result in empty test sets. Exiting.");
135            System.exit(0);
136        }
137
138        if (leaveOneOut) {
139            // note that leave-one-out is not identical to k-fold with
140            // k = nr. of examples in the current implementation, because
141            // with n folds and n examples there is no guarantee that a fold
142            // is never empty (this is an implementation issue)
143            int nrOfExamples = posExamples.size() + negExamples.size();
144            for (int i = 0; i < nrOfExamples; i++) {
145                // ...
146            }
147            logger.error("Leave-one-out not supported yet.");
148            System.exit(1);
149        } else {
150            // calculating where to split the sets, ; note that we split
151            // positive and negative examples separately such that the 
152            // distribution of positive and negative examples remains similar
153            // (note that there are better but more complex ways to implement this,
154            // which guarantee that the sum of the elements of a fold for pos
155            // and neg differs by at most 1 - it can differ by 2 in our implementation,
156            // e.g. with 3 folds, 4 pos. examples, 4 neg. examples)
157            int[] splitsPos = calculateSplits(posExamples.size(), folds);
158            int[] splitsNeg = calculateSplits(negExamples.size(), folds);
159
160//                              System.out.println(splitsPos[0]);
161//                              System.out.println(splitsNeg[0]);
162            // calculating training and test sets
163            for (int i = 0; i < folds; i++) {
164                Set<OWLIndividual> testPos = getTestingSet(posExamplesList, splitsPos, i);
165                Set<OWLIndividual> testNeg = getTestingSet(negExamplesList, splitsNeg, i);
166                testSetsPos.add(i, testPos);
167                testSetsNeg.add(i, testNeg);
168                trainingSetsPos.add(i, getTrainingSet(posExamples, testPos));
169                trainingSetsNeg.add(i, getTrainingSet(negExamples, testNeg));
170            }
171
172        }
173
174        String completeLearnedOntology = psla.getOutputFile();
175        String cloBase = FilenameUtils.removeExtension(completeLearnedOntology);
176        String cloExt = FilenameUtils.getExtension(completeLearnedOntology);
177
178        String positiveFile = "posExamples.owl";
179        String pfBase = FilenameUtils.removeExtension(positiveFile);
180        String pfExt = FilenameUtils.getExtension(positiveFile);
181        String negativeFile = "negExamples.owl";
182        String nfBase = FilenameUtils.removeExtension(negativeFile);
183        String nfExt = FilenameUtils.getExtension(negativeFile);
184
185        logger.debug("Performing Cross Validation");
186        // run the algorithm
187        for (int currFold = 0; currFold < folds; currFold++) {
188            logger.debug("Current Fold: " + (currFold + 1));
189            // setting positive and negative individuals
190            final Set<OWLIndividual> trainPos = trainingSetsPos.get(currFold);
191            final Set<OWLIndividual> trainNeg = trainingSetsNeg.get(currFold);
192            final Set<OWLIndividual> testPos = testSetsPos.get(currFold);
193            final Set<OWLIndividual> testNeg = testSetsNeg.get(currFold);
194            if (lp instanceof PosNegLP) {
195                ((PosNegLP) lp).setPositiveExamples(trainPos);
196                ((PosNegLP) lp).setNegativeExamples(trainNeg);
197                try {
198                    lp.init();
199                } catch (ComponentInitException e) {
200                    logger.error(e);
201                    logger.error(e.getLocalizedMessage());
202                    System.exit(-2);
203                }
204            } else if (lp instanceof PosOnlyLP) {
205                // il cross training viene fatto solo per gli esempi/individui positivi
206                ((PosOnlyLP) lp).setPositiveExamples(new TreeSet<OWLIndividual>(trainPos));
207                try {
208                    lp.init();
209                } catch (ComponentInitException e) {
210                    logger.error(e);
211                    logger.error(e.getLocalizedMessage());
212                    System.exit(-2);
213                }
214                // set negative f
215            } else if (lp instanceof ClassLearningProblem) {
216                try {
217                    // Initialize the ClassLearningProblem object first and then 
218                    // modify his  private fields
219                    //lp.init();
220                    ReflectionHelper.setPrivateField(lp, "classInstances", trainPos);
221                    ReflectionHelper.setPrivateField(lp, "superClassInstances", trainNeg);
222                    ReflectionHelper.setPrivateField(lp, "negatedClassInstances", trainNeg);
223                } catch (Exception e) {
224                    logger.error("Cannot set positive and negative individuals for the cross validation");
225                    logger.error(e);
226                    System.exit(-2);
227                }
228            }
229
230            AbstractEDGE edge = (AbstractEDGE) psla.getLearningParameterAlgorithm();
231            OWLOntology startOntology = null;
232            try {
233                startOntology = BundleUtilities.copyOntology(edge.getSourcesOntology());
234
235            } catch (OWLOntologyCreationException e) {
236                e.printStackTrace();
237            }
238
239            psla.setOutputFile(cloBase + (currFold + 1) + "." + cloExt);
240            try {
241                //rs.init();
242                edge.init();
243                psla.init();
244                //edge.setPositiveFile(pfBase + (currFold + 1) + "." + pfExt);
245                //edge.setNegativeFile(nfBase + (currFold + 1) + "." + nfExt);
246                //edge.init();
247            } catch (ComponentInitException e) {
248                // TODO Auto-generated catch block
249                e.printStackTrace();
250            }
251
252            psla.start();
253
254            if (master) {
255                Set<OWLAxiom> posExamplesAxioms = edge.getPositiveExampleAxioms();
256                Set<OWLAxiom> negExamplesAxioms = edge.getNegativeExampleAxioms();
257                OWLDataFactory odf = OWLManager.getOWLDataFactory();
258                // in the case replace superClass
259                if (lp instanceof ClassLearningProblem) {
260                    ClassLearningProblem clp = (ClassLearningProblem) lp;
261                    Set<OWLAxiom> tempPos = new HashSet<>();
262                    Set<OWLAxiom> tempNeg = new HashSet<>();
263
264                    for (OWLAxiom ax : posExamplesAxioms) {
265                        if (ax.isOfType(AxiomType.CLASS_ASSERTION)) {
266                            OWLClassAssertionAxiom ax1 = (OWLClassAssertionAxiom) ax;
267                            tempPos.add(odf.getOWLClassAssertionAxiom(clp.getClassToDescribe(), ax1.getIndividual()));
268                        }
269                    }
270                    for (OWLAxiom ax : negExamplesAxioms) {
271                        if (ax.isOfType(AxiomType.CLASS_ASSERTION)) {
272                            OWLClassAssertionAxiom ax1 = (OWLClassAssertionAxiom) ax;
273                            tempNeg.add(odf.getOWLClassAssertionAxiom(clp.getClassToDescribe(), ax1.getIndividual()));
274                        }
275                    }
276                    posExamplesAxioms = tempPos;
277                    negExamplesAxioms = tempNeg;
278                }
279                // convert test set into axioms
280                Set<OWLAxiom> testAxiomsPos = new HashSet<>();
281                Set<OWLAxiom> testAxiomsNeg = new HashSet<>();
282                OWLClass clazz = ((AbstractLEAP) psla).getDummyClass();
283                if (lp instanceof ClassLearningProblem) {
284                    clazz = ((ClassLearningProblem) lp).getClassToDescribe();
285                }
286                for (OWLIndividual ind : testPos) {
287                    testAxiomsPos.add(odf.getOWLClassAssertionAxiom(clazz, ind));
288                }
289                for (OWLIndividual ind : testNeg) {
290                    testAxiomsNeg.add(odf.getOWLClassAssertionAxiom(clazz, ind));
291                }
292
293                OWLUtils.saveAxioms(testAxiomsPos, "posTestExamples" + (currFold + 1) + "." + pfExt, "OWLXML");
294                OWLUtils.saveAxioms(testAxiomsNeg, "negTestExamples" + (currFold + 1) + "." + nfExt, "OWLXML");
295                OWLUtils.saveAxioms(posExamplesAxioms, pfBase + (currFold + 1) + "." + pfExt, "OWLXML");
296                OWLUtils.saveAxioms(negExamplesAxioms, nfBase + (currFold + 1) + "." + nfExt, "OWLXML");
297            }
298        }
299    }
300
301    protected int getCorrectPosClassified(AbstractReasonerComponent rs, OWLClass concept, Set<OWLIndividual> testSetPos) {
302        return rs.hasType(concept, testSetPos).size();
303    }
304
305    protected int getCorrectNegClassified(AbstractReasonerComponent rs, OWLClass concept, Set<OWLIndividual> testSetNeg) {
306        return testSetNeg.size() - rs.hasType(concept, testSetNeg).size();
307    }
308
309    public static Set<OWLIndividual> getTestingSet(List<OWLIndividual> examples, int[] splits, int fold) {
310        int fromIndex;
311        // we either start from 0 or after the last fold ended
312        if (fold == 0) {
313            fromIndex = 0;
314        } else {
315            fromIndex = splits[fold - 1];
316        }
317        // the split corresponds to the ends of the folds
318        int toIndex = splits[fold];
319
320//              System.out.println("from " + fromIndex + " to " + toIndex);
321        Set<OWLIndividual> testingSet = new HashSet<>();
322        // +1 because 2nd element is exclusive in subList method
323        testingSet.addAll(examples.subList(fromIndex, toIndex));
324        return testingSet;
325    }
326
327}