/**
 * Copyright (C) 2007 - 2016, Jens Lehmann
 *
 * This file is part of DL-Learner.
 *
 * DL-Learner is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * DL-Learner is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.dllearner.utilities.semkernel;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.*;

import org.apache.log4j.Logger;
import org.dllearner.algorithms.semkernel.SemKernel;
import org.dllearner.algorithms.semkernel.SemKernel.SvmType;
import org.dllearner.core.ComponentAnn;
import org.dllearner.core.ComponentInitException;
import org.semanticweb.elk.owlapi.ElkReasonerFactory;
import org.semanticweb.owlapi.apibinding.OWLManager;
import org.semanticweb.owlapi.model.IRI;
import org.semanticweb.owlapi.model.OWLClass;
import org.semanticweb.owlapi.model.OWLDataFactory;
import org.semanticweb.owlapi.model.OWLOntology;
import org.semanticweb.owlapi.model.OWLOntologyCreationException;
import org.semanticweb.owlapi.model.OWLOntologyManager;
import org.semanticweb.owlapi.reasoner.ConsoleProgressMonitor;
import org.semanticweb.owlapi.reasoner.InferenceType;
import org.semanticweb.owlapi.reasoner.OWLReasoner;
import org.semanticweb.owlapi.reasoner.OWLReasonerConfiguration;
import org.semanticweb.owlapi.reasoner.OWLReasonerFactory;
import org.semanticweb.owlapi.reasoner.SimpleConfiguration;

import uk.ac.manchester.cs.owl.owlapi.OWLClassImpl;

import com.google.common.collect.Sets;

/**
 * Since the current setup for running a SemKernel example comprises several
 * steps, like preparing the training data, doing the training, preparing the
 * prediction data and so on, this component is intended to encapsulate this
 * whole process and make it callable and configurable via the standard
 * DL-Learner CLI.
 * There are different steps, depending on the task to solve:
 *
 * <pre>
 * - training:
 *   T1) read URIs to train
 *   T2) read the underlying MP knowledge base
 *   T3) write the prepared training data (in SVM light format) to the
 *       training data directory
 *   T4) read the GO knowledge base
 *   T5) do the training run on the GO knowledge base and write out the
 *       training model to the model directory
 *
 * - prediction:
 *   P1) read URIs to predict
 *   P2) read the underlying knowledge base (if not done already)
 *   P3) write out the prepared prediction data to the prediction data
 *       directory
 *   P4) read the GO knowledge base
 *   P5) do the prediction based on the GO knowledge base and write out the
 *       prediction results to the result directory
 * </pre>
 *
 * Additionally this workflow uses special files that contain association
 * mappings between MGI marker accession IDs and gene functions (i.e. GO
 * classes), called MGI2GO, and between MGI marker accession IDs and
 * phenotypes (i.e. MP classes), called MGI2MP.
 * After reading the phenotype class URIs to train/predict, the corresponding
 * gene functions are determined via these mapping files as follows:
 *
 * <pre>
 * classifierFor set = { input MP class and all its subclasses defined in the
 *                       underlying knowledge base from T2/P2 }
 *
 * for MGI ID in MGI2MP {
 *     MP classes = MGI2MP.get(MGI ID)
 *
 *     if the classifierFor set and the MP classes contain common classes {
 *         // this MGI ID is a positive sample
 *         GO classes = MGI2GO.get(MGI ID)
 *         add GO classes to positive samples
 *
 *     } else {
 *         // this MGI ID is a negative sample
 *         GO classes = MGI2GO.get(MGI ID)
 *         add GO classes to negative samples
 *     }
 * }
 * </pre>
 *
 * A rough illustration (neglecting the determination of positive/negative
 * samples) looks like this:
 *
 * <pre>
 * MP class --(MGI2MP)--&gt; MGI ID --(MGI2GO)--&gt; GO classes
 * </pre>
 *
 * To wrap up, the following files and directories need to be specified:
 *
 * <pre>
 * T1) file containing MP URIs to train (one URI per line)
 *     (trainURIsFilePath)
 * T2) ontology file containing the MP ontology to derive all subclasses of
 *     the MP class to train from (mpKBFilePath)
 * T3) directory where the prepared training input data in SVM light format
 *     should be written to (one file per MP class to train)
 *     (trainingInputDirectoryPath)
 * T4) ontology file containing the GO ontology used to derive a semantic
 *     similarity between GO classes which is used by the SemKernel
 *     (goKBFilePath)
 * T5) directory where the model data should be written to (one file per MP
 *     class to train) (trainingOutputDirectoryPath)
 *
 * P1) file containing MP URIs to calculate a prediction for (one URI per
 *     line) (predictionURIsFilePath)
 * P2) see T2)
 * P3) directory where the prepared prediction input data in SVM light format
 *     should be written to (one file per MP class to calculate a prediction
 *     for) (predictionInputDirectoryPath)
 * P4) see T4)
 * P5) directory where the prediction output should be written to (one file
 *     per MP class the prediction was made for)
 *     (predictionOutputDirectoryPath)
 *
 * - file containing the MGI ID to MP class mappings (mgi2mpMappingsFilePath)
 * - file containing the MGI ID to GO class association mappings
 *   (mgi2goMappingsFilePath)
 * </pre>
 *
 * @author Patrick Westphal
 */
