001/**
002 * Copyright (C) 2007-2011, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 *
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 */
019
020package org.dllearner.experiments;
021
022import com.google.common.collect.Sets;
023import org.apache.log4j.Logger;
024import org.dllearner.utilities.URLencodeUTF8;
025
026import java.io.FileWriter;
027import java.text.DecimalFormat;
028import java.util.Arrays;
029import java.util.Collection;
030import java.util.SortedSet;
031import java.util.TreeSet;
032
033/**
034 * a container for examples used for operations like randomization
035 * 
036 * @author Sebastian Hellmann <hellmann@informatik.uni-leipzig.de>
037 * 
038 */
039public class Examples {
040        private static final Logger logger = Logger.getLogger(Examples.class);
041        public static DecimalFormat df1 = new DecimalFormat("00.#%");
042        public static DecimalFormat df2 = new DecimalFormat("00.##%");
043        public static DecimalFormat df3 = new DecimalFormat("00.###%");
044        private DecimalFormat myDf = df2;
045
046        private final SortedSet<String> posTrain = new TreeSet<>();
047        private final SortedSet<String> negTrain = new TreeSet<>();
048        private final SortedSet<String> posTest = new TreeSet<>();
049        private final SortedSet<String> negTest = new TreeSet<>();
050
051        /**
052         * default constructor
053         */
054        public Examples() {
055        }
056
057        /**
058         * constructor to add training examples
059         * 
060         * @param posTrain
061         * @param negTrain
062         */
063        public Examples(SortedSet<String> posTrain, SortedSet<String> negTrain) {
064                this.addPosTrain(posTrain);
065                this.addNegTrain(negTrain);
066        }
067
068        /**
069         * adds all examples, doublettes are removed automatically
070         * 
071         * @param posTrain
072         * @param negTrain
073         * @param posTest
074         * @param negTest
075         */
076        public Examples(SortedSet<String> posTrain, SortedSet<String> negTrain, SortedSet<String> posTest,
077                        SortedSet<String> negTest) {
078                this.addPosTrain(posTrain);
079                this.addPosTest(posTest);
080                this.addNegTrain(negTrain);
081                this.addNegTest(negTest);
082        }
083
084        /**
085         * calculates precision based on the test set removes all training data from
086         * retrieved first
087         * 
088         * @param retrieved
089         * @return
090         */
091        public double precision(SortedSet<String> retrieved) {
092                if (retrieved.size() == 0) {
093                        return 0.0d;
094                }
095                SortedSet<String> retrievedClean = new TreeSet<>(retrieved);
096                retrievedClean.removeAll(posTrain);
097                retrievedClean.removeAll(negTrain);
098
099                int posAsPos = Sets.intersection(retrievedClean, getPosTest()).size();
100                return ((double) posAsPos) / ((double) retrievedClean.size());
101        }
102
103        /**
104         * calculates recall based on the test set
105         * 
106         * 
107         * @param retrieved
108         * @return
109         */
110        public double recall(SortedSet<String> retrieved) {
111                if (sizeTotalOfPositives() == 0) {
112                        return 0.0d;
113                }
114                int posAsPos = Sets.intersection(getPosTest(), retrieved).size();
115                return ((double) posAsPos) / ((double) posTest.size());
116        }
117
118        private void _remove(String toBeRemoved) {
119                _removeAll(Arrays.asList(toBeRemoved));
120        }
121
122        private void _removeAll(Collection<String> toBeRemoved) {
123                if (posTrain.removeAll(toBeRemoved) || negTrain.removeAll(toBeRemoved)
124                                || posTest.removeAll(toBeRemoved) || negTest.removeAll(toBeRemoved)) {
125                        logger.warn("There has been some overlap in the examples, but it was removed automatically");
126                }
127        }
128
129        public void addPosTrain(Collection<String> pos) {
130                _removeAll(pos);
131                posTrain.addAll(pos);
132        }
133
134        public void addPosTest(Collection<String> pos) {
135                _removeAll(pos);
136                posTest.addAll(pos);
137        }
138
139        public void addNegTrain(Collection<String> neg) {
140                _removeAll(neg);
141                negTrain.addAll(neg);
142        }
143
144        public void addNegTest(Collection<String> neg) {
145                _removeAll(neg);
146                negTest.addAll(neg);
147        }
148
149        public void addPosTrain(String pos) {
150                _remove(pos);
151                posTrain.add(pos);
152        }
153
154        public void addPosTest(String pos) {
155                _remove(pos);
156                posTest.add(pos);
157        }
158
159        public void addNegTrain(String neg) {
160                _remove(neg);
161                negTrain.add(neg);
162        }
163
164        public void addNegTest(String neg) {
165                _remove(neg);
166                negTest.add(neg);
167        }
168
169        public boolean checkConsistency() {
170                for (String one : posTrain) {
171                        if (negTrain.contains(one)) {
172                                logger.error("positve and negative example overlap " + one);
173                                return false;
174                        }
175                }
176                return true;
177        }
178
179        @Override
180        public String toString() {
181                String ret = "Total: " + size();
182                double posPercent = posTrain.size() / (double) sizeTotalOfPositives();
183                double negPercent = negTrain.size() / (double) sizeTotalOfNegatives();
184                ret += "\nPositive: " + posTrain.size() + " | " + posTest.size() + " (" + myDf.format(posPercent)
185                                + ")";
186                ret += "\nNegative: " + negTrain.size() + " | " + negTest.size() + " (" + myDf.format(negPercent)
187                                + ")";
188
189                return ret;
190        }
191
192        public String toFullString() {
193
194                String ret = "Training:\n";
195                for (String one : posTrain) {
196                        ret += "+\"" + one + "\"\n";
197                }
198                for (String one : negTrain) {
199                        ret += "-\"" + one + "\"\n";
200                }
201                ret += "Testing:\n";
202                for (String one : posTest) {
203                        ret += "+\"" + one + "\"\n";
204                }
205                for (String one : negTest) {
206                        ret += "-\"" + one + "\"\n";
207                }
208
209                return ret + this.toString();
210
211        }
212
213        public void writeExamples(String filename) {
214                try {
215                        FileWriter a = new FileWriter(filename, false);
216
217                        StringBuffer buffer = new StringBuffer();
218                        buffer.append("\n\n\n\n\n");
219                        for (String s : posTrain) {
220                                a.write("import(\"" + URLencodeUTF8.encode(s) + "\");\n");
221                                buffer.append("+\"").append(s).append("\"\n");
222                        }
223                        for (String s : negTrain) {
224                                a.write("import(\"" + URLencodeUTF8.encode(s) + "\");\n");
225                                buffer.append("-\"").append(s).append("\"\n");
226                        }
227
228                        a.write(buffer.toString());
229                        a.flush();
230                        a.close();
231                        logger.info("wrote examples to " + filename);
232                } catch (Exception e) {
233                        e.printStackTrace();
234                }
235        }
236
237        /**
238         * sum of training and test data
239         * @return
240         */
241        public int size() {
242                return posTrain.size() + negTrain.size() + posTest.size() + negTest.size();
243        }
244
245        public int sizeTotalOfPositives() {
246                return posTrain.size() + posTest.size();
247        }
248
249        public int sizeTotalOfNegatives() {
250                return negTrain.size() + negTest.size();
251        }
252
253        public int sizeOfTrainingSets() {
254                return posTrain.size() + negTrain.size();
255        }
256
257        public int sizeOfTestSets() {
258                return posTest.size() + negTest.size();
259        }
260
261        public SortedSet<String> getAllExamples() {
262                SortedSet<String> total = new TreeSet<>();
263                total.addAll(getPositiveExamples());
264                total.addAll(getNegativeExamples());
265                return total;
266        }
267
268        public SortedSet<String> getPositiveExamples() {
269                SortedSet<String> total = new TreeSet<>();
270                total.addAll(posTrain);
271                total.addAll(posTest);
272                return total;
273        }
274
275        public SortedSet<String> getNegativeExamples() {
276                SortedSet<String> total = new TreeSet<>();
277                total.addAll(negTrain);
278                total.addAll(negTest);
279                return total;
280        }
281
282        public SortedSet<String> getTestExamples() {
283                SortedSet<String> total = new TreeSet<>();
284                total.addAll(posTest);
285                total.addAll(negTest);
286                return total;
287        }
288
289        public SortedSet<String> getTrainExamples() {
290                SortedSet<String> total = new TreeSet<>();
291                total.addAll(posTrain);
292                total.addAll(negTrain);
293                return total;
294        }
295
296        public SortedSet<String> getPosTrain() {
297                return posTrain;
298        }
299
300        public SortedSet<String> getNegTrain() {
301                return negTrain;
302        }
303
304        public SortedSet<String> getPosTest() {
305                return posTest;
306        }
307
308        public SortedSet<String> getNegTest() {
309                return negTest;
310        }
311
312}