001package org.dllearner.examples;
002
003import java.io.BufferedReader;
004import java.io.BufferedWriter;
005import java.io.File;
006import java.io.FileReader;
007import java.io.FileWriter;
008import java.io.IOException;
009import java.util.ArrayList;
010import java.util.Collections;
011import java.util.HashMap;
012import java.util.List;
013import java.util.Set;
014import java.util.TreeSet;
015
016import org.apache.log4j.Level;
017import org.apache.log4j.Logger;
018import org.dllearner.algorithms.semkernel.SemKernel;
019import org.dllearner.algorithms.semkernel.SemKernel.SvmType;
020import org.semanticweb.elk.owlapi.ElkReasonerFactory;
021import org.semanticweb.owlapi.apibinding.OWLManager;
022import org.semanticweb.owlapi.model.IRI;
023import org.semanticweb.owlapi.model.OWLClass;
024import org.semanticweb.owlapi.model.OWLDataFactory;
025import org.semanticweb.owlapi.model.OWLOntology;
026import org.semanticweb.owlapi.model.OWLOntologyCreationException;
027import org.semanticweb.owlapi.model.OWLOntologyManager;
028import org.semanticweb.owlapi.reasoner.ConsoleProgressMonitor;
029import org.semanticweb.owlapi.reasoner.InferenceType;
030import org.semanticweb.owlapi.reasoner.OWLReasoner;
031import org.semanticweb.owlapi.reasoner.OWLReasonerConfiguration;
032import org.semanticweb.owlapi.reasoner.OWLReasonerFactory;
033import org.semanticweb.owlapi.reasoner.SimpleConfiguration;
034
035import com.google.common.collect.Sets;
036
037public class SemKernelExample {
038    private static final Logger logger = Logger.getLogger(SemKernelExample.class);
039
040    private static HashMap<String, Set<String>> mgi2go;
041    private static HashMap<String, Set<String>> mgi2mp;
042    private static Set<String> allClsUriStrs;
043    private static OWLDataFactory factory;
044
045    private static SemKernel kernel;
046    /* kernel settings, according to
047     * svm_train -s 0 -t 5 -b 1 -v 10 -c $cost -f go.obo \
048     *      $TRAININGDIR/$NAME-$par > $CVDIR/$NAME-$par-$cost
049     */
050    private static final SvmType svmType = SvmType.C_SVC;
051    private static final boolean doProbabilityEstimates = true;
052    private static final int crossValidationFolds = 10;
053    private static final float cost = 5f;
054    private static final boolean predictProbability = true;
055
056    /* semkernel settings */
057    /** URIs of phenotype classes to build a classifier for */
058    private static final Set<String> trainClsUriStrs = Sets.newHashSet(
059//            "http://purl.obolibrary.org/obo/MP_0004031",
060//            "http://purl.obolibrary.org/obo/MP_0000202",
061            "http://purl.obolibrary.org/obo/MP_0001186"
062//            "http://purl.obolibrary.org/obo/MP_0000001"
063            );
064    private static final Set<String> predictClsUriStrs = Sets.newHashSet(
065//          "http://purl.obolibrary.org/obo/MP_0004031",
066//          "http://purl.obolibrary.org/obo/MP_0000202",
067          "http://purl.obolibrary.org/obo/MP_0001186"
068//          "http://purl.obolibrary.org/obo/MP_0000001"
069          );
070    private static final double posNegExampleRatio = 1;
071
072    /* input/output */
073    private static final String workingDirPath = "/tmp/semkernel/";
074    // get it at http://purl.obolibrary.org/obo/go.obo
075//    private static final String kbFilePath = workingDirPath + "go.obo";
076    private static final String kbFilePath = workingDirPath + "go.owl";
077    private static final String trainDirPath = workingDirPath + "train/";
078    private static final String modelDirPath = workingDirPath + "models/";
079    private static final String predictionDirPath = workingDirPath + "predictions/";
080    private static final String resultsDirPath = workingDirPath + "results/";
081    // get it at http://purl.obolibrary.org/obo/mp.obo
082    private static final String samplesKbFilePath = workingDirPath + "mp.obo";
083    private static OWLOntology samplesKb;
084    private static OWLReasoner samplesReasoner;
085    // get it at ftp://ftp.informatics.jax.org/pub/reports/gene_association.mgi
086    private static String mgi2goClsMappingsFilePath = workingDirPath + "gene_association.mgi";
087    // get it at http://aber-owl.net/aber-owl/diseasephenotypes/data/mousephenotypes.txt
088    private static String mgi2mpClsMappingsFilePath = workingDirPath + "mousephenotypes_new.txt";
089
090    /* URI constants */
091    private static final String oboPrefix = "http://purl.obolibrary.org/obo/";
092
093    public static void main(String[] args) throws Exception {
094        initExample();
095        prepareTrainingData();
096
097        initSemkernel();
098        train();
099
100        preparePredictionData();
101        predict();
102    }
103
104    private static void initExample() throws OWLOntologyCreationException, IOException {
105        logger.setLevel(Level.DEBUG);
106
107        factory = OWLManager.getOWLDataFactory();
108        allClsUriStrs = new TreeSet<>();
109
110        logger.info(String.format("Loading samples ontology file %s ...", samplesKbFilePath));
111        samplesKb = readKb(samplesKbFilePath);
112        logger.info("-Done-");
113
114        logger.info("Initialising reasoner...");
115        ConsoleProgressMonitor mon = new ConsoleProgressMonitor();
116        OWLReasonerConfiguration reasonerConf = new SimpleConfiguration(mon);
117        OWLReasonerFactory reasonerFactory = new ElkReasonerFactory();
118        samplesReasoner = reasonerFactory.createReasoner(samplesKb, reasonerConf);
119        samplesReasoner.precomputeInferences(InferenceType.CLASS_HIERARCHY);
120        logger.info("-Done-");
121
122        logger.info(String.format(
123                "Reading MGI ID to GO class mappings from %s ...",
124                mgi2goClsMappingsFilePath));
125        mgi2go = new HashMap<>();
126        BufferedReader buffReader = new BufferedReader(
127                new FileReader(mgi2goClsMappingsFilePath));
128
129        String line;
130        while ((line = buffReader.readLine()) != null) {
131            if (line.startsWith("!")) continue;
132
133            String[] fields = line.split("\t");
134            String mgiId = fields[1];
135            // one or more of "NOT", "contributes_to", "co-localizes_with"
136            String qualifier = fields[3];
137            String goId = fields[4];
138            String evidenceCode = fields[6];
139
140            if (goId.trim().length() == 0) continue;  // skip lines not containing a GO ID
141
142            if (evidenceCode != "ND" && !qualifier.contains("NOT")) {
143                String goUriStr = oboPrefix + goId.replace(":", "_");
144
145                if (!mgi2go.containsKey(mgiId)) {
146                    mgi2go.put(mgiId, new TreeSet<>());
147                }
148                mgi2go.get(mgiId).add(goUriStr);
149                allClsUriStrs.add(goUriStr);
150            } else {
151                logger.debug(String.format("-- Skipping line \"%s\"", line));
152            }
153        }
154        buffReader.close();
155        logger.info("-Done-");
156
157        logger.info(String.format(
158                "Reading MGI ID to MP class mappings from %s ...",
159                mgi2mpClsMappingsFilePath));
160        mgi2mp = new HashMap<>();
161        buffReader = new BufferedReader(new FileReader(mgi2mpClsMappingsFilePath));
162
163        while ((line = buffReader.readLine()) != null) {
164            String[] fields = line.split("\t");
165            if (fields.length < 2) continue;
166
167            String mgiId = fields[0];
168
169            if (mgi2go.containsKey(mgiId)) {
170                String mpId = fields[1];
171                if (mpId.trim().length() == 0) continue;  // skip lines not containing an MP ID
172
173                String mpUriStr = oboPrefix + mpId.replace(":", "_");
174
175                if (!mgi2mp.containsKey(mgiId)) {
176                    mgi2mp.put(mgiId, new TreeSet<>());
177                }
178                mgi2mp.get(mgiId).add(mpUriStr);
179            }
180        }
181        buffReader.close();
182        logger.info("-Done-");
183    }
184
185    private static OWLOntology readKb(String filePath) throws OWLOntologyCreationException {
186        OWLOntologyManager man = OWLManager.createOWLOntologyManager();
187        return man.loadOntologyFromOntologyDocument(new File(filePath));
188    }
189
190    private static void initSemkernel() {
191        logger.info("Initialising the semkernel...");
192        kernel = new SemKernel();
193
194        // svm_train -s 0 ...
195        kernel.setSvmType(svmType);
196        // ... -t 5 ...  --> use semantic kernel (fixed for SemKernel anyway)
197        // ... -b 1 ...
198        kernel.setDoProbabilityEstimates(doProbabilityEstimates);
199        // ... -v 10 ...
200        kernel.setCrossValidationFolds(crossValidationFolds);
201        // ... -c $cost ...
202        kernel.setCost(cost);
203        // ... -f go.obo ...
204        kernel.setOntologyFilePath(kbFilePath);
205        // ... $TRAININGDIR/$NAME-$par > $CVDIR/$NAME-$par-$cost
206        kernel.setTrainingDirPath(trainDirPath);
207        kernel.setModelDirPath(modelDirPath );
208        kernel.setPredictionDataDirPath(predictionDirPath);
209        kernel.setResultsDirPath(resultsDirPath);
210
211        kernel.setGamma(0);
212        kernel.setPredictProbability(predictProbability);
213
214        kernel.init();
215        logger.info("-Done-");
216    }
217
218    private static void prepareTrainingData() throws IOException {
219        logger.info(String.format("Writing training sample data to file (%s)...",
220                trainDirPath));
221        for (String searchClassUriStr : trainClsUriStrs) {
222            String localPart = getLocalPart(searchClassUriStr);
223            String trainOutFilePath = trainDirPath + localPart;
224            logger.debug("-- " + trainOutFilePath);
225
226//            OWLClass searchCls = new OWLClassImpl(IRI.create(searchClassUriStr));
227            OWLClass searchCls = factory.getOWLClass(IRI.create(searchClassUriStr));
228
229            Set<String> classifierFor = new TreeSet<>();
230            classifierFor.add(searchClassUriStr);
231
232            Set<OWLClass> subClasses =
233                    samplesReasoner.getSubClasses(searchCls, false).getFlattened();
234
235            for (OWLClass owlClass : subClasses) {
236                String uriStr = owlClass.getIRI().toString();
237                classifierFor.add(uriStr);
238            }
239
240            List<String> negatives = new ArrayList<>();
241            List<String> positives = new ArrayList<>();
242
243            // build lines to write to file (SVM light format)
244            for (String mgiId : mgi2mp.keySet()) {
245                String outputLine = "";
246
247                Set<String> mpUriStrs = mgi2mp.get(mgiId);
248
249                if (Sets.intersection(classifierFor, mpUriStrs).isEmpty()) {
250                    outputLine += "0";
251                } else {
252                    outputLine += "1";
253                }
254
255                if (!mgi2go.containsKey(mgiId)) continue;
256                for (String goUristr : mgi2go.get(mgiId)) {
257                    outputLine += "\t" + goUristr;
258                }
259
260                if (Sets.intersection(classifierFor, mpUriStrs).isEmpty()) {
261                    negatives.add(outputLine);
262                } else {
263                    positives.add(outputLine);
264                }
265            }
266
267            // shorten negative SVM light lines set to the configured
268            // positives-negatives ratio
269            Collections.shuffle(negatives);
270            int cutoff = (int) Math.round(posNegExampleRatio * positives.size());
271            if (cutoff < negatives.size()) {
272                negatives = negatives.subList(0, cutoff);
273            }
274
275            BufferedWriter buffWriter = new BufferedWriter(new FileWriter(trainOutFilePath));
276
277            for (String posLine : positives) {
278                buffWriter.write(posLine);
279                buffWriter.newLine();
280            }
281            for (String negLine : negatives) {
282                buffWriter.write(negLine);
283                buffWriter.newLine();
284            }
285
286            buffWriter.close();
287        }
288        logger.info("-Done-");
289    }
290
291    private static void preparePredictionData() throws IOException {
292        logger.info("Preparing prediction data...");
293        for (String predClsUriStr : predictClsUriStrs) {
294            String localPart = getLocalPart(predClsUriStr);
295            String predOutFilePath = predictionDirPath + localPart;
296            logger.debug("-- " + predOutFilePath);
297
298            BufferedWriter buffWriter = new BufferedWriter(
299                    new FileWriter(predOutFilePath));
300//            buffWriter.write("map");
301//            for (String clsUri : allClsUriStrs) {
302//                buffWriter.write("\t" + clsUri);
303//            }
304//            buffWriter.newLine();
305
306            Set<String> classifierFor = new TreeSet<>();
307            classifierFor.add(predClsUriStr);
308
309            OWLClass predCls = factory.getOWLClass(IRI.create(predClsUriStr));
310            Set<OWLClass> subClasses =
311                    samplesReasoner.getSubClasses(predCls, false).getFlattened();
312
313            for (OWLClass subClass : subClasses) {
314                String uriStr = subClass.getIRI().toString();
315                classifierFor.add(uriStr);
316            }
317
318            List<String> negatives = new ArrayList<>();
319            List<String> positives = new ArrayList<>();
320
321            // build lines to write to file (SVM light format)
322            for (String mgiId : mgi2mp.keySet()) {
323                String outputLine = "";
324
325                Set<String> mpUriStrs = mgi2mp.get(mgiId);
326
327                if (Sets.intersection(classifierFor, mpUriStrs).isEmpty()) {
328                    outputLine += "0";
329                } else {
330                    outputLine += "1";
331                }
332
333                if (!mgi2go.containsKey(mgiId)) continue;
334                for (String goUristr : mgi2go.get(mgiId)) {
335                    outputLine += "\t" + goUristr;
336                }
337
338                if (Sets.intersection(classifierFor, mpUriStrs).isEmpty()) {
339                    negatives.add(outputLine);
340                } else {
341                    positives.add(outputLine);
342                }
343            }
344
345            // shorten negative SVM light lines set to the configured
346            // positives-negatives ratio
347            Collections.shuffle(negatives);
348            int cutoff = (int) Math.round(posNegExampleRatio * positives.size());
349            if (cutoff < negatives.size()) {
350                negatives = negatives.subList(0, cutoff);
351            }
352
353            for (String posLine : positives) {
354                buffWriter.write(posLine);
355                buffWriter.newLine();
356            }
357            for (String negLine : negatives) {
358                buffWriter.write(negLine);
359                buffWriter.newLine();
360            }
361
362            buffWriter.close();
363        }
364        logger.info("-Done-");
365    }
366
367    private static String getLocalPart(String uriStr) {
368        int lastSlashIdx = uriStr.lastIndexOf('/');
369        if (lastSlashIdx > -1) {
370            return uriStr.substring(lastSlashIdx);
371
372        } else {  // try looking for a hash sign
373            int lastHashIdx = uriStr.lastIndexOf('#');
374            if (lastHashIdx > -1) {
375                return uriStr.substring(lastHashIdx);
376            }
377        }
378
379        return uriStr;
380    }
381
382    private static void train() {
383        logger.info("Training...");
384        kernel.train();
385        logger.info("-Done-");
386    }
387
388    private static void predict() {
389        logger.info("Running prediction...");
390        kernel.predict();
391        logger.info("-Done-");
392    }
393}