001package org.dllearner.algorithms.miles; 002 003import java.util.ArrayList; 004import java.util.Collection; 005import java.util.List; 006import java.util.Random; 007 008import com.google.common.collect.Lists; 009 010import weka.core.Attribute; 011import weka.core.DenseInstance; 012import weka.core.Instances; 013 014public class FoldGenerator<T> { 015 016 List<T> examples; 017 018 private Collection<T> posExamples; 019 private Collection<T> negExamples; 020 021 Random random = new Random(123); 022 023 public FoldGenerator(Collection<T> posExamples, Collection<T> negExamples) { 024 this.posExamples = posExamples; 025 this.negExamples = negExamples; 026 027 examples = new ArrayList<>(); 028 examples.addAll(posExamples); 029 examples.addAll(negExamples); 030 031// Collections.shuffle(examples, random); 032 033 System.out.println(examples); 034 } 035 036 public List<T> trainCV(int numFolds, int numFold) { 037 int numInstForFold, first, offset; 038 List<T> train; 039 040 if (numFolds < 2) { 041 throw new IllegalArgumentException("Number of folds must be at least 2!"); 042 } 043 if (numFolds > examples.size()) { 044 throw new IllegalArgumentException("Can't have more folds than instances!"); 045 } 046 numInstForFold = examples.size() / numFolds; 047 if (numFold < examples.size() % numFolds) { 048 numInstForFold++; 049 offset = numFold; 050 } else { 051 offset = examples.size() % numFolds; 052 } 053 train = new ArrayList<>(examples.size() - numInstForFold); 054 first = numFold * (examples.size() / numFolds) + offset; 055 train.addAll(examples.subList(0, first)); 056 int from = first + numInstForFold; 057 int size = examples.size() - first - numInstForFold; 058 int to = from + size; 059 train.addAll(examples.subList(from, to)); 060 061 return train; 062 } 063 064 public List<T> testCV(int numFolds, int numFold) { 065 int numInstForFold, first, offset; 066 List<T> test; 067 068 if (numFolds < 2) { 069 throw new IllegalArgumentException("Number of folds must be at least 2!"); 070 } 071 if (numFolds > examples.size()) { 072 throw new IllegalArgumentException("Can't have more folds than instances!"); 073 } 074 numInstForFold = examples.size() / numFolds; 075 if (numFold < examples.size() % numFolds) { 076 numInstForFold++; 077 offset = numFold; 078 } else { 079 offset = examples.size() % numFolds; 080 } 081 test = new ArrayList<>(numInstForFold); 082 first = numFold * (examples.size() / numFolds) + offset; 083 test.addAll(examples.subList(first, first + numInstForFold)); 084 return test; 085 } 086 087 public static void main(String[] args) throws Exception { 088 int numPos = 4; 089 int numNeg = 5; 090 List<String> posExamples = Lists.newArrayList(); 091 List<String> negExamples = Lists.newArrayList(); 092 for (int i = 1; i <= numPos; i++) { 093 posExamples.add("p" + i); 094 } 095 for (int i = 1; i <= numNeg; i++) { 096 negExamples.add("n" + i); 097 } 098 099 FoldGenerator<String> foldGenerator = new FoldGenerator<>(posExamples, negExamples); 100 101 int numFolds = 5; 102 103 int numAttributes = 2; 104 int num = 9; 105 ArrayList<Attribute> attInfo = new ArrayList<>(numAttributes); 106 attInfo.add(new Attribute("C1")); 107 attInfo.add(new Attribute("t", Lists.newArrayList("0","1"))); 108 109 Instances data = new Instances("rel", attInfo, num); 110 for (int i = 1; i <= num; i++) { 111 data.add(new DenseInstance(1.0, new double[]{i,0})); 112 } 113// System.out.println(data); 114 for (int i = 0; i < numFolds; i++) { 115 System.out.println("Fold " + i); 116 System.out.println("Train:" + foldGenerator.trainCV(numFolds, i)); 117 System.out.println("Test:" + foldGenerator.testCV(numFolds, i)); 118 119// System.out.println(data.trainCV(numFolds, i)); 120 } 121 122 123 } 124 125}