001/**
002 * Copyright (C) 2007-2011, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 *
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 */
019
020package org.dllearner.experiments;
021
022import org.apache.log4j.Logger;
023
024import java.util.*;
025
026public class ExMakerCrossFolds {
027        private static Logger logger = Logger.getLogger(ExMakerCrossFolds.class);
028
029        private final Examples examples;
030        
031        public static int minElementsPerFold = 6;
032        
033        public ExMakerCrossFolds(Examples examples){
034                this.examples = examples;
035        }
036        
037        public static void main(String[] args) {
038                Examples ex = new Examples();
039                
040                for (int i = 0; i < 10000; i++) {
041                        ex.addPosTrain("p"+i);
042                        ex.addNegTrain("n"+i);
043                }
044                long n = System.currentTimeMillis();
045                 System.out.println("initial size: "+ex.size());
046                 ExMakerCrossFolds r = new ExMakerCrossFolds(ex);
047                 List<Examples> l = r.splitLeaveOneOut(10);
048                 printFolds(l );
049                 System.out.println(System.currentTimeMillis()-n);
050                
051                
052        }
053        public static void printFolds(List<Examples> l ){
054                 int i = 1;
055                 int totalsize = 0;
056                 StringBuffer b = new StringBuffer();
057                 b.append("Number of folds ").append(l.size()).append("\n");
058                 for (Examples examples : l) {
059                         b.append("Fold: ").append(i++).append("\n");
060                         b.append(examples.toString());
061                         b.append("\n");
062                        
063                         totalsize+=examples.size();
064                }
065                 b.append("total size: ").append(totalsize);
066                 logger.info(b.toString());
067        }
068        
069        
070        public List<Examples> splitLeaveOneOut(int folds){
071                if(     folds*minElementsPerFold > examples.sizeTotalOfPositives()
072                                || folds*minElementsPerFold > examples.sizeTotalOfNegatives()
073                ){
074                        logger.error("Too many folds for, too few data. cant spread: ");
075                        logger.error(examples.sizeTotalOfPositives()+" examples over "+folds+" folds OR");
076                        logger.error(examples.sizeTotalOfNegatives()+" examples over "+folds+" folds");
077                        logger.error("each fold must have more than "+minElementsPerFold+" elements");
078                        return null;
079                }
080                
081                List<Examples> foldSets = new ArrayList<>();
082                double foldPercentage = 1.0d/((double)folds);
083                int tenPercentPos = (int)Math.floor(((double)examples.sizeTotalOfPositives())*foldPercentage);
084                int tenPercentNeg = (int)Math.floor(((double)examples.sizeTotalOfNegatives())*foldPercentage);
085                
086                List<String> posRemaining = new ArrayList<>(examples.getPositiveExamples());
087                List<String> negRemaining  = new ArrayList<>(examples.getNegativeExamples());
088                Collections.shuffle(posRemaining);
089                Collections.shuffle(negRemaining);
090                
091                
092                Examples tmp;
093//              Examples oneFold;
094                for(int i = 0; i<folds;i++){
095//                      logger.trace("Foldprogess: "+i+" of "+folds);
096                        SortedSet<String> newPos = new TreeSet<>();
097                        SortedSet<String> newNeg = new TreeSet<>();
098                        String one = "";
099
100                        for(int a =0; a<tenPercentPos&& !posRemaining.isEmpty();a++){
101                                one = posRemaining.remove(posRemaining.size()-1);
102                                newPos.add(one);
103                        }
104                        for(int a =0; a <tenPercentNeg&& !negRemaining.isEmpty() ; a++){
105                                one = negRemaining.remove(negRemaining.size()-1);
106                                newNeg.add(one);
107                        }
108
109                        tmp = new Examples();
110                        tmp.addPosTrain(newPos);
111                        tmp.addNegTrain(newNeg);
112                        foldSets.add(tmp);
113
114                }
115                List<Examples> ret = new ArrayList<>();
116                for(int i =0; i<foldSets.size();i++){
117                        Examples oneFold = new Examples();
118                        oneFold.addPosTest(foldSets.get(i).getPositiveExamples());
119                        oneFold.addNegTest(foldSets.get(i).getNegativeExamples());
120                        for(int a =0; a<foldSets.size();a++){
121                                if(a==i){
122                                        continue;
123                                }else{
124                                        oneFold.addPosTrain(foldSets.get(a).getPositiveExamples());
125                                        oneFold.addNegTrain(foldSets.get(a).getNegativeExamples());
126                                }
127                                
128                        }
129                        ret.add(oneFold);
130                }
131                
132                return ret;
133        }
134        
135}