@ComponentAnn(name="Mammalian Phenotype SemKernel Workflow", shortName="mpskw", version=0.1)
public class MPSemKernelWorkflow extends SemKernelWorkflow {

    // ------------------- files and directories to specify -------------------
    /** file containing MP URIs to train (one URI per line) */
    private String trainURIsFilePath;
    /**
     * ontology file containing the MP ontology to derive all subclasses of the
     * MP class to train from */
    private String mpKBFilePath;
    /**
     * directory where the prepared training input data in SVM light format
     * should be written to (one file per MP class to train) */
    private String trainingInputDirectoryPath;
    /**
     * ontology file containing the GO ontology used to derive a semantic
     * similarity between GO classes which is used by the SemKernel */
    private String goKBFilePath;
    /**
     * directory where the model data should be written to (one file per MP
     * class to train) */
    private String trainingOutputDirectoryPath;
    /**
     * file containing MP URIs to calculate a prediction for (one URI per line)
     */
    private String predictionURIsFilePath;
    /**
     * directory where the prepared prediction input data in SVM light format
     * should be written to (one file per MP class to calculate a prediction
     * for) */
    private String predictionInputDirectoryPath;
    /**
     * directory where the prediction output should be written to (one file
     * per MP class the prediction was made for) */
    private String predictionOutputDirectoryPath;
    /** file containing the MGI ID to MP class mappings */
    private String mgi2mpMappingsFilePath;
    /** file containing the MGI ID to GO class association mappings */
    private String mgi2goMappingsFilePath;

    // -------------------------- SemKernel settings --------------------------
    private SemKernel kernel;
    private SvmType svmType = SvmType.C_SVC;
    private boolean doProbabilityEstimates = true;
    private int crossValidationFolds = 10;
    private float cost = 5f;
    private boolean predictProbability = true;
    /**
     * factor applied to the number of positive samples to get the maximum
     * number of negative samples written per output file */
    private double posNegExampleRatio = 1;
    private boolean doTraining = true;
    private boolean doPrediction = true;

    // -------------------------------- misc ---------------------------------
    private static final Logger logger = Logger.getLogger(MPSemKernelWorkflow.class);
    /** URI prefix prepended to the MP/GO IDs read from the mapping files */
    private static final String OBO_PREFIX = "http://purl.obolibrary.org/obo/";
    private OWLDataFactory dataFactory;
    private OWLOntology mpKB;
    private OWLReasoner mpKBReasoner;
    /** MGI ID -&gt; MP class URIs; only contains MGI IDs also present in {@link #mgi2go} */
    private Map<String, Set<String>> mgi2mp;
    /** MGI ID -&gt; GO class URIs */
    private Map<String, Set<String>> mgi2go;

