001/**
002 * 
003 */
004package org.dllearner.algorithms.isle;
005
import java.io.IOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.core.SimpleAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.BytesRef;
031
032/**
033 * Imagine an N-dimensional space where N is the number of unique words in a pair of texts. Each of the two texts 
034 * can be treated like a vector in this N-dimensional space. The distance between the two vectors is an indication 
035 * of the similarity of the two texts. The cosine of the angle between the two vectors is the most common distance measure.
036 * @author Lorenz Buehmann
037 *
038 */
039public class VSMCosineDocumentSimilarity {
040        
041        enum TermWeighting {
042                TF, TF_IDF
043        }
044        
045        public static final String CONTENT = "Content";
046    public static final FieldType TYPE_STORED = new FieldType();
047    
048    private final Set<String> terms = new HashSet<>();
049    private final RealVector v1;
050    private final RealVector v2;
051    
052    static {
053        TYPE_STORED.setIndexOptions(IndexOptions.DOCS_AND_FREQS);
054        TYPE_STORED.setTokenized(true);
055        TYPE_STORED.setStored(true);
056        TYPE_STORED.setStoreTermVectors(true);
057        TYPE_STORED.setStoreTermVectorPositions(true);
058        TYPE_STORED.freeze();
059    }
060    
061    public VSMCosineDocumentSimilarity(String s1, String s2, TermWeighting termWeighting) throws IOException {
062        //create the index
063        Directory directory = createIndex(s1, s2);
064        IndexReader reader = DirectoryReader.open(directory);
065        //generate the document vectors
066        if(termWeighting == TermWeighting.TF){//based on term frequency only
067                //compute the term frequencies for document 1
068            Map<String, Integer> f1 = getTermFrequencies(reader, 0);
069            //compute the term frequencies for document 2
070            Map<String, Integer> f2 = getTermFrequencies(reader, 1);
071            reader.close();
072            //map both documents to vector objects
073            v1 = getTermVectorInteger(f1);
074            v2 = getTermVectorInteger(f2);
075        } else if(termWeighting == TermWeighting.TF_IDF){//based on tf*idf weighting
076                //compute the term frequencies for document 1
077            Map<String, Double> f1 = getTermWeights(reader, 0);
078            //compute the term frequencies for document 2
079            Map<String, Double> f2 = getTermWeights(reader, 1);
080            reader.close();
081            //map both documents to vector objects
082            v1 = getTermVectorDouble(f1);
083            v2 = getTermVectorDouble(f2);
084        } else {
085                v1 = null;
086                v2 = null;
087        }
088    }
089    
090    public VSMCosineDocumentSimilarity(String s1, String s2) throws IOException {
091        this(s1, s2, TermWeighting.TF_IDF);
092    }
093    
094    /**
095     * Returns the cosine document similarity between document {@code doc1} and {@code doc2} using TF-IDF as weighting for each term.
096     * The resulting similarity ranges from -1 meaning exactly opposite, to 1 meaning exactly the same, 
097     * with 0 usually indicating independence, and in-between values indicating intermediate similarity or dissimilarity.
098     * @param doc1
099     * @param doc2
100     * @return
101     * @throws IOException
102     */
103    public static double getCosineSimilarity(String doc1, String doc2)
104            throws IOException {
105        return new VSMCosineDocumentSimilarity(doc1, doc2).getCosineSimilarity();
106    }
107    
108    /**
109     * Returns the cosine document similarity between document {@code doc1} and {@code doc2} based on {@code termWeighting} to compute the weight
110     * for each term in the documents.
111     * The resulting similarity ranges from -1 meaning exactly opposite, to 1 meaning exactly the same, 
112     * with 0 usually indicating independence, and in-between values indicating intermediate similarity or dissimilarity.
113     * @param doc1
114     * @param doc2
115     * @return
116     * @throws IOException
117     */
118    public static double getCosineSimilarity(String doc1, String doc2, TermWeighting termWeighting)
119            throws IOException {
120        return new VSMCosineDocumentSimilarity(doc1, doc2, termWeighting).getCosineSimilarity();
121    }
122    
123    /**
124     * Create a in-memory Lucene index for both documents.
125     * @param s1
126     * @param s2
127     * @return
128     * @throws IOException
129     */
130    private Directory createIndex(String s1, String s2) throws IOException {
131        Directory directory = new MMapDirectory(Files.createTempDirectory("Lucene"));
132        Analyzer analyzer = new SimpleAnalyzer();
133        IndexWriterConfig iwc = new IndexWriterConfig(analyzer);
134        IndexWriter writer = new IndexWriter(directory, iwc);
135        addDocument(writer, s1);
136        addDocument(writer, s2);
137        writer.close();
138        return directory;
139    }
140    
141    /**
142     * Add the document to the Lucene index.
143     * @param writer
144     * @param content
145     * @throws IOException
146     */
147    private void addDocument(IndexWriter writer, String content) throws IOException {
148        Document doc = new Document();
149        Field field = new Field(CONTENT, content, TYPE_STORED);
150        doc.add(field);
151        writer.addDocument(doc);
152    }
153    
154    /**
155     * Get the frequency of each term contained in the document.
156     * @param reader
157     * @param docId
158     * @return
159     * @throws IOException
160     */
161    private Map<String, Integer> getTermFrequencies(IndexReader reader, int docId)
162            throws IOException {
163        Terms vector = reader.getTermVector(docId, CONTENT);
164        TermsEnum termsEnum = vector.iterator();
165        Map<String, Integer> frequencies = new HashMap<>();
166        BytesRef text = null;
167        while ((text = termsEnum.next()) != null) {
168            String term = text.utf8ToString();
169            int freq = (int) termsEnum.totalTermFreq();
170            frequencies.put(term, freq);
171            terms.add(term);
172        }
173        return frequencies;
174    }
175    
176    /**
177     * Get the weight(tf*idf) of each term contained in the document.
178     * @param reader
179     * @param docId
180     * @return
181     * @throws IOException
182     */
183    private Map<String, Double> getTermWeights(IndexReader reader, int docId)
184            throws IOException {
185        Terms vector = reader.getTermVector(docId, CONTENT);
186        //TODO: not sure if this is reasonable but it prevents NPEs
187        if (vector == null) {
188            return new HashMap<>();
189        }
190        TermsEnum termsEnum = vector.iterator();
191        Map<String, Double> weights = new HashMap<>();
192        BytesRef text = null;
193        while ((text = termsEnum.next()) != null) {
194            String term = text.utf8ToString();
195            //get the term frequency
196            int tf = (int) termsEnum.totalTermFreq();
197            //get the document frequency
198            int df = reader.docFreq(new Term(CONTENT, text));
199            //compute the inverse document frequency
200            double idf = getIDF(reader.numDocs(), df);
201            //compute tf*idf
202            double weight = tf * idf;
203            
204            weights.put(term, weight);
205            terms.add(term);
206        }
207        return weights;
208    }
209    
210    private double getIDF(int totalNumberOfDocuments, int documentFrequency){
211        return 1 + Math.log(totalNumberOfDocuments/documentFrequency);
212    }
213    
214    private double getCosineSimilarity() {
215        return (v1.dotProduct(v2)) / (v1.getNorm() * v2.getNorm());
216    }
217    
218    private RealVector getTermVectorInteger(Map<String, Integer> map) {
219        RealVector vector = new ArrayRealVector(terms.size());
220        int i = 0;
221        for (String term : terms) {
222            int value = map.containsKey(term) ? map.get(term) : 0;
223            vector.setEntry(i++, value);
224        }
225        return vector.mapDivide(vector.getL1Norm());
226    }
227    
228    private RealVector getTermVectorDouble(Map<String, Double> map) {
229        RealVector vector = new ArrayRealVector(terms.size());
230        int i = 0;
231        for (String term : terms) {
232            double value = map.containsKey(term) ? map.get(term) : 0d;
233            vector.setEntry(i++, value);
234        }
235        return vector.mapDivide(vector.getL1Norm());
236    }
237    
238    public static void main(String[] args) throws Exception {
239                double cosineSimilarity = VSMCosineDocumentSimilarity.getCosineSimilarity("The king is here", "The salad is cold");
240                System.out.println(cosineSimilarity);
241        }
242
243}