001/**
002 * Copyright (C) 2007 - 2016, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 *
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 */
019package org.dllearner.utilities;
020
021import com.google.common.collect.ImmutableSet;
022import com.google.common.collect.Sets;
023import org.dllearner.core.AbstractReasonerComponent;
024import org.dllearner.core.Component;
025import org.dllearner.accuracymethods.AccMethodApproximate;
026import org.dllearner.accuracymethods.AccMethodTwoValued;
027import org.dllearner.accuracymethods.AccMethodTwoValuedApproximate;
028import org.dllearner.reasoning.SPARQLReasoner;
029import org.semanticweb.owlapi.model.OWLClassExpression;
030import org.semanticweb.owlapi.model.OWLIndividual;
031import org.slf4j.Logger;
032import org.slf4j.LoggerFactory;
033
034import java.util.Collection;
035import java.util.Set;
036import java.util.SortedSet;
037import java.util.TreeSet;
038
039/**
040 * Common utilities for using a reasoner in learning problems
041 */
042public class ReasoningUtils implements Component {
043
044        final static Logger logger = LoggerFactory.getLogger(ReasoningUtils.class);
045
046        /**
047         * binary counter to divide a set in 2 partitions
048         */
049        public class CoverageCount {
050                public int trueCount;
051                public int falseCount;
052                public int total;
053        }
054
055        /**
056         * binary counter to divide a set in 2 partitions and an additional counter for unkown individuals
057         */
058        class Coverage3Count extends CoverageCount {
059                int unknownCount;
060        }
061
062        /**
063         * binary set to divide a set in 2 partitions
064         */
065        public class Coverage extends CoverageCount {
066                public SortedSet<OWLIndividual> trueSet = new TreeSet<>();
067                public SortedSet<OWLIndividual> falseSet = new TreeSet<>();
068        }
069
070        /**
071         * binary set to divide a set in 2 partitions, and an additonal for unknown individuals
072         */
073        class Coverage3 extends Coverage3Count {
074                final SortedSet<OWLIndividual> trueSet = new TreeSet<>();
075                final SortedSet<OWLIndividual> falseSet = new TreeSet<>();
076                final SortedSet<OWLIndividual> unknownSet = new TreeSet<>();
077        }
078
079        protected AbstractReasonerComponent reasoner;
080
081        /**
082         * create new reasoning utils
083         * @param reasoner reasoner to use
084         */
085        public ReasoningUtils(AbstractReasonerComponent reasoner) {
086                this.reasoner = reasoner;
087        }
088
089        /**
090         * callback to interrupt individual instance check
091         * @return true when instance check loop should be aborted
092         */
093        protected boolean interrupted() { return false; }
094
095
096        /**
097         * binary partition a list of sets into true and false, depending on whether they satisfy concept. wrapper to convert collections to sets
098         * @param concept the concept for partitioning
099         * @param collections list of collections to partition. they will be converted to sets first
100         * @return array of coverage data
101         */
102        public final Coverage[] getCoverage(OWLClassExpression concept,
103                                            Collection<OWLIndividual>... collections) {
104                Set[] sets = new Set[collections.length];
105                for (int i = 0; i < collections.length; ++i) {
106                        sets[i] = makeSet(collections[i]);
107                }
108                return getCoverage(concept, sets);
109        }
110
111        /**
112         * binary partition a list of sets into true and false, depending on whether they satisfy concept
113         * @param concept the OWL concept used for partition
114         * @param sets list of sets to partition
115         * @return an array of Coverage data, one entry for each input set
116         */
117        @SafeVarargs
118        public final Coverage[] getCoverage(OWLClassExpression concept, Set<OWLIndividual>... sets) {
119                Coverage[] rv = new Coverage [ sets.length ];
120
121                if(!reasoner.isUseInstanceChecks()) {
122                        if (reasoner instanceof SPARQLReasoner &&
123                                        ((SPARQLReasoner)reasoner).isUseValueLists()) {
124                                for (int i = 0; i < sets.length; ++i) {
125                                        SortedSet<OWLIndividual> trueSet = reasoner.hasType(concept, sets[i]);
126
127                                        rv[i] = new Coverage();
128                                        rv[i].total = sets[i].size();
129
130                                        rv[i].trueSet.addAll(trueSet);
131                                        rv[i].falseSet.addAll(Sets.difference(sets[i], trueSet));
132
133                                        rv[i].trueCount = rv[i].trueSet.size();
134                                        rv[i].falseCount = rv[i].falseSet.size();
135                                }
136                        } else {
137                                SortedSet<OWLIndividual> individuals = reasoner.getIndividuals(concept);
138                                for (int i = 0; i < sets.length; ++i) {
139                                        rv[i] = new Coverage();
140                                        rv[i].total = sets[i].size();
141
142                                        rv[i].trueSet.addAll(Sets.intersection(sets[i], individuals));
143                                        rv[i].falseSet.addAll(Sets.difference(sets[i], individuals));
144
145                                        rv[i].trueCount = rv[i].trueSet.size();
146                                        rv[i].falseCount = rv[i].falseSet.size();
147                                }
148                        }
149                } else {
150                        for (int i = 0; i < sets.length; ++i) {
151                                rv[i] = new Coverage();
152                                rv[i].total = sets[i].size();
153
154                                for (OWLIndividual example : sets[i]) {
155                                        if (getReasoner().hasType(concept, example)) {
156                                                rv[i].trueSet.add(example);
157                                        } else {
158                                                rv[i].falseSet.add(example);
159                                        }
160                                        if (interrupted()) {
161                                                return null;
162                                        }
163                                }
164
165                                rv[i].trueCount = rv[i].trueSet.size();
166                                rv[i].falseCount = rv[i].falseSet.size();
167                        }
168                }
169                return rv;
170        }
171
172
173        /**
174         * count the numbers of individuals satisfying a concept. wrapper converting collections to set
175         * @param concept the OWL concept used for counting
176         * @param collections list of collections of individuals to count on. will be converted to sets first
177         * @return an array of Coverage counts, one entry for each input set
178         */
179        public final CoverageCount[] getCoverageCount(OWLClassExpression concept,
180                                                      Collection<OWLIndividual>... collections) {
181                Set[] sets = new Set [ collections.length ];
182                for (int i = 0; i < collections.length; ++i) {
183                        sets[i] = makeSet(collections[i]);
184                }
185                return getCoverageCount(concept, sets);
186        }
187
188        /**
189         * count the numbers of individuals satisfying a concept
190         * @param concept the OWL concept used for counting
191         * @param sets list of sets of individuals to count on
192         * @return an array of Coverage counts, one entry for each input set
193         */
194        @SafeVarargs
195        public final CoverageCount[] getCoverageCount(OWLClassExpression concept,
196                                                                                                  Set<OWLIndividual>... sets) {
197                CoverageCount[] rv = new CoverageCount [ sets.length ];
198
199                if(!reasoner.isUseInstanceChecks()) {
200                        if (reasoner instanceof SPARQLReasoner &&
201                                        ((SPARQLReasoner)reasoner).isUseValueLists()) {
202
203                                for (int i = 0; i < sets.length; ++i) {
204                                        int trueCount = ((SPARQLReasoner) reasoner).getIndividualsCount(concept, sets[i]);
205
206                                        rv[i] = new CoverageCount();
207                                        rv[i].total = sets[i].size();
208
209                                        rv[i].trueCount = trueCount;
210                                        rv[i].falseCount = sets[i].size()- trueCount;
211                                }
212                        } else {
213                                SortedSet<OWLIndividual> individuals = reasoner.getIndividuals(concept);
214                                for (int i = 0; i < sets.length; ++i) {
215                                        rv[i] = new CoverageCount();
216                                        rv[i].total = sets[i].size();
217                                
218                                        rv[i].trueCount  = Sets.intersection(sets[i], individuals).size();
219                                        rv[i].falseCount = Sets.difference(sets[i], individuals).size();
220                                }
221                        }
222                } else {
223                        for (int i = 0; i < sets.length; ++i) {
224                                rv[i] = new CoverageCount();
225                                rv[i].total = sets[i].size();
226
227                                for (OWLIndividual example : sets[i]) {
228                                        if (getReasoner().hasType(concept, example)) {
229                                                ++rv[i].trueCount;
230                                        } else {
231                                                ++rv[i].falseCount;
232                                        }
233                                        if (interrupted()) {
234                                                return null;
235                                        }
236                                }
237                        }
238                }
239                return rv;
240        }
241
242        /**
243         * partition an array of sets into true, false and unknown, depending on whether they satisfy concept A or B
244         * @param trueConcept the OWL concept used for true partition
245         * @param falseConcept the OWL concept used for false partition
246         * @param sets list of sets to partition
247         * @return an array of Coverage data, one entry for each input set
248         */
249        @SafeVarargs
250        public final Coverage3[] getCoverage3(OWLClassExpression trueConcept, OWLClassExpression falseConcept, Set<OWLIndividual>... sets) {
251                Coverage3[] rv = new Coverage3 [ sets.length ];
252
253                if(!reasoner.isUseInstanceChecks()) {
254                        if (reasoner instanceof SPARQLReasoner &&
255                                        ((SPARQLReasoner)reasoner).isUseValueLists()) {
256                                for (int i = 0; i < sets.length; ++i) {
257                                        rv[i] = new Coverage3();
258                                        rv[i].total = sets[i].size();
259
260                                        SortedSet<OWLIndividual> trueSet = reasoner.hasType(trueConcept, sets[i]);
261                                        SortedSet<OWLIndividual> falseSet = reasoner.hasType(falseConcept, sets[i]);
262                                        rv[i].trueSet.addAll(trueSet);
263                                        rv[i].falseSet.addAll(falseSet);
264                                        rv[i].unknownSet.addAll(Sets.difference(sets[i], Sets.union(trueSet, falseSet)));
265
266                                        rv[i].trueCount = rv[i].trueSet.size();
267                                        rv[i].falseCount = rv[i].falseSet.size();
268                                        rv[i].unknownCount = rv[i].unknownSet.size();
269                                }
270                        } else {
271                                SortedSet<OWLIndividual> trueIndividuals = reasoner.getIndividuals(trueConcept);
272                                SortedSet<OWLIndividual> falseIndividuals = reasoner.getIndividuals(falseConcept);
273                                for (int i = 0; i < sets.length; ++i) {
274                                        rv[i] = new Coverage3();
275                                        rv[i].total = sets[i].size();
276
277                                        rv[i].trueSet.addAll(Sets.intersection(sets[i], trueIndividuals));
278                                        rv[i].falseSet.addAll(Sets.intersection(sets[i], falseIndividuals));
279                                        rv[i].unknownSet.addAll(Sets.difference(sets[i], Sets.union(rv[i].trueSet, rv[i].falseSet)));
280
281                                        rv[i].trueCount = rv[i].trueSet.size();
282                                        rv[i].falseCount = rv[i].falseSet.size();
283                                        rv[i].unknownCount = rv[i].unknownSet.size();
284                                }
285                        }
286                } else {
287                        for (int i = 0; i < sets.length; ++i) {
288                                rv[i] = new Coverage3();
289                                rv[i].total = sets[i].size();
290
291                                for (OWLIndividual example : sets[i]) {
292                                        if (getReasoner().hasType(trueConcept, example)) {
293                                                rv[i].trueSet.add(example);
294                                        } else if (getReasoner().hasType(falseConcept, example)) {
295                                                rv[i].falseSet.add(example);
296                                        } else {
297                                                rv[i].unknownSet.add(example);
298                                        }
299                                        if (interrupted()) {
300                                                return null;
301                                        }
302                                }
303
304                                rv[i].trueCount = rv[i].trueSet.size();
305                                rv[i].falseCount = rv[i].falseSet.size();
306                                rv[i].unknownCount = rv[i].unknownSet.size();
307                        }
308                }
309                return rv;
310        }
311
312        /**
313         * calculate accuracy of a concept, using the supplied accuracy method
314         * @param accuracyMethod accuracy method to use
315         * @param description concept to test
316         * @param positiveExamples set of positive examples to use for calculating the accuracy
317         * @param negativeExamples set of negative examples to use for calculating the accuracy
318         * @param noise noise level of the data
319         * @return -1 when the concept is too weak or the accuracy value as calculated by the accuracy method
320         */
321        public double getAccuracyOrTooWeak2(AccMethodTwoValued accuracyMethod, OWLClassExpression description, Collection<OWLIndividual> positiveExamples,
322                        Collection<OWLIndividual> negativeExamples, double noise) {
323                if (accuracyMethod instanceof AccMethodApproximate) {
324                        logger.trace("AccMethodApproximate");
325                        return ((AccMethodTwoValuedApproximate) accuracyMethod).getAccApprox2(description, positiveExamples, negativeExamples, noise);
326                } else {
327                        CoverageCount[] cc = getCoverageCount(description, positiveExamples, negativeExamples);
328                        logger.trace("AccMethodExact: " + (new CoverageAdapter.CoverageCountAdapter2(cc)));
329                        return getAccuracyOrTooWeakExact2(accuracyMethod, cc, noise);
330                }
331        }
332
333
334        /**
335         * wrapper to call accuracy method with coverage count
336         * @param accuracyMethod method to use
337         * @param cc already calculated coverage count
338         * @param noise noise level
339         * @return @{AccMethodTwoValued.getAccOrTooWeak2}
340         */
341        public double getAccuracyOrTooWeakExact2(AccMethodTwoValued accuracyMethod, CoverageCount[] cc, double noise) {
342//              return accuracyMethod.getAccOrTooWeak2(cc[0].trueCount, cc[0].falseCount, cc[1].trueCount, cc[1].falseCount, noise);
343                CoverageAdapter.CoverageCountAdapter2 c2 = new CoverageAdapter.CoverageCountAdapter2(cc);
344                logger.trace("calling getAccOrToWeak2["+c2.tp()+","+c2.fn()+","+c2.fp()+","+c2.tn()+","+noise+"]");
345                return accuracyMethod.getAccOrTooWeak2(c2.tp(), c2.fn(), c2.fp(), c2.tn(), noise);
346        }
347
348        @Override
349        public void init() {
350        }
351
352        public AbstractReasonerComponent getReasoner() {
353                return reasoner;
354        }
355
356        public void setReasoner(AbstractReasonerComponent reasoner) {
357                this.reasoner = reasoner;
358        }
359
360        /**
361         * helper method to create a set from a collection
362         * @param collection
363         * @param <T>
364         * @return set (or hashset)
365         */
366        protected <T> Set<T> makeSet(Collection<T> collection) {
367                return collection instanceof Set ? (Set)collection : ImmutableSet.copyOf(collection);
368        }
369
370}