001/**
002 * Copyright (C) 2007 - 2016, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 *
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 */
019package org.dllearner.algorithms.ocel;
020
021import org.dllearner.core.Component;
022import org.dllearner.core.ComponentAnn;
023import org.dllearner.core.ComponentInitException;
024import org.dllearner.core.annotations.NoConfigOption;
025import org.dllearner.core.config.ConfigOption;
026import org.semanticweb.owlapi.model.OWLClassExpression;
027import org.semanticweb.owlapi.model.OWLDataSomeValuesFrom;
028import org.semanticweb.owlapi.model.OWLObjectComplementOf;
029
030import java.util.Set;
031
032/**
033 * This heuristic combines the following criteria to assign a
034 * double score value to a node:
035 * <ul>
036 * <li>quality/accuracy of a concept (based on the full training set, not
037 *   the negative example coverage as the flexible heuristic)</li>
038 * <li>horizontal expansion</li>
039 * <li>accuracy gain: The heuristic takes into account the accuracy
040 *   difference between a node and its parent. If there is no gain (even
041 *   though we know that the refinement is proper) it is unlikely (although
042 *   not excluded) that the refinement is a necessary path to take towards a
043 *   solution.</li>
044 * </ul> 
045 *
046 * The heuristic has two parameters:
047 * <ul>
048 * <li>expansion penalty factor: describes how much accuracy gain is worth
049 *   an increase of horizontal expansion by one (typical value: 0.01)</li>
050 * <li>gain bonus factor: describes how accuracy gain should be weighted
051 *   versus accuracy itself (typical value: 1.00)</li>
052 * </ul>
053 *   
054 * The value of a node is calculated as follows:
055 * 
056 * <p><code>value = accuracy + gain bonus factor * accuracy gain - expansion penalty
057 * factor * horizontal expansion - node children penalty factor * number of children of node</code></p>
058 * 
059 * <p><code>accuracy = (TP + TN)/(P + N)</code></p>
060 * 
061 * <p><code>
062 * TP = number of true positives (= covered positives)<br />
063 * TN = number of true negatives (= nr of negatives examples - covered negatives)<br />
064 * P = number of positive examples<br />
065 * N = number of negative examples<br />
066 * </code></p>
067 * 
068 * @author Jens Lehmann
069 *
070 */
071@ComponentAnn(name = "multiple criteria heuristic", shortName = "multiheuristic", version = 0.7)
072public class MultiHeuristic implements ExampleBasedHeuristic, Component {
073        
074//      private OCELConfigurator configurator;
075        
076        // heuristic parameters
077        
078        @ConfigOption(description = "how much accuracy gain is worth an increase of horizontal expansion by one (typical value: 0.01)", defaultValue="0.02")
079        private double expansionPenaltyFactor = 0.02;
080        
081        @ConfigOption(description = "how accuracy gain should be weighted versus accuracy itself (typical value: 1.00)", defaultValue="0.5")
082        private double gainBonusFactor = 0.5;
083        
084        @ConfigOption(description = "penalty factor for the search tree node child count (use higher values for simple learning problems)", defaultValue="0.0001")
085        private double nodeChildPenalty = 0.0001;
086        
087        @ConfigOption(description = "the score value for the start node", defaultValue="0.1")
088        private double startNodeBonus = 0.1; //was 2.0
089        
090        // penalise errors on positive examples harder than on negative examples
091        // (positive weight = 1)
092        @ConfigOption(description = "weighting factor on the number of true negatives (true positives are weigthed with 1)", defaultValue="1.0")
093        private double negativeWeight = 1.0; // was 0.8;
094        
095        @ConfigOption(description = "penalty value to deduce for using a negated class expression (complementOf)", defaultValue="0")
096        private int negationPenalty = 0;
097
098        // examples
099        @NoConfigOption
100        private int nrOfNegativeExamples;
101        @NoConfigOption
102        private int nrOfExamples;
103        
104        @Deprecated
105        public MultiHeuristic(int nrOfPositiveExamples, int nrOfNegativeExamples) {
106                this.nrOfNegativeExamples = nrOfNegativeExamples;
107                nrOfExamples = nrOfPositiveExamples + nrOfNegativeExamples;
108//              this(nrOfPositiveExamples, nrOfNegativeExamples, 0.02, 0.5);
109        }
110        
111        public MultiHeuristic(int nrOfPositiveExamples, int nrOfNegativeExamples, double negativeWeight, double startNodeBonus, double expansionPenaltyFactor, int negationPenalty) {
112                this.nrOfNegativeExamples = nrOfNegativeExamples;
113                nrOfExamples = nrOfPositiveExamples + nrOfNegativeExamples;
114//              this.configurator = configurator;
115                this.negativeWeight = negativeWeight;
116                this.startNodeBonus = startNodeBonus;
117                this.expansionPenaltyFactor = expansionPenaltyFactor;
118        }
119
120    public MultiHeuristic(){
121
122    }
123
124//      public MultiHeuristic(int nrOfPositiveExamples, int nrOfNegativeExamples, double expansionPenaltyFactor, double gainBonusFactor) {
125//              this.nrOfNegativeExamples = nrOfNegativeExamples;
126//              nrOfExamples = nrOfPositiveExamples + nrOfNegativeExamples;
127//              this.expansionPenaltyFactor = expansionPenaltyFactor;
128//              this.gainBonusFactor = gainBonusFactor;
129//      }
130
131        @Override
132        public void init() throws ComponentInitException {
133                // nothing to do here
134        }       
135        
136        /* (non-Javadoc)
137         * @see java.util.Comparator#compare(java.lang.Object, java.lang.Object)
138         */
139        @Override
140        public int compare(ExampleBasedNode node1, ExampleBasedNode node2) {
141                double score1 = getNodeScore(node1);
142                double score2 = getNodeScore(node2);
143                double diff = score1 - score2;
144                if(diff>0)
145                        return 1;
146                else if(diff<0)
147                        return -1;
148                else
149                        // we cannot return 0 here otherwise different nodes/concepts with the
150                        // same score may be ignored (not added to a set because an equal element exists)
151                        return node1.getConcept().compareTo(node2.getConcept());
152        }
153
154        public double getNodeScore(ExampleBasedNode node) {
155                double accuracy = getWeightedAccuracy(node.getCoveredPositives().size(),node.getCoveredNegatives().size());
156                ExampleBasedNode parent = node.getParent();
157                double gain = 0;
158                if(parent != null) {
159                        double parentAccuracy =  getWeightedAccuracy(parent.getCoveredPositives().size(),parent.getCoveredNegatives().size());
160                        gain = accuracy - parentAccuracy;
161                } else {
162                        accuracy += startNodeBonus;
163                }
164                int he = node.getHorizontalExpansion() - getHeuristicLengthBonus(node.getConcept());
165                return accuracy + gainBonusFactor * gain - expansionPenaltyFactor * he - nodeChildPenalty * node.getChildren().size();
166        }
167        
168        private double getWeightedAccuracy(int coveredPositives, int coveredNegatives) {
169                return (coveredPositives + negativeWeight * (nrOfNegativeExamples - coveredNegatives))/(double)nrOfExamples;
170        }
171        
172        public static double getNodeScore(ExampleBasedNode node, int nrOfPositiveExamples, int nrOfNegativeExamples, double negativeWeight, double startNodeBonus, double expansionPenaltyFactor, int negationPenalty) {
173                MultiHeuristic multi = new MultiHeuristic(nrOfPositiveExamples, nrOfNegativeExamples, negativeWeight, startNodeBonus, expansionPenaltyFactor, negationPenalty);
174                return multi.getNodeScore(node);
175        }
176        
177        // this function can be used to give some constructs a length bonus
178        // compared to their syntactic length
179        private int getHeuristicLengthBonus(OWLClassExpression description) {
180                
181                
182                int bonus = 0;
183                
184                Set<OWLClassExpression> nestedClassExpressions = description.getNestedClassExpressions();
185                for (OWLClassExpression expression : nestedClassExpressions) {
186                        // do not count TOP symbols (in particular in ALL r.TOP and EXISTS r.TOP)
187                        // as they provide no extra information
188                        if(expression.isOWLThing())
189                                bonus = 1; //2;
190                        
191                        // we put a penalty on negations, because they often overfit
192                        // (TODO: make configurable)
193                        else if(expression instanceof OWLObjectComplementOf) {
194                                bonus = -negationPenalty;
195                        }
196                        
197//                      if(OWLClassExpression instanceof BooleanValueRestriction)
198//                              bonus = -1;
199                        
200                        // some bonus for doubles because they are already penalised by length 3
201                        else if(expression instanceof OWLDataSomeValuesFrom) {
202//                              System.out.println(description);
203                                bonus = 3; //2;
204                        }
205                }
206                
207                return bonus;
208        }
209
210    public double getExpansionPenaltyFactor() {
211        return expansionPenaltyFactor;
212    }
213
214    public void setExpansionPenaltyFactor(double expansionPenaltyFactor) {
215        this.expansionPenaltyFactor = expansionPenaltyFactor;
216    }
217
218        public int getNrOfNegativeExamples() {
219                return nrOfNegativeExamples;
220        }
221
222        public void setNrOfNegativeExamples(int nrOfNegativeExamples) {
223                this.nrOfNegativeExamples = nrOfNegativeExamples;
224        }
225
226        public int getNrOfExamples() {
227                return nrOfExamples;
228        }
229
230        public void setNrOfExamples(int nrOfExamples) {
231                this.nrOfExamples = nrOfExamples;
232        }
233
234        public double getGainBonusFactor() {
235                return gainBonusFactor;
236        }
237
238        public void setGainBonusFactor(double gainBonusFactor) {
239                this.gainBonusFactor = gainBonusFactor;
240        }
241
242        public double getNodeChildPenalty() {
243                return nodeChildPenalty;
244        }
245
246        public void setNodeChildPenalty(double nodeChildPenalty) {
247                this.nodeChildPenalty = nodeChildPenalty;
248        }
249
250        public double getStartNodeBonus() {
251                return startNodeBonus;
252        }
253
254        public void setStartNodeBonus(double startNodeBonus) {
255                this.startNodeBonus = startNodeBonus;
256        }
257
258        public double getNegativeWeight() {
259                return negativeWeight;
260        }
261
262        public void setNegativeWeight(double negativeWeight) {
263                this.negativeWeight = negativeWeight;
264        }
265
266        public int getNegationPenalty() {
267                return negationPenalty;
268        }
269
270        public void setNegationPenalty(int negationPenalty) {
271                this.negationPenalty = negationPenalty;
272        }
273}