package org.dllearner.algorithms.meta;

import com.google.common.collect.Iterables;
import org.dllearner.core.*;
import org.dllearner.core.config.ConfigOption;
import org.dllearner.learningproblems.ClassLearningProblem;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.learningproblems.PosOnlyLP;
import org.dllearner.utilities.Helper;
import org.semanticweb.owlapi.model.OWLClassExpression;
import org.semanticweb.owlapi.model.OWLIndividual;
import org.semanticweb.owlapi.model.OWLObjectUnionOf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

/**
 * A meta algorithm that combines the (partial) solutions of multiple calls of the base class learning algorithm LA
 * into a disjunction.
 * In particular, a partial solution is computed by running LA for a given time only on examples of the learning
 * problem that aren't already covered by previously computed solutions.
 *
 * <p>The learning problem is taken from the base algorithm at construction time, so the base algorithm must
 * already have its learning problem configured when it is wrapped.
 *
 * @author Lorenz Buehmann
 */
public class DisjunctiveCELA extends AbstractCELA {

    private static final Logger log = LoggerFactory.getLogger(DisjunctiveCELA.class);

    @ConfigOption(defaultValue = "0.1", description="Specifies the min accuracy for a partial solution.")
    private double minAccuracyPartialSolution = 0.1;

    @ConfigOption(defaultValue = "10", description="Specifies how long the algorithm should search for a partial solution.")
    private int partialSolutionSearchTimeSeconds = 10;

    @ConfigOption(defaultValue = "false", description = "If yes, then the algorithm tries to cover all positive examples. " +
            "Note that while this improves accuracy on the testing set, it may lead to overfitting.")
    private boolean tryFullCoverage = false;

    // NOTE(review): currently not consulted by this class; the loop always stops once no positive
    // examples remain. Kept so that the declared config option does not disappear.
    @ConfigOption(defaultValue="false", description="algorithm will terminate immediately when a correct definition is found")
    private boolean stopOnFirstDefinition = false;

    // NOTE(review): not read anywhere in this class.
    @ConfigOption(defaultValue="0.0", description="the (approximated) percentage of noise within the examples")
    private double noisePercentage = 0.0;

    // the class with which we start the refinement process
    // NOTE(review): not read anywhere in this class.
    @ConfigOption(defaultValue="owl:Thing", description="You can specify a start class for the algorithm. " +
            "To do this, you have to use Manchester OWL syntax without using prefixes.")
    private OWLClassExpression startClass;


    // the core learning algorithm
    private final AbstractCELA la;
    // the learning problem of the core learning algorithm, captured once in the constructor
    private final AbstractClassExpressionLearningProblem<? extends Score> lp;

    // examples that are still part of the (shrinking) learning problem
    private Set<OWLIndividual> currentPosExamples;
    private Set<OWLIndividual> currentNegExamples;

    // snapshot of the positive examples taken in reset(), used for the coverage threshold
    private Set<OWLIndividual> initialPosExamples;

    // accepted partial solutions in the order they were found
    private List<EvaluatedDescription<? extends Score>> partialSolutions = new ArrayList<>();

    // set when the most recent candidate did not pass the acceptance criteria; ends the main loop
    private boolean lastCandidateRejected;

    /**
     * @param la the basic learning algorithm; must not be {@code null}
     */
    public DisjunctiveCELA(AbstractCELA la) {
        this.la = Objects.requireNonNull(la, "la");
        this.lp = la.getLearningProblem();
    }

    /**
     * Limits the run time of the base algorithm to the configured partial solution search time and
     * takes a snapshot of the examples of the learning problem.
     *
     * @throws ComponentInitException if the learning problem is missing or of an unsupported type
     */
    @Override
    public void init() throws ComponentInitException {
        la.setMaxExecutionTimeInSeconds(partialSolutionSearchTimeSeconds);

        reset();
        initialized = true;
    }

