/**
 * Copyright (C) 2007 - 2016, Jens Lehmann
 *
 * This file is part of DL-Learner.
 *
 * DL-Learner is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * DL-Learner is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.dllearner.learningproblems;

/**
 * Implementation of various heuristics. The methods can be used in learning
 * problems and various evaluation scripts. They are verified in unit tests
 * and, thus, should be fairly stable.
 *
 * @author Jens Lehmann
 *
 */
public class Heuristics {

    public enum HeuristicType { PRED_ACC, AMEASURE, JACCARD, FMEASURE, GEN_FMEASURE, ENTROPY, MATTHEWS_CORRELATION, YOUDEN_INDEX }

    /**
     * Computes the F1-score.
     * @param recall Recall.
     * @param precision Precision.
     * @return Harmonic mean of precision and recall.
     */
    public static double getFScore(double recall, double precision) {
        return (precision + recall == 0) ? 0 :
                ( 2 * (precision * recall) / (precision + recall) );
    }
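
    // Illustrative check of getFScore (hypothetical recall/precision values, not part of the original code):
    //   Heuristics.getFScore(0.6, 0.8)
    //   = 2 * (0.8 * 0.6) / (0.8 + 0.6) = 0.96 / 1.4 ≈ 0.686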

    /**
     * Computes the F-beta score.
     * @param recall Recall.
     * @param precision Precision.
     * @param beta Weights precision and recall. If beta is >1, then recall is more important
     * than precision.
     * @return Harmonic mean of precision and recall weighted by beta.
     */
    public static double getFScore(double recall, double precision, double beta) {
        return (precision + recall == 0) ? 0 :
                ( (1 + beta * beta) * (precision * recall)
                        / (beta * beta * precision + recall) );
    }
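
    // Illustrative check of the beta-weighted F-score (hypothetical values, sketch only):
    //   Heuristics.getFScore(0.6, 0.8, 2.0)
    //   = (1 + 4) * (0.8 * 0.6) / (4 * 0.8 + 0.6) = 2.4 / 3.8 ≈ 0.632  (pulled towards recall = 0.6)
    //   Heuristics.getFScore(0.6, 0.8, 0.5)
    //   = 1.25 * 0.48 / (0.25 * 0.8 + 0.6) = 0.6 / 0.8 = 0.75          (pulled towards precision = 0.8)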

    /**
     * Variant of the F-beta score ("balanced F measure") in which the weighting factor
     * enters as sqrt(beta) instead of beta squared.
     */
    public static double getFScoreBalanced(double recall, double precision, double beta) {
        return (precision + recall == 0) ? 0 :
                ( (1 + Math.sqrt(beta)) * (precision * recall)
                        / (Math.sqrt(beta) * precision + recall) );
    }

    /**
     * Computes the arithmetic mean of precision and recall, which is called "A-Score"
     * here (A = arithmetic), but is not an established notion in machine learning.
     * @param recall Recall.
     * @param precision Precision.
     * @return Arithmetic mean of precision and recall.
     */
    public static double getAScore(double recall, double precision) {
        return (recall + precision) / 2;
    }

    /**
     * Computes the weighted arithmetic mean of precision and recall, which is called "A-Score"
     * here (A = arithmetic), but is not an established notion in machine learning.
     * @param recall Recall.
     * @param precision Precision.
     * @param beta Weights precision and recall. If beta is >1, then recall is more important
     * than precision.
     * @return Weighted arithmetic mean of precision and recall.
     */
    public static double getAScore(double recall, double precision, double beta) {
        return (beta * recall + precision) / (beta + 1);
    }
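
    // Illustrative check of the A-Score (hypothetical values, not part of the original code):
    //   Heuristics.getAScore(0.6, 0.8)      = (0.6 + 0.8) / 2     = 0.7
    //   Heuristics.getAScore(0.6, 0.8, 2.0) = (2 * 0.6 + 0.8) / 3 ≈ 0.667  (recall weighted twice)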

    /**
     * Computes the Jaccard coefficient of two sets.
     * @param elementsIntersection Number of elements in the intersection of the two sets.
     * @param elementsUnion Number of elements in the union of the two sets.
     * @return #intersection divided by #union.
     */
    public static double getJaccardCoefficient(int elementsIntersection, int elementsUnion) {
        if(elementsIntersection > elementsUnion || elementsUnion < 1) {
            throw new IllegalArgumentException();
        }
        return elementsIntersection / (double) elementsUnion;
    }
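
    // Illustrative check of the Jaccard coefficient (hypothetical set sizes, sketch only):
    //   two concepts sharing 30 instances out of 50 distinct instances overall:
    //   Heuristics.getJaccardCoefficient(30, 50) = 30 / 50 = 0.6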

    public static double getPredictiveAccuracy(int nrOfExamples, int nrOfPosClassifiedPositives, int nrOfNegClassifiedNegatives) {
        return (nrOfPosClassifiedPositives + nrOfNegClassifiedNegatives) / (double) nrOfExamples;
    }

    public static double getPredictiveAccuracy(int nrOfPosExamples, int nrOfNegExamples, int nrOfPosClassifiedPositives, int nrOfNegClassifiedNegatives, double beta) {
        return (nrOfPosClassifiedPositives + beta * nrOfNegClassifiedNegatives) / (nrOfPosExamples + beta * nrOfNegExamples);
    }

    public static double getPredictiveAccuracy2(int nrOfExamples, int nrOfPosClassifiedPositives, int nrOfPosClassifiedNegatives) {
        return (nrOfPosClassifiedPositives + nrOfExamples - nrOfPosClassifiedNegatives) / (double) nrOfExamples;
    }

    public static double getPredictiveAccuracy2(int nrOfPosExamples, int nrOfNegExamples, int nrOfPosClassifiedPositives, int nrOfNegClassifiedNegatives, double beta) {
        return (nrOfPosClassifiedPositives + beta * nrOfNegClassifiedNegatives) / (nrOfPosExamples + beta * nrOfNegExamples);
    }
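
    // Illustrative check of predictive accuracy (hypothetical counts, not part of the original code):
    //   60 positive and 40 negative examples, 50 positives and 30 negatives classified correctly:
    //   Heuristics.getPredictiveAccuracy(100, 50, 30)       = (50 + 30) / 100           = 0.8
    //   Heuristics.getPredictiveAccuracy(60, 40, 50, 30, 2) = (50 + 2*30) / (60 + 2*40) ≈ 0.786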

    // Matthews correlation coefficient computed from a binary confusion matrix (tp/fp/tn/fn)
    public static double getMatthewsCorrelationCoefficient(int tp, int fp, int tn, int fn) {
        // cast to double before multiplying to avoid int overflow for large example sets
        return ((double) tp * tn - (double) fp * fn)
                / Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
    }
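
    // Illustrative check of the Matthews correlation coefficient (hypothetical confusion matrix, sketch only):
    //   tp = 40, fp = 10, tn = 35, fn = 15:
    //   (40*35 - 10*15) / sqrt(50 * 55 * 45 * 50) = 1250 / sqrt(6187500) ≈ 0.50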

    /**
     * Computes the 95% confidence interval of an experiment with boolean outcomes,
     * e.g. heads or tails coin throws. It uses the efficient, but still accurate
     * modified Wald method (two successes and two failures are added before the
     * standard Wald interval is computed).
     * @param total Total number of tries, e.g. total number of times the coin was thrown.
     * @param success Number of successes, e.g. number of times the coin shows head.
     * @return A two element double array, where element 0 is the lower border and element
     * 1 the upper border of the 95% confidence interval.
     */
    public static double[] getConfidenceInterval95Wald(int total, int success) {
        if(success > total || total < 1) {
            throw new IllegalArgumentException("95% confidence interval for " + success + " out of " + total + " trials cannot be estimated.");
        }
        double[] ret = new double[2];
        double p1 = (success + 2) / (double) (total + 4);
        double p2 = 1.96 * Math.sqrt(p1 * (1 - p1) / (total + 4));
        ret[0] = Math.max(0, p1 - p2);
        ret[1] = Math.min(1, p1 + p2);
        return ret;
    }
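
    // Illustrative check of the modified Wald interval (hypothetical counts, not part of the original code):
    //   60 successes out of 100 trials: p1 = 62/104 ≈ 0.596, margin = 1.96 * sqrt(0.596 * 0.404 / 104) ≈ 0.094
    //   Heuristics.getConfidenceInterval95Wald(100, 60) ≈ [0.502, 0.690]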

    /**
     * Computes the center of the 95% confidence interval of an experiment with boolean
     * outcomes, e.g. heads or tails coin throws, i.e. the average of the lower and upper
     * border returned by {@link #getConfidenceInterval95Wald(int, int)}.
     * @param total Total number of tries, e.g. total number of times the coin was thrown.
     * @param success Number of successes, e.g. number of times the coin shows head.
     * @return The average of the lower border and upper border of the 95% confidence interval.
     */
    public static double getConfidenceInterval95WaldAverage(int total, int success) {
        if(success > total || total < 1) {
            throw new IllegalArgumentException("95% confidence interval for " + success + " out of " + total + " trials cannot be estimated.");
        }
        double[] interval = getConfidenceInterval95Wald(total, success);
        return (interval[0] + interval[1]) / 2;
    }

    /**
     * Computes whether a hypothesis is too weak, i.e. it has more errors on the positive examples
     * than allowed by the noise parameter.
     * @param nrOfPositiveExamples The number of positive examples in the learning problem.
     * @param nrOfPosClassifiedPositives The number of positive examples, which were classified as positive by the hypothesis.
     * @param noise The noise parameter is a value between 0 and 1, which indicates how noisy the example data is (0 = no noise, 1 = completely random).
     * If a hypothesis makes more errors on the positive examples than the noise value multiplied by the
     * number of positive examples, then the hypothesis is too weak.
     * @return True if the hypothesis is too weak and false otherwise.
     */
    public boolean isTooWeak(int nrOfPositiveExamples, int nrOfPosClassifiedPositives, double noise) {
        if(noise < 0 || noise > 1 || nrOfPosClassifiedPositives > nrOfPositiveExamples || nrOfPositiveExamples < 1) {
            throw new IllegalArgumentException();
        }
        return (noise * nrOfPositiveExamples) < (nrOfPositiveExamples - nrOfPosClassifiedPositives);
    }
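
    // Illustrative check of isTooWeak (hypothetical counts, sketch only):
    //   100 positive examples, 85 covered by the hypothesis, noise = 0.1:
    //   errors on positives = 100 - 85 = 15 > 0.1 * 100 = 10  =>  too weak (returns true)
    //   with noise = 0.2 the 15 errors would be tolerated      =>  returns false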

    /**
     * Computes whether a hypothesis is too weak, i.e. it has more errors on the positive examples
     * than allowed by the noise parameter.
     * @param nrOfPositiveExamples The number of positive examples in the learning problem.
     * @param nrOfNegClassifiedPositives The number of positive examples, which were classified as negative by the hypothesis.
     * @param noise The noise parameter is a value between 0 and 1, which indicates how noisy the example data is (0 = no noise, 1 = completely random).
     * If a hypothesis makes more errors on the positive examples than the noise value multiplied by the
     * number of positive examples, then the hypothesis is too weak.
     * @return True if the hypothesis is too weak and false otherwise.
     */
    public boolean isTooWeak2(int nrOfPositiveExamples, int nrOfNegClassifiedPositives, double noise) {
        if(noise < 0 || noise > 1 || nrOfNegClassifiedPositives > nrOfPositiveExamples || nrOfPositiveExamples < 1) {
            throw new IllegalArgumentException();
        }
        return (noise * nrOfPositiveExamples) < nrOfNegClassifiedPositives;
    }

    // see paper: p'
    public static double p1(int success, int total) {
        return (success + 2) / (double) (total + 4);
    }

    // see paper: expression used in confidence interval estimation
    public static double p3(double p1, int total) {
        return 1.96 * Math.sqrt(p1 * (1 - p1) / (total + 4));
    }

    /**
     * This method can be used to approximate the F-Measure and thereby save a lot of
     * instance checks. It assumes that all positive examples (or instances of a class)
     * have already been tested via instance checks, i.e. recall is already known and
     * precision is approximated.
     * @param nrOfPosClassifiedPositives Positive examples (instances of a class), which are classified as positive.
     * @param recall The already known recall.
     * @param beta Weights precision and recall. If beta is >1, then recall is more important
     * than precision.
     * @param nrOfRelevantInstances Number of relevant instances, i.e. number of instances, which
     * would have been tested without approximations. TODO: relevant = pos + neg examples?
     * @param nrOfInstanceChecks Performed instance checks for the approximation.
     * @param nrOfSuccessfulInstanceChecks Number of successfully performed instance checks.
     * @return A two element array, where the first element is the computed F-beta score and the
     * second element is the length of the 95% confidence interval around it.
     */
    public static double[] getFScoreApproximation(int nrOfPosClassifiedPositives, double recall, double beta, int nrOfRelevantInstances, int nrOfInstanceChecks, int nrOfSuccessfulInstanceChecks) {
        // compute 95% confidence interval
        double[] interval = Heuristics.getConfidenceInterval95Wald(nrOfInstanceChecks, nrOfSuccessfulInstanceChecks);
        // multiply by number of instances from which the random samples are drawn
        double lowerBorder = interval[0] * nrOfRelevantInstances;
        double upperBorder = interval[1] * nrOfRelevantInstances;
        // compute F-Measure for both borders (lower border = higher F-Measure)
        double fMeasureHigh = (1 + Math.sqrt(beta)) * (nrOfPosClassifiedPositives / (nrOfPosClassifiedPositives + lowerBorder) * recall) / (Math.sqrt(beta) * nrOfPosClassifiedPositives / (nrOfPosClassifiedPositives + lowerBorder) + recall);
        double fMeasureLow = (1 + Math.sqrt(beta)) * (nrOfPosClassifiedPositives / (nrOfPosClassifiedPositives + upperBorder) * recall) / (Math.sqrt(beta) * nrOfPosClassifiedPositives / (nrOfPosClassifiedPositives + upperBorder) + recall);
        double diff = fMeasureHigh - fMeasureLow;
        // compute F-score for the observed proportion instead?
        // double proportionInstanceChecks = nrOfSuccessfulInstanceChecks / (double) nrOfInstanceChecks * nrOfRelevantInstances;
        // => don't do it for now, because the difference between the proportion and the center of the interval
        // is usually quite small for sufficiently small diffs
        // return interval center and length
        double[] ret = new double[2];
        ret[0] = fMeasureLow + 0.5 * diff;
        ret[1] = diff;
        return ret;
    }
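
    // Illustrative usage of the F-score approximation (all values hypothetical, sketch only):
    //   recall is known exactly (e.g. 80 of 100 positives covered => recall = 0.8), precision is
    //   approximated from a random sample of instance checks over the relevant instances:
    //   double[] approx = Heuristics.getFScoreApproximation(80, 0.8, 1.0, 1000, 50, 10);
    //   approx[0] is the estimated F-score, approx[1] the width of its 95% confidence interval;
    //   a caller might keep adding instance checks until approx[1] drops below a chosen threshold.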

    /**
     * In the first step of the A-Score approximation, we estimate recall (taking the factor
     * beta into account). This is not much more than a wrapper around the modified Wald method.
     * @param beta Weights precision and recall. If beta is >1, then recall is more important
     * than precision.
     * @param nrOfPosExamples Number of positive examples (or instances of the considered class).
     * @param nrOfInstanceChecks Number of positive examples (or instances of the considered class) which have been checked.
     * @param nrOfSuccessfulInstanceChecks Number of positive examples (or instances of the considered class), where the instance check returned true.
     * @return A two element array, where the first element is the recall multiplied by beta and the
     * second element is the length of the 95% confidence interval around it.
     */
    public static double[] getAScoreApproximationStep1(double beta, int nrOfPosExamples, int nrOfInstanceChecks, int nrOfSuccessfulInstanceChecks) {
        // the method is just a wrapper around a single confidence interval approximation;
        // it approximates t * a / |R(A)|
        double[] interval = Heuristics.getConfidenceInterval95Wald(nrOfInstanceChecks, nrOfSuccessfulInstanceChecks);
        double diff = beta * (interval[1] - interval[0]);
        double[] ret = new double[2];
        ret[0] = beta * interval[0] + 0.5 * diff;
        ret[1] = diff;
        return ret;
    }

    /**
     * In step 2 of the A-Score approximation, the precision and overall A-Score are estimated based on
     * the estimated recall.
     * @param nrOfPosClassifiedPositives Positive examples (instances of a class), which are classified as positive.
     * @param recallInterval The estimated recall, which needs to be given as a two element array with the first element being the mean value and the second element being the length of the interval (to be compatible with the step 1 method).
     * @param beta Weights precision and recall. If beta is >1, then recall is more important
     * than precision.
     * @param nrOfRelevantInstances Number of relevant instances, i.e. number of instances, which
     * would have been tested without approximations.
     * @param nrOfInstanceChecks Performed instance checks for the approximation.
     * @param nrOfSuccessfulInstanceChecks Number of performed instance checks, which returned true.
     * @return A two element array, where the first element is the estimated A-Score and the
     * second element is the length of the 95% confidence interval around it.
     */
    public static double[] getAScoreApproximationStep2(int nrOfPosClassifiedPositives, double[] recallInterval, double beta, int nrOfRelevantInstances, int nrOfInstanceChecks, int nrOfSuccessfulInstanceChecks) {
        // the recall interval is given as mean + interval length (to fit the other method calls) => compute lower and upper border
        double recallLowerBorder = (recallInterval[0] - 0.5 * recallInterval[1]) / beta;
        double recallUpperBorder = (recallInterval[0] + 0.5 * recallInterval[1]) / beta;
        // estimate precision
        double[] interval = Heuristics.getConfidenceInterval95Wald(nrOfInstanceChecks, nrOfSuccessfulInstanceChecks);

        double precisionLowerBorder = nrOfPosClassifiedPositives / (nrOfPosClassifiedPositives + interval[1] * nrOfRelevantInstances);
        double precisionUpperBorder = nrOfPosClassifiedPositives / (nrOfPosClassifiedPositives + interval[0] * nrOfRelevantInstances);

//      System.out.println("rec low: " + recallLowerBorder);
//      System.out.println("rec up: " + recallUpperBorder);
//      System.out.println("prec low: " + precisionLowerBorder);
//      System.out.println("prec up: " + precisionUpperBorder);
        double lowerBorder = Heuristics.getAScore(recallLowerBorder, precisionLowerBorder, beta);
        double upperBorder = Heuristics.getAScore(recallUpperBorder, precisionUpperBorder, beta);
        double diff = upperBorder - lowerBorder;
        double[] ret = new double[2];
        ret[0] = lowerBorder + 0.5 * diff;
        ret[1] = diff;
        return ret;
    }
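
    // Illustrative usage of the two-step A-Score approximation (all values hypothetical, sketch only):
    //   double[] recallEstimate = Heuristics.getAScoreApproximationStep1(1.0, 100, 40, 32);
    //   double[] aScoreEstimate = Heuristics.getAScoreApproximationStep2(80, recallEstimate, 1.0, 1000, 50, 10);
    //   aScoreEstimate[0] is the estimated A-Score, aScoreEstimate[1] the width of its 95% confidence interval.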

    // WARNING: unstable/untested
    // uses the following formula: (|R(C) \cap E^+| + beta * |E^- \ R(C)|) / (|E^+| + beta * |E^-|)
    // approximates |R(C) \cap E^+| and beta * |E^- \ R(C)| separately and adds their lower and upper borders (pessimistic estimate)
    // TODO: only works well if there are many negatives at the moment, so speedup is not great
    public static double[] getPredAccApproximation(int nrOfPositiveExamples, int nrOfNegativeExamples, double beta, int nrOfPosExampleInstanceChecks, int nrOfSuccessfulPosExampleChecks, int nrOfNegExampleInstanceChecks, int nrOfNegativeNegExampleChecks) {
        // compute both 95% confidence intervals
        double[] intervalPos = Heuristics.getConfidenceInterval95Wald(nrOfPosExampleInstanceChecks, nrOfSuccessfulPosExampleChecks);
        double[] intervalNeg = Heuristics.getConfidenceInterval95Wald(nrOfNegExampleInstanceChecks, nrOfNegativeNegExampleChecks);
        // multiply by number of instances from which the random samples are drawn
        double lowerBorder = intervalPos[0] * nrOfPositiveExamples + beta * intervalNeg[0] * nrOfNegativeExamples;
        double upperBorder = intervalPos[1] * nrOfPositiveExamples + beta * intervalNeg[1] * nrOfNegativeExamples;
        double predAccLow = lowerBorder / (nrOfPositiveExamples + beta * nrOfNegativeExamples);
        double predAccHigh = upperBorder / (nrOfPositiveExamples + beta * nrOfNegativeExamples);
        double diff = predAccHigh - predAccLow;
        // return interval center and length
        double[] ret = new double[2];
        ret[0] = predAccLow + 0.5 * diff;
        ret[1] = diff;
        return ret;
    }
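
    // Illustrative usage of the predictive accuracy approximation (all values hypothetical, sketch only):
    //   double[] approx = Heuristics.getPredAccApproximation(100, 900, 1.0, 40, 32, 90, 75);
    //   approx[0] is the estimated (weighted) predictive accuracy, approx[1] the width of its
    //   95% confidence interval; as noted above, the speedup is only substantial with many negatives.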

    public static double divideOrZero(int numerator, int denominator) {
        return denominator == 0 ? 0 : numerator / (double) denominator;
    }

}