001/**
002 * Copyright (C) 2007-2011, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 *
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 */
019
020package org.dllearner.experiments;
021
022import java.util.ArrayList;
023import java.util.List;
024import java.util.Random;
025
026/**
027 * used to randomize examples and split them into training and test sets
028 * gets a percentage of the examples 
029 * @author Sebastian Hellmann <hellmann@informatik.uni-leipzig.de>
030 *
031 */
032public class ExMakerRandomizer {
033//      private static Logger logger = Logger.getLogger(ExMakerRandomizer.class);
034
035        private final Examples examples;
036        
037        public ExMakerRandomizer(Examples examples ){
038                this.examples = examples;
039        }
040        
041        public static void main(String[] args) {
042                Examples ex = new Examples();
043                long n = System.currentTimeMillis();
044                for (int i = 0; i < 100000; i++) {
045                        ex.addPosTrain("p"+i);
046                        ex.addNegTrain("n"+i);
047                }
048                
049                ExMakerRandomizer r = new ExMakerRandomizer(ex);
050                ex = r.split(0.7d);
051                System.out.println("needed: "+(System.currentTimeMillis()-n)+ " ms");
052                System.out.println(ex.toString());
053                
054        }
055        
056        
057        /**
058         * Not quite exact, but fast
059         * has an error of 0.1 % 
060         * @param percentageOfTrainingSet
061         * @return
062         */
063        public Examples split(double percentageOfTrainingSet){
064                int sizeOfPosTrainingSet = (int)Math.floor(((double)examples.sizeTotalOfPositives())*percentageOfTrainingSet);
065                int sizeOfNegTrainingSet = (int)Math.floor(((double)examples.sizeTotalOfNegatives())*percentageOfTrainingSet);
066                
067                int sizeOfPosTestSet = examples.sizeTotalOfPositives()-sizeOfPosTrainingSet;
068                int sizeOfNegTestSet = examples.sizeTotalOfNegatives()-sizeOfNegTrainingSet;
069                
070//              System.out.println(sizeOfPosTrainingSet);
071//              System.out.println(sizeOfNegTrainingSet);
072//              System.out.println(sizeOfPosTestSet);
073//              System.out.println(sizeOfNegTestSet);
074                
075                List<String> posRemaining = new ArrayList<>(examples.getPositiveExamples());
076                List<String> negRemaining  = new ArrayList<>(examples.getNegativeExamples());
077        
078                Random r = new Random();
079                Examples ret = new Examples();
080                for (String one : posRemaining) {
081                        if (ret.getPosTrain().size() > sizeOfPosTrainingSet) {
082                                ret.addPosTest(one);
083                                continue;
084                        }
085                        if (ret.getPosTest().size() > sizeOfPosTestSet) {
086                                ret.addPosTrain(one);
087                                continue;
088                        }
089
090                        if (r.nextDouble() < percentageOfTrainingSet) {
091                                ret.addPosTrain(one);
092                        } else {
093                                ret.addPosTest(one);
094                        }
095
096                }
097                for (String one : negRemaining) {
098                        if (ret.getNegTrain().size() > sizeOfNegTrainingSet) {
099                                ret.addNegTest(one);
100                                continue;
101                        }
102                        if (ret.getNegTest().size() > sizeOfNegTestSet) {
103                                ret.addNegTrain(one);
104                                continue;
105                        }
106
107                        if (r.nextDouble() < percentageOfTrainingSet) {
108                                ret.addNegTrain(one);
109                        } else {
110                                ret.addNegTest(one);
111                        }
112
113                }
114                
115//              Collections.shuffle(posRemaining);
116//              Collections.shuffle(negRemaining);
117//              
118//              List<String> newPos = new ArrayList<String>();
119//              List<String> newNeg = new ArrayList<String>();
120//              
121//              Examples ret = new Examples();
122//              String one;
123//              while (posRemaining.size()>sizeOfPosTrainingSet){
124//                      one = posRemaining.remove(posRemaining.size()-1);
125//                      newPos.add(one);
126//                      
127//              }
128//              
129//              ret.addPosTest(newPos);
130//              ret.addPosTrain(posRemaining);
131//              
132//              while (negRemaining.size()>sizeOfNegTrainingSet){
133//                      one = negRemaining.remove(negRemaining.size()-1);
134//                      newNeg.add(one);
135//              }
136//              
137//              ret.addNegTest(newNeg);
138//              ret.addNegTrain(negRemaining);
139//              
140//              double posPercent = ret.getPosTrain().size()/(double)examples.getPositiveExamples().size();
141//              double negPercent = ret.getNegTrain().size()/(double)examples.getNegativeExamples().size();
142//              
143////            if there is more than a 10% error
144//              if(Math.abs(posPercent - percentageOfTrainingSet)>0.1d || Math.abs(negPercent - percentageOfTrainingSet)>0.1d ){
145//                      logger.info("repeating, unevenly matched");
146//                      return split(percentageOfTrainingSet);
147//              }
148                return ret;
149        }
150        
151        
152        
153}