    /**
     * Repeatedly runs the base algorithm, adds each acceptable partial solution to the combined
     * disjunction and removes the covered examples from the learning problem, until a stopping
     * criterion holds or {@link #stop()} was called.
     */
    @Override
    public void start() {
        nanoStartTime = System.nanoTime();

        while (!stop && !stoppingCriteriaSatisfied()) {

            // compute next partial solution
            EvaluatedDescription<? extends Score> partialSolution = computePartialSolution();

            // add to global solution if criteria are satisfied
            if (addPartialSolution(partialSolution)) {
                log.info("new partial solution found: {}", partialSolution);

                // update the learning problem
                updateLearningProblem(partialSolution);
            } else {
                // The learning problem is unchanged after a rejection, so another run of the base
                // algorithm would work on exactly the same input. Remember the rejection so that
                // the stopping criteria end the loop instead of spinning until the timeout.
                lastCandidateRejected = true;
            }

        }
        log.info("finished computation in {}.\n top 10 solutions:\n{}",
                Helper.prettyPrintMilliSeconds(getCurrentRuntimeInMilliSeconds()),
                getSolutionString());

    }

    /**
     * Copies the examples of the learning problem into the working sets and clears the rejection flag.
     *
     * @throws ComponentInitException if the learning problem is neither a {@link PosNegLP} nor a {@link PosOnlyLP}
     */
    private void reset() throws ComponentInitException {
        if (lp instanceof PosNegLP) {
            PosNegLP posNegLp = (PosNegLP) lp;
            currentPosExamples = new TreeSet<>(posNegLp.getPositiveExamples());
            currentNegExamples = new TreeSet<>(posNegLp.getNegativeExamples());
        } else if (lp instanceof PosOnlyLP) {
            currentPosExamples = new TreeSet<>(((PosOnlyLP) lp).getPositiveExamples());
            // a positive-only problem has no negative examples to track
            currentNegExamples = new TreeSet<>();
        } else {
            // TODO support ClassLearningProblem
            throw new ComponentInitException("Unsupported learning problem for " + getClass().getSimpleName()
                    + ": " + (lp == null ? "null" : lp.getClass().getName()));
        }

        // keep copy of the initial pos examples
        initialPosExamples = new TreeSet<>(currentPosExamples);

        lastCandidateRejected = false;
    }

    /**
     * Runs the base algorithm once and returns the description it currently considers best.
     *
     * @return the best evaluated description reported by the base algorithm
     */
    private EvaluatedDescription<? extends Score> computePartialSolution() {
        log.info("computing next partial solution...");
        la.start();
        return la.getCurrentlyBestEvaluatedDescription();
    }

    /**
     * Adds the given candidate to the partial solutions if its accuracy is strictly above the
     * configured minimum, and records the evaluated disjunction of all partial solutions found so far.
     *
     * @param partialSolution the candidate, may be {@code null} if the base algorithm found nothing
     * @return {@code true} if the candidate was accepted, {@code false} otherwise
     */
    private boolean addPartialSolution(EvaluatedDescription<? extends Score> partialSolution) {
        if (partialSolution == null) {
            return false;
        }

        // check whether partial solution follows criteria (currently only accuracy threshold)
        if (Double.compare(partialSolution.getAccuracy(), minAccuracyPartialSolution) <= 0) {
            return false;
        }

        partialSolutions.add(partialSolution);

        // create combined solution
        OWLObjectUnionOf combinedCE = dataFactory.getOWLObjectUnionOf(
                partialSolutions.stream()
                        .map(EvaluatedHypothesis::getDescription)
                        .collect(Collectors.toSet()));

        // evaluate combined solution
        // NOTE(review): lp is the same instance whose examples updateLearningProblem() shrinks, so
        // presumably this scores against the reduced example sets - confirm whether scoring against
        // the initial examples was intended.
        EvaluatedDescription<? extends Score> combinedSolution = lp.evaluate(combinedCE);
        bestEvaluatedDescriptions.add(combinedSolution);

        return true;
    }

