001/**
002 * 
003 */
004package org.dllearner.algorithms.miles;
005
006import java.util.ArrayList;
007import java.util.HashSet;
008import java.util.List;
009import java.util.Set;
010import java.util.SortedSet;
011import java.util.Timer;
012import java.util.TimerTask;
013
014import org.dllearner.core.AbstractCELA;
015import org.dllearner.core.AbstractReasonerComponent;
016import org.dllearner.core.ComponentInitException;
017import org.dllearner.learningproblems.ClassLearningProblem;
018import org.dllearner.learningproblems.PosNegLP;
019import org.dllearner.learningproblems.PosNegLPStandard;
020import org.semanticweb.owlapi.model.OWLClassExpression;
021import org.semanticweb.owlapi.model.OWLIndividual;
022import org.slf4j.Logger;
023import org.slf4j.LoggerFactory;
024
025import com.google.common.collect.Sets;
026
027/**
028 * First draft of a new kind of learning algorithm:
029 * The basic idea is as follows: 
030 * We take an existing learning algorithm which holds internally a search tree.
031 * During the base algorithm run, we query periodically for nodes in the search tree,
032 * and check if a linear combination gives better results.
033 * @author Lorenz Buehmann
034 *
035 */
036public class MILES {
037        
038        private static final Logger logger = LoggerFactory.getLogger(MILES.class);
039        
040        
041        private AbstractCELA la;
042        private PosNegLP lp;
043        private AbstractReasonerComponent rc;
044        
045        // settings for the frequency
046        private int delay = 0;
047        private int period = 1000;
048        
049        // we can apply the base learning algorithm on a subset of the examples
050        // and evaluate the combined solution on the rest of the data
051        private boolean performInternalCV = true;
052        private double sampleSize = 0.9;
053        
054        //
055
056        public MILES(AbstractCELA la, PosNegLP lp, AbstractReasonerComponent rc) {
057                this.la = la;
058                this.lp = lp;
059                this.rc = rc;
060        }
061        
062        public MILES(AbstractCELA la, ClassLearningProblem lp, AbstractReasonerComponent rc) {
063                this.la = la;
064                this.rc = rc;
065                
066                // we convert to PosNegLP because we need at least the distinction between pos and neg examples
067                // for the sampling
068                // TODO do this in the PosNegLP constructor
069                this.lp = new PosNegLPStandard(rc);
070                SortedSet<OWLIndividual> posExamples = rc.getIndividuals(lp.getClassToDescribe());
071                Set<OWLIndividual> negExamples = Sets.difference(rc.getIndividuals(), posExamples);
072                this.lp.setPositiveExamples(posExamples);
073                this.lp.setNegativeExamples(negExamples);
074        }
075        
076        public void start(){
077                // if enabled, we split the data into a train and a test set
078                if(performInternalCV){
079                        List<OWLIndividual> posExamples = new ArrayList<>(lp.getPositiveExamples());
080                        List<OWLIndividual> negExamples = new ArrayList<>(lp.getNegativeExamples());
081                        
082                        // pos example subsets
083                        int trainSizePos = (int) (0.9 * posExamples.size());
084                        List<OWLIndividual> posExamplesTrain = posExamples.subList(0, trainSizePos);
085                        List<OWLIndividual> posExamplesTest = posExamples.subList(trainSizePos, posExamples.size());
086                        
087                        // neg example subsets
088                        int trainSizeNeg = (int) (0.9 * negExamples.size());
089                        List<OWLIndividual> negExamplesTrain = negExamples.subList(0, trainSizeNeg);
090                        List<OWLIndividual> negExamplesTest = negExamples.subList(trainSizeNeg, negExamples.size());
091                        
092                        lp.setPositiveExamples(new HashSet<>(posExamplesTrain));
093                        lp.setNegativeExamples(new HashSet<>(negExamplesTrain));
094                        
095                        // TODO replace by      
096                                                //FoldGenerator<OWLIndividual> foldGenerator = new FoldGenerator<OWLIndividual>(lp.getPositiveExamples(), lp.getNegativeExamples());
097                        
098                        try {
099                                lp.init();
100                        } catch (ComponentInitException e) {
101                                e.printStackTrace();
102                        }
103                }
104                
105                // 1. start the base learning algorithm in a separate thread
106                Thread t = new Thread(new Runnable() {
107                        
108                        @Override
109                        public void run() {
110                                la.start();
111                        }
112                });
113                t.start();
114                
115                // 2. each x seconds get the top n concepts and validate the linear combination
116                Timer timer = new Timer();
117                LinearClassificationTask linearClassificationTask = new LinearClassificationTask();
118                timer.schedule(linearClassificationTask, delay, period);
119                
120                try {
121                        t.join();
122                } catch (InterruptedException e) {
123                        e.printStackTrace();
124                }
125                timer.cancel();
126                
127                // run the task one more time to ensure that we did it also with the final data
128                linearClassificationTask.run();
129        }
130        
131        class LinearClassificationTask extends TimerTask {
132                
133                private DescriptionLinearClassifier classifier;
134
135                public LinearClassificationTask() {
136                        classifier = new DescriptionLinearClassifier(lp, rc);
137                }
138                
139                @Override
140                public void run() {
141                        logger.debug("Computing linear combination...");
142                        long start = System.currentTimeMillis();
143                        List<OWLClassExpression> descriptions = la.getCurrentlyBestDescriptions(5);
144                        classifier.getLinearCombination(descriptions); 
145                        long end = System.currentTimeMillis();
146                        if (logger.isDebugEnabled()) {
147                                logger.debug("Operation took " + (end - start) + "ms");
148                        }
149                }
150        }
151}