001package org.dllearner.algorithms.qtl.util; 002 003import java.util.ArrayList; 004import java.util.List; 005import java.util.Set; 006import java.util.function.Supplier; 007 008import org.jgrapht.Graph; 009import org.jgrapht.alg.shortestpath.BellmanFordShortestPath; 010import org.jgrapht.alg.shortestpath.DijkstraShortestPath; 011import org.jgrapht.alg.spanning.KruskalMinimumSpanningTree; 012import org.jgrapht.graph.Pseudograph; 013import org.jgrapht.graph.WeightedMultigraph; 014import org.jgrapht.graph.WeightedPseudograph; 015import org.slf4j.Logger; 016import org.slf4j.LoggerFactory; 017 018public class SteinerTreeGeneric<V, E> { 019 020 private static final Logger logger = LoggerFactory.getLogger(SteinerTreeGeneric.class); 021 022 023 Graph<V, E> graph; 024 WeightedMultigraph<V, E> tree; 025 List<V> steinerNodes; 026 private final Class<? extends E> edgeClass; 027 028 public SteinerTreeGeneric(Graph<V, E> graph, List<V> steinerNodes, Class<? extends E> edgeClass) { 029 this.graph = graph; 030 this.steinerNodes = steinerNodes; 031 this.edgeClass = edgeClass; 032 033 runAlgorithm(); 034 } 035 036 /** 037 * Construct the complete undirected distance graph G1=(V1,EI,d1) from G and S. 038 */ 039 private Pseudograph<V, E> step1() { 040 041 logger.debug("<enter"); 042 043 Pseudograph<V, E> g = new WeightedPseudograph<V, E>(edgeClass); 044 045 for (V n : this.steinerNodes) { 046 g.addVertex(n); 047 } 048 049 BellmanFordShortestPath<V, E> pathGen = new BellmanFordShortestPath<>(this.graph); 050 051 for (V n1 : this.steinerNodes) { 052 for (V n2 : this.steinerNodes) { 053 054 if (n1.equals(n2)) 055 continue; 056 057 if (g.containsEdge(n1, n2)) 058 continue; 059 060 E e = g.addEdge(n1, n2); 061 g.setEdgeWeight(e, pathGen.getPathWeight(n1, n2)); 062 063 } 064 065 } 066 067 logger.debug("exit>"); 068 069 return g; 070 071 } 072 073 /** 074 * Find the minimal spanning tree, T1, of G1. (If there are several minimal spanning trees, pick an arbitrary one.) 075 * 076 * @param g1 077 * @return 078 */ 079 private WeightedMultigraph<V, E> step2(Pseudograph<V, E> g1) { 080 081 logger.debug("<enter"); 082 083 KruskalMinimumSpanningTree<V, E> mst = new KruskalMinimumSpanningTree<>(g1); 084 085// logger.debug("Total MST Cost: " + mst.getSpanningTreeCost()); 086 087 Set<E> edges = mst.getSpanningTree().getEdges(); 088 089 WeightedMultigraph<V, E> g2 = new WeightedMultigraph<>(edgeClass); 090 091 List<E> edgesSortedById = new ArrayList<>(edges); 092// edgesSortedById.sort(); 093 094 for (E edge : edgesSortedById) { 095 g2.addVertex(g1.getEdgeSource(edge)); 096 g2.addVertex(g1.getEdgeTarget(edge)); 097 g2.addEdge(g1.getEdgeSource(edge), g1.getEdgeTarget(edge), edge); 098 } 099 100 logger.debug("exit>"); 101 102 return g2; 103 } 104 105 /** 106 * Construct the subgraph, Gs, of G by replacing each edge in T1 by its corresponding shortest path in G. 107 * (If there are several shortest paths, pick an arbitrary one.) 108 * 109 * @param g2 110 * @return 111 */ 112 private WeightedMultigraph<V, E> step3(WeightedMultigraph<V, E> g2) { 113 114 logger.debug("<enter"); 115 116 WeightedMultigraph<V, E> g3 = new WeightedMultigraph<>(edgeClass); 117 118 Set<E> edges = g2.edgeSet(); 119 DijkstraShortestPath<V, E> pathGen = new DijkstraShortestPath<>(this.graph); 120 121 V source, target; 122 123 for (E edge : edges) { 124 source = g2.getEdgeSource(edge); 125 target = g2.getEdgeTarget(edge); 126 127 128 List<E> pathEdges = pathGen.getPath(source, target).getEdgeList(); 129 130 if (pathEdges == null) 131 continue; 132 133 for (int i = 0; i < pathEdges.size(); i++) { 134 135 if (g3.edgeSet().contains(pathEdges.get(i))) 136 continue; 137 138 source = g2.getEdgeSource(pathEdges.get(i)); 139 target = g2.getEdgeTarget(pathEdges.get(i)); 140 141 if (!g3.vertexSet().contains(source)) 142 g3.addVertex(source); 143 144 if (!g3.vertexSet().contains(target)) 145 g3.addVertex(target); 146 147 g3.addEdge(source, target, pathEdges.get(i)); 148 } 149 } 150 151 logger.debug("exit>"); 152 153 return g3; 154 } 155 156 /** 157 * Find the minimal spanning tree, Ts, of Gs. (If there are several minimal spanning trees, pick an arbitrary one.) 158 * 159 * @param g3 160 * @return 161 */ 162 private WeightedMultigraph<V, E> step4(WeightedMultigraph<V, E> g3) { 163 164 logger.debug("<enter"); 165 166 KruskalMinimumSpanningTree<V, E> mst = new KruskalMinimumSpanningTree<>(g3); 167 168// logger.debug("Total MST Cost: " + mst.getSpanningTreeCost()); 169 170 Set<E> edges = mst.getSpanningTree().getEdges(); 171 172 WeightedMultigraph<V, E> g4 = new WeightedMultigraph<>(edgeClass); 173 174 List<E> edgesSortedById = new ArrayList<>(edges); 175// Collections.sort(edgesSortedById); 176 177 for (E edge : edgesSortedById) { 178 g4.addVertex(g3.getEdgeSource(edge)); 179 g4.addVertex(g3.getEdgeTarget(edge)); 180 g4.addEdge(g3.getEdgeSource(edge), g3.getEdgeTarget(edge), edge); 181 } 182 183 logger.debug("exit>"); 184 185 return g4; 186 } 187 188 /** 189 * Construct a Steiner tree, Th, from Ts by deleting edges in Ts,if necessary, 190 * so that all the leaves in Th are Steiner points. 191 * 192 * @param g4 193 * @return 194 */ 195 private WeightedMultigraph<V, E> step5(WeightedMultigraph<V, E> g4) { 196 197 logger.debug("<enter"); 198 199 WeightedMultigraph<V, E> g5 = g4; 200 201 List<V> nonSteinerLeaves = new ArrayList<>(); 202 203 Set<V> vertexSet = g4.vertexSet(); 204 for (V vertex : vertexSet) { 205 if (g5.degreeOf(vertex) == 1 && steinerNodes.indexOf(vertex) == -1) { 206 nonSteinerLeaves.add(vertex); 207 } 208 } 209 210 V source, target; 211 for (int i = 0; i < nonSteinerLeaves.size(); i++) { 212 source = nonSteinerLeaves.get(i); 213 do { 214 E e = g5.edgesOf(source).iterator().next(); 215 target = this.graph.getEdgeTarget(e); 216 217 // this should not happen, but just in case of ... 218 if (target.equals(source)) 219 target = g5.getEdgeSource(e); 220 221 g5.removeVertex(source); 222 source = target; 223 } while (g5.degreeOf(source) == 1 && steinerNodes.indexOf(source) == -1); 224 225 } 226 227 logger.debug("exit>"); 228 229 return g5; 230 } 231 232 private void runAlgorithm() { 233 234 logger.debug("<enter"); 235 236 logger.debug("step1 ..."); 237 Pseudograph<V, E> g1 = step1(); 238// logger.info("after doing step 1 ...................................................................."); 239// GraphUtil.printGraphSimple(g1); 240// GraphUtil.printGraph(g1); 241 242 if (g1.vertexSet().size() < 2) { 243 this.tree = new WeightedMultigraph<>(edgeClass); 244 for (V n : g1.vertexSet()) this.tree.addVertex(n); 245 return; 246 } 247 248 logger.debug("step2 ..."); 249 WeightedMultigraph<V, E> g2 = step2(g1); 250// logger.info("after doing step 2 ...................................................................."); 251// GraphUtil.printGraphSimple(g2); 252// GraphUtil.printGraph(g2); 253 254 255 logger.debug("step3 ..."); 256 WeightedMultigraph<V, E> g3 = step3(g2); 257// logger.info("after doing step 3 ...................................................................."); 258// GraphUtil.printGraphSimple(g3); 259// GraphUtil.printGraph(g3); 260 261 logger.debug("step4 ..."); 262 WeightedMultigraph<V, E> g4 = step4(g3); 263// logger.info("after doing step 4 ...................................................................."); 264// GraphUtil.printGraphSimple(g4); 265// GraphUtil.printGraph(g4); 266 267 268 logger.debug("step5 ..."); 269 WeightedMultigraph<V, E> g5 = step5(g4); 270// logger.info("after doing step 5 ...................................................................."); 271// GraphUtil.printGraphSimple(g5); 272// GraphUtil.printGraph(g5); 273 274 this.tree = g5; 275 logger.debug("exit>"); 276 277 //Add all the force added vertices 278// for (V n : g1.vertexSet()) { 279// if (n.isForced()) 280// this.tree.addVertex(n); 281// } 282 } 283 284 public WeightedMultigraph<V, E> getDefaultSteinerTree() { 285 return this.tree; 286 } 287}