    /**
     * Loads the MP ontology, sets up an ELK reasoner on it, reads both MGI
     * mapping files and configures and initializes the SemKernel.
     *
     * @throws ComponentInitException if the MP ontology cannot be loaded or
     *         one of the mapping files cannot be read
     */
    @Override
    public void init() throws ComponentInitException {
        logger.info("Inializing workflow...");
        dataFactory = OWLManager.getOWLDataFactory();

        // T2)/P2) ------------
        OWLOntologyManager man = OWLManager.createOWLOntologyManager();
        try {
            mpKB = man.loadOntologyFromOntologyDocument(new File(mpKBFilePath));
        } catch (OWLOntologyCreationException e) {
            // report through the declared exception instead of killing the JVM
            throw new ComponentInitException(
                    "Could not load the MP ontology from " + mpKBFilePath, e);
        }
        ConsoleProgressMonitor mon = new ConsoleProgressMonitor();
        OWLReasonerConfiguration reasonerConf = new SimpleConfiguration(mon);
        OWLReasonerFactory reasonerFactory = new ElkReasonerFactory();
        mpKBReasoner = reasonerFactory.createReasoner(mpKB, reasonerConf);
        mpKBReasoner.precomputeInferences(InferenceType.CLASS_HIERARCHY);

        try {
            // order matters: readMGI2MPMapping filters by the keys of mgi2go
            mgi2go = readMGI2GOMapping(mgi2goMappingsFilePath);
            mgi2mp = readMGI2MPMapping(mgi2mpMappingsFilePath);
        } catch (IOException e) {
            throw new ComponentInitException(
                    "Could not read the MGI mapping files ("
                    + mgi2goMappingsFilePath + ", " + mgi2mpMappingsFilePath + ")", e);
        }

        kernel = new SemKernel();
        kernel.setSvmType(svmType);
        kernel.setDoProbabilityEstimates(doProbabilityEstimates);
        kernel.setCrossValidationFolds(crossValidationFolds);
        kernel.setCost(cost);
        kernel.setOntologyFilePath(goKBFilePath);
        kernel.setTrainingDirPath(trainingInputDirectoryPath);
        kernel.setModelDirPath(trainingOutputDirectoryPath);
        kernel.setPredictionDataDirPath(predictionInputDirectoryPath);
        kernel.setResultsDirPath(predictionOutputDirectoryPath);
        kernel.setGamma(0);
        kernel.setPredictProbability(predictProbability);
        kernel.init();

        initialized = true;
        logger.info("Finished workflow initialization.");
    }

    /**
     * Runs the training and/or prediction steps, depending on
     * {@code doTraining} and {@code doPrediction}. If the input data cannot be
     * prepared the error is logged and the JVM exits with status 1.
     */
    @Override
    public void start() {
        if (doTraining) {
            logger.info("Preparing training data...");
            try {
                prepareMPSampleTrainingData();
            } catch (IOException e) {
                logger.error("Could not prepare the training data", e);
                System.exit(1);
            }
            logger.info("Finished training data preparation.");

            // T4) done during kernel.init() ------------
            // T5) ------------
            logger.info("Training...");
            kernel.train();
            logger.info("Finished trainig.");
        }

        if (doPrediction) {
            logger.info("Preparing prediction data...");
            try {
                prepareMPPredictionData();
            } catch (IOException e) {
                logger.error("Could not prepare the prediction data", e);
                System.exit(1);
            }
            logger.info("Finished prediction data preparation.");

            // P4) done during kernel.init() ------------
            // P5) ------------
            logger.info("Doing predictions...");
            kernel.predict();
            logger.info("Finished prediction.");
        }
    }

    /**
     * Reads the tab-separated MGI2MP mapping file: column 0 is used as MGI ID,
     * column 1 as MP ID. Lines with fewer than two columns, with an empty MP
     * ID, or with an MGI ID that is not a key of {@link #mgi2go} are skipped,
     * so {@link #mgi2go} must have been read before this method is called.
     *
     * @param mgi2mpFilePath path of the MGI2MP mapping file
     * @return map from MGI ID to the (sorted) set of MP class URIs
     * @throws IOException if the file cannot be read
     */
    private Map<String, Set<String>> readMGI2MPMapping(String mgi2mpFilePath)
            throws IOException {

        Map<String, Set<String>> mpMappings = new HashMap<>();

        try (BufferedReader bufferedReader = new BufferedReader(
                new FileReader(new File(mgi2mpFilePath)))) {

            String line;
            while ((line = bufferedReader.readLine()) != null) {
                String[] fields = line.split("\t");
                if (fields.length < 2) {
                    continue;
                }

                String mgiId = fields[0];
                if (!mgi2go.containsKey(mgiId)) {
                    continue;  // no GO annotations --> useless as a sample
                }

                String mpId = fields[1];
                if (mpId.trim().length() == 0) {
                    continue;  // skip lines not containing an MP ID
                }

                String mpUriStr = OBO_PREFIX + mpId.replace(":", "_");

                Set<String> mpUriStrs = mpMappings.get(mgiId);
                if (mpUriStrs == null) {
                    mpUriStrs = new TreeSet<>();
                    mpMappings.put(mgiId, mpUriStrs);
                }
                mpUriStrs.add(mpUriStr);
            }
        }

        return mpMappings;
    }

