/**
 *
 */
package org.dllearner.algorithms.miles;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.SortedSet;
import java.util.Timer;
import java.util.TimerTask;

import org.dllearner.core.AbstractCELA;
import org.dllearner.core.AbstractReasonerComponent;
import org.dllearner.core.ComponentInitException;
import org.dllearner.learningproblems.ClassLearningProblem;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.learningproblems.PosNegLPStandard;
import org.semanticweb.owlapi.model.OWLClassExpression;
import org.semanticweb.owlapi.model.OWLIndividual;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.Sets;

/**
 * First draft of a new kind of learning algorithm:
 * The basic idea is as follows:
 * We take an existing learning algorithm which holds internally a search tree.
 * During the base algorithm run, we query periodically for nodes in the search tree,
 * and check if a linear combination gives better results.
 *
 * @author Lorenz Buehmann
 */
public class MILES {

	private static final Logger logger = LoggerFactory.getLogger(MILES.class);

	/** number of currently best descriptions fed into the linear classifier on each tick */
	private static final int NR_OF_DESCRIPTIONS = 5;

	private AbstractCELA la;
	private PosNegLP lp;
	private AbstractReasonerComponent rc;

	// settings for the frequency of the periodic linear-combination check (milliseconds)
	private int delay = 0;
	private int period = 1000;

	// we can apply the base learning algorithm on a subset of the examples
	// and evaluate the combined solution on the rest of the data
	private boolean performInternalCV = true;
	private double sampleSize = 0.9;

	/**
	 * @param la the base learning algorithm whose search tree is queried periodically
	 * @param lp the pos/neg learning problem (used both for learning and for sampling)
	 * @param rc the reasoner component
	 */
	public MILES(AbstractCELA la, PosNegLP lp, AbstractReasonerComponent rc) {
		this.la = la;
		this.lp = lp;
		this.rc = rc;
	}

	/**
	 * Convenience constructor for a {@link ClassLearningProblem}: instances of the
	 * class to describe become positive examples, all remaining individuals negatives.
	 */
	public MILES(AbstractCELA la, ClassLearningProblem lp, AbstractReasonerComponent rc) {
		this.la = la;
		this.rc = rc;

		// we convert to PosNegLP because we need at least the distinction between pos and neg examples
		// for the sampling
		// TODO do this in the PosNegLP constructor
		this.lp = new PosNegLPStandard(rc);
		SortedSet<OWLIndividual> posExamples = rc.getIndividuals(lp.getClassToDescribe());
		Set<OWLIndividual> negExamples = Sets.difference(rc.getIndividuals(), posExamples);
		this.lp.setPositiveExamples(posExamples);
		this.lp.setNegativeExamples(negExamples);
	}

	/**
	 * Runs the base learning algorithm in a background thread while a timer task
	 * periodically (every {@code period} ms, after an initial {@code delay} ms)
	 * computes a linear combination of the currently best class expressions.
	 * The task is executed once more after the base algorithm terminates so the
	 * final search-tree state is also covered.
	 */
	public void start() {
		// if enabled, we split the data into a train and a test set
		if (performInternalCV) {
			splitData();
		}

		// 1. start the base learning algorithm in a separate thread
		Thread t = new Thread(new Runnable() {

			@Override
			public void run() {
				la.start();
			}
		});
		t.start();

		// 2. each `period` ms get the top n concepts and validate the linear combination
		Timer timer = new Timer();
		LinearClassificationTask linearClassificationTask = new LinearClassificationTask();
		timer.schedule(linearClassificationTask, delay, period);

		try {
			t.join();
		} catch (InterruptedException e) {
			// restore the interrupt status instead of swallowing it
			Thread.currentThread().interrupt();
			logger.error("Interrupted while waiting for the base learning algorithm to finish.", e);
		}
		timer.cancel();

		// run the task one more time to ensure that we did it also with the final data
		linearClassificationTask.run();
	}

	/**
	 * Splits the examples into train/test subsets according to {@code sampleSize}
	 * and re-initializes the learning problem on the training portion only.
	 * NOTE(review): the held-out test lists are computed but not yet evaluated
	 * anywhere — presumably pending the FoldGenerator TODO below; confirm.
	 */
	private void splitData() {
		List<OWLIndividual> posExamples = new ArrayList<>(lp.getPositiveExamples());
		List<OWLIndividual> negExamples = new ArrayList<>(lp.getNegativeExamples());

		// pos example subsets
		// BUGFIX: use the configured sampleSize field instead of a hard-coded 0.9
		int trainSizePos = (int) (sampleSize * posExamples.size());
		List<OWLIndividual> posExamplesTrain = posExamples.subList(0, trainSizePos);
		List<OWLIndividual> posExamplesTest = posExamples.subList(trainSizePos, posExamples.size());

		// neg example subsets
		int trainSizeNeg = (int) (sampleSize * negExamples.size());
		List<OWLIndividual> negExamplesTrain = negExamples.subList(0, trainSizeNeg);
		List<OWLIndividual> negExamplesTest = negExamples.subList(trainSizeNeg, negExamples.size());

		lp.setPositiveExamples(new HashSet<>(posExamplesTrain));
		lp.setNegativeExamples(new HashSet<>(negExamplesTrain));

		// TODO replace by
		//FoldGenerator<OWLIndividual> foldGenerator = new FoldGenerator<OWLIndividual>(lp.getPositiveExamples(), lp.getNegativeExamples());

		try {
			lp.init();
		} catch (ComponentInitException e) {
			logger.error("Failed to re-initialize the learning problem on the training sample.", e);
		}
	}

	/**
	 * Periodic task that queries the base algorithm for its currently best
	 * descriptions and fits a linear combination over them.
	 */
	class LinearClassificationTask extends TimerTask {

		private DescriptionLinearClassifier classifier;

		public LinearClassificationTask() {
			classifier = new DescriptionLinearClassifier(lp, rc);
		}

		@Override
		public void run() {
			logger.debug("Computing linear combination...");
			long start = System.currentTimeMillis();
			List<OWLClassExpression> descriptions = la.getCurrentlyBestDescriptions(NR_OF_DESCRIPTIONS);
			classifier.getLinearCombination(descriptions);
			long end = System.currentTimeMillis();
			if (logger.isDebugEnabled()) {
				logger.debug("Operation took " + (end - start) + "ms");
			}
		}
	}
}