001/**
002 * Copyright (C) 2007 - 2016, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 *
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 */
019package org.dllearner.utilities.semkernel;
020
021import java.io.BufferedReader;
022import java.io.BufferedWriter;
023import java.io.File;
024import java.io.FileReader;
025import java.io.FileWriter;
026import java.io.IOException;
027import java.util.*;
028
029import org.apache.log4j.Logger;
030import org.dllearner.algorithms.semkernel.SemKernel;
031import org.dllearner.algorithms.semkernel.SemKernel.SvmType;
032import org.dllearner.core.ComponentAnn;
033import org.dllearner.core.ComponentInitException;
034import org.semanticweb.elk.owlapi.ElkReasonerFactory;
035import org.semanticweb.owlapi.apibinding.OWLManager;
036import org.semanticweb.owlapi.model.IRI;
037import org.semanticweb.owlapi.model.OWLClass;
038import org.semanticweb.owlapi.model.OWLDataFactory;
039import org.semanticweb.owlapi.model.OWLOntology;
040import org.semanticweb.owlapi.model.OWLOntologyCreationException;
041import org.semanticweb.owlapi.model.OWLOntologyManager;
042import org.semanticweb.owlapi.reasoner.ConsoleProgressMonitor;
043import org.semanticweb.owlapi.reasoner.InferenceType;
044import org.semanticweb.owlapi.reasoner.OWLReasoner;
045import org.semanticweb.owlapi.reasoner.OWLReasonerConfiguration;
046import org.semanticweb.owlapi.reasoner.OWLReasonerFactory;
047import org.semanticweb.owlapi.reasoner.SimpleConfiguration;
048
049import uk.ac.manchester.cs.owl.owlapi.OWLClassImpl;
050
051import com.google.common.collect.Sets;
052
053/**
054 * Since the current setup for running a SemKernel example comprises several
 * steps, like preparing the training data, doing the training, preparing the
056 * prediction data and so on, this component is intended to encapsulate this
057 * whole process and make it callable and configurable via the standard
058 * DL-Learner CLI.
 * As already said, there are different steps, depending on the task to solve:
060 *
061 * - training:
062 *   T1) read URIs to train
063 *   T2) read the underlying MP knowledge base
064 *   T3) write the prepared training data (in SVM light format) to the training
065 *       data directory
066 *   T4) read the GO knowledge base
067 *   T5) do the training run on the GO knowledge base and write out the
068 *       training model to the model directory
069 *
070 * - prediction
071 *   P1) read URIs to predict
 *   P2) read the underlying MP knowledge base (if not done already)
073 *   P3) write out the prepared prediction data to the prediction data directory
074 *   P4) read the GO knowledge base
075 *   P5) do the prediction based on the GO knowledge base and write out the
076 *       prediction results to the result directory
077 *
078 * Additionally this workflow also uses special files that contain association
079 * mappings between MGI marker accession IDs and gene functions (i.e. GO
080 * classes) called MGI2GO and between MGI marker accession IDs and phenotypes
081 * (i.e. MP classes) called MGI2MP.
082 * So after reading the phenotype class URIs to train/predict the corresponding
083 * gene functions are determined via these mapping files as follows:
084 *
085 * classifierFor set = { input MP class and all its subclasses defined in the
086 *                       underlying knowledge base from T2/P2 }
087 *
088 * for MGI ID in MGI2MP {
089 *      MP classes = MGI2MP.get(MGI ID)
090 *
 *      if the classifierFor set and the MP classes contain common classes {
092 *          // this MGI ID is a positive sample
093 *          GO classes = MGI2GO.get(MGI ID)
094 *          add GO classes to positive samples
095 *
096 *      } else {
097 *          // this MGI ID is a negative sample
098 *          GO classes = MGI2GO.get(MGI ID)
099 *          add GO classes to negative samples
100 *      }
101 * }
102 *
103 * So a rough illustration (neglecting the determination of positive/negative
104 * samples) would look like this:
105 *
106 * MP class --(MGI2MP)--> MGI ID --(MGI2GO)--> GO classes
107 *
108 * So, to wrap up, the following files and directories need to be specified:
109 *
110 * T1) file containing MP URIs to train (one URI per line) (trainURIsFilePath)
111 * T2) ontology file containing the MP ontology to derive all subclasses of the
112 *     MP class to train from (mpKBFilePath)
113 * T3) directory where the prepared training input data in SVM light format
114 *     should be written to (one file per MP class to train)
115 *     (trainingInputDirectoryPath)
116 * T4) ontology file containing the GO ontology used to derive a semantic
117 *     similarity between GO classes which is used by the SemKernel
118 *     (goKBFilePath)
119 * T5) directory where the model data should be written to (one file per MP
120 *     class to train) (trainingOutputDirectoryPath)
121 *
122 * P1) file containing MP URIs to calculate a prediction for (one URI per line)
123 *     (predictionURIsFilePath)
124 * P2) see T2)
125 * P3) directory where the prepared prediction input data in SVM light format
126 *     should be written to (one file per MP class to calculate a prediction
127 *     for) (predictionInputDirectoryPath)
128 * P4) see T4)
129 * P5) directory where the prediction output should be written to (one file
130 *     per MP class the prediction was made for) (predictionOutputDirectoryPath)
131 *
132 * - file containing the MGI ID to MP class mappings (mgi2mpMappingsFilePath)
133 * - file containing the MGI ID to GO class association mappings
134 *   (mgi2goMappingsFilePath)
135 *
136 * @author Patrick Westphal
137 */
138@ComponentAnn(name="Mammalian Phenotype SemKernel Workflow", shortName="mpskw", version=0.1)
139public class MPSemKernelWorkflow extends SemKernelWorkflow {
140
141    // ------------------- files and directories to specify -------------------
142    /** file containing MP URIs to train (one URI per line) */
143    private String trainURIsFilePath;
144    /**
145     * ontology file containing the MP ontology to derive all subclasses of the
146     * MP class to train from */
147    private String mpKBFilePath;
148    /**
149     * directory where the prepared training input data in SVM light format
150     * should be written to (one file per MP class to train) */
151    private String trainingInputDirectoryPath;
152    /**
153     * ontology file containing the GO ontology used to derive a semantic
154     * similarity between GO classes which is used by the SemKernel */
155    private String goKBFilePath;
156    /**
157     * directory where the model data should be written to (one file per MP
158     * class to train) */
159    private String trainingOutputDirectoryPath;
160    /**
161     * file containing MP URIs to calculate a prediction for (one URI per line)
162     */
163    private String predictionURIsFilePath;
164    /**
165     * directory where the prepared prediction input data in SVM light format
166     * should be written to (one file per MP class to calculate a prediction
167     * for */
168    private String predictionInputDirectoryPath;
169    /**
170     * directory where the prediction output should be written to (one file
171     * per MP class the prediction was made for) */
172    private String predictionOutputDirectoryPath;
173    /** file containing the MGI ID to MP class mappings */
174    private String mgi2mpMappingsFilePath;
175    /** file containing the MGI ID to GO class association mappings */
176    private String mgi2goMappingsFilePath;
177
178    // -------------------------- SemKernel settings --------------------------
179    private SemKernel kernel;
180    private SvmType svmType = SvmType.C_SVC;
181    private boolean doProbabilityEstimates = true;
182    private int crossValidationFolds = 10;
183    private float cost = 5f;
184    private boolean predictProbability = true;
185    private double posNegExampleRatio = 1;
186    private boolean doTraining = true;
187    private boolean doPrediction = true;
188
189    // -------------------------------- misc ---------------------------------
190    private final Logger logger = Logger.getLogger(MPSemKernelWorkflow.class);
191    private OWLDataFactory dataFactory;
192    private OWLOntology mpKB;
193    private OWLReasoner mpKBReasoner;
194    private Map<String, Set<String>> mgi2mp;
195    private Map<String, Set<String>> mgi2go;
196    private final String oboPrefix = "http://purl.obolibrary.org/obo/";
197
198    @Override
199    public void init() throws ComponentInitException {
200        logger.info("Inializing workflow...");
201        dataFactory = OWLManager.getOWLDataFactory();
202
203        OWLOntologyManager man = OWLManager.createOWLOntologyManager();
204        try {
205            // T2)/P2) ------------
206            mpKB = man.loadOntologyFromOntologyDocument(new File(mpKBFilePath));
207        } catch (OWLOntologyCreationException e) {
208            e.printStackTrace();
209            System.exit(1);
210        }
211        ConsoleProgressMonitor mon = new ConsoleProgressMonitor();
212        OWLReasonerConfiguration reasonerConf = new SimpleConfiguration(mon);
213        OWLReasonerFactory reasonerFactory = new ElkReasonerFactory();
214        mpKBReasoner = reasonerFactory.createReasoner(mpKB, reasonerConf);
215        mpKBReasoner.precomputeInferences(InferenceType.CLASS_HIERARCHY);
216
217        try {
218            mgi2go = readMGI2GOMapping(mgi2goMappingsFilePath);
219            mgi2mp = readMGI2MPMapping(mgi2mpMappingsFilePath);
220        } catch (IOException e) {
221            e.printStackTrace();
222            System.exit(1);
223        }
224
225        kernel = new SemKernel();
226        kernel.setSvmType(svmType);
227        kernel.setDoProbabilityEstimates(doProbabilityEstimates);
228        kernel.setCrossValidationFolds(crossValidationFolds);
229        kernel.setCost(cost);
230        kernel.setOntologyFilePath(goKBFilePath);
231        kernel.setTrainingDirPath(trainingInputDirectoryPath);
232        kernel.setModelDirPath(trainingOutputDirectoryPath);
233        kernel.setPredictionDataDirPath(predictionInputDirectoryPath);
234        kernel.setResultsDirPath(predictionOutputDirectoryPath);
235        kernel.setGamma(0);
236        kernel.setPredictProbability(predictProbability);
237        kernel.init();
238
239        initialized = true;
240        logger.info("Finished workflow initialization.");
241    }
242
243    @Override
244    public void start() {
245        if (doTraining) {
246            logger.info("Preparing training data...");
247            try {
248                prepareMPSampleTrainingData();
249            } catch (IOException e) {
250                e.printStackTrace();
251                System.exit(1);
252            }
253            logger.info("Finished training data preparation.");
254
255            // T4) done during kernel.init() ------------
256            // T5) ------------
257            logger.info("Training...");
258            kernel.train();
259            logger.info("Finished trainig.");
260        }
261
262        if (doPrediction) {
263            logger.info("Preparing prediction data...");
264            try {
265                prepareMPPredictionData();
266            } catch (IOException e) {
267                e.printStackTrace();
268                System.exit(1);
269            }
270            logger.info("Finished prediction data preparation.");
271
272            // P4) done during kernel.init() ------------
273            // P5) ------------
274            logger.info("Doing predictions...");
275            kernel.predict();
276            logger.info("Finished prediction.");
277        }
278    }
279
280    private Map<String, Set<String>> readMGI2MPMapping(String mgi2mpFilePath)
281            throws IOException {
282
283        Map<String, Set<String>> mgi2mp = new HashMap<>();
284
285        BufferedReader bufferedReader = new BufferedReader(new FileReader(
286                new File(mgi2mpFilePath)));
287
288        String line;
289        while ((line = bufferedReader.readLine()) != null) {
290            String[] fields = line.split("\t");
291            if (fields.length < 2) continue;
292
293            String mgiId = fields[0];
294
295            if (mgi2go.containsKey(mgiId)) {
296                String mpId = fields[1];
297                if (mpId.trim().length() == 0) continue;  // skip lines not containing an MP ID
298
299                String mpUriStr = oboPrefix + mpId.replace(":", "_");
300
301                if (!mgi2mp.containsKey(mgiId)) {
302                    mgi2mp.put(mgiId, new TreeSet<>());
303                }
304                mgi2mp.get(mgiId).add(mpUriStr);
305            }
306        }
307        bufferedReader.close();
308
309        return mgi2mp;
310    }
311
312    public void prepareMPSampleTrainingData() throws IOException {
313        // append path separator if not set already
314        if (!trainingInputDirectoryPath.endsWith(File.separator)) {
315            trainingInputDirectoryPath = trainingInputDirectoryPath + File.separator;
316        }
317
318        // T1) ------------
319        Set<String> trainUriStrs = readTrainURIs(trainURIsFilePath);
320
321        // T3) (for each URI from T1)) ------------
322        for (String searchClassUriStr : trainUriStrs) {
323            String localPart = getLocalPart(searchClassUriStr);
324
325            String trainOutFilePath = trainingInputDirectoryPath + localPart;
326
327            OWLClass searchCls = new OWLClassImpl(IRI.create(searchClassUriStr));
328
329            Set<String> classifierFor = new TreeSet<>();
330            classifierFor.add(searchClassUriStr);
331
332            Set<OWLClass> subClasses =
333                    mpKBReasoner.getSubClasses(searchCls, false).getFlattened();
334
335            for (OWLClass owlClass : subClasses) {
336                String uriStr = owlClass.getIRI().toString();
337                classifierFor.add(uriStr);
338            }
339
340            List<String> negatives = new ArrayList<>();
341            List<String> positives = new ArrayList<>();
342
343            // build lines to write to file (SVM light format)
344            for (String mgiId : mgi2mp.keySet()) {
345                String outputLine = "";
346
347                Set<String> mpUriStrs = mgi2mp.get(mgiId);
348
349                if (Sets.intersection(classifierFor, mpUriStrs).isEmpty()) {
350                    outputLine += "0";
351                } else {
352                    outputLine += "1";
353                }
354
355                if (!mgi2go.containsKey(mgiId)) continue;
356                for (String goUristr : mgi2go.get(mgiId)) {
357                    outputLine += "\t" + goUristr;
358                }
359
360                if (Sets.intersection(classifierFor, mpUriStrs).isEmpty()) {
361                    negatives.add(outputLine);
362                } else {
363                    positives.add(outputLine);
364                }
365            }
366
367            // shorten negative SVM light lines set to the configured
368            // positives-negatives ratio
369            Collections.shuffle(negatives);
370            int cutoff = (int) Math.round(posNegExampleRatio * positives.size());
371            if (cutoff < negatives.size()) {
372                negatives = negatives.subList(0, cutoff);
373            }
374
375            BufferedWriter buffWriter = new BufferedWriter(new FileWriter(trainOutFilePath));
376
377            for (String posLine : positives) {
378                buffWriter.write(posLine);
379                buffWriter.newLine();
380            }
381            for (String negLine : negatives) {
382                buffWriter.write(negLine);
383                buffWriter.newLine();
384            }
385            buffWriter.close();
386        }
387    }
388
389    private void prepareMPPredictionData() throws IOException {
390        if (!predictionInputDirectoryPath.endsWith(File.separator)) {
391            predictionInputDirectoryPath += File.separator;
392        }
393
394        // P1) ------------
395        Set<String> predictClsUriStrs = readTrainURIs(predictionURIsFilePath);
396
397        // P3) (for each URI from P1)) ------------
398        for (String predClsUriStr : predictClsUriStrs) {
399            String localPart = getLocalPart(predClsUriStr);
400            String predOutFilePath = predictionInputDirectoryPath + localPart;
401
402            BufferedWriter buffWriter = new BufferedWriter(
403                    new FileWriter(predOutFilePath));
404//            buffWriter.write("map");
405//            for (String clsUri : allClsUriStrs) {
406//                buffWriter.write("\t" + clsUri);
407//            }
408//            buffWriter.newLine();
409
410            Set<String> classifierFor = new TreeSet<>();
411            classifierFor.add(predClsUriStr);
412
413            OWLClass predCls = new OWLClassImpl(IRI.create(predClsUriStr));
414            Set<OWLClass> subClasses =
415                    mpKBReasoner.getSubClasses(predCls, false).getFlattened();
416
417            for (OWLClass subClass : subClasses) {
418                String uriStr = subClass.getIRI().toString();
419                classifierFor.add(uriStr);
420            }
421
422            List<String> negatives = new ArrayList<>();
423            List<String> positives = new ArrayList<>();
424
425            // build lines to write to file (SVM light format)
426            for (String mgiId : mgi2mp.keySet()) {
427                String outputLine = "";
428
429                Set<String> mpUriStrs = mgi2mp.get(mgiId);
430
431                if (Sets.intersection(classifierFor, mpUriStrs).isEmpty()) {
432                    outputLine += "0";
433                } else {
434                    outputLine += "1";
435                }
436
437                if (!mgi2go.containsKey(mgiId)) continue;
438                for (String goUristr : mgi2go.get(mgiId)) {
439                    outputLine += "\t" + goUristr;
440                }
441
442                if (Sets.intersection(classifierFor, mpUriStrs).isEmpty()) {
443                    negatives.add(outputLine);
444                } else {
445                    positives.add(outputLine);
446                }
447            }
448
449            // shorten negative SVM light lines set to the configured
450            // positives-negatives ratio
451            Collections.shuffle(negatives);
452            int cutoff = (int) Math.round(posNegExampleRatio * positives.size());
453            if (cutoff < negatives.size()) {
454                negatives = negatives.subList(0, cutoff);
455            }
456
457            for (String posLine : positives) {
458                buffWriter.write(posLine);
459                buffWriter.newLine();
460            }
461            for (String negLine : negatives) {
462                buffWriter.write(negLine);
463                buffWriter.newLine();
464            }
465            buffWriter.close();
466        }
467    }
468
469    private static Set<String> readTrainURIs(String trainURIsFilePath) throws IOException {
470        Set<String> uriStrs = new HashSet<>();
471        BufferedReader bufferedReader = new BufferedReader(
472                new FileReader(new File(trainURIsFilePath)));
473
474        String line;
475        while ((line = bufferedReader.readLine()) != null) {
476            // strip off leading and trailing angled bracket
477            if (line.startsWith("<") && line.endsWith(">")) {
478                line = line.substring(1, line.length()-1);
479            }
480
481            uriStrs.add(line);
482        }
483        bufferedReader.close();
484
485        return uriStrs;
486    }
487
488    private Map<String, Set<String>> readMGI2GOMapping(String mgi2goFilePath)
489            throws IOException {
490        Map<String, Set<String>> mgi2go = new HashMap<>();
491
492        BufferedReader bufferedReader = new BufferedReader(new FileReader(
493                new File(mgi2goFilePath)));
494
495        String line;
496        while ((line = bufferedReader.readLine()) != null) {
497            if (line.startsWith("!")) continue;
498
499            String[] fields = line.split("\t");
500            String mgiId = fields[1];
501            // one or more of "NOT", "contributes_to", "co-localizes_with"
502            String qualifier = fields[3];
503            String goId = fields[4];
504            String evidenceCode = fields[6];
505
506            if (goId.trim().length() == 0) continue;  // skip lines not containing a GO ID
507
508            if (!Objects.equals(evidenceCode, "ND") && !qualifier.contains("NOT")) {
509                String goUriStr = oboPrefix + goId.replace(":", "_");
510
511                if (!mgi2go.containsKey(mgiId)) {
512                    mgi2go.put(mgiId, new TreeSet<>());
513                }
514                mgi2go.get(mgiId).add(goUriStr);
515            }
516        }
517        bufferedReader.close();
518
519        return mgi2go;
520    }
521
522    private static String getLocalPart(String uriStr) {
523        int lastSlashIdx = uriStr.lastIndexOf('/');
524        if (lastSlashIdx > -1) {
525            return uriStr.substring(lastSlashIdx);
526
527        } else {  // try looking for a hash sign
528            int lastHashIdx = uriStr.lastIndexOf('#');
529            if (lastHashIdx > -1) {
530                return uriStr.substring(lastHashIdx);
531            }
532        }
533
534        return uriStr;
535    }
536
537    // -------------------- only getters and setters below --------------------
538    public String getTrainURIsFilePath() {
539        return trainURIsFilePath;
540    }
541
542    public void setTrainURIsFilePath(String trainURIsFilePath) {
543        this.trainURIsFilePath = trainURIsFilePath;
544    }
545
546    public String getMpKBFilePath() {
547        return mpKBFilePath;
548    }
549
550    public void setMpKBFilePath(String mpKBFilePath) {
551        this.mpKBFilePath = mpKBFilePath;
552    }
553
554    public String getTrainingInputDirectoryPath() {
555        return trainingInputDirectoryPath;
556    }
557
558    public void setTrainingInputDirectoryPath(String trainingInputDirectoryPath) {
559        this.trainingInputDirectoryPath = trainingInputDirectoryPath;
560    }
561
562    public String getGoKBFilePath() {
563        return goKBFilePath;
564    }
565
566    public void setGoKBFilePath(String goKBFilePath) {
567        this.goKBFilePath = goKBFilePath;
568    }
569
570    public String getTrainingOutputDirectoryPath() {
571        return trainingOutputDirectoryPath;
572    }
573
574    public void setTrainingOutputDirectoryPath(
575            String trainingOutputDirectoryPath) {
576        this.trainingOutputDirectoryPath = trainingOutputDirectoryPath;
577    }
578
579    public String getPredictionURIsFilePath() {
580        return predictionURIsFilePath;
581    }
582
583    public void setPredictionURIsFilePath(String predictionURIsFilePath) {
584        this.predictionURIsFilePath = predictionURIsFilePath;
585    }
586
587    public String getPredictionInputDirectoryPath() {
588        return predictionInputDirectoryPath;
589    }
590
591    public void setPredictionInputDirectoryPath(
592            String predictionInputDirectoryPath) {
593        this.predictionInputDirectoryPath = predictionInputDirectoryPath;
594    }
595
596    public String getPredictionOutputDirectoryPath() {
597        return predictionOutputDirectoryPath;
598    }
599
600    public void setPredictionOutputDirectoryPath(
601            String predictionOutputDirectoryPath) {
602        this.predictionOutputDirectoryPath = predictionOutputDirectoryPath;
603    }
604
605    public String getMgi2mpMappingsFilePath() {
606        return mgi2mpMappingsFilePath;
607    }
608
609    public void setMgi2mpMappingsFilePath(String mgi2mpMappingsFilePath) {
610        this.mgi2mpMappingsFilePath = mgi2mpMappingsFilePath;
611    }
612
613    public String getMgi2goMappingsFilePath() {
614        return mgi2goMappingsFilePath;
615    }
616
617    public void setMgi2goMappingsFilePath(String mgi2goMappingsFilePath) {
618        this.mgi2goMappingsFilePath = mgi2goMappingsFilePath;
619    }
620
621    public SvmType getSvmType() {
622        return svmType;
623    }
624
625    public void setSvmType(SvmType svmType) {
626        this.svmType = svmType;
627    }
628
629    public boolean isDoProbabilityEstimates() {
630        return doProbabilityEstimates;
631    }
632
633    public void setDoProbabilityEstimates(boolean doProbabilityEstimates) {
634        this.doProbabilityEstimates = doProbabilityEstimates;
635    }
636
637    public int getCrossValidationFolds() {
638        return crossValidationFolds;
639    }
640
641    public void setCrossValidationFolds(int crossValidationFolds) {
642        this.crossValidationFolds = crossValidationFolds;
643    }
644
645    public float getCost() {
646        return cost;
647    }
648
649    public void setCost(float cost) {
650        this.cost = cost;
651    }
652
653    public boolean isPredictProbability() {
654        return predictProbability;
655    }
656
657    public void setPredictProbability(boolean predictProbability) {
658        this.predictProbability = predictProbability;
659    }
660
661    public double getPosNegExampleRatio() {
662        return posNegExampleRatio;
663    }
664
665    public void setPosNegExampleRatio(double posNegExampleRatio) {
666        this.posNegExampleRatio = posNegExampleRatio;
667    }
668
669    public boolean isDoTraining() {
670        return doTraining;
671    }
672
673    public void setDoTraining(boolean doTraining) {
674        this.doTraining = doTraining;
675    }
676
677    public boolean isDoPrediction() {
678        return doPrediction;
679    }
680
681    public void setDoPrediction(boolean doPrediction) {
682        this.doPrediction = doPrediction;
683    }
684
685}