001package org.dllearner.algorithms.qtl.util;
002
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;

import org.jgrapht.Graph;
import org.jgrapht.GraphPath;
import org.jgrapht.alg.shortestpath.BellmanFordShortestPath;
import org.jgrapht.alg.shortestpath.DijkstraShortestPath;
import org.jgrapht.alg.spanning.KruskalMinimumSpanningTree;
import org.jgrapht.graph.Pseudograph;
import org.jgrapht.graph.WeightedMultigraph;
import org.jgrapht.graph.WeightedPseudograph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
017
018public class SteinerTreeGeneric<V, E> {
019
020        private static final Logger logger = LoggerFactory.getLogger(SteinerTreeGeneric.class);
021
022
023        Graph<V, E> graph;
024        WeightedMultigraph<V, E> tree;
025        List<V> steinerNodes;
026        private final Class<? extends E> edgeClass;
027
028        public SteinerTreeGeneric(Graph<V, E> graph, List<V> steinerNodes, Class<? extends E> edgeClass) {
029                this.graph = graph;
030                this.steinerNodes = steinerNodes;
031                this.edgeClass = edgeClass;
032
033                runAlgorithm();
034        }
035
036        /**
037         * Construct the complete undirected distance graph G1=(V1,EI,d1) from G and S.
038         */
039        private Pseudograph<V, E> step1() {
040
041                logger.debug("<enter");
042
043                Pseudograph<V, E> g = new WeightedPseudograph<V, E>(edgeClass);
044
045                for (V n : this.steinerNodes) {
046                        g.addVertex(n);
047                }
048
049                BellmanFordShortestPath<V, E> pathGen = new BellmanFordShortestPath<>(this.graph);
050
051                for (V n1 : this.steinerNodes) {
052                        for (V n2 : this.steinerNodes) {
053
054                                if (n1.equals(n2))
055                                        continue;
056
057                                if (g.containsEdge(n1, n2))
058                                        continue;
059
060                                E e = g.addEdge(n1, n2);
061                                g.setEdgeWeight(e, pathGen.getPathWeight(n1, n2));
062
063                        }
064
065                }
066
067                logger.debug("exit>");
068
069                return g;
070
071        }
072
073        /**
074         * Find the minimal spanning tree, T1, of G1. (If there are several minimal spanning trees, pick an arbitrary one.)
075         *
076         * @param g1
077         * @return
078         */
079        private WeightedMultigraph<V, E> step2(Pseudograph<V, E> g1) {
080
081                logger.debug("<enter");
082
083                KruskalMinimumSpanningTree<V, E> mst = new KruskalMinimumSpanningTree<>(g1);
084
085//      logger.debug("Total MST Cost: " + mst.getSpanningTreeCost());
086
087                Set<E> edges = mst.getSpanningTree().getEdges();
088
089                WeightedMultigraph<V, E> g2 = new WeightedMultigraph<>(edgeClass);
090
091                List<E> edgesSortedById = new ArrayList<>(edges);
092//              edgesSortedById.sort();
093
094                for (E edge : edgesSortedById) {
095                        g2.addVertex(g1.getEdgeSource(edge));
096                        g2.addVertex(g1.getEdgeTarget(edge));
097                        g2.addEdge(g1.getEdgeSource(edge), g1.getEdgeTarget(edge), edge);
098                }
099
100                logger.debug("exit>");
101
102                return g2;
103        }
104
105        /**
106         * Construct the subgraph, Gs, of G by replacing each edge in T1 by its corresponding shortest path in G.
107         * (If there are several shortest paths, pick an arbitrary one.)
108         *
109         * @param g2
110         * @return
111         */
112        private WeightedMultigraph<V, E> step3(WeightedMultigraph<V, E> g2) {
113
114                logger.debug("<enter");
115
116                WeightedMultigraph<V, E> g3 = new WeightedMultigraph<>(edgeClass);
117
118                Set<E> edges = g2.edgeSet();
119                DijkstraShortestPath<V, E> pathGen = new DijkstraShortestPath<>(this.graph);
120
121                V source, target;
122
123                for (E edge : edges) {
124                        source = g2.getEdgeSource(edge);
125                        target = g2.getEdgeTarget(edge);
126
127
128                        List<E> pathEdges = pathGen.getPath(source, target).getEdgeList();
129
130                        if (pathEdges == null)
131                                continue;
132
133                        for (int i = 0; i < pathEdges.size(); i++) {
134
135                                if (g3.edgeSet().contains(pathEdges.get(i)))
136                                        continue;
137
138                                source = g2.getEdgeSource(pathEdges.get(i));
139                                target = g2.getEdgeTarget(pathEdges.get(i));
140
141                                if (!g3.vertexSet().contains(source))
142                                        g3.addVertex(source);
143
144                                if (!g3.vertexSet().contains(target))
145                                        g3.addVertex(target);
146
147                                g3.addEdge(source, target, pathEdges.get(i));
148                        }
149                }
150
151                logger.debug("exit>");
152
153                return g3;
154        }
155
156        /**
157         * Find the minimal spanning tree, Ts, of Gs. (If there are several minimal spanning trees, pick an arbitrary one.)
158         *
159         * @param g3
160         * @return
161         */
162        private WeightedMultigraph<V, E> step4(WeightedMultigraph<V, E> g3) {
163
164                logger.debug("<enter");
165
166                KruskalMinimumSpanningTree<V, E> mst = new KruskalMinimumSpanningTree<>(g3);
167
168//      logger.debug("Total MST Cost: " + mst.getSpanningTreeCost());
169
170                Set<E> edges = mst.getSpanningTree().getEdges();
171
172                WeightedMultigraph<V, E> g4 = new WeightedMultigraph<>(edgeClass);
173
174                List<E> edgesSortedById = new ArrayList<>(edges);
175//              Collections.sort(edgesSortedById);
176
177                for (E edge : edgesSortedById) {
178                        g4.addVertex(g3.getEdgeSource(edge));
179                        g4.addVertex(g3.getEdgeTarget(edge));
180                        g4.addEdge(g3.getEdgeSource(edge), g3.getEdgeTarget(edge), edge);
181                }
182
183                logger.debug("exit>");
184
185                return g4;
186        }
187
188        /**
189         * Construct a Steiner tree, Th, from Ts by deleting edges in Ts,if necessary,
190         * so that all the leaves in Th are Steiner points.
191         *
192         * @param g4
193         * @return
194         */
195        private WeightedMultigraph<V, E> step5(WeightedMultigraph<V, E> g4) {
196
197                logger.debug("<enter");
198
199                WeightedMultigraph<V, E> g5 = g4;
200
201                List<V> nonSteinerLeaves = new ArrayList<>();
202
203                Set<V> vertexSet = g4.vertexSet();
204                for (V vertex : vertexSet) {
205                        if (g5.degreeOf(vertex) == 1 && steinerNodes.indexOf(vertex) == -1) {
206                                nonSteinerLeaves.add(vertex);
207                        }
208                }
209
210                V source, target;
211                for (int i = 0; i < nonSteinerLeaves.size(); i++) {
212                        source = nonSteinerLeaves.get(i);
213                        do {
214                                E e = g5.edgesOf(source).iterator().next();
215                                target = this.graph.getEdgeTarget(e);
216
217                                // this should not happen, but just in case of ...
218                                if (target.equals(source))
219                                        target = g5.getEdgeSource(e);
220
221                                g5.removeVertex(source);
222                                source = target;
223                        } while (g5.degreeOf(source) == 1 && steinerNodes.indexOf(source) == -1);
224
225                }
226
227                logger.debug("exit>");
228
229                return g5;
230        }
231
232        private void runAlgorithm() {
233
234                logger.debug("<enter");
235
236                logger.debug("step1 ...");
237                Pseudograph<V, E> g1 = step1();
238//              logger.info("after doing step 1 ....................................................................");
239//              GraphUtil.printGraphSimple(g1);
240//              GraphUtil.printGraph(g1);
241
242                if (g1.vertexSet().size() < 2) {
243                        this.tree = new WeightedMultigraph<>(edgeClass);
244                        for (V n : g1.vertexSet()) this.tree.addVertex(n);
245                        return;
246                }
247
248                logger.debug("step2 ...");
249                WeightedMultigraph<V, E> g2 = step2(g1);
250//              logger.info("after doing step 2 ....................................................................");
251//              GraphUtil.printGraphSimple(g2);
252//              GraphUtil.printGraph(g2);
253
254
255                logger.debug("step3 ...");
256                WeightedMultigraph<V, E> g3 = step3(g2);
257//              logger.info("after doing step 3 ....................................................................");
258//              GraphUtil.printGraphSimple(g3);
259//              GraphUtil.printGraph(g3);
260
261                logger.debug("step4 ...");
262                WeightedMultigraph<V, E> g4 = step4(g3);
263//              logger.info("after doing step 4 ....................................................................");
264//              GraphUtil.printGraphSimple(g4);
265//              GraphUtil.printGraph(g4);
266
267
268                logger.debug("step5 ...");
269                WeightedMultigraph<V, E> g5 = step5(g4);
270//              logger.info("after doing step 5 ....................................................................");
271//              GraphUtil.printGraphSimple(g5);
272//              GraphUtil.printGraph(g5);
273
274                this.tree = g5;
275                logger.debug("exit>");
276
277                //Add all the force added vertices
278//              for (V n : g1.vertexSet()) {
279//                      if (n.isForced())
280//                              this.tree.addVertex(n);
281//              }
282        }
283
284        public WeightedMultigraph<V, E> getDefaultSteinerTree() {
285                return this.tree;
286        }
287}