    /**
     * Removes all individuals covered by the given partial solution from the working example sets
     * and pushes the reduced sets into the learning problem.
     *
     * @param partialSolution the partial solution that was just accepted
     */
    private void updateLearningProblem(EvaluatedDescription<? extends Score> partialSolution) {
        // get individuals covered by the solution
        SortedSet<OWLIndividual> coveredExamples = la.getReasoner().getIndividuals(partialSolution.getDescription());

        // remove from pos examples as those are already covered
        currentPosExamples.removeAll(coveredExamples);

        // remove from neg examples as those will always be covered in the combined solution
        currentNegExamples.removeAll(coveredExamples);

        // update the learning problem itself // TODO do we need some re-init of the lp afterwards?
        if (lp instanceof PosNegLP) {
            PosNegLP posNegLp = (PosNegLP) lp;
            posNegLp.setPositiveExamples(currentPosExamples);
            posNegLp.setNegativeExamples(currentNegExamples);
        } else if (lp instanceof PosOnlyLP) {
            ((PosOnlyLP) lp).setPositiveExamples(currentPosExamples);
        }
        // other learning problem types are rejected in reset()
    }

    /**
     * Checks whether the main loop has to end.
     *
     * @return {@code true} if the global time limit expired, the last candidate was rejected, no
     *         positive examples remain, or (unless full coverage is requested) at most 5% of the
     *         initial positive examples remain uncovered
     */
    private boolean stoppingCriteriaSatisfied() {
        // global time expired
        if (isTimeExpired()) {
            return true;
        }

        // we stop when the score of the last candidate was too low
        // (indicating that the algorithm could not find anything appropriate
        // in the timeframe set)
        if (lastCandidateRejected) {
            return true;
        }

        // stop if there are no more positive examples to cover
        if (currentPosExamples.isEmpty()) {
            return true;
        }

        // with full coverage requested we go on as long as positive examples remain
        if (tryFullCoverage) {
            return false;
        }

        // stop when almost all positive examples have been covered
        int maxPosRemaining = (int) Math.ceil(initialPosExamples.size() * 0.05d);
        return currentPosExamples.size() <= maxPosRemaining;
    }

    /**
     * Stops this algorithm as well as the underlying base algorithm.
     */
    @Override
    public void stop() {
        // we also have to stop the underlying learning algorithm
        la.stop();
        super.stop();
    }

    /**
     * Sets the max. execution time of the whole algorithm. Note, this value should always be higher
     * than the max. execution time to compute a partial solution.
     *
     * @param maxExecutionTimeInSeconds the overall max. execution time
     */
    @Override
    public void setMaxExecutionTimeInSeconds(long maxExecutionTimeInSeconds) {
        super.setMaxExecutionTimeInSeconds(maxExecutionTimeInSeconds);
    }

    /**
     * @param tryFullCoverage whether the algorithm should try to cover all positive examples
     */
    public void setTryFullCoverage(boolean tryFullCoverage) {
        this.tryFullCoverage = tryFullCoverage;
    }

    /**
     * @param minAccuracyPartialSolution the accuracy a partial solution has to exceed to be accepted
     */
    public void setMinAccuracyPartialSolution(double minAccuracyPartialSolution) {
        this.minAccuracyPartialSolution = minAccuracyPartialSolution;
    }

    /**
     * Sets the time budget for a single run of the base algorithm. The value is handed to the base
     * algorithm in {@link #init()}, so it has to be set before initialization.
     *
     * @param partialSolutionSearchTimeSeconds the search time per partial solution in seconds
     */
    public void setPartialSolutionSearchTimeSeconds(int partialSolutionSearchTimeSeconds) {
        this.partialSolutionSearchTimeSeconds = partialSolutionSearchTimeSeconds;
    }
}