package org.dllearner.examples;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.dllearner.algorithms.semkernel.SemKernel;
import org.dllearner.algorithms.semkernel.SemKernel.SvmType;
import org.semanticweb.elk.owlapi.ElkReasonerFactory;
import org.semanticweb.owlapi.apibinding.OWLManager;
import org.semanticweb.owlapi.model.IRI;
import org.semanticweb.owlapi.model.OWLClass;
import org.semanticweb.owlapi.model.OWLDataFactory;
import org.semanticweb.owlapi.model.OWLOntology;
import org.semanticweb.owlapi.model.OWLOntologyCreationException;
import org.semanticweb.owlapi.model.OWLOntologyManager;
import org.semanticweb.owlapi.reasoner.ConsoleProgressMonitor;
import org.semanticweb.owlapi.reasoner.InferenceType;
import org.semanticweb.owlapi.reasoner.OWLReasoner;
import org.semanticweb.owlapi.reasoner.OWLReasonerConfiguration;
import org.semanticweb.owlapi.reasoner.OWLReasonerFactory;
import org.semanticweb.owlapi.reasoner.SimpleConfiguration;

import com.google.common.collect.Sets;

/**
 * Example driver for {@link SemKernel}.
 *
 * <p>The program reads an MGI-to-GO mapping file and an MGI-to-MP mapping
 * file, writes one sample file per configured MP class (a 0/1 label followed
 * by the tab-separated GO class URIs of a gene), trains the kernel on the
 * training directory and finally runs the prediction on the prediction
 * directory. All input and output locations are fixed below and live under
 * {@code /tmp/semkernel/}.
 */
public class SemKernelExample {
    private static final Logger logger = Logger.getLogger(SemKernelExample.class);

    /** MGI ID -> URIs of the GO classes the gene is annotated with. */
    private static Map<String, Set<String>> mgi2go;
    /** MGI ID -> URIs of the MP classes the gene is annotated with. */
    private static Map<String, Set<String>> mgi2mp;
    /**
     * All GO class URIs seen while reading the MGI-to-GO mappings. Currently
     * only collected; nothing in this class reads it.
     */
    private static Set<String> allClsUriStrs;
    private static OWLDataFactory factory;

    private static SemKernel kernel;
    /* kernel settings, according to
     * svm_train -s 0 -t 5 -b 1 -v 10 -c $cost -f go.obo \
     *      $TRAININGDIR/$NAME-$par > $CVDIR/$NAME-$par-$cost
     */
    private static final SvmType svmType = SvmType.C_SVC;
    private static final boolean doProbabilityEstimates = true;
    private static final int crossValidationFolds = 10;
    private static final float cost = 5f;
    private static final boolean predictProbability = true;

    /* semkernel settings */
    /** URIs of phenotype classes to build a classifier for */
    private static final Set<String> trainClsUriStrs = Sets.newHashSet(
//            "http://purl.obolibrary.org/obo/MP_0004031",
//            "http://purl.obolibrary.org/obo/MP_0000202",
            "http://purl.obolibrary.org/obo/MP_0001186"
//            "http://purl.obolibrary.org/obo/MP_0000001"
            );
    /** URIs of phenotype classes to write prediction data for */
    private static final Set<String> predictClsUriStrs = Sets.newHashSet(
//            "http://purl.obolibrary.org/obo/MP_0004031",
//            "http://purl.obolibrary.org/obo/MP_0000202",
            "http://purl.obolibrary.org/obo/MP_0001186"
//            "http://purl.obolibrary.org/obo/MP_0000001"
            );
    /** Number of negative samples kept per positive sample (upper bound). */
    private static final double posNegExampleRatio = 1;

    /* input/output */
    private static final String workingDirPath = "/tmp/semkernel/";
    // get it at http://purl.obolibrary.org/obo/go.obo
//    private static final String kbFilePath = workingDirPath + "go.obo";
    private static final String kbFilePath = workingDirPath + "go.owl";
    private static final String trainDirPath = workingDirPath + "train/";
    private static final String modelDirPath = workingDirPath + "models/";
    private static final String predictionDirPath = workingDirPath + "predictions/";
    private static final String resultsDirPath = workingDirPath + "results/";
    // get it at http://purl.obolibrary.org/obo/mp.obo
    private static final String samplesKbFilePath = workingDirPath + "mp.obo";
    private static OWLOntology samplesKb;
    private static OWLReasoner samplesReasoner;
    // get it at ftp://ftp.informatics.jax.org/pub/reports/gene_association.mgi
    private static final String mgi2goClsMappingsFilePath =
            workingDirPath + "gene_association.mgi";
    // get it at http://aber-owl.net/aber-owl/diseasephenotypes/data/mousephenotypes.txt
    private static final String mgi2mpClsMappingsFilePath =
            workingDirPath + "mousephenotypes_new.txt";

    /* URI constants */
    private static final String oboPrefix = "http://purl.obolibrary.org/obo/";

    /* 0-based column indexes read from the tab-separated MGI-to-GO file */
    private static final int GO_FILE_MGI_ID_COL = 1;
    private static final int GO_FILE_QUALIFIER_COL = 3;
    private static final int GO_FILE_GO_ID_COL = 4;
    private static final int GO_FILE_EVIDENCE_CODE_COL = 6;

    /**
     * Runs the whole example: load inputs, write training data, train,
     * write prediction data, predict.
     *
     * @param args ignored
     * @throws Exception if an input cannot be read or an output cannot be written
     */
    public static void main(String[] args) throws Exception {
        initExample();
        prepareTrainingData();

        initSemkernel();
        train();

        preparePredictionData();
        predict();
    }

    /**
     * Loads the samples ontology, sets up the ELK reasoner on it and reads both
     * mapping files into {@link #mgi2go} and {@link #mgi2mp}.
     *
     * @throws OWLOntologyCreationException if the samples ontology cannot be loaded
     * @throws IOException if one of the mapping files cannot be read
     */
    private static void initExample() throws OWLOntologyCreationException, IOException {
        logger.setLevel(Level.DEBUG);

        factory = OWLManager.getOWLDataFactory();
        allClsUriStrs = new TreeSet<>();

        logger.info(String.format("Loading samples ontology file %s ...", samplesKbFilePath));
        samplesKb = readKb(samplesKbFilePath);
        logger.info("-Done-");

        logger.info("Initialising reasoner...");
        ConsoleProgressMonitor mon = new ConsoleProgressMonitor();
        OWLReasonerConfiguration reasonerConf = new SimpleConfiguration(mon);
        OWLReasonerFactory reasonerFactory = new ElkReasonerFactory();
        samplesReasoner = reasonerFactory.createReasoner(samplesKb, reasonerConf);
        samplesReasoner.precomputeInferences(InferenceType.CLASS_HIERARCHY);
        logger.info("-Done-");

        logger.info(String.format(
                "Reading MGI ID to GO class mappings from %s ...",
                mgi2goClsMappingsFilePath));
        mgi2go = readMgi2GoMappings(mgi2goClsMappingsFilePath);
        logger.info("-Done-");

        logger.info(String.format(
                "Reading MGI ID to MP class mappings from %s ...",
                mgi2mpClsMappingsFilePath));
        mgi2mp = readMgi2MpMappings(mgi2mpClsMappingsFilePath);
        logger.info("-Done-");
    }

    /**
     * Reads the tab-separated MGI-to-GO mapping file.
     *
     * <p>Lines starting with {@code !}, lines without a GO ID, lines whose
     * evidence code is {@code ND} and lines whose qualifier contains
     * {@code NOT} are not added to the result. Every accepted GO class URI is
     * also added to {@link #allClsUriStrs}.
     *
     * @param filePath path of the mapping file
     * @return map from MGI ID to the (sorted) set of GO class URIs
     * @throws IOException if the file cannot be read
     */
    private static Map<String, Set<String>> readMgi2GoMappings(String filePath)
            throws IOException {
        Map<String, Set<String>> mappings = new HashMap<>();

        try (BufferedReader buffReader = new BufferedReader(new FileReader(filePath))) {
            String line;
            while ((line = buffReader.readLine()) != null) {
                if (line.startsWith("!")) {
                    continue;
                }

                String[] fields = line.split("\t");
                if (fields.length <= GO_FILE_EVIDENCE_CODE_COL) {
                    // too few columns to hold an evidence code; do not crash on it
                    logger.debug(String.format("-- Skipping line \"%s\"", line));
                    continue;
                }

                String mgiId = fields[GO_FILE_MGI_ID_COL];
                // one or more of "NOT", "contributes_to", "co-localizes_with"
                String qualifier = fields[GO_FILE_QUALIFIER_COL];
                String goId = fields[GO_FILE_GO_ID_COL];
                String evidenceCode = fields[GO_FILE_EVIDENCE_CODE_COL];

                if (goId.trim().length() == 0) {
                    continue; // skip lines not containing a GO ID
                }

                // equals(), not ==/!=: the field comes from split() and is never
                // the same object as the "ND" literal
                if ("ND".equals(evidenceCode) || qualifier.contains("NOT")) {
                    logger.debug(String.format("-- Skipping line \"%s\"", line));
                    continue;
                }

                String goUriStr = oboPrefix + goId.replace(":", "_");
                mappings.computeIfAbsent(mgiId, key -> new TreeSet<>()).add(goUriStr);
                allClsUriStrs.add(goUriStr);
            }
        }

        return mappings;
    }

    /**
     * Reads the tab-separated MGI-to-MP mapping file (MGI ID in the first
     * column, MP ID in the second). Only genes that already have an entry in
     * {@link #mgi2go} are kept, so {@link #mgi2go} must be populated first.
     *
     * @param filePath path of the mapping file
     * @return map from MGI ID to the (sorted) set of MP class URIs
     * @throws IOException if the file cannot be read
     */
    private static Map<String, Set<String>> readMgi2MpMappings(String filePath)
            throws IOException {
        Map<String, Set<String>> mappings = new HashMap<>();

        try (BufferedReader buffReader = new BufferedReader(new FileReader(filePath))) {
            String line;
            while ((line = buffReader.readLine()) != null) {
                String[] fields = line.split("\t");
                if (fields.length < 2) {
                    continue;
                }

                String mgiId = fields[0];
                if (!mgi2go.containsKey(mgiId)) {
                    continue; // no GO annotations -> no features for this gene
                }

                String mpId = fields[1];
                if (mpId.trim().length() == 0) {
                    continue; // skip lines not containing an MP ID
                }

                String mpUriStr = oboPrefix + mpId.replace(":", "_");
                mappings.computeIfAbsent(mgiId, key -> new TreeSet<>()).add(mpUriStr);
            }
        }

        return mappings;
    }

    /**
     * Loads an ontology document with a fresh ontology manager.
     *
     * @param filePath path of the ontology file
     * @return the loaded ontology
     * @throws OWLOntologyCreationException if loading fails
     */
    private static OWLOntology readKb(String filePath) throws OWLOntologyCreationException {
        OWLOntologyManager man = OWLManager.createOWLOntologyManager();
        return man.loadOntologyFromOntologyDocument(new File(filePath));
    }

    /** Creates the {@link SemKernel}, applies the settings above and initialises it. */
    private static void initSemkernel() {
        logger.info("Initialising the semkernel...");
        kernel = new SemKernel();

        // svm_train -s 0 ...
        kernel.setSvmType(svmType);
        // ... -t 5 ... --> use semantic kernel (fixed for SemKernel anyway)
        // ... -b 1 ...
        kernel.setDoProbabilityEstimates(doProbabilityEstimates);
        // ... -v 10 ...
        kernel.setCrossValidationFolds(crossValidationFolds);
        // ... -c $cost ...
        kernel.setCost(cost);
        // ... -f go.obo ...
        kernel.setOntologyFilePath(kbFilePath);
        // ... $TRAININGDIR/$NAME-$par > $CVDIR/$NAME-$par-$cost
        kernel.setTrainingDirPath(trainDirPath);
        kernel.setModelDirPath(modelDirPath);
        kernel.setPredictionDataDirPath(predictionDirPath);
        kernel.setResultsDirPath(resultsDirPath);

        kernel.setGamma(0);
        kernel.setPredictProbability(predictProbability);

        kernel.init();
        logger.info("-Done-");
    }

    /**
     * Writes one training sample file per class in {@link #trainClsUriStrs}
     * into {@link #trainDirPath}.
     *
     * @throws IOException if a sample file cannot be written
     */
    private static void prepareTrainingData() throws IOException {
        logger.info(String.format("Writing training sample data to file (%s)...",
                trainDirPath));
        writeSampleFiles(trainClsUriStrs, trainDirPath);
        logger.info("-Done-");
    }

    /**
     * Writes one prediction sample file per class in {@link #predictClsUriStrs}
     * into {@link #predictionDirPath}.
     *
     * @throws IOException if a sample file cannot be written
     */
    private static void preparePredictionData() throws IOException {
        logger.info("Preparing prediction data...");
        writeSampleFiles(predictClsUriStrs, predictionDirPath);
        logger.info("-Done-");
    }

    /**
     * Writes, for every given class URI, a sample file named after the URI's
     * local part into {@code outDirPath}.
     *
     * <p>A gene is a positive sample (label {@code 1}) if at least one of its
     * MP classes is the target class or one of the sub classes reported by the
     * reasoner, and a negative sample (label {@code 0}) otherwise. Negatives
     * are shuffled and truncated to {@code posNegExampleRatio * #positives}.
     * Positives are written first, then the remaining negatives.
     *
     * @param targetClsUriStrs URIs of the classes to write sample files for
     * @param outDirPath output directory path, expected to end with a slash
     * @throws IOException if the directory cannot be created or a file cannot be written
     */
    private static void writeSampleFiles(Set<String> targetClsUriStrs, String outDirPath)
            throws IOException {
        ensureDirExists(outDirPath);

        for (String targetClsUriStr : targetClsUriStrs) {
            String outFilePath = outDirPath + getLocalPart(targetClsUriStr);
            logger.debug("-- " + outFilePath);

            Set<String> classifierFor = getClassAndSubClassUriStrs(targetClsUriStr);

            List<String> positives = new ArrayList<>();
            List<String> negatives = new ArrayList<>();

            // build lines to write to file (SVM light format)
            for (Map.Entry<String, Set<String>> mpEntry : mgi2mp.entrySet()) {
                Set<String> goUriStrs = mgi2go.get(mpEntry.getKey());
                if (goUriStrs == null) {
                    continue;
                }

                boolean isPositive = !Collections.disjoint(classifierFor, mpEntry.getValue());

                StringBuilder outputLine = new StringBuilder(isPositive ? "1" : "0");
                for (String goUriStr : goUriStrs) {
                    outputLine.append('\t').append(goUriStr);
                }

                if (isPositive) {
                    positives.add(outputLine.toString());
                } else {
                    negatives.add(outputLine.toString());
                }
            }

            // shorten negative SVM light lines set to the configured
            // positives-negatives ratio
            Collections.shuffle(negatives);
            int cutoff = (int) Math.round(posNegExampleRatio * positives.size());
            if (cutoff < negatives.size()) {
                negatives = negatives.subList(0, cutoff);
            }

            try (BufferedWriter buffWriter = new BufferedWriter(new FileWriter(outFilePath))) {
                for (String posLine : positives) {
                    buffWriter.write(posLine);
                    buffWriter.newLine();
                }
                for (String negLine : negatives) {
                    buffWriter.write(negLine);
                    buffWriter.newLine();
                }
            }
        }
    }

    /**
     * Returns the given class URI together with the URIs of all its direct and
     * indirect sub classes as reported by {@link #samplesReasoner}.
     *
     * @param clsUriStr URI of the class to look up
     * @return sorted set containing {@code clsUriStr} and its sub class URIs
     */
    private static Set<String> getClassAndSubClassUriStrs(String clsUriStr) {
        Set<String> uriStrs = new TreeSet<>();
        uriStrs.add(clsUriStr);

        OWLClass cls = factory.getOWLClass(IRI.create(clsUriStr));
        for (OWLClass subClass : samplesReasoner.getSubClasses(cls, false).getFlattened()) {
            uriStrs.add(subClass.getIRI().toString());
        }

        return uriStrs;
    }

    /**
     * Creates the given directory (including missing parents) if it does not
     * exist yet, since {@link FileWriter} does not create parent directories.
     *
     * @param dirPath path of the directory
     * @throws IOException if the directory does not exist and cannot be created
     */
    private static void ensureDirExists(String dirPath) throws IOException {
        File dir = new File(dirPath);
        if (!dir.isDirectory() && !dir.mkdirs()) {
            throw new IOException("Could not create directory " + dirPath);
        }
    }

    /**
     * Returns the part of a URI after its last {@code /} or {@code #},
     * whichever comes later, without the separator itself.
     *
     * @param uriStr the URI string
     * @return the local part, or {@code uriStr} unchanged if it contains
     *         neither separator
     */
    private static String getLocalPart(String uriStr) {
        int lastSeparatorIdx = Math.max(uriStr.lastIndexOf('/'), uriStr.lastIndexOf('#'));
        if (lastSeparatorIdx > -1) {
            return uriStr.substring(lastSeparatorIdx + 1);
        }

        return uriStr;
    }

    /** Trains the kernel on the previously written training data. */
    private static void train() {
        logger.info("Training...");
        kernel.train();
        logger.info("-Done-");
    }

    /** Runs the kernel's prediction on the previously written prediction data. */
    private static void predict() {
        logger.info("Running prediction...");
        kernel.predict();
        logger.info("-Done-");
    }
}