001package org.dllearner.algorithms.qtl.util; 002 003import java.util.ArrayList; 004import java.util.List; 005import java.util.Set; 006 007import org.apache.jena.graph.Node; 008import org.jgrapht.Graph; 009import org.jgrapht.alg.shortestpath.BellmanFordShortestPath; 010import org.jgrapht.alg.shortestpath.DijkstraShortestPath; 011import org.jgrapht.alg.spanning.KruskalMinimumSpanningTree; 012import org.jgrapht.graph.DefaultEdge; 013import org.jgrapht.graph.Pseudograph; 014import org.jgrapht.graph.WeightedMultigraph; 015import org.slf4j.Logger; 016import org.slf4j.LoggerFactory; 017 018public class SteinerTree { 019 020 private static final Logger logger = LoggerFactory.getLogger(SteinerTree.class); 021 022 Graph<Node, DefaultEdge> graph; 023 WeightedMultigraph<Node, DefaultEdge> tree; 024 List<Node> steinerNodes; 025 026 public SteinerTree(Graph<Node, DefaultEdge> graph, List<Node> steinerNodes) { 027 this.graph = graph; 028 this.steinerNodes = steinerNodes; 029 030 runAlgorithm(); 031 } 032 033 /** 034 * Construct the complete undirected distance graph G1=(V1,EI,d1) from G and S. 035 */ 036 private Pseudograph<Node, DefaultEdge> step1() { 037 038 logger.debug("<enter"); 039 040 Pseudograph<Node, DefaultEdge> g = new Pseudograph<>(DefaultEdge.class); 041 042 for (Node n : this.steinerNodes) { 043 g.addVertex(n); 044 } 045 046 BellmanFordShortestPath<Node, DefaultEdge> pathGen = new BellmanFordShortestPath<>(this.graph); 047 048 for (Node n1 : this.steinerNodes) { 049 for (Node n2 : this.steinerNodes) { 050 051 if (n1.equals(n2)) 052 continue; 053 054 if (g.containsEdge(n1, n2)) 055 continue; 056 057 DefaultEdge e = new DefaultEdge(); 058 g.addEdge(n1, n2, e); 059 g.setEdgeWeight(e, pathGen.getPathWeight(n1, n2)); 060 061 } 062 063 } 064 065 logger.debug("exit>"); 066 067 return g; 068 069 } 070 071 /** 072 * Find the minimal spanning tree, T1, of G1. (If there are several minimal spanning trees, pick an arbitrary one.) 073 * 074 * @param g1 075 * @return 076 */ 077 private WeightedMultigraph<Node, DefaultEdge> step2(Pseudograph<Node, DefaultEdge> g1) { 078 079 logger.debug("<enter"); 080 081 KruskalMinimumSpanningTree<Node, DefaultEdge> mst = new KruskalMinimumSpanningTree<>(g1); 082 083// logger.debug("Total MST Cost: " + mst.getSpanningTreeCost()); 084 085 Set<DefaultEdge> edges = mst.getSpanningTree().getEdges(); 086 087 WeightedMultigraph<Node, DefaultEdge> g2 = new WeightedMultigraph<>(DefaultEdge.class); 088 089 List<DefaultEdge> edgesSortedById = new ArrayList<>(edges); 090// edgesSortedById.sort(); 091 092 for (DefaultEdge edge : edgesSortedById) { 093 g2.addVertex(g1.getEdgeSource(edge)); 094 g2.addVertex(g1.getEdgeTarget(edge)); 095 g2.addEdge(g1.getEdgeSource(edge), g1.getEdgeTarget(edge), edge); 096 } 097 098 logger.debug("exit>"); 099 100 return g2; 101 } 102 103 /** 104 * Construct the subgraph, Gs, of G by replacing each edge in T1 by its corresponding shortest path in G. 105 * (If there are several shortest paths, pick an arbitrary one.) 106 * 107 * @param g2 108 * @return 109 */ 110 private WeightedMultigraph<Node, DefaultEdge> step3(WeightedMultigraph<Node, DefaultEdge> g2) { 111 112 logger.debug("<enter"); 113 114 WeightedMultigraph<Node, DefaultEdge> g3 = new WeightedMultigraph<>(DefaultEdge.class); 115 116 Set<DefaultEdge> edges = g2.edgeSet(); 117 DijkstraShortestPath<Node, DefaultEdge> pathGen = new DijkstraShortestPath<>(this.graph); 118 119 Node source, target; 120 121 for (DefaultEdge edge : edges) { 122 source = g2.getEdgeSource(edge); 123 target = g2.getEdgeTarget(edge); 124 125 126 List<DefaultEdge> pathEdges = pathGen.getPath(source, target).getEdgeList(); 127 128 if (pathEdges == null) 129 continue; 130 131 for (int i = 0; i < pathEdges.size(); i++) { 132 133 if (g3.edgeSet().contains(pathEdges.get(i))) 134 continue; 135 136 source = g2.getEdgeSource(pathEdges.get(i)); 137 target = g2.getEdgeTarget(pathEdges.get(i)); 138 139 if (!g3.vertexSet().contains(source)) 140 g3.addVertex(source); 141 142 if (!g3.vertexSet().contains(target)) 143 g3.addVertex(target); 144 145 g3.addEdge(source, target, pathEdges.get(i)); 146 } 147 } 148 149 logger.debug("exit>"); 150 151 return g3; 152 } 153 154 /** 155 * Find the minimal spanning tree, Ts, of Gs. (If there are several minimal spanning trees, pick an arbitrary one.) 156 * 157 * @param g3 158 * @return 159 */ 160 private WeightedMultigraph<Node, DefaultEdge> step4(WeightedMultigraph<Node, DefaultEdge> g3) { 161 162 logger.debug("<enter"); 163 164 KruskalMinimumSpanningTree<Node, DefaultEdge> mst = new KruskalMinimumSpanningTree<>(g3); 165 166// logger.debug("Total MST Cost: " + mst.getSpanningTreeCost()); 167 168 Set<DefaultEdge> edges = mst.getSpanningTree().getEdges(); 169 170 WeightedMultigraph<Node, DefaultEdge> g4 = 171 new WeightedMultigraph<>(DefaultEdge.class); 172 173 List<DefaultEdge> edgesSortedById = new ArrayList<>(edges); 174// Collections.sort(edgesSortedById); 175 176 for (DefaultEdge edge : edgesSortedById) { 177 g4.addVertex(g3.getEdgeSource(edge)); 178 g4.addVertex(g3.getEdgeTarget(edge)); 179 g4.addEdge(g3.getEdgeSource(edge), g3.getEdgeTarget(edge), edge); 180 } 181 182 logger.debug("exit>"); 183 184 return g4; 185 } 186 187 /** 188 * Construct a Steiner tree, Th, from Ts by deleting edges in Ts,if necessary, 189 * so that all the leaves in Th are Steiner points. 190 * 191 * @param g4 192 * @return 193 */ 194 private WeightedMultigraph<Node, DefaultEdge> step5(WeightedMultigraph<Node, DefaultEdge> g4) { 195 196 logger.debug("<enter"); 197 198 WeightedMultigraph<Node, DefaultEdge> g5 = g4; 199 200 List<Node> nonSteinerLeaves = new ArrayList<>(); 201 202 Set<Node> vertexSet = g4.vertexSet(); 203 for (Node vertex : vertexSet) { 204 if (g5.degreeOf(vertex) == 1 && steinerNodes.indexOf(vertex) == -1) { 205 nonSteinerLeaves.add(vertex); 206 } 207 } 208 209 Node source, target; 210 for (int i = 0; i < nonSteinerLeaves.size(); i++) { 211 source = nonSteinerLeaves.get(i); 212 do { 213 DefaultEdge e = g5.edgesOf(source).toArray(new DefaultEdge[0])[0]; 214 target = this.graph.getEdgeTarget(e); 215 216 // this should not happen, but just in case of ... 217 if (target.equals(source)) 218 target = g5.getEdgeSource(e); 219 220 g5.removeVertex(source); 221 source = target; 222 } while (g5.degreeOf(source) == 1 && steinerNodes.indexOf(source) == -1); 223 224 } 225 226 logger.debug("exit>"); 227 228 return g5; 229 } 230 231 private void runAlgorithm() { 232 233 logger.debug("<enter"); 234 235 logger.debug("step1 ..."); 236 Pseudograph<Node, DefaultEdge> g1 = step1(); 237// logger.info("after doing step 1 ...................................................................."); 238// GraphUtil.printGraphSimple(g1); 239// GraphUtil.printGraph(g1); 240 241 if (g1.vertexSet().size() < 2) { 242 this.tree = new WeightedMultigraph<>(DefaultEdge.class); 243 for (Node n : g1.vertexSet()) this.tree.addVertex(n); 244 return; 245 } 246 247 logger.debug("step2 ..."); 248 WeightedMultigraph<Node, DefaultEdge> g2 = step2(g1); 249// logger.info("after doing step 2 ...................................................................."); 250// GraphUtil.printGraphSimple(g2); 251// GraphUtil.printGraph(g2); 252 253 254 logger.debug("step3 ..."); 255 WeightedMultigraph<Node, DefaultEdge> g3 = step3(g2); 256// logger.info("after doing step 3 ...................................................................."); 257// GraphUtil.printGraphSimple(g3); 258// GraphUtil.printGraph(g3); 259 260 logger.debug("step4 ..."); 261 WeightedMultigraph<Node, DefaultEdge> g4 = step4(g3); 262// logger.info("after doing step 4 ...................................................................."); 263// GraphUtil.printGraphSimple(g4); 264// GraphUtil.printGraph(g4); 265 266 267 logger.debug("step5 ..."); 268 WeightedMultigraph<Node, DefaultEdge> g5 = step5(g4); 269// logger.info("after doing step 5 ...................................................................."); 270// GraphUtil.printGraphSimple(g5); 271// GraphUtil.printGraph(g5); 272 273 this.tree = g5; 274 logger.debug("exit>"); 275 276 //Add all the force added vertices 277// for (Node n : g1.vertexSet()) { 278// if (n.isForced()) 279// this.tree.addVertex(n); 280// } 281 } 282 283 public WeightedMultigraph<Node, DefaultEdge> getDefaultSteinerTree() { 284 return this.tree; 285 } 286}