001package org.dllearner.algorithms.meta;
002
003import com.google.common.collect.Iterables;
004import org.dllearner.core.*;
005import org.dllearner.core.config.ConfigOption;
006import org.dllearner.learningproblems.ClassLearningProblem;
007import org.dllearner.learningproblems.PosNegLP;
008import org.dllearner.learningproblems.PosOnlyLP;
009import org.dllearner.utilities.Helper;
010import org.semanticweb.owlapi.model.OWLClassExpression;
011import org.semanticweb.owlapi.model.OWLIndividual;
012import org.semanticweb.owlapi.model.OWLObjectUnionOf;
013import org.slf4j.Logger;
014import org.slf4j.LoggerFactory;
015
016import java.util.*;
017import java.util.stream.Collectors;
018
019/**
020 * A meta algorithm that combines the (partial) solutions of multiple calls of the base class learning algorithm LA
021 * into a disjunction.
022 * In particular, a partial solution is computed by running LA for a given time only on examples of the learning
023 * problem that aren't already covered by previously computed solutions.
024 *
025 * @author Lorenz Buehmann
026 */
027public class DisjunctiveCELA extends AbstractCELA {
028
029    private static final Logger log = LoggerFactory.getLogger(DisjunctiveCELA.class);
030
031    @ConfigOption(defaultValue = "0.1", description="Specifies the min accuracy for a partial solution.")
032    private double minAccuracyPartialSolution = 0.1;
033
034    @ConfigOption(defaultValue = "10", description="Specifies how long the algorithm should search for a partial solution.")
035    private int partialSolutionSearchTimeSeconds = 10;
036
037    @ConfigOption(defaultValue = "false", description = "If yes, then the algorithm tries to cover all positive examples. " +
038            "Note that while this improves accuracy on the testing set, it may lead to overfitting.")
039    private boolean tryFullCoverage = false;
040
041    @ConfigOption(defaultValue="false", description="algorithm will terminate immediately when a correct definition is found")
042    private boolean stopOnFirstDefinition = false;
043
044    @ConfigOption(defaultValue="0.0", description="the (approximated) percentage of noise within the examples")
045    private double noisePercentage = 0.0;
046
047    // the class with which we start the refinement process
048    @ConfigOption(defaultValue="owl:Thing", description="You can specify a start class for the algorithm. " +
049            "To do this, you have to use Manchester OWL syntax without using prefixes.")
050    private OWLClassExpression startClass;
051
052
053    // the core learning algorithm
054    private final AbstractCELA la;
055    private final AbstractClassExpressionLearningProblem<? extends Score> lp;
056
057    private Set<OWLIndividual> currentPosExamples;
058    private Set<OWLIndividual> currentNegExamples;
059
060    private Set<OWLIndividual> initialPosExamples;
061
062    private List<EvaluatedDescription<? extends Score>> partialSolutions = new ArrayList<>();
063
064    /**
065     * @param la the basic learning algorithm
066     */
067    public DisjunctiveCELA(AbstractCELA la) {
068        this.la = la;
069        this.lp = la.getLearningProblem();
070    }
071
072    @Override
073    public void init() throws ComponentInitException {
074        la.setMaxExecutionTimeInSeconds(partialSolutionSearchTimeSeconds);
075
076        reset();
077        initialized = true;
078    }
079
080    @Override
081    public void start() {
082        nanoStartTime = System.nanoTime();
083
084        while(!stop && !stoppingCriteriaSatisfied()) {
085
086            // compute next partial solution
087            EvaluatedDescription<? extends Score> partialSolution = computePartialSolution();
088
089            // add to global solution if criteria are satisfied
090            if(addPartialSolution(partialSolution)) {
091                log.info("new partial solution found: {}", partialSolution);
092
093                // update the learning problem
094                updateLearningProblem(partialSolution);
095            }
096
097        }
098        log.info("finished computation in {}.\n top 10 solutions:\n{}",
099                Helper.prettyPrintMilliSeconds(getCurrentRuntimeInMilliSeconds()),
100                getSolutionString());
101
102    }
103
104    private void reset() {
105        currentPosExamples = new TreeSet<>(((PosNegLP) la.getLearningProblem()).getPositiveExamples());
106        currentNegExamples = new TreeSet<>(((PosNegLP) la.getLearningProblem()).getNegativeExamples());
107
108        // keep copy of the initial pos examples
109        initialPosExamples = new TreeSet<>(currentPosExamples);
110    }
111
112    private EvaluatedDescription<? extends Score> computePartialSolution() {
113        log.info("computing next partial solution...");
114        la.start();
115        EvaluatedDescription<? extends Score> partialSolution = la.getCurrentlyBestEvaluatedDescription();
116        return partialSolution;
117    }
118
119    private boolean addPartialSolution(EvaluatedDescription<? extends Score> partialSolution) {
120        // check whether partial solution follows criteria (currently only accuracy threshold)
121        if(Double.compare(partialSolution.getAccuracy(), minAccuracyPartialSolution) > 0) {
122            partialSolutions.add(partialSolution);
123
124            // create combined solution
125            OWLObjectUnionOf combinedCE = dataFactory.getOWLObjectUnionOf(
126                                                            partialSolutions.stream()
127                                                                            .map(EvaluatedHypothesis::getDescription)
128                                                                            .collect(Collectors.toSet()));
129            // evalute combined solution
130            EvaluatedDescription<? extends Score> combinedSolution = lp.evaluate(combinedCE);
131            bestEvaluatedDescriptions.add(combinedSolution);
132
133            return true;
134        }
135        return false;
136    }
137
138    private void updateLearningProblem(EvaluatedDescription<? extends Score> partialSolution) {
139        // get individuals covered by the solution
140        SortedSet<OWLIndividual> coveredExamples = la.getReasoner().getIndividuals(partialSolution.getDescription());
141
142        // remove from pos examples as those are already covered
143        currentPosExamples.removeAll(coveredExamples);
144
145        // remove from neg examples as those will always be covered in the combined solution
146        currentNegExamples.removeAll(coveredExamples);
147
148        // update the learning problem itself // TODO do we need some re-init of the lp afterwards?
149        if(lp instanceof PosNegLP) {
150            ((PosNegLP) la.getLearningProblem()).setPositiveExamples(currentPosExamples);
151            ((PosNegLP) la.getLearningProblem()).setNegativeExamples(currentNegExamples);
152        } else if(lp instanceof PosOnlyLP) {
153            ((PosOnlyLP) la.getLearningProblem()).setPositiveExamples(currentPosExamples);
154        } else if(lp instanceof ClassLearningProblem){
155            // TODO
156        }
157
158    }
159
160    private boolean stoppingCriteriaSatisfied() {
161        // global time expired
162        if(isTimeExpired()) {
163            return true;
164        }
165
166        // stop if there are no more positive examples to cover
167        if(stopOnFirstDefinition && currentPosExamples.size()==0) {
168            return true;
169        }
170
171        // we stop when the score of the last tree added is too low
172        // (indicating that the algorithm could not find anything appropriate
173        // in the timeframe set)
174        EvaluatedDescription<? extends Score> lastPartialSolution = Iterables.getLast(partialSolutions, null);
175        if(lastPartialSolution != null && Double.compare(lastPartialSolution.getAccuracy(), minAccuracyPartialSolution) <= 0) {
176            return true;
177        }
178
179        // stop when almost all positive examples have been covered
180        if(tryFullCoverage) {
181            return false;
182        } else {
183            int maxPosRemaining = (int) Math.ceil(initialPosExamples.size() * 0.05d);
184            return (currentPosExamples.size()<=maxPosRemaining);
185        }
186    }
187
188    @Override
189    public void stop() {
190        // we also have to stop the underlying learning algorithm
191        la.stop();
192        super.stop();
193    }
194
195    /**
196     * Sets the max. execution time of the whole algorithm. Note, this values should always be higher
197     * than the max. execution time to compute a partial solution.
198     *
199     * @param maxExecutionTimeInSeconds the overall the max. execution time
200     */
201    @Override
202    public void setMaxExecutionTimeInSeconds(long maxExecutionTimeInSeconds) {
203        super.setMaxExecutionTimeInSeconds(maxExecutionTimeInSeconds);
204    }
205
206    public void setTryFullCoverage(boolean tryFullCoverage) {
207        this.tryFullCoverage = tryFullCoverage;
208    }
209
210    public void setMinAccuracyPartialSolution(double minAccuracyPartialSolution) {
211        this.minAccuracyPartialSolution = minAccuracyPartialSolution;
212    }
213
214    public void setPartialSolutionSearchTimeSeconds(int partialSolutionSearchTimeSeconds) {
215        this.partialSolutionSearchTimeSeconds = partialSolutionSearchTimeSeconds;
216    }
217}