    /**
     * Writes one training input file per MP class URI read from
     * {@code trainURIsFilePath} to the training input directory (steps T1 and
     * T3). A trailing path separator is appended to
     * {@code trainingInputDirectoryPath} if it is missing.
     *
     * @throws IOException if the URI file cannot be read or an output file
     *         cannot be written
     */
    public void prepareMPSampleTrainingData() throws IOException {
        trainingInputDirectoryPath = ensureTrailingSeparator(trainingInputDirectoryPath);

        // T1) ------------
        Set<String> trainUriStrs = readTrainURIs(trainURIsFilePath);

        // T3) (for each URI from T1)) ------------
        for (String searchClassUriStr : trainUriStrs) {
            String trainOutFilePath =
                    trainingInputDirectoryPath + getLocalPart(searchClassUriStr);
            writeSvmLightSamples(searchClassUriStr, trainOutFilePath);
        }
    }

    /**
     * Writes one prediction input file per MP class URI read from
     * {@code predictionURIsFilePath} to the prediction input directory (steps
     * P1 and P3). A trailing path separator is appended to
     * {@code predictionInputDirectoryPath} if it is missing.
     *
     * @throws IOException if the URI file cannot be read or an output file
     *         cannot be written
     */
    private void prepareMPPredictionData() throws IOException {
        predictionInputDirectoryPath = ensureTrailingSeparator(predictionInputDirectoryPath);

        // P1) ------------
        Set<String> predictClsUriStrs = readTrainURIs(predictionURIsFilePath);

        // P3) (for each URI from P1)) ------------
        for (String predClsUriStr : predictClsUriStrs) {
            String predOutFilePath =
                    predictionInputDirectoryPath + getLocalPart(predClsUriStr);
            writeSvmLightSamples(predClsUriStr, predOutFilePath);
        }
    }

    /** Returns the given directory path, extended by the path separator if it does not end with it. */
    private static String ensureTrailingSeparator(String directoryPath) {
        if (directoryPath.endsWith(File.separator)) {
            return directoryPath;
        }
        return directoryPath + File.separator;
    }

    /** Returns the URIs of the given class and of all its subclasses as reported by the MP reasoner. */
    private Set<String> collectClassWithSubClasses(String classUriStr) {
        Set<String> classUriStrs = new TreeSet<>();
        classUriStrs.add(classUriStr);

        OWLClass cls = new OWLClassImpl(IRI.create(classUriStr));
        // direct = false --> all descendants, not just the direct subclasses
        for (OWLClass subClass : mpKBReasoner.getSubClasses(cls, false).getFlattened()) {
            classUriStrs.add(subClass.getIRI().toString());
        }

        return classUriStrs;
    }

