001/** 002 * Copyright (C) 2007-2011, Jens Lehmann 003 * 004 * This file is part of DL-Learner. 005 * 006 * DL-Learner is free software; you can redistribute it and/or modify 007 * it under the terms of the GNU General Public License as published by 008 * the Free Software Foundation; either version 3 of the License, or 009 * (at your option) any later version. 010 * 011 * DL-Learner is distributed in the hope that it will be useful, 012 * but WITHOUT ANY WARRANTY; without even the implied warranty of 013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 014 * GNU General Public License for more details. 015 * 016 * You should have received a copy of the GNU General Public License 017 * along with this program. If not, see <http://www.gnu.org/licenses/>. 018 */ 019 020package org.dllearner.experiments; 021 022import com.google.common.collect.Sets; 023import org.apache.log4j.Logger; 024import org.dllearner.utilities.URLencodeUTF8; 025 026import java.io.FileWriter; 027import java.text.DecimalFormat; 028import java.util.Arrays; 029import java.util.Collection; 030import java.util.SortedSet; 031import java.util.TreeSet; 032 033/** 034 * a container for examples used for operations like randomization 035 * 036 * @author Sebastian Hellmann <hellmann@informatik.uni-leipzig.de> 037 * 038 */ 039public class Examples { 040 private static final Logger logger = Logger.getLogger(Examples.class); 041 public static DecimalFormat df1 = new DecimalFormat("00.#%"); 042 public static DecimalFormat df2 = new DecimalFormat("00.##%"); 043 public static DecimalFormat df3 = new DecimalFormat("00.###%"); 044 private DecimalFormat myDf = df2; 045 046 private final SortedSet<String> posTrain = new TreeSet<>(); 047 private final SortedSet<String> negTrain = new TreeSet<>(); 048 private final SortedSet<String> posTest = new TreeSet<>(); 049 private final SortedSet<String> negTest = new TreeSet<>(); 050 051 /** 052 * default constructor 053 */ 054 public Examples() { 055 } 056 057 /** 058 * constructor to add training examples 059 * 060 * @param posTrain 061 * @param negTrain 062 */ 063 public Examples(SortedSet<String> posTrain, SortedSet<String> negTrain) { 064 this.addPosTrain(posTrain); 065 this.addNegTrain(negTrain); 066 } 067 068 /** 069 * adds all examples, doublettes are removed automatically 070 * 071 * @param posTrain 072 * @param negTrain 073 * @param posTest 074 * @param negTest 075 */ 076 public Examples(SortedSet<String> posTrain, SortedSet<String> negTrain, SortedSet<String> posTest, 077 SortedSet<String> negTest) { 078 this.addPosTrain(posTrain); 079 this.addPosTest(posTest); 080 this.addNegTrain(negTrain); 081 this.addNegTest(negTest); 082 } 083 084 /** 085 * calculates precision based on the test set removes all training data from 086 * retrieved first 087 * 088 * @param retrieved 089 * @return 090 */ 091 public double precision(SortedSet<String> retrieved) { 092 if (retrieved.size() == 0) { 093 return 0.0d; 094 } 095 SortedSet<String> retrievedClean = new TreeSet<>(retrieved); 096 retrievedClean.removeAll(posTrain); 097 retrievedClean.removeAll(negTrain); 098 099 int posAsPos = Sets.intersection(retrievedClean, getPosTest()).size(); 100 return ((double) posAsPos) / ((double) retrievedClean.size()); 101 } 102 103 /** 104 * calculates recall based on the test set 105 * 106 * 107 * @param retrieved 108 * @return 109 */ 110 public double recall(SortedSet<String> retrieved) { 111 if (sizeTotalOfPositives() == 0) { 112 return 0.0d; 113 } 114 int posAsPos = Sets.intersection(getPosTest(), retrieved).size(); 115 return ((double) posAsPos) / ((double) posTest.size()); 116 } 117 118 private void _remove(String toBeRemoved) { 119 _removeAll(Arrays.asList(toBeRemoved)); 120 } 121 122 private void _removeAll(Collection<String> toBeRemoved) { 123 if (posTrain.removeAll(toBeRemoved) || negTrain.removeAll(toBeRemoved) 124 || posTest.removeAll(toBeRemoved) || negTest.removeAll(toBeRemoved)) { 125 logger.warn("There has been some overlap in the examples, but it was removed automatically"); 126 } 127 } 128 129 public void addPosTrain(Collection<String> pos) { 130 _removeAll(pos); 131 posTrain.addAll(pos); 132 } 133 134 public void addPosTest(Collection<String> pos) { 135 _removeAll(pos); 136 posTest.addAll(pos); 137 } 138 139 public void addNegTrain(Collection<String> neg) { 140 _removeAll(neg); 141 negTrain.addAll(neg); 142 } 143 144 public void addNegTest(Collection<String> neg) { 145 _removeAll(neg); 146 negTest.addAll(neg); 147 } 148 149 public void addPosTrain(String pos) { 150 _remove(pos); 151 posTrain.add(pos); 152 } 153 154 public void addPosTest(String pos) { 155 _remove(pos); 156 posTest.add(pos); 157 } 158 159 public void addNegTrain(String neg) { 160 _remove(neg); 161 negTrain.add(neg); 162 } 163 164 public void addNegTest(String neg) { 165 _remove(neg); 166 negTest.add(neg); 167 } 168 169 public boolean checkConsistency() { 170 for (String one : posTrain) { 171 if (negTrain.contains(one)) { 172 logger.error("positve and negative example overlap " + one); 173 return false; 174 } 175 } 176 return true; 177 } 178 179 @Override 180 public String toString() { 181 String ret = "Total: " + size(); 182 double posPercent = posTrain.size() / (double) sizeTotalOfPositives(); 183 double negPercent = negTrain.size() / (double) sizeTotalOfNegatives(); 184 ret += "\nPositive: " + posTrain.size() + " | " + posTest.size() + " (" + myDf.format(posPercent) 185 + ")"; 186 ret += "\nNegative: " + negTrain.size() + " | " + negTest.size() + " (" + myDf.format(negPercent) 187 + ")"; 188 189 return ret; 190 } 191 192 public String toFullString() { 193 194 String ret = "Training:\n"; 195 for (String one : posTrain) { 196 ret += "+\"" + one + "\"\n"; 197 } 198 for (String one : negTrain) { 199 ret += "-\"" + one + "\"\n"; 200 } 201 ret += "Testing:\n"; 202 for (String one : posTest) { 203 ret += "+\"" + one + "\"\n"; 204 } 205 for (String one : negTest) { 206 ret += "-\"" + one + "\"\n"; 207 } 208 209 return ret + this.toString(); 210 211 } 212 213 public void writeExamples(String filename) { 214 try { 215 FileWriter a = new FileWriter(filename, false); 216 217 StringBuffer buffer = new StringBuffer(); 218 buffer.append("\n\n\n\n\n"); 219 for (String s : posTrain) { 220 a.write("import(\"" + URLencodeUTF8.encode(s) + "\");\n"); 221 buffer.append("+\"").append(s).append("\"\n"); 222 } 223 for (String s : negTrain) { 224 a.write("import(\"" + URLencodeUTF8.encode(s) + "\");\n"); 225 buffer.append("-\"").append(s).append("\"\n"); 226 } 227 228 a.write(buffer.toString()); 229 a.flush(); 230 a.close(); 231 logger.info("wrote examples to " + filename); 232 } catch (Exception e) { 233 e.printStackTrace(); 234 } 235 } 236 237 /** 238 * sum of training and test data 239 * @return 240 */ 241 public int size() { 242 return posTrain.size() + negTrain.size() + posTest.size() + negTest.size(); 243 } 244 245 public int sizeTotalOfPositives() { 246 return posTrain.size() + posTest.size(); 247 } 248 249 public int sizeTotalOfNegatives() { 250 return negTrain.size() + negTest.size(); 251 } 252 253 public int sizeOfTrainingSets() { 254 return posTrain.size() + negTrain.size(); 255 } 256 257 public int sizeOfTestSets() { 258 return posTest.size() + negTest.size(); 259 } 260 261 public SortedSet<String> getAllExamples() { 262 SortedSet<String> total = new TreeSet<>(); 263 total.addAll(getPositiveExamples()); 264 total.addAll(getNegativeExamples()); 265 return total; 266 } 267 268 public SortedSet<String> getPositiveExamples() { 269 SortedSet<String> total = new TreeSet<>(); 270 total.addAll(posTrain); 271 total.addAll(posTest); 272 return total; 273 } 274 275 public SortedSet<String> getNegativeExamples() { 276 SortedSet<String> total = new TreeSet<>(); 277 total.addAll(negTrain); 278 total.addAll(negTest); 279 return total; 280 } 281 282 public SortedSet<String> getTestExamples() { 283 SortedSet<String> total = new TreeSet<>(); 284 total.addAll(posTest); 285 total.addAll(negTest); 286 return total; 287 } 288 289 public SortedSet<String> getTrainExamples() { 290 SortedSet<String> total = new TreeSet<>(); 291 total.addAll(posTrain); 292 total.addAll(negTrain); 293 return total; 294 } 295 296 public SortedSet<String> getPosTrain() { 297 return posTrain; 298 } 299 300 public SortedSet<String> getNegTrain() { 301 return negTrain; 302 } 303 304 public SortedSet<String> getPosTest() { 305 return posTest; 306 } 307 308 public SortedSet<String> getNegTest() { 309 return negTest; 310 } 311 312}