001package org.dllearner.algorithms.isle.index;
002
003import com.google.common.base.Splitter;
004import com.google.common.collect.Lists;
005import org.semanticweb.owlapi.model.IRI;
006import org.semanticweb.owlapi.model.OWLDataFactory;
007import org.semanticweb.owlapi.model.OWLEntity;
008import uk.ac.manchester.cs.owl.owlapi.OWLDataFactoryImpl;
009
010import java.util.*;
011
012/**
013 * Tree for finding longest matching Token sequence
014 *
015 * @author Daniel Fleischhacker
016 */
017public class TokenTree {
018    public static final double WORDNET_FACTOR = 0.3d;
019    public static final double ORIGINAL_FACTOR = 1.0d;
020
021    private LinkedHashMap<Token, TokenTree> children;
022    private Set<OWLEntity> entities;
023    private List<Token> originalTokens;
024    private boolean ignoreStopWords = true;
025
026    public TokenTree() {
027        this.children = new LinkedHashMap<>();
028        this.entities = new HashSet<>();
029        this.originalTokens = new ArrayList<>();
030    }
031
032    /**
033     * If set to TRUE, stopwords like 'of, on' are ignored during creation and retrieval operations.
034     *
035     * @param ignoreStopWords the ignoreStopWords to set
036     */
037    public void setIgnoreStopWords(boolean ignoreStopWords) {
038        this.ignoreStopWords = ignoreStopWords;
039    }
040
041    /**
042     * Adds all given entities to the end of the path resulting from the given tokens.
043     *
044     * @param tokens   tokens to locate insertion point for entities
045     * @param entities entities to add
046     */
047    public void add(List<Token> tokens, Set<OWLEntity> entities, List<Token> originalTokens) {
048        TokenTree curNode = this;
049        for (Token t : tokens) {
050            if (!ignoreStopWords || (ignoreStopWords && !t.isStopWord())) {
051                TokenTree nextNode = curNode.children.get(t);
052                if (nextNode == null) {
053                    nextNode = new TokenTree();
054                    curNode.children.put(t, nextNode);
055                }
056                curNode = nextNode;
057            }
058        }
059        curNode.entities.addAll(entities);
060        curNode.originalTokens = new ArrayList<>(originalTokens);
061    }
062
063    public void add(List<Token> tokens, Set<OWLEntity> entities) {
064        add(tokens, entities, tokens);
065    }
066
067    /**
068     * Adds the given entity to the tree.
069     *
070     * @param tokens tokens to locate insertion point for entities
071     * @param entity entity to add
072     */
073    public void add(List<Token> tokens, OWLEntity entity) {
074        add(tokens, Collections.singleton(entity));
075    }
076
077    public void add(List<Token> tokens, OWLEntity entity, List<Token> originalTokens) {
078        add(tokens, Collections.singleton(entity), originalTokens);
079    }
080
081    /**
082     * Returns the set of entities located by the given list of tokens. This method does not consider alternative forms.
083     *
084     * @param tokens tokens to locate the information to get
085     * @return located set of entities or null if token sequence not contained in tree
086     */
087    public Set<OWLEntity> get(List<Token> tokens) {
088        TokenTree curNode = this;
089        for (Token t : tokens) {
090            TokenTree nextNode = getNextTokenTree(curNode, t);
091            if (nextNode == null) {
092                return null;
093            }
094            curNode = nextNode;
095        }
096        return curNode.entities;
097    }
098
099    public Set<EntityScorePair> getAllEntitiesScored(List<Token> tokens) {
100        HashSet<EntityScorePair> resEntities = new HashSet<>();
101        getAllEntitiesScoredRec(tokens, 0, this, resEntities, 1.0);
102
103        // only keep highest confidence for each entity
104        HashMap<OWLEntity, Double> entityScores = new HashMap<>();
105
106        for (EntityScorePair p : resEntities) {
107            if (!entityScores.containsKey(p.getEntity())) {
108                entityScores.put(p.getEntity(), p.getScore());
109            }
110            else {
111                entityScores.put(p.getEntity(), Math.max(p.getScore(), entityScores.get(p.getEntity())));
112            }
113        }
114
115        TreeSet<EntityScorePair> result = new TreeSet<>();
116        for (Map.Entry<OWLEntity, Double> e : entityScores.entrySet()) {
117            result.add(new EntityScorePair(e.getKey(), e.getValue()));
118        }
119
120        return result;
121    }
122
123    public void getAllEntitiesScoredRec(List<Token> tokens, int curPosition, TokenTree curTree,
124                                        HashSet<EntityScorePair> resEntities, Double curScore) {
125
126        if (curPosition == tokens.size()) {
127            for (OWLEntity e : curTree.entities) {
128                resEntities.add(new EntityScorePair(e, curScore));
129            }
130            return;
131        }
132        Token currentTextToken = tokens.get(curPosition);
133        for (Map.Entry<Token, TokenTree> treeTokenEntry : curTree.children.entrySet()) {
134            if (currentTextToken.equals(treeTokenEntry.getKey())) {
135                getAllEntitiesScoredRec(tokens, curPosition + 1, treeTokenEntry.getValue(), resEntities,
136                        curScore * ORIGINAL_FACTOR);
137            }
138            else {
139                for (Map.Entry<String, Double> treeAlternativeForm : treeTokenEntry.getKey().getScoredAlternativeForms()
140                        .entrySet()) {
141                    if (currentTextToken.getStemmedForm().equals(treeAlternativeForm.getKey())) {
142                        getAllEntitiesScoredRec(tokens, curPosition + 1, treeTokenEntry.getValue(), resEntities,
143                                curScore * ORIGINAL_FACTOR * treeAlternativeForm.getValue());
144                    }
145                }
146                for (Map.Entry<String, Double> textAlternativeForm : currentTextToken.getScoredAlternativeForms()
147                        .entrySet()) {
148                    if (treeTokenEntry.getKey().getStemmedForm().equals(textAlternativeForm.getKey())) {
149                        getAllEntitiesScoredRec(tokens, curPosition + 1, treeTokenEntry.getValue(), resEntities,
150                                curScore * ORIGINAL_FACTOR * textAlternativeForm.getValue());
151                    }
152                }
153
154                for (Map.Entry<String, Double> treeAlternativeForm : treeTokenEntry.getKey().getScoredAlternativeForms()
155                        .entrySet()) {
156                    for (Map.Entry<String, Double> textAlternativeForm : currentTextToken.getScoredAlternativeForms()
157                            .entrySet()) {
158                        if (treeAlternativeForm.getKey().equals(textAlternativeForm.getKey())) {
159                            getAllEntitiesScoredRec(tokens, curPosition + 1, treeTokenEntry.getValue(), resEntities,
160                                    curScore * treeAlternativeForm.getValue() * textAlternativeForm.getValue());
161                        }
162                    }
163                }
164            }
165        }
166    }
167
168    public Set<OWLEntity> getAllEntities(List<Token> tokens) {
169        HashSet<OWLEntity> resEntities = new HashSet<>();
170        getAllEntitiesRec(tokens, 0, this, resEntities);
171        return resEntities;
172    }
173
174    public void getAllEntitiesRec(List<Token> tokens, int curPosition, TokenTree curTree, HashSet<OWLEntity> resEntities) {
175        if (curPosition == tokens.size()) {
176            resEntities.addAll(curTree.entities);
177            return;
178        }
179        Token t = tokens.get(curPosition);
180        for (Map.Entry<Token, TokenTree> entry : curTree.children.entrySet()) {
181            if (t.equalsWithAlternativeForms(entry.getKey())) {
182                getAllEntitiesRec(tokens, curPosition + 1, entry.getValue(), resEntities);
183            }
184        }
185    }
186
187    /**
188     * Returns the list of tokens which are the longest match with entities assigned in this tree.
189     *
190     * @param tokens list of tokens to check for longest match
191     * @return list of tokens being the longest match, sublist of {@code tokens} anchored at the first token
192     */
193    public List<Token> getLongestMatch(List<Token> tokens) {
194        List<Token> fallbackTokenList = new ArrayList<>();
195        TokenTree curNode = this;
196
197        for (Token t : tokens) {
198            TokenTree nextNode = getNextTokenTree(curNode, t);
199            if (nextNode == null) {
200                return fallbackTokenList;
201            }
202            curNode = nextNode;
203            fallbackTokenList.add(t);
204        }
205        return fallbackTokenList;
206    }
207
208    private TokenTree getNextTokenTree(TokenTree current, Token t) {
209        TokenTree next = current.children.get(t);
210        if (next != null) {
211            return next;
212        }
213        for (Map.Entry<Token, TokenTree> child : current.children.entrySet()) {
214            if (child.getKey().equalsWithAlternativeForms(t)) {
215                return child.getValue();
216            }
217        }
218        return null;
219    }
220
221    /**
222     * Returns the set of entities assigned to the longest matching token subsequence of the given token sequence.
223     *
224     * @param tokens token sequence to search for longest match
225     * @return set of entities assigned to the longest matching token subsequence of the given token sequence
226     */
227    public Set<OWLEntity> getEntitiesForLongestMatch(List<Token> tokens) {
228        TokenTree fallback = this.entities.isEmpty() ? null : this;
229        TokenTree curNode = this;
230
231        for (Token t : tokens) {
232            TokenTree nextNode = getNextTokenTree(curNode, t);
233            if (nextNode == null) {
234                return fallback == null ? null : fallback.entities;
235            }
236            curNode = nextNode;
237            if (!curNode.entities.isEmpty()) {
238                fallback = curNode;
239            }
240        }
241
242        return fallback == null ? Collections.<OWLEntity>emptySet() : fallback.entities;
243    }
244
245    /**
246     * Returns the original ontology tokens for the longest match
247     */
248    public List<Token> getOriginalTokensForLongestMatch(List<Token> tokens) {
249        TokenTree fallback = this.entities.isEmpty() ? null : this;
250        TokenTree curNode = this;
251
252        for (Token t : tokens) {
253            TokenTree nextNode = getNextTokenTree(curNode, t);
254            if (nextNode == null) {
255                return fallback == null ? null : fallback.originalTokens;
256            }
257            curNode = nextNode;
258            if (!curNode.entities.isEmpty()) {
259                fallback = curNode;
260            }
261        }
262
263        return fallback == null ? Collections.<Token>emptyList() : fallback.originalTokens;
264    }
265
266    public static void main(String[] args) throws Exception {
267        List<Token> tokens1 = Lists.newLinkedList();
268        for (String s : Splitter.on(" ").split("this is a token tree")) {
269            tokens1.add(new Token(s, s, s, false, false));
270        }
271
272        List<Token> tokens2 = Lists.newLinkedList();
273        for (String s : Splitter.on(" ").split("this is a tokenized tree")) {
274            tokens2.add(new Token(s, s, s, false, false));
275        }
276
277        OWLDataFactory df = new OWLDataFactoryImpl();
278        TokenTree tree = new TokenTree();
279        tree.add(tokens1, df.getOWLClass(IRI.create("TokenTree")));
280        tree.add(tokens2, df.getOWLClass(IRI.create("TokenizedTree")));
281        System.out.println(tree);
282
283        System.out.println(tree.getEntitiesForLongestMatch(tokens1));
284        System.out.println(tree.getLongestMatch(tokens1));
285
286        List<Token> tokens3 = Lists.newLinkedList();
287        for (String s : Splitter.on(" ").split("this is a very nice tokenized tree")) {
288            tokens3.add(new Token(s, s, s, false, false));
289        }
290        System.out.println(tree.getLongestMatch(tokens3));
291    }
292
293
294    public String toString() {
295        return "TokenTree\n" + toString(0);
296    }
297
298    public String toString(int indent) {
299        StringBuilder indentStringBuilder = new StringBuilder();
300        for (int i = 0; i < indent; i++) {
301            indentStringBuilder.append(" ");
302        }
303        String indentString = indentStringBuilder.toString();
304        StringBuilder sb = new StringBuilder();
305        for (Map.Entry<Token, TokenTree> e : new TreeMap<>(children).entrySet()) {
306            sb.append(indentString).append(e.getKey().toString());
307            sb.append("\n");
308            sb.append(e.getValue().toString(indent + 1));
309        }
310        return sb.toString();
311    }
312
313
314}