Line data Source code
1 : #pragma once
2 :
#include <algorithm>
#include <deque>
#include <fstream>
#include <initializer_list>
#include <map>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
11 :
12 : #include "elsaDefines.h"
13 : #include "Common.h"
14 : #include "TypeCasts.hpp"
15 :
16 : namespace elsa::ml
17 : {
18 : namespace detail
19 : {
20 : /// A node in a graph.
21 : ///
22 : /// @author David Tellenbach
23 : ///
24 : /// @tparam T type of the data this node holds
25 : /// @tparam UseRawPtr If this parameter is set to `true` data is stored
26 : /// as a raw pointer, i.e., `T*`. If this parameter is set to `false`
27 : /// data is stored as `std::shared_ptr<T>`.
28 : template <typename T, bool UseRawPtr = true>
29 : class Node
30 : {
31 : public:
32 : using PointerType = std::conditional_t<UseRawPtr, T*, std::shared_ptr<T>>;
33 :
34 : /// construct a node by specifying its index in the graph
35 92 : explicit Node(index_t index) : index_(index), data_(nullptr) {}
36 :
37 : /// return the index of this node
38 302 : inline index_t getIndex() const { return index_; }
39 :
40 : /// return a pointer to the data held by this node
41 73 : PointerType getData() { return data_; }
42 :
43 : /// return a constant pointer to the data held by this node
44 4 : const PointerType getData() const { return data_; }
45 :
46 : /// set data held by this node
47 22 : inline void setData(PointerType data) { data_ = data; }
48 :
49 : private:
50 : /// this node's index in the graph
51 : index_t index_;
52 :
53 : /// the data this node holds
54 : PointerType data_;
55 : };
56 :
57 : /// An edge between nodes of the graph
58 : ///
59 : /// @author David Tellenbach
60 : ///
61 : /// @tparam T type of the data this node holds
62 : /// @tparam UseRawPtr If this parameter is set to `true` data is stored
63 : /// as a raw pointer, i.e., `T*`. If this parameter is set to `false`
64 : /// data is stored as `std::shared_ptr<T>`.
65 : template <typename T, bool UseRawPtr = true>
66 : class Edge
67 : {
68 : public:
69 : /// the type of nodes this edge connects
70 : using NodeType = Node<T, UseRawPtr>;
71 :
72 : /// construct an Edge by specifying its begin and end
73 46 : Edge(NodeType* begin, NodeType* end) : begin_(begin), end_(end) {}
74 :
75 : /// get a constant pointer to the begin-node of this edge
76 219 : const inline NodeType* begin() const { return begin_; }
77 :
78 : /// get a pointer to the begin-node of this edge
79 26 : inline NodeType* begin() { return begin_; }
80 :
81 : /// get a constant pointer to the end-node of this edge
82 29 : const inline NodeType* end() const { return end_; }
83 :
84 : /// get a pointer to the end-node of this edge
85 46 : inline NodeType* end() { return end_; }
86 :
87 : private:
88 : /// the begin-node of this edge
89 : NodeType* begin_;
90 :
91 : /// the end-node of this edge
92 : NodeType* end_;
93 : };
94 :
95 : /// A graph that can be altered and traversed in an efficient and
96 : /// structured manner.
97 : ///
98 : /// @author David Tellenbach
99 : ///
100 : /// @tparam T type of the data this node holds
101 : /// @tparam UseRawPtr If this parameter is set to `true` data is stored
102 : /// as a raw pointer, i.e., `T*`. If this parameter is set to `false`
103 : /// data is stored as `std::shared_ptr<T>`.
104 : template <typename T, bool UseRawPtr = true>
105 : class Graph
106 : {
107 : public:
108 : using PointerType = std::conditional_t<UseRawPtr, T*, std::shared_ptr<T>>;
109 :
110 : /// type of the graph nodes
111 : using NodeType = Node<T, UseRawPtr>;
112 :
113 : /// type of edges of the graph
114 : using EdgeType = Edge<T, UseRawPtr>;
115 :
116 : /// default constructor
117 8 : Graph() = default;
118 :
119 : /// construct a Graph by specifying an adjacency list, i.e., a list
120 : /// of edges connecting nodes
121 5 : Graph(std::initializer_list<std::pair<index_t, index_t>> edges)
122 5 : {
123 33 : for (const auto& edge : edges)
124 28 : addEdge(edge.first, edge.second);
125 5 : }
126 :
127 : /// return the number of nodes
128 30 : inline index_t getNumberOfNodes() const { return static_cast<index_t>(nodes_.size()); }
129 :
130 : /// return a pointer to the data held by the node at index
131 15 : inline PointerType getData(index_t index) { return nodes_.at(index).getData(); }
132 :
133 : /// return a constant pointer to the data held by the node at index
134 : const inline PointerType getData(index_t index) const
135 : {
136 : return nodes_.at(index).getData();
137 : }
138 :
139 : /// set data of node index
140 21 : inline void setData(index_t index, PointerType data)
141 : {
142 21 : return nodes_.at(index).setData(data);
143 : }
144 :
145 : /// add edge from begin to end
146 46 : inline void addEdge(index_t begin, index_t end)
147 : {
148 46 : nodes_.insert({begin, NodeType(begin)});
149 46 : nodes_.insert({end, NodeType(end)});
150 46 : edges_.emplace_back(EdgeType(&nodes_.at(begin), &nodes_.at(end)));
151 46 : }
152 :
153 : /// @returns a reference to a vector of edges
154 0 : inline std::vector<EdgeType>& getEdges() { return edges_; }
155 :
156 : /// @returns a constant reference to a vector of edges
157 : inline const std::vector<EdgeType>& getEdges() const { return edges_; }
158 :
159 : /// @returns a vector containing all edges beginning at a given
160 : /// index
161 31 : inline std::vector<EdgeType> getOutgoingEdges(index_t begin) const
162 : {
163 31 : std::vector<EdgeType> ret;
164 243 : for (const auto& edge : edges_)
165 212 : if (edge.begin()->getIndex() == begin)
166 49 : ret.push_back(edge);
167 :
168 31 : return ret;
169 : }
170 :
171 : /// @returns a vector containing all edges ending at a given index
172 4 : inline std::vector<EdgeType> getIncomingEdges(index_t end) const
173 : {
174 4 : std::vector<EdgeType> ret;
175 23 : for (const auto& edge : edges_)
176 19 : if (edge.end()->getIndex() == end)
177 7 : ret.push_back(edge);
178 :
179 4 : return ret;
180 : }
181 :
182 : /// Insert a new node after a node with a given index.
183 : ///
184 : /// @param index Index of the node after that the insertion
185 : /// should be done
186 : /// @param newIndex Index of the to be inserted node
187 1 : void insertNodeAfter(index_t index, index_t newIndex)
188 : {
189 4 : for (auto& e : getOutgoingEdges(index))
190 3 : addEdge(newIndex, e.end()->getIndex());
191 1 : removeOutgoingEdges(index);
192 1 : addEdge(index, newIndex);
193 1 : }
194 :
195 : /// @returns a reference to a map of indices and nodes
196 0 : inline std::map<index_t, NodeType>& getNodes() { return nodes_; }
197 :
198 : /// @returns a constant reference to a map of indices and nodes
199 : inline const std::map<index_t, NodeType>& getNodes() const { return nodes_; }
200 :
201 : /// delete all nodes and edges of the graph
202 0 : inline void clear()
203 : {
204 0 : edges_.clear();
205 0 : nodes_.clear();
206 0 : }
207 :
208 : /// Remove all outgoing edges from node with index begin
209 3 : inline void removeOutgoingEdges(index_t begin)
210 : {
211 3 : edges_.erase(std::remove_if(std::begin(edges_), std::end(edges_),
212 25 : [&begin](auto& edge) {
213 25 : return edge.begin()->getIndex() == begin;
214 : }),
215 3 : std::end(edges_));
216 3 : }
217 :
218 : /// Remove all incoming edges from node with index end
219 1 : inline void removeIncomingEdges(index_t end)
220 : {
221 2 : edges_.erase(
222 1 : std::remove_if(std::begin(edges_), std::end(edges_),
223 7 : [&end](auto& edge) { return edge.end()->getIndex() == end; }),
224 1 : std::end(edges_));
225 1 : }
226 :
227 : /// Remove a node from the graph.
228 : ///
229 : /// If preserveConnectivity is true, all begin-nodes of incoming
230 : /// edges get begin-nodes of end-nodes of outgoing edges.
231 1 : inline void removeNode(index_t idx, bool preserveConnectivity = true)
232 : {
233 1 : if (preserveConnectivity) {
234 : // Add an edge that bypassed the node that is to be removed
235 2 : auto incomingEdges = getIncomingEdges(idx);
236 2 : auto outgoingEdges = getOutgoingEdges(idx);
237 4 : for (const auto& inEdge : incomingEdges) {
238 3 : auto beginIdx = inEdge.begin()->getIndex();
239 9 : for (const auto& outEdge : outgoingEdges) {
240 6 : auto endIdx = outEdge.end()->getIndex();
241 6 : addEdge(beginIdx, endIdx);
242 : }
243 : }
244 : }
245 : // Remove all edges of the node
246 1 : removeOutgoingEdges(idx);
247 1 : removeIncomingEdges(idx);
248 :
249 : // Remove the node itself
250 1 : nodes_.erase(idx);
251 1 : }
252 :
253 : /// Visit all nodes of the graph. The visitors have access the node
254 : /// indices
255 : ///
256 : /// @param root Index of the node that serves as starting point
257 : /// of the traversal.
258 : ///
259 : /// @param visitor A function-like object (function, overloaded
260 : /// call operator, lambda) with the signature
261 : /// `void(T* data, index_t index)`.
262 : /// The visitor will be applied to every node in a breadth-first
263 : /// traversal. The semantics of the parameter is
264 : /// @param nextVisitor
265 : /// @param stop
266 : ///
267 : /// - `data` is the data held by the current node in the traversal
268 : /// - `index` is the index of the current node
269 : template <typename Visitor, typename NextVisitor, typename StopFunctor>
270 0 : void visitWithIndex(index_t root, Visitor visitor, NextVisitor nextVisitor,
271 : StopFunctor stop)
272 : {
273 : visitImpl<Visitor, NextVisitor, StopFunctor, /* access index */ true,
274 0 : /* forward */ true>(root, std::forward<Visitor>(visitor),
275 0 : std::forward<NextVisitor>(nextVisitor),
276 0 : std::forward<StopFunctor>(stop));
277 0 : }
278 :
279 : template <typename Visitor, typename NextVisitor, typename StopFunctor>
280 3 : void visit(index_t root, Visitor visitor, NextVisitor nextVisitor, StopFunctor stop)
281 : {
282 : visitImpl<Visitor, NextVisitor, StopFunctor, /* access index */ false,
283 3 : /* forward */ true>(root, std::forward<Visitor>(visitor),
284 3 : std::forward<NextVisitor>(nextVisitor),
285 3 : std::forward<StopFunctor>(stop));
286 3 : }
287 :
288 : template <typename Visitor, typename NextVisitor>
289 0 : void visit(index_t root, Visitor visitor, NextVisitor nextVisitor)
290 : {
291 0 : visit(root, visitor, nextVisitor, []([[maybe_unused]] auto node) { return false; });
292 0 : }
293 :
294 : template <typename Visitor>
295 3 : void visit(index_t root, Visitor visitor)
296 : {
297 16 : visit(
298 : root, visitor,
299 13 : []([[maybe_unused]] auto node, [[maybe_unused]] auto nextNode) {},
300 16 : []([[maybe_unused]] auto node) { return false; });
301 3 : }
302 :
303 : template <typename Visitor, typename NextVisitor, typename StopFunctor>
304 0 : void visitBackward(index_t root, Visitor visitor, NextVisitor nextVisitor,
305 : StopFunctor stop)
306 : {
307 : visitImpl<Visitor, NextVisitor, StopFunctor, /* access index */ false,
308 0 : /* forward */ false>(root, std::forward<Visitor>(visitor),
309 0 : std::forward<NextVisitor>(nextVisitor),
310 0 : std::forward<StopFunctor>(stop));
311 0 : }
312 :
313 : template <typename Visitor, typename NextVisitor, typename StopFunctor>
314 0 : void visitBackwardWithIndex(index_t root, Visitor visitor, NextVisitor nextVisitor,
315 : StopFunctor stop)
316 : {
317 : visitImpl<Visitor, NextVisitor, StopFunctor, /* access index */ true,
318 0 : /* forward */ false>(root, std::forward<Visitor>(visitor),
319 0 : std::forward<NextVisitor>(nextVisitor),
320 0 : std::forward<StopFunctor>(stop));
321 0 : }
322 :
323 : enum class RankDir { TD, LR };
324 :
325 : /// Print a representation of the graph in the Dot language.
326 : ///
327 : /// @param filename the name of the file that is written to
328 : /// @param nodePrinter a function (or lambda) that defines how
329 : /// to print the data of a single node
330 : /// @param dpi Dots-per-inch definition for the Dot representation
331 : /// @param rankDir RankDir::TD if the graph should be drawn top-down,
332 : /// rankdDir::LR if it should be drawn left-right
333 : template <typename NodePrinter>
334 : void toDot(const std::string& filename, NodePrinter nodePrinter, index_t dpi = 90,
335 : RankDir rankDir = RankDir::TD)
336 : {
337 : std::ofstream os(filename);
338 : os << "digraph {\n";
339 : os << "graph [" << (rankDir == RankDir::TD ? "rankdir=TD" : "rankdir=LR")
340 : << ", dpi=" << dpi << "];\n";
341 :
342 : // print all nodes
343 : for (const auto& node : getNodes()) {
344 : if (node.second.getData()) {
345 : os << nodePrinter(node.second.getData(), node.first) << "\n";
346 : } else {
347 : os << node.first << "\n";
348 : }
349 : }
350 :
351 : // print all edges
352 : for (const auto& edge : getEdges())
353 : os << edge.begin()->getIndex() << "->" << edge.end()->getIndex() << ";\n";
354 :
355 : os << "}";
356 : }
357 :
358 : private:
359 : template <typename Visitor, typename NextVisitor, typename StopFunctor,
360 : bool AccessIndex, bool Forward>
361 3 : void visitImpl(index_t root, Visitor visitor, NextVisitor nextVisitor, StopFunctor stop)
362 : {
363 3 : if (nodes_.find(root) == nodes_.end())
364 0 : throw std::invalid_argument("Unknown node index");
365 :
366 : // We maintain a list of nodes we've already visited and a call-queue to
367 : // ensure the correct order of traversal.
368 6 : std::vector<bool> visited(asUnsigned(getNumberOfNodes()));
369 : // Note that this queue is in fact a deque, so we can push and pop from
370 : // both, the front and back.
371 6 : std::deque<index_t> queue;
372 :
373 : // Perform an iterative depth-first traversal through the graph
374 : // Push the input-node onto the call-queue
375 3 : queue.push_back(root);
376 :
377 19 : while (!queue.empty()) {
378 16 : index_t s = queue.back();
379 16 : queue.pop_back();
380 :
381 16 : if (!visited[static_cast<std::size_t>(s)]) {
382 :
383 : bool needStop;
384 : if constexpr (AccessIndex) {
385 0 : needStop = stop(nodes_.at(s).getData(), s);
386 : } else {
387 16 : needStop = stop(nodes_.at(s).getData());
388 : }
389 16 : if (!needStop) {
390 : if constexpr (AccessIndex) {
391 0 : visitor(nodes_.at(s).getData(), s);
392 : } else {
393 16 : visitor(nodes_.at(s).getData());
394 : }
395 16 : visited[asUnsigned(s)] = true;
396 : } else {
397 0 : queue.push_front(s);
398 0 : continue;
399 : }
400 : }
401 :
402 : if constexpr (Forward) {
403 35 : for (auto& e : getOutgoingEdges(s)) {
404 19 : auto idx = e.end()->getIndex();
405 :
406 19 : if (!visited[static_cast<std::size_t>(idx)]) {
407 13 : queue.push_back(idx);
408 : if constexpr (AccessIndex) {
409 0 : nextVisitor(nodes_.at(s).getData(), s, e.end()->getData(), idx);
410 : } else {
411 13 : nextVisitor(nodes_.at(s).getData(), e.end()->getData());
412 : }
413 : }
414 : }
415 : } else {
416 0 : for (auto& e : getIncomingEdges(s)) {
417 0 : auto idx = e.begin()->getIndex();
418 :
419 0 : if (!visited[static_cast<std::size_t>(idx)]) {
420 0 : queue.push_back(idx);
421 : if constexpr (AccessIndex) {
422 0 : nextVisitor(nodes_.at(s).getData(), s, e.begin()->getData(),
423 : idx);
424 : } else {
425 0 : nextVisitor(nodes_.at(s).getData(), e.begin()->getData());
426 : }
427 : }
428 : }
429 : }
430 : }
431 3 : }
432 :
433 : std::vector<EdgeType> edges_;
434 : std::map<index_t, NodeType> nodes_;
435 : };
436 :
437 : } // namespace detail
438 : } // namespace elsa::ml
|