001package org.dllearner.algorithms.qtl.util;
002
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.jena.graph.Node;
import org.jgrapht.Graph;
import org.jgrapht.GraphPath;
import org.jgrapht.alg.interfaces.ShortestPathAlgorithm.SingleSourcePaths;
import org.jgrapht.alg.shortestpath.BellmanFordShortestPath;
import org.jgrapht.alg.shortestpath.DijkstraShortestPath;
import org.jgrapht.alg.spanning.KruskalMinimumSpanningTree;
import org.jgrapht.graph.AsWeightedGraph;
import org.jgrapht.graph.DefaultEdge;
import org.jgrapht.graph.Pseudograph;
import org.jgrapht.graph.WeightedMultigraph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
017
018public class SteinerTree {
019
020        private static final Logger logger = LoggerFactory.getLogger(SteinerTree.class);
021
022        Graph<Node, DefaultEdge> graph;
023        WeightedMultigraph<Node, DefaultEdge> tree;
024        List<Node> steinerNodes;
025
026        public SteinerTree(Graph<Node, DefaultEdge> graph, List<Node> steinerNodes) {
027                this.graph = graph;
028                this.steinerNodes = steinerNodes;
029
030                runAlgorithm();
031        }
032
033        /**
034         * Construct the complete undirected distance graph G1=(V1,EI,d1) from G and S.
035         */
036        private Pseudograph<Node, DefaultEdge> step1() {
037
038                logger.debug("<enter");
039
040                Pseudograph<Node, DefaultEdge> g = new Pseudograph<>(DefaultEdge.class);
041
042                for (Node n : this.steinerNodes) {
043                        g.addVertex(n);
044                }
045
046                BellmanFordShortestPath<Node, DefaultEdge> pathGen = new BellmanFordShortestPath<>(this.graph);
047
048                for (Node n1 : this.steinerNodes) {
049                        for (Node n2 : this.steinerNodes) {
050
051                                if (n1.equals(n2))
052                                        continue;
053
054                                if (g.containsEdge(n1, n2))
055                                        continue;
056
057                                DefaultEdge e = new DefaultEdge();
058                                g.addEdge(n1, n2, e);
059                                g.setEdgeWeight(e, pathGen.getPathWeight(n1, n2));
060
061                        }
062
063                }
064
065                logger.debug("exit>");
066
067                return g;
068
069        }
070
071        /**
072         * Find the minimal spanning tree, T1, of G1. (If there are several minimal spanning trees, pick an arbitrary one.)
073         *
074         * @param g1
075         * @return
076         */
077        private WeightedMultigraph<Node, DefaultEdge> step2(Pseudograph<Node, DefaultEdge> g1) {
078
079                logger.debug("<enter");
080
081                KruskalMinimumSpanningTree<Node, DefaultEdge> mst = new KruskalMinimumSpanningTree<>(g1);
082
083//      logger.debug("Total MST Cost: " + mst.getSpanningTreeCost());
084
085                Set<DefaultEdge> edges = mst.getSpanningTree().getEdges();
086
087                WeightedMultigraph<Node, DefaultEdge> g2 = new WeightedMultigraph<>(DefaultEdge.class);
088
089                List<DefaultEdge> edgesSortedById = new ArrayList<>(edges);
090//              edgesSortedById.sort();
091
092                for (DefaultEdge edge : edgesSortedById) {
093                        g2.addVertex(g1.getEdgeSource(edge));
094                        g2.addVertex(g1.getEdgeTarget(edge));
095                        g2.addEdge(g1.getEdgeSource(edge), g1.getEdgeTarget(edge), edge);
096                }
097
098                logger.debug("exit>");
099
100                return g2;
101        }
102
103        /**
104         * Construct the subgraph, Gs, of G by replacing each edge in T1 by its corresponding shortest path in G.
105         * (If there are several shortest paths, pick an arbitrary one.)
106         *
107         * @param g2
108         * @return
109         */
110        private WeightedMultigraph<Node, DefaultEdge> step3(WeightedMultigraph<Node, DefaultEdge> g2) {
111
112                logger.debug("<enter");
113
114                WeightedMultigraph<Node, DefaultEdge> g3 = new WeightedMultigraph<>(DefaultEdge.class);
115
116                Set<DefaultEdge> edges = g2.edgeSet();
117                DijkstraShortestPath<Node, DefaultEdge> pathGen = new DijkstraShortestPath<>(this.graph);
118
119                Node source, target;
120
121                for (DefaultEdge edge : edges) {
122                        source = g2.getEdgeSource(edge);
123                        target = g2.getEdgeTarget(edge);
124
125
126                        List<DefaultEdge> pathEdges = pathGen.getPath(source, target).getEdgeList();
127
128                        if (pathEdges == null)
129                                continue;
130
131                        for (int i = 0; i < pathEdges.size(); i++) {
132
133                                if (g3.edgeSet().contains(pathEdges.get(i)))
134                                        continue;
135
136                                source = g2.getEdgeSource(pathEdges.get(i));
137                                target = g2.getEdgeTarget(pathEdges.get(i));
138
139                                if (!g3.vertexSet().contains(source))
140                                        g3.addVertex(source);
141
142                                if (!g3.vertexSet().contains(target))
143                                        g3.addVertex(target);
144
145                                g3.addEdge(source, target, pathEdges.get(i));
146                        }
147                }
148
149                logger.debug("exit>");
150
151                return g3;
152        }
153
154        /**
155         * Find the minimal spanning tree, Ts, of Gs. (If there are several minimal spanning trees, pick an arbitrary one.)
156         *
157         * @param g3
158         * @return
159         */
160        private WeightedMultigraph<Node, DefaultEdge> step4(WeightedMultigraph<Node, DefaultEdge> g3) {
161
162                logger.debug("<enter");
163
164                KruskalMinimumSpanningTree<Node, DefaultEdge> mst = new KruskalMinimumSpanningTree<>(g3);
165
166//      logger.debug("Total MST Cost: " + mst.getSpanningTreeCost());
167
168                Set<DefaultEdge> edges = mst.getSpanningTree().getEdges();
169
170                WeightedMultigraph<Node, DefaultEdge> g4 =
171                                new WeightedMultigraph<>(DefaultEdge.class);
172
173                List<DefaultEdge> edgesSortedById = new ArrayList<>(edges);
174//              Collections.sort(edgesSortedById);
175
176                for (DefaultEdge edge : edgesSortedById) {
177                        g4.addVertex(g3.getEdgeSource(edge));
178                        g4.addVertex(g3.getEdgeTarget(edge));
179                        g4.addEdge(g3.getEdgeSource(edge), g3.getEdgeTarget(edge), edge);
180                }
181
182                logger.debug("exit>");
183
184                return g4;
185        }
186
187        /**
188         * Construct a Steiner tree, Th, from Ts by deleting edges in Ts,if necessary,
189         * so that all the leaves in Th are Steiner points.
190         *
191         * @param g4
192         * @return
193         */
194        private WeightedMultigraph<Node, DefaultEdge> step5(WeightedMultigraph<Node, DefaultEdge> g4) {
195
196                logger.debug("<enter");
197
198                WeightedMultigraph<Node, DefaultEdge> g5 = g4;
199
200                List<Node> nonSteinerLeaves = new ArrayList<>();
201
202                Set<Node> vertexSet = g4.vertexSet();
203                for (Node vertex : vertexSet) {
204                        if (g5.degreeOf(vertex) == 1 && steinerNodes.indexOf(vertex) == -1) {
205                                nonSteinerLeaves.add(vertex);
206                        }
207                }
208
209                Node source, target;
210                for (int i = 0; i < nonSteinerLeaves.size(); i++) {
211                        source = nonSteinerLeaves.get(i);
212                        do {
213                                DefaultEdge e = g5.edgesOf(source).toArray(new DefaultEdge[0])[0];
214                                target = this.graph.getEdgeTarget(e);
215
216                                // this should not happen, but just in case of ...
217                                if (target.equals(source))
218                                        target = g5.getEdgeSource(e);
219
220                                g5.removeVertex(source);
221                                source = target;
222                        } while (g5.degreeOf(source) == 1 && steinerNodes.indexOf(source) == -1);
223
224                }
225
226                logger.debug("exit>");
227
228                return g5;
229        }
230
231        private void runAlgorithm() {
232
233                logger.debug("<enter");
234
235                logger.debug("step1 ...");
236                Pseudograph<Node, DefaultEdge> g1 = step1();
237//              logger.info("after doing step 1 ....................................................................");
238//              GraphUtil.printGraphSimple(g1);
239//              GraphUtil.printGraph(g1);
240
241                if (g1.vertexSet().size() < 2) {
242                        this.tree = new WeightedMultigraph<>(DefaultEdge.class);
243                        for (Node n : g1.vertexSet()) this.tree.addVertex(n);
244                        return;
245                }
246
247                logger.debug("step2 ...");
248                WeightedMultigraph<Node, DefaultEdge> g2 = step2(g1);
249//              logger.info("after doing step 2 ....................................................................");
250//              GraphUtil.printGraphSimple(g2);
251//              GraphUtil.printGraph(g2);
252
253
254                logger.debug("step3 ...");
255                WeightedMultigraph<Node, DefaultEdge> g3 = step3(g2);
256//              logger.info("after doing step 3 ....................................................................");
257//              GraphUtil.printGraphSimple(g3);
258//              GraphUtil.printGraph(g3);
259
260                logger.debug("step4 ...");
261                WeightedMultigraph<Node, DefaultEdge> g4 = step4(g3);
262//              logger.info("after doing step 4 ....................................................................");
263//              GraphUtil.printGraphSimple(g4);
264//              GraphUtil.printGraph(g4);
265
266
267                logger.debug("step5 ...");
268                WeightedMultigraph<Node, DefaultEdge> g5 = step5(g4);
269//              logger.info("after doing step 5 ....................................................................");
270//              GraphUtil.printGraphSimple(g5);
271//              GraphUtil.printGraph(g5);
272
273                this.tree = g5;
274                logger.debug("exit>");
275
276                //Add all the force added vertices
277//              for (Node n : g1.vertexSet()) {
278//                      if (n.isForced())
279//                              this.tree.addVertex(n);
280//              }
281        }
282
283        public WeightedMultigraph<Node, DefaultEdge> getDefaultSteinerTree() {
284                return this.tree;
285        }
286}