001/**
002 * Copyright (C) 2007 - 2016, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 *
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 */
019package org.dllearner.utilities.statistics;
020
021import java.text.DecimalFormat;
022import java.util.Set;
023
024/**
025 * Utility class for calculating the mean and standard deviation of a given set
026 * of numbers. The class also contains convenience methods for printing values.
027 * 
028 * @author Jens Lehmann
029 * 
030 */
public class Stat {

    private int count = 0;
    private double sum = 0;
    private double squareSum = 0;
    // Sentinels meaning "no value seen yet"; any added number replaces them.
    // The max sentinel must be the most negative finite double. Double.MIN_NORMAL
    // is a tiny *positive* number and would hide the maximum of all-negative input.
    private double min = Double.MAX_VALUE;
    private double max = -Double.MAX_VALUE;
    // Used to give a good percentage output; the leading '0' keeps values below
    // 1% readable ("0.50%" rather than ".50%").
    private final DecimalFormat df = new DecimalFormat("0.00%");

    /**
     * Creates an empty stat object to which numbers can be added.
     * <p>
     * Instances are not thread-safe: all fields are updated without synchronization.
     */
    public Stat() {
    }

    /**
     * Creates a new stat object by merging two stat objects. The result is the same as if
     * the numbers, which have been added to stat1 and stat2 would have been added to this
     * stat object.
     *
     * @param stat1 Statistical object 1.
     * @param stat2 Statistical object 2.
     */
    public Stat(Stat stat1, Stat stat2) {
        mergeFrom(stat1);
        mergeFrom(stat2);
    }

    /**
     * Creates a new stat object by merging several stat objects. The result is the same as if
     * the numbers, which have been added to each stat would have been added to this
     * stat object.
     *
     * @param stats the stats to merge; an empty set yields an empty stat object
     */
    public Stat(Set<Stat> stats) {
        for (Stat stat : stats) {
            mergeFrom(stat);
        }
    }

    /**
     * Merges the aggregates of another stat object into this one, as if every number
     * added to {@code stat} had also been added here.
     *
     * @param stat the stat object whose values are merged into this one
     */
    public void add(Stat stat) {
        mergeFrom(stat);
    }

    /**
     * Adds the running aggregates of {@code other} to this object. This is private so
     * the constructors do not call an overridable method.
     */
    private void mergeFrom(Stat other) {
        count += other.count;
        sum += other.sum;
        squareSum += other.squareSum;
        min = Math.min(min, other.min);
        max = Math.max(max, other.max);
    }

    /**
     * Add a number to this object.
     *
     * @param number
     *            The new number.
     */
    public void addNumber(double number) {
        count++;
        sum += number;
        squareSum += number * number;
        if (number < min) {
            min = number;
        }
        if (number > max) {
            max = number;
        }
    }

    /**
     * Gets the number of numbers.
     *
     * @return The number of numbers.
     */
    public int getCount() {
        return count;
    }

    /**
     * Gets the sum of all numbers.
     *
     * @return The sum of all numbers.
     */
    public double getSum() {
        return sum;
    }

    /**
     * Gets the mean of all numbers.
     *
     * @return The mean of all numbers, or {@code NaN} if no numbers have been added.
     */
    public double getMean() {
        return sum / count;
    }

    /**
     * Gets the mean of all numbers as a percentage, i.e. multiplied by 100 and
     * rounded to two decimals: 0.5678 becomes "56.78%" and 0.005 becomes "0.50%".
     * The decimal separator follows the default locale.
     *
     * @return The mean as formatted string.
     */
    public String getMeanAsPercentage() {
        return df.format(getMean());
    }

    /**
     * Gets the (sample, n-1) standard deviation of all numbers.
     *
     * @return The standard deviation of all numbers; 0 if fewer than two numbers
     *         have been added.
     */
    public double getStandardDeviation() {
        if (count <= 1) {
            return 0.0;
        }

        // Formula from http://de.wikipedia.org/wiki/Standardabweichung.
        // The denominator is computed in double: count * (count - 1) in int
        // arithmetic overflows once count exceeds ~46341.
        double val = (count * squareSum - sum * sum) / ((double) count * (count - 1));
        double root = Math.sqrt(val);

        // Due to rounding errors "val" can be slightly negative, which makes the
        // root NaN. The real value is then 0 (or very close to it), so return 0.
        return Double.isNaN(root) ? 0.0 : root;
    }

    /**
     * Gets the smallest number added.
     *
     * @return the min, or {@link Double#MAX_VALUE} if no numbers have been added
     */
    public double getMin() {
        return min;
    }

    /**
     * Gets the largest number added.
     *
     * @return the max, or {@code -Double.MAX_VALUE} if no numbers have been added
     */
    public double getMax() {
        return max;
    }

    /**
     * Summarises mean, deviation, min, max and count without a unit.
     *
     * @return a human-readable summary, or "no data collected" if empty
     */
    public String prettyPrint() {
        return prettyPrint("");
    }

    /**
     * Summarises mean, deviation, min, max and count, appending {@code unit} to
     * each value and using a default {@link DecimalFormat}.
     *
     * @param unit suffix appended to each value (e.g. "ms")
     * @return a human-readable summary, or "no data collected" if empty
     */
    public String prettyPrint(String unit) {
        if (count <= 0) {
            return "no data collected";
        }
        return summaryPrefix(unit, new DecimalFormat()) + "; count " + count + ")";
    }

    /**
     * Summarises mean, deviation, min and max, appending {@code unit} to each
     * value and formatting numbers with the given format.
     *
     * @param unit suffix appended to each value (e.g. "ms")
     * @param df format used for all numeric values
     * @return a human-readable summary, or "no data collected" if empty
     */
    public String prettyPrint(String unit, DecimalFormat df) {
        // Same guard as prettyPrint(String): without it, an empty object would
        // print NaN symbols and the min/max sentinel values.
        if (count <= 0) {
            return "no data collected";
        }
        return summaryPrefix(unit, df) + ")";
    }

    /**
     * Builds the shared, not-yet-closed part of the pretty-printed summary:
     * "av. M u (deviation S u; min m u; max x u".
     */
    private String summaryPrefix(String unit, DecimalFormat format) {
        return "av. " + format.format(getMean()) + unit
                + " (deviation " + format.format(getStandardDeviation()) + unit + "; "
                + "min " + format.format(getMin()) + unit + "; "
                + "max " + format.format(getMax()) + unit;
    }

}