001package org.dllearner.algorithms.miles;
002
003import java.util.ArrayList;
004import java.util.Collection;
005import java.util.List;
006import java.util.Random;
007
008import com.google.common.collect.Lists;
009
010import weka.core.Attribute;
011import weka.core.DenseInstance;
012import weka.core.Instances;
013
/**
 * Splits a fixed collection of positive and negative examples into cross-validation folds.
 *
 * <p>The examples are concatenated as all positives followed by all negatives (each in the
 * iteration order of its input collection) and are <em>not</em> shuffled, so fold membership is
 * fully determined by that order. Folds are contiguous slices of this list: with {@code n}
 * examples and {@code k} folds, the first {@code n % k} folds hold {@code n / k + 1} examples and
 * the remaining folds hold {@code n / k}. This is the same partitioning scheme used by Weka's
 * {@code Instances#trainCV} / {@code Instances#testCV}.
 *
 * <p>NOTE(review): because positives precede negatives and nothing is shuffled, the folds are not
 * stratified; a test fold may contain examples of only one class. Confirm this is intended.
 *
 * @param <T> the example type
 */
public class FoldGenerator<T> {

        /** All examples: positives first, then negatives. */
        List<T> examples;

        private final Collection<T> posExamples;
        private final Collection<T> negExamples;

        /**
         * Fixed-seed random generator. Currently unused: shuffling of {@link #examples} is
         * disabled, so the fold layout is deterministic in the input order.
         */
        Random random = new Random(123);

        /**
         * Creates a fold generator over the given examples.
         *
         * @param posExamples the positive examples; placed first in the example order
         * @param negExamples the negative examples; placed after the positives
         * @throws NullPointerException if either collection is {@code null}
         */
        public FoldGenerator(Collection<T> posExamples, Collection<T> negExamples) {
                this.posExamples = posExamples;
                this.negExamples = negExamples;

                examples = new ArrayList<>(posExamples.size() + negExamples.size());
                examples.addAll(posExamples);
                examples.addAll(negExamples);
                // Examples are intentionally not shuffled, so folds follow insertion order.
        }

        /**
         * Returns the training set for one fold, i.e. every example that is not in that fold's
         * test set, in the original example order.
         *
         * @param numFolds the total number of folds; at least 2 and at most the number of examples
         * @param numFold  the index of the fold to hold out, in {@code [0, numFolds)}
         * @return a new, mutable list containing the training examples
         * @throws IllegalArgumentException if {@code numFolds} or {@code numFold} is out of range
         */
        public List<T> trainCV(int numFolds, int numFold) {
                checkFoldArguments(numFolds, numFold);
                int first = foldStart(numFolds, numFold);
                int end = first + foldSize(numFolds, numFold);

                List<T> train = new ArrayList<>(examples.size() - (end - first));
                train.addAll(examples.subList(0, first));
                train.addAll(examples.subList(end, examples.size()));
                return train;
        }

        /**
         * Returns the test set for one fold: a contiguous slice of the example list.
         *
         * @param numFolds the total number of folds; at least 2 and at most the number of examples
         * @param numFold  the index of the fold to return, in {@code [0, numFolds)}
         * @return a new, mutable list containing the test examples
         * @throws IllegalArgumentException if {@code numFolds} or {@code numFold} is out of range
         */
        public List<T> testCV(int numFolds, int numFold) {
                checkFoldArguments(numFolds, numFold);
                int first = foldStart(numFolds, numFold);
                int end = first + foldSize(numFolds, numFold);
                return new ArrayList<>(examples.subList(first, end));
        }

        /** Validates the fold count and fold index against the current number of examples. */
        private void checkFoldArguments(int numFolds, int numFold) {
                if (numFolds < 2) {
                        throw new IllegalArgumentException("Number of folds must be at least 2!");
                }
                if (numFolds > examples.size()) {
                        throw new IllegalArgumentException("Can't have more folds than instances!");
                }
                if (numFold < 0 || numFold >= numFolds) {
                        throw new IllegalArgumentException(
                                        "Fold index must be in [0, " + numFolds + "), but was " + numFold + "!");
                }
        }

        /** Returns the number of examples in fold {@code numFold}. */
        private int foldSize(int numFolds, int numFold) {
                int base = examples.size() / numFolds;
                // the first (size % numFolds) folds absorb the remainder, one extra example each
                return numFold < examples.size() % numFolds ? base + 1 : base;
        }

        /** Returns the index of the first example of fold {@code numFold}. */
        private int foldStart(int numFolds, int numFold) {
                int base = examples.size() / numFolds;
                int remainder = examples.size() % numFolds;
                // every preceding fold that absorbed a remainder example shifts the start by one
                return numFold * base + Math.min(numFold, remainder);
        }

        /**
         * Small demo: builds 4 positive and 5 negative string examples and prints the example
         * order followed by the train/test split of each of 5 folds.
         */
        public static void main(String[] args) throws Exception {
                int numPos = 4;
                int numNeg = 5;
                List<String> posExamples = new ArrayList<>();
                List<String> negExamples = new ArrayList<>();
                for (int i = 1; i <= numPos; i++) {
                        posExamples.add("p" + i);
                }
                for (int i = 1; i <= numNeg; i++) {
                        negExamples.add("n" + i);
                }

                FoldGenerator<String> foldGenerator = new FoldGenerator<>(posExamples, negExamples);
                System.out.println(foldGenerator.examples);

                int numFolds = 5;
                for (int i = 0; i < numFolds; i++) {
                        System.out.println("Fold " + i);
                        System.out.println("Train:" + foldGenerator.trainCV(numFolds, i));
                        System.out.println("Test:" + foldGenerator.testCV(numFolds, i));
                }
        }
}