    /**
     * Builds the sample lines for one MP class and writes them to the given
     * file. Each line is the label ("1" if the MGI ID is annotated with the
     * class or one of its subclasses, "0" otherwise) followed by the
     * tab-separated GO class URIs of the MGI ID. All positive lines are
     * written first, followed by a random selection of at most
     * {@code posNegExampleRatio * #positives} negative lines.
     *
     * @param classUriStr URI of the MP class the samples are built for
     * @param outFilePath path of the file to (over)write
     * @throws IOException if the file cannot be written
     */
    private void writeSvmLightSamples(String classUriStr, String outFilePath)
            throws IOException {

        Set<String> classifierFor = collectClassWithSubClasses(classUriStr);

        List<String> positives = new ArrayList<>();
        List<String> negatives = new ArrayList<>();

        for (Map.Entry<String, Set<String>> mgiEntry : mgi2mp.entrySet()) {
            Set<String> goUriStrs = mgi2go.get(mgiEntry.getKey());
            if (goUriStrs == null) {
                continue;
            }

            boolean isPositive =
                    !Sets.intersection(classifierFor, mgiEntry.getValue()).isEmpty();

            StringBuilder outputLine = new StringBuilder(isPositive ? "1" : "0");
            for (String goUriStr : goUriStrs) {
                outputLine.append("\t").append(goUriStr);
            }

            if (isPositive) {
                positives.add(outputLine.toString());
            } else {
                negatives.add(outputLine.toString());
            }
        }

        // shorten negative lines to the configured positives-negatives ratio;
        // clamp to 0 so a negative ratio cannot produce an invalid sub list
        Collections.shuffle(negatives);
        int cutoff = (int) Math.max(0, Math.round(posNegExampleRatio * positives.size()));
        if (cutoff < negatives.size()) {
            negatives = negatives.subList(0, cutoff);
        }

        try (BufferedWriter buffWriter = new BufferedWriter(new FileWriter(outFilePath))) {
            for (String posLine : positives) {
                buffWriter.write(posLine);
                buffWriter.newLine();
            }
            for (String negLine : negatives) {
                buffWriter.write(negLine);
                buffWriter.newLine();
            }
        }
    }

    /**
     * Reads class URIs from a file containing one URI per line. Surrounding
     * whitespace and one pair of enclosing angled brackets are stripped; blank
     * lines are skipped.
     *
     * @param trainURIsFilePath path of the URI file
     * @return the distinct URIs found in the file
     * @throws IOException if the file cannot be read
     */
    private static Set<String> readTrainURIs(String trainURIsFilePath) throws IOException {
        Set<String> uriStrs = new HashSet<>();

        try (BufferedReader bufferedReader = new BufferedReader(
                new FileReader(new File(trainURIsFilePath)))) {

            String line;
            while ((line = bufferedReader.readLine()) != null) {
                line = line.trim();

                // strip off leading and trailing angled bracket
                if (line.startsWith("<") && line.endsWith(">")) {
                    line = line.substring(1, line.length() - 1);
                }

                // an empty URI would yield an empty IRI and an output path
                // that equals the output directory
                if (line.isEmpty()) {
                    continue;
                }

                uriStrs.add(line);
            }
        }

        return uriStrs;
    }

    /**
     * Reads the tab-separated MGI2GO association file. Lines starting with
     * "!" are skipped. Column 1 is used as MGI ID, column 3 as qualifier,
     * column 4 as GO ID and column 6 as evidence code. Lines with fewer than
     * seven columns, with an empty GO ID, with evidence code "ND" or with a
     * qualifier containing "NOT" are skipped.
     *
     * @param mgi2goFilePath path of the MGI2GO association file
     * @return map from MGI ID to the (sorted) set of GO class URIs
     * @throws IOException if the file cannot be read
     */
    private Map<String, Set<String>> readMGI2GOMapping(String mgi2goFilePath)
            throws IOException {
        Map<String, Set<String>> goMappings = new HashMap<>();

        try (BufferedReader bufferedReader = new BufferedReader(
                new FileReader(new File(mgi2goFilePath)))) {

            String line;
            while ((line = bufferedReader.readLine()) != null) {
                if (line.startsWith("!")) {
                    continue;
                }

                String[] fields = line.split("\t");
                if (fields.length < 7) {
                    continue;  // blank or truncated line
                }

                String mgiId = fields[1];
                // one or more of "NOT", "contributes_to", "co-localizes_with"
                String qualifier = fields[3];
                String goId = fields[4];
                String evidenceCode = fields[6];

                if (goId.trim().length() == 0) {
                    continue;  // skip lines not containing a GO ID
                }

                if (Objects.equals(evidenceCode, "ND") || qualifier.contains("NOT")) {
                    continue;
                }

                String goUriStr = OBO_PREFIX + goId.replace(":", "_");

                Set<String> goUriStrs = goMappings.get(mgiId);
                if (goUriStrs == null) {
                    goUriStrs = new TreeSet<>();
                    goMappings.put(mgiId, goUriStrs);
                }
                goUriStrs.add(goUriStr);
            }
        }

        return goMappings;
    }

    /**
     * Returns the part of the URI after the last slash or hash sign
     * (whichever comes later), or the whole string if it contains neither.
     * The separator itself is not part of the result.
     */
    private static String getLocalPart(String uriStr) {
        int lastSeparatorIdx =
                Math.max(uriStr.lastIndexOf('/'), uriStr.lastIndexOf('#'));

        if (lastSeparatorIdx > -1) {
            return uriStr.substring(lastSeparatorIdx + 1);
        }

        return uriStr;
    }

    // -------------------- only getters and setters below --------------------
    public String getTrainURIsFilePath() {
        return trainURIsFilePath;
    }

    public void setTrainURIsFilePath(String trainURIsFilePath) {
        this.trainURIsFilePath = trainURIsFilePath;
    }

    public String getMpKBFilePath() {
        return mpKBFilePath;
    }

    public void setMpKBFilePath(String mpKBFilePath) {
        this.mpKBFilePath = mpKBFilePath;
    }

    public String getTrainingInputDirectoryPath() {
        return trainingInputDirectoryPath;
    }

    public void setTrainingInputDirectoryPath(String trainingInputDirectoryPath) {
        this.trainingInputDirectoryPath = trainingInputDirectoryPath;
    }

    public String getGoKBFilePath() {
        return goKBFilePath;
    }

    public void setGoKBFilePath(String goKBFilePath) {
        this.goKBFilePath = goKBFilePath;
    }

    public String getTrainingOutputDirectoryPath() {
        return trainingOutputDirectoryPath;
    }

    public void setTrainingOutputDirectoryPath(
            String trainingOutputDirectoryPath) {
        this.trainingOutputDirectoryPath = trainingOutputDirectoryPath;
    }

    public String getPredictionURIsFilePath() {
        return predictionURIsFilePath;
    }

    public void setPredictionURIsFilePath(String predictionURIsFilePath) {
        this.predictionURIsFilePath = predictionURIsFilePath;
    }

    public String getPredictionInputDirectoryPath() {
        return predictionInputDirectoryPath;
    }

    public void setPredictionInputDirectoryPath(
            String predictionInputDirectoryPath) {
        this.predictionInputDirectoryPath = predictionInputDirectoryPath;
    }

    public String getPredictionOutputDirectoryPath() {
        return predictionOutputDirectoryPath;
    }

    public void setPredictionOutputDirectoryPath(
            String predictionOutputDirectoryPath) {
        this.predictionOutputDirectoryPath = predictionOutputDirectoryPath;
    }

    public String getMgi2mpMappingsFilePath() {
        return mgi2mpMappingsFilePath;
    }

    public void setMgi2mpMappingsFilePath(String mgi2mpMappingsFilePath) {
        this.mgi2mpMappingsFilePath = mgi2mpMappingsFilePath;
    }

    public String getMgi2goMappingsFilePath() {
        return mgi2goMappingsFilePath;
    }

    public void setMgi2goMappingsFilePath(String mgi2goMappingsFilePath) {
        this.mgi2goMappingsFilePath = mgi2goMappingsFilePath;
    }

    public SvmType getSvmType() {
        return svmType;
    }

    public void setSvmType(SvmType svmType) {
        this.svmType = svmType;
    }

    public boolean isDoProbabilityEstimates() {
        return doProbabilityEstimates;
    }

    public void setDoProbabilityEstimates(boolean doProbabilityEstimates) {
        this.doProbabilityEstimates = doProbabilityEstimates;
    }

    public int getCrossValidationFolds() {
        return crossValidationFolds;
    }

    public void setCrossValidationFolds(int crossValidationFolds) {
        this.crossValidationFolds = crossValidationFolds;
    }

    public float getCost() {
        return cost;
    }

    public void setCost(float cost) {
        this.cost = cost;
    }

    public boolean isPredictProbability() {
        return predictProbability;
    }

    public void setPredictProbability(boolean predictProbability) {
        this.predictProbability = predictProbability;
    }

    public double getPosNegExampleRatio() {
        return posNegExampleRatio;
    }

    public void setPosNegExampleRatio(double posNegExampleRatio) {
        this.posNegExampleRatio = posNegExampleRatio;
    }

    public boolean isDoTraining() {
        return doTraining;
    }

    public void setDoTraining(boolean doTraining) {
        this.doTraining = doTraining;
    }

    public boolean isDoPrediction() {
        return doPrediction;
    }

    public void setDoPrediction(boolean doPrediction) {
        this.doPrediction = doPrediction;
    }

}