001/** 002 * Copyright (C) 2007 - 2016, Jens Lehmann 003 * 004 * This file is part of DL-Learner. 005 * 006 * DL-Learner is free software; you can redistribute it and/or modify 007 * it under the terms of the GNU General Public License as published by 008 * the Free Software Foundation; either version 3 of the License, or 009 * (at your option) any later version. 010 * 011 * DL-Learner is distributed in the hope that it will be useful, 012 * but WITHOUT ANY WARRANTY; without even the implied warranty of 013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 014 * GNU General Public License for more details. 015 * 016 * You should have received a copy of the GNU General Public License 017 * along with this program. If not, see <http://www.gnu.org/licenses/>. 018 */ 019package org.dllearner.learningproblems; 020 021/** 022 * Implementation of various heuristics. The methods can be used in learning 023 * problems and various evaluation scripts. They are verified in unit tests 024 * and, thus, should be fairly stable. 025 * 026 * @author Jens Lehmann 027 * 028 */ 029public class Heuristics { 030 031 public enum HeuristicType { PRED_ACC, AMEASURE, JACCARD, FMEASURE, GEN_FMEASURE, ENTROPY, MATTHEWS_CORRELATION, YOUDEN_INDEX } 032 033 /** 034 * Computes F1-Score. 035 * @param recall Recall. 036 * @param precision Precision. 037 * @return Harmonic mean of precision and recall. 038 */ 039 public static double getFScore(double recall, double precision) { 040 return (precision + recall == 0) ? 0 : 041 ( 2 * (precision * recall) / (precision + recall) ); 042 } 043 044 /** 045 * Computes F-beta-Score. 046 * @param recall Recall. 047 * @param precision Precision. 048 * @param beta Weights precision and recall. If beta is >1, then recall is more important 049 * than precision. 050 * @return Harmonic mean of precision and recall weighted by beta. 051 */ 052 public static double getFScore(double recall, double precision, double beta) { 053 return (precision + recall == 0) ? 0 : 054 ( (1+ beta * beta) * (precision * recall) 055 / (beta * beta * precision + recall) ); 056 } 057 058 public static double getFScoreBalanced(double recall, double precision, double beta) { 059 // balanced F measure 060 return (precision + recall == 0) ? 0 : 061 ( (1+Math.sqrt(beta)) * (precision * recall) 062 / (Math.sqrt(beta) * precision + recall) ); 063 } 064 /** 065 * Computes arithmetic mean of precision and recall, which is called "A-Score" 066 * here (A=arithmetic), but is not an established notion in machine learning. 067 * @param recall Recall. 068 * @param precision Precison. 069 * @return Arithmetic mean of precision and recall. 070 */ 071 public static double getAScore(double recall, double precision) { 072 return (recall + precision) / 2; 073 } 074 075 /** 076 * Computes arithmetic mean of precision and recall, which is called "A-Score" 077 * here (A=arithmetic), but is not an established notion in machine learning. 078 * @param recall Recall. 079 * @param precision Precison. 080 * @param beta Weights precision and recall. If beta is >1, then recall is more important 081 * than precision. 082 * @return Arithmetic mean of precision and recall. 083 */ 084 public static double getAScore(double recall, double precision, double beta) { 085 return (beta * recall + precision) / (beta + 1); 086 } 087 088 /** 089 * Computes the Jaccard coefficient of two sets. 090 * @param elementsIntersection Number of elements in the intersection of the two sets. 091 * @param elementsUnion Number of elements in the union of the two sets. 092 * @return #intersection divided by #union. 093 */ 094 public static double getJaccardCoefficient(int elementsIntersection, int elementsUnion) { 095 if(elementsIntersection > elementsUnion || elementsUnion < 1) { 096 throw new IllegalArgumentException(); 097 } 098 return elementsIntersection / (double) elementsUnion; 099 } 100 101 public static double getPredictiveAccuracy(int nrOfExamples, int nrOfPosClassifiedPositives, int nrOfNegClassifiedNegatives) { 102 return (nrOfPosClassifiedPositives + nrOfNegClassifiedNegatives) / (double) nrOfExamples; 103 } 104 105 public static double getPredictiveAccuracy(int nrOfPosExamples, int nrOfNegExamples, int nrOfPosClassifiedPositives, int nrOfNegClassifiedNegatives, double beta) { 106 return (nrOfPosClassifiedPositives + beta * nrOfNegClassifiedNegatives) / (nrOfPosExamples + beta * nrOfNegExamples); 107 } 108 109 public static double getPredictiveAccuracy2(int nrOfExamples, int nrOfPosClassifiedPositives, int nrOfPosClassifiedNegatives) { 110 return (nrOfPosClassifiedPositives + nrOfExamples - nrOfPosClassifiedNegatives) / (double) nrOfExamples; 111 } 112 113 public static double getPredictiveAccuracy2(int nrOfPosExamples, int nrOfNegExamples, int nrOfPosClassifiedPositives, int nrOfNegClassifiedNegatives, double beta) { 114 return (nrOfPosClassifiedPositives + beta * nrOfNegClassifiedNegatives) / (nrOfPosExamples + beta * nrOfNegExamples); 115 } 116 117 public static double getMatthewsCorrelationCoefficient(int tp, int fp, int tn, int fn) { 118 return (tp * tn - fp * fn) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)); 119 } 120 121 /** 122 * Computes the 95% confidence interval of an experiment with boolean outcomes, 123 * e.g. heads or tails coin throws. It uses the very efficient, but still accurate 124 * Wald method. 125 * @param success Number of successes, e.g. number of times the coin shows head. 126 * @param total Total number of tries, e.g. total number of times the coin was thrown. 127 * @return A two element double array, where element 0 is the lower border and element 128 * 1 the upper border of the 95% confidence interval. 129 */ 130 public static double[] getConfidenceInterval95Wald(int total, int success) { 131 if(success > total || total < 1) { 132 throw new IllegalArgumentException("95% confidence interval for " + success + " out of " + total + " trials cannot be estimated."); 133 } 134 double[] ret = new double[2]; 135 double p1 = (success+2)/(double)(total+4); 136 double p2 = 1.96 * Math.sqrt(p1*(1-p1)/(total+4)); 137 ret[0] = Math.max(0, p1 - p2); 138 ret[1] = Math.min(1, p1 + p2); 139 return ret; 140 } 141 142 /** 143 * Computes the 95% confidence interval average of an experiment with boolean outcomes, 144 * e.g. heads or tails coin throws. It uses the very efficient, but still accurate 145 * Wald method. 146 * @param success Number of successes, e.g. number of times the coin shows head. 147 * @param total Total number of tries, e.g. total number of times the coin was thrown. 148 * @return The average of the lower border and upper border of the 95% confidence interval. 149 */ 150 public static double getConfidenceInterval95WaldAverage(int total, int success) { 151 if(success > total || total < 1) { 152 throw new IllegalArgumentException("95% confidence interval for " + success + " out of " + total + " trials cannot be estimated."); 153 } 154 double[] interval = getConfidenceInterval95Wald(total, success); 155 return (interval[0] + interval[1]) / 2; 156 } 157 158 /** 159 * Computes whether a hypothesis is too weak, i.e. it has more errors on the positive examples 160 * than allowed by the noise parameter. 161 * @param nrOfPositiveExamples The number of positive examples in the learning problem. 162 * @param nrOfPosClassifiedPositives The number of positive examples, which were indeed classified as positive by the hypothesis. 163 * @param noise The noise parameter is a value between 0 and 1, which indicates how noisy the example data is (0 = no noise, 1 = completely random). 164 * If a hypothesis contains more errors on the positive examples than the noise value multiplied by the 165 * number of all examples, then the hypothesis is too weak. 166 * @return True if the hypothesis is too weak and false otherwise. 167 */ 168 public boolean isTooWeak(int nrOfPositiveExamples, int nrOfPosClassifiedPositives, double noise) { 169 if(noise < 0 || noise > 1 || nrOfPosClassifiedPositives <= nrOfPositiveExamples || nrOfPositiveExamples < 1) { 170 throw new IllegalArgumentException(); 171 } 172 return (noise * nrOfPositiveExamples) < (nrOfPositiveExamples - nrOfPosClassifiedPositives); 173 } 174 175 /** 176 * Computes whether a hypothesis is too weak, i.e. it has more errors on the positive examples 177 * than allowed by the noise parameter. 178 * @param nrOfPositiveExamples The number of positive examples in the learning problem. 179 * @param nrOfNegClassifiedPositives The number of positive examples, which were indeed classified as negative by the hypothesis. 180 * @param noise The noise parameter is a value between 0 and 1, which indicates how noisy the example data is (0 = no noise, 1 = completely random). 181 * If a hypothesis contains more errors on the positive examples than the noise value multiplied by the 182 * number of all examples, then the hypothesis is too weak. 183 * @return True if the hypothesis is too weak and false otherwise. 184 */ 185 public boolean isTooWeak2(int nrOfPositiveExamples, int nrOfNegClassifiedPositives, double noise) { 186 if(noise < 0 || noise > 1 || nrOfNegClassifiedPositives <= nrOfPositiveExamples || nrOfPositiveExamples < 1) { 187 throw new IllegalArgumentException(); 188 } 189 return (noise * nrOfPositiveExamples) < nrOfNegClassifiedPositives; 190 } 191 192 // see paper: p' 193 public static double p1(int success, int total) { 194 return (success+2)/(double)(total+4); 195 } 196 197 // see paper: expression used in confidence interval estimation 198 public static double p3(double p1, int total) { 199 return 1.96 * Math.sqrt(p1*(1-p1)/(total+4)); 200 } 201 202 /** 203 * This method can be used to approximate F-Measure and thereby saving a lot of 204 * instance checks. It assumes that all positive examples (or instances of a class) 205 * have already been tested via instance checks, i.e. recall is already known and 206 * precision is approximated. 207 * @param nrOfPosClassifiedPositives Positive examples (instance of a class), which are classified as positives. 208 * @param recall The already known recall. 209 * @param beta Weights precision and recall. If beta is >1, then recall is more important 210 * than precision. 211 * @param nrOfRelevantInstances Number of relevant instances, i.e. number of instances, which 212 * would have been tested without approximations. TODO: relevant = pos + neg examples? 213 * @param nrOfInstanceChecks Performed instance checks for the approximation. 214 * @param nrOfSuccessfulInstanceChecks Number of successful performed instance checks. 215 * @return A two element array, where the first element is the computed F-beta score and the 216 * second element is the length of the 95% confidence interval around it. 217 */ 218 public static double[] getFScoreApproximation(int nrOfPosClassifiedPositives, double recall, double beta, int nrOfRelevantInstances, int nrOfInstanceChecks, int nrOfSuccessfulInstanceChecks) { 219 // compute 95% confidence interval 220 double[] interval = Heuristics.getConfidenceInterval95Wald(nrOfInstanceChecks, nrOfSuccessfulInstanceChecks); 221 // multiply by number of instances from which the random samples are drawn 222 double lowerBorder = interval[0] * nrOfRelevantInstances; 223 double upperBorder = interval[1] * nrOfRelevantInstances; 224 // compute F-Measure for both borders (lower value = higher F-Measure) 225 double fMeasureHigh = (1 + Math.sqrt(beta)) * (nrOfPosClassifiedPositives/(nrOfPosClassifiedPositives+lowerBorder)*recall) / (Math.sqrt(beta)*nrOfPosClassifiedPositives/(nrOfPosClassifiedPositives+lowerBorder)+recall); 226 double fMeasureLow = (1 + Math.sqrt(beta)) * (nrOfPosClassifiedPositives/(nrOfPosClassifiedPositives+upperBorder)*recall) / (Math.sqrt(beta)*nrOfPosClassifiedPositives/(nrOfPosClassifiedPositives+upperBorder)+recall); 227 double diff = fMeasureHigh - fMeasureLow; 228 // compute F-score for proportion ? 229 // double proportionInstanceChecks = successfulInstanceChecks / (double) nrOfInstanceChecks * nrOfRelevantInstances; // 230 // => don't do it for now, because the difference between proportion and center of interval is usually quite small 231 // for sufficiently small diffs 232 // return interval length and center 233 double[] ret = new double[2]; 234 ret[0] = fMeasureLow + 0.5 * diff; 235 ret[1] = diff; 236 return ret; 237 } 238 239 /** 240 * In the first step of the AScore approximation, we estimate recall (taking the factor 241 * beta into account). This is not much more than a wrapper around the modified Wald method. 242 * @param beta Weights precision and recall. If beta is >1, then recall is more important 243 * than precision. 244 * @param nrOfPosExamples Number of positive examples (or instances of the considered class). 245 * @param nrOfInstanceChecks Number of positive examples (or instances of the considered class) which have been checked. 246 * @param nrOfSuccessfulInstanceChecks Number of positive examples (or instances of the considered class), where the instance check returned true. 247 * @return A two element array, where the first element is the recall multiplied by beta and the 248 * second element is the length of the 95% confidence interval around it. 249 */ 250 public static double[] getAScoreApproximationStep1(double beta, int nrOfPosExamples, int nrOfInstanceChecks, int nrOfSuccessfulInstanceChecks) { 251 // the method is just a wrapper around a single confidence interval approximation; 252 // method approximates t * a / |R(A)| 253 double[] interval = Heuristics.getConfidenceInterval95Wald(nrOfSuccessfulInstanceChecks, nrOfInstanceChecks); 254 double diff = beta * (interval[1] - interval[0]); 255 double[] ret = new double[2]; 256 ret[0] = beta * interval[0] + 0.5*diff; 257 ret[1] = diff; 258 return ret; 259 } 260 261 /** 262 * In step 2 of the A-Score approximation, the precision and overall A-Score is estimated based on 263 * the estimated recall. 264 * @param nrOfPosClassifiedPositives Positive examples (instance of a class), which are classified as positives. 265 * @param recallInterval The estimated recall, which needs to be given as a two element array with the first element being the mean value and the second element being the length of the interval (to be compatible with the step1 method). 266 * @param beta Weights precision and recall. If beta is >1, then recall is more important 267 * than precision. 268 * @param nrOfRelevantInstances Number of relevant instances, i.e. number of instances, which 269 * would have been tested without approximations. 270 * @param nrOfInstanceChecks Performed instance checks for the approximation. 271 * @param nrOfSuccessfulInstanceChecks Number of performed instance checks, which returned true. 272 * @return A two element array, where the first element is the estimated A-Score and the 273 * second element is the length of the 95% confidence interval around it. 274 */ 275 public static double[] getAScoreApproximationStep2(int nrOfPosClassifiedPositives, double[] recallInterval, double beta, int nrOfRelevantInstances, int nrOfInstanceChecks, int nrOfSuccessfulInstanceChecks) { 276 // recall interval is given as mean + interval size (to fit the other method calls) => computer lower and upper border 277 double recallLowerBorder = (recallInterval[0] - 0.5*recallInterval[1]) / beta; 278 double recallUpperBorder = (recallInterval[0] + 0.5*recallInterval[1]) / beta; 279 // estimate precision 280 double[] interval = Heuristics.getConfidenceInterval95Wald(nrOfInstanceChecks, nrOfSuccessfulInstanceChecks); 281 282 double precisionLowerBorder = nrOfPosClassifiedPositives / (nrOfPosClassifiedPositives + interval[1] * nrOfRelevantInstances); 283 double precisionUpperBorder = nrOfPosClassifiedPositives / (nrOfPosClassifiedPositives + interval[0] * nrOfRelevantInstances); 284 285// System.out.println("rec low: " + recallLowerBorder); 286// System.out.println("rec up: " + recallUpperBorder); 287// System.out.println("prec low: " + precisionLowerBorder); 288// System.out.println("prec up: " + precisionUpperBorder); 289 double lowerBorder = Heuristics.getAScore(recallLowerBorder, precisionLowerBorder, beta); 290 double upperBorder = Heuristics.getAScore(recallUpperBorder, precisionUpperBorder, beta); 291 double diff = upperBorder - lowerBorder; 292 double[] ret = new double[2]; 293 ret[0] = lowerBorder + 0.5*diff; 294 ret[1] = diff; 295 return ret; 296 } 297 298 // WARNING: unstable/untested 299 // uses the following formula: (|R(C) \cap E^+| + beta * |E^- \ R(C)|) / (|E^+|+|E^-|) 300 // approximates |R(C) \cap E^+| and beta * |E^- \ R(C)| separately; and adds their lower and upper borders (pessimistic estimate) 301 // TODO: only works well if there are many negatives at the moment, so speedup is not great 302 public static double[] getPredAccApproximation(int nrOfPositiveExamples, int nrOfNegativeExamples, double beta, int nrOfPosExampleInstanceChecks, int nrOfSuccessfulPosExampleChecks, int nrOfNegExampleInstanceChecks, int nrOfNegativeNegExampleChecks) { 303 // compute both 95% confidence intervals 304 double[] intervalPos = Heuristics.getConfidenceInterval95Wald(nrOfPosExampleInstanceChecks, nrOfSuccessfulPosExampleChecks); 305 double[] intervalNeg = Heuristics.getConfidenceInterval95Wald(nrOfNegExampleInstanceChecks, nrOfNegativeNegExampleChecks); 306 // multiply by number of instances from which the random samples are drawn 307 double lowerBorder = intervalPos[0] * nrOfPositiveExamples + beta * intervalNeg[0] * nrOfNegativeExamples; 308 double upperBorder = intervalNeg[1] * nrOfPositiveExamples + beta * intervalNeg[1] * nrOfNegativeExamples; 309 double predAccLow = lowerBorder / (nrOfPositiveExamples + beta * nrOfNegativeExamples); 310 double predAccHigh = upperBorder / (nrOfPositiveExamples + beta * nrOfNegativeExamples); 311 double diff = predAccHigh - predAccLow; 312 // return interval length and center 313 double[] ret = new double[2]; 314 ret[0] = predAccLow + 0.5 * diff; 315 ret[1] = diff; 316 return ret; 317 } 318 319 public static double divideOrZero(int numerator, int denominator) { 320 return denominator == 0 ? 0 : numerator / (double)denominator; 321 } 322 323}