001/** 002 * Copyright (C) 2007-2011, Jens Lehmann 003 * 004 * This file is part of DL-Learner. 005 * 006 * DL-Learner is free software; you can redistribute it and/or modify 007 * it under the terms of the GNU General Public License as published by 008 * the Free Software Foundation; either version 3 of the License, or 009 * (at your option) any later version. 010 * 011 * DL-Learner is distributed in the hope that it will be useful, 012 * but WITHOUT ANY WARRANTY; without even the implied warranty of 013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 014 * GNU General Public License for more details. 015 * 016 * You should have received a copy of the GNU General Public License 017 * along with this program. If not, see <http://www.gnu.org/licenses/>. 018 */ 019 020package org.dllearner.experiments; 021 022import org.apache.log4j.Logger; 023 024import java.util.*; 025 026public class ExMakerCrossFolds { 027 private static Logger logger = Logger.getLogger(ExMakerCrossFolds.class); 028 029 private final Examples examples; 030 031 public static int minElementsPerFold = 6; 032 033 public ExMakerCrossFolds(Examples examples){ 034 this.examples = examples; 035 } 036 037 public static void main(String[] args) { 038 Examples ex = new Examples(); 039 040 for (int i = 0; i < 10000; i++) { 041 ex.addPosTrain("p"+i); 042 ex.addNegTrain("n"+i); 043 } 044 long n = System.currentTimeMillis(); 045 System.out.println("initial size: "+ex.size()); 046 ExMakerCrossFolds r = new ExMakerCrossFolds(ex); 047 List<Examples> l = r.splitLeaveOneOut(10); 048 printFolds(l ); 049 System.out.println(System.currentTimeMillis()-n); 050 051 052 } 053 public static void printFolds(List<Examples> l ){ 054 int i = 1; 055 int totalsize = 0; 056 StringBuffer b = new StringBuffer(); 057 b.append("Number of folds ").append(l.size()).append("\n"); 058 for (Examples examples : l) { 059 b.append("Fold: ").append(i++).append("\n"); 060 b.append(examples.toString()); 061 b.append("\n"); 062 063 totalsize+=examples.size(); 064 } 065 b.append("total size: ").append(totalsize); 066 logger.info(b.toString()); 067 } 068 069 070 public List<Examples> splitLeaveOneOut(int folds){ 071 if( folds*minElementsPerFold > examples.sizeTotalOfPositives() 072 || folds*minElementsPerFold > examples.sizeTotalOfNegatives() 073 ){ 074 logger.error("Too many folds for, too few data. cant spread: "); 075 logger.error(examples.sizeTotalOfPositives()+" examples over "+folds+" folds OR"); 076 logger.error(examples.sizeTotalOfNegatives()+" examples over "+folds+" folds"); 077 logger.error("each fold must have more than "+minElementsPerFold+" elements"); 078 return null; 079 } 080 081 List<Examples> foldSets = new ArrayList<>(); 082 double foldPercentage = 1.0d/((double)folds); 083 int tenPercentPos = (int)Math.floor(((double)examples.sizeTotalOfPositives())*foldPercentage); 084 int tenPercentNeg = (int)Math.floor(((double)examples.sizeTotalOfNegatives())*foldPercentage); 085 086 List<String> posRemaining = new ArrayList<>(examples.getPositiveExamples()); 087 List<String> negRemaining = new ArrayList<>(examples.getNegativeExamples()); 088 Collections.shuffle(posRemaining); 089 Collections.shuffle(negRemaining); 090 091 092 Examples tmp; 093// Examples oneFold; 094 for(int i = 0; i<folds;i++){ 095// logger.trace("Foldprogess: "+i+" of "+folds); 096 SortedSet<String> newPos = new TreeSet<>(); 097 SortedSet<String> newNeg = new TreeSet<>(); 098 String one = ""; 099 100 for(int a =0; a<tenPercentPos&& !posRemaining.isEmpty();a++){ 101 one = posRemaining.remove(posRemaining.size()-1); 102 newPos.add(one); 103 } 104 for(int a =0; a <tenPercentNeg&& !negRemaining.isEmpty() ; a++){ 105 one = negRemaining.remove(negRemaining.size()-1); 106 newNeg.add(one); 107 } 108 109 tmp = new Examples(); 110 tmp.addPosTrain(newPos); 111 tmp.addNegTrain(newNeg); 112 foldSets.add(tmp); 113 114 } 115 List<Examples> ret = new ArrayList<>(); 116 for(int i =0; i<foldSets.size();i++){ 117 Examples oneFold = new Examples(); 118 oneFold.addPosTest(foldSets.get(i).getPositiveExamples()); 119 oneFold.addNegTest(foldSets.get(i).getNegativeExamples()); 120 for(int a =0; a<foldSets.size();a++){ 121 if(a==i){ 122 continue; 123 }else{ 124 oneFold.addPosTrain(foldSets.get(a).getPositiveExamples()); 125 oneFold.addNegTrain(foldSets.get(a).getNegativeExamples()); 126 } 127 128 } 129 ret.add(oneFold); 130 } 131 132 return ret; 133 } 134 135}