LCOV - code coverage report
Current view: top level - ml - Graph.h (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 97 133 72.9 %
Date: 2022-07-06 02:47:47 Functions: 60 108 55.6 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <map>
       4             : #include <vector>
       5             : #include <fstream>
       6             : #include <algorithm>
       7             : #include <deque>
       8             : #include <memory>
       9             : #include <string>
      10             : #include <utility>
      11             : 
      12             : #include "elsaDefines.h"
      13             : #include "Common.h"
      14             : #include "TypeCasts.hpp"
      15             : 
      16             : namespace elsa::ml
      17             : {
      18             :     namespace detail
      19             :     {
      20             :         /// A node in a graph.
      21             :         ///
      22             :         /// @author David Tellenbach
      23             :         ///
      24             :         /// @tparam T type of the data this node holds
      25             :         /// @tparam UseRawPtr If this parameter is set to `true` data is stored
      26             :         /// as a raw pointer, i.e., `T*`. If this parameter is set to `false`
      27             :         /// data is stored as `std::shared_ptr<T>`.
      28             :         template <typename T, bool UseRawPtr = true>
      29             :         class Node
      30             :         {
      31             :         public:
      32             :             using PointerType = std::conditional_t<UseRawPtr, T*, std::shared_ptr<T>>;
      33             : 
      34             :             /// construct a node by specifying its index in the graph
      35          92 :             explicit Node(index_t index) : index_(index), data_(nullptr) {}
      36             : 
      37             :             /// return the index of this node
      38         302 :             inline index_t getIndex() const { return index_; }
      39             : 
      40             :             /// return a pointer to the data held by this node
      41          73 :             PointerType getData() { return data_; }
      42             : 
      43             :             /// return a constant pointer to the data held by this node
      44           4 :             const PointerType getData() const { return data_; }
      45             : 
      46             :             /// set data held by this node
      47          22 :             inline void setData(PointerType data) { data_ = data; }
      48             : 
      49             :         private:
      50             :             /// this node's index in the graph
      51             :             index_t index_;
      52             : 
      53             :             /// the data this node holds
      54             :             PointerType data_;
      55             :         };
      56             : 
      57             :         /// An edge between nodes of the graph
      58             :         ///
      59             :         /// @author David Tellenbach
      60             :         ///
      61             :         /// @tparam T type of the data this node holds
      62             :         /// @tparam UseRawPtr If this parameter is set to `true` data is stored
      63             :         /// as a raw pointer, i.e., `T*`. If this parameter is set to `false`
      64             :         /// data is stored as `std::shared_ptr<T>`.
      65             :         template <typename T, bool UseRawPtr = true>
      66             :         class Edge
      67             :         {
      68             :         public:
      69             :             /// the type of nodes this edge connects
      70             :             using NodeType = Node<T, UseRawPtr>;
      71             : 
      72             :             /// construct an Edge by specifying its begin and end
      73          46 :             Edge(NodeType* begin, NodeType* end) : begin_(begin), end_(end) {}
      74             : 
      75             :             /// get a constant pointer to the begin-node of this edge
      76         219 :             const inline NodeType* begin() const { return begin_; }
      77             : 
      78             :             /// get a pointer to the begin-node of this edge
      79          26 :             inline NodeType* begin() { return begin_; }
      80             : 
      81             :             /// get a constant pointer to the end-node of this edge
      82          29 :             const inline NodeType* end() const { return end_; }
      83             : 
      84             :             /// get a pointer to the end-node of this edge
      85          46 :             inline NodeType* end() { return end_; }
      86             : 
      87             :         private:
      88             :             /// the begin-node of this edge
      89             :             NodeType* begin_;
      90             : 
      91             :             /// the end-node of this edge
      92             :             NodeType* end_;
      93             :         };
      94             : 
      95             :         /// A graph that can be altered and traversed in an efficient and
      96             :         /// structured manner.
      97             :         ///
      98             :         /// @author David Tellenbach
      99             :         ///
     100             :         /// @tparam T type of the data this node holds
     101             :         /// @tparam UseRawPtr If this parameter is set to `true` data is stored
     102             :         /// as a raw pointer, i.e., `T*`. If this parameter is set to `false`
     103             :         /// data is stored as `std::shared_ptr<T>`.
     104             :         template <typename T, bool UseRawPtr = true>
     105             :         class Graph
     106             :         {
     107             :         public:
     108             :             using PointerType = std::conditional_t<UseRawPtr, T*, std::shared_ptr<T>>;
     109             : 
     110             :             /// type of the graph nodes
     111             :             using NodeType = Node<T, UseRawPtr>;
     112             : 
     113             :             /// type of edges of the graph
     114             :             using EdgeType = Edge<T, UseRawPtr>;
     115             : 
     116             :             /// default constructor
     117           8 :             Graph() = default;
     118             : 
     119             :             /// construct a Graph by specifying an adjacency list, i.e., a list
     120             :             /// of edges connecting nodes
     121           5 :             Graph(std::initializer_list<std::pair<index_t, index_t>> edges)
     122           5 :             {
     123          33 :                 for (const auto& edge : edges)
     124          28 :                     addEdge(edge.first, edge.second);
     125           5 :             }
     126             : 
     127             :             /// return the number of nodes
     128          30 :             inline index_t getNumberOfNodes() const { return static_cast<index_t>(nodes_.size()); }
     129             : 
     130             :             /// return a pointer to the data held by the node at index
     131          15 :             inline PointerType getData(index_t index) { return nodes_.at(index).getData(); }
     132             : 
     133             :             /// return a constant pointer to the data held by the node at index
     134             :             const inline PointerType getData(index_t index) const
     135             :             {
     136             :                 return nodes_.at(index).getData();
     137             :             }
     138             : 
     139             :             /// set data of node index
     140          21 :             inline void setData(index_t index, PointerType data)
     141             :             {
     142          21 :                 return nodes_.at(index).setData(data);
     143             :             }
     144             : 
     145             :             /// add edge from begin to end
     146          46 :             inline void addEdge(index_t begin, index_t end)
     147             :             {
     148          46 :                 nodes_.insert({begin, NodeType(begin)});
     149          46 :                 nodes_.insert({end, NodeType(end)});
     150          46 :                 edges_.emplace_back(EdgeType(&nodes_.at(begin), &nodes_.at(end)));
     151          46 :             }
     152             : 
     153             :             /// @returns a reference to a vector of edges
     154           0 :             inline std::vector<EdgeType>& getEdges() { return edges_; }
     155             : 
     156             :             /// @returns a constant reference to a vector of edges
     157             :             inline const std::vector<EdgeType>& getEdges() const { return edges_; }
     158             : 
     159             :             /// @returns a vector containing all edges beginning at a given
     160             :             /// index
     161          31 :             inline std::vector<EdgeType> getOutgoingEdges(index_t begin) const
     162             :             {
     163          31 :                 std::vector<EdgeType> ret;
     164         243 :                 for (const auto& edge : edges_)
     165         212 :                     if (edge.begin()->getIndex() == begin)
     166          49 :                         ret.push_back(edge);
     167             : 
     168          31 :                 return ret;
     169             :             }
     170             : 
     171             :             /// @returns a vector containing all edges ending at a given index
     172           4 :             inline std::vector<EdgeType> getIncomingEdges(index_t end) const
     173             :             {
     174           4 :                 std::vector<EdgeType> ret;
     175          23 :                 for (const auto& edge : edges_)
     176          19 :                     if (edge.end()->getIndex() == end)
     177           7 :                         ret.push_back(edge);
     178             : 
     179           4 :                 return ret;
     180             :             }
     181             : 
     182             :             /// Insert a new node after a node with a given index.
     183             :             ///
     184             :             /// @param index Index of the node after that the insertion
     185             :             /// should be done
     186             :             /// @param newIndex Index of the to be inserted node
     187           1 :             void insertNodeAfter(index_t index, index_t newIndex)
     188             :             {
     189           4 :                 for (auto& e : getOutgoingEdges(index))
     190           3 :                     addEdge(newIndex, e.end()->getIndex());
     191           1 :                 removeOutgoingEdges(index);
     192           1 :                 addEdge(index, newIndex);
     193           1 :             }
     194             : 
     195             :             /// @returns a reference to a map of indices and nodes
     196           0 :             inline std::map<index_t, NodeType>& getNodes() { return nodes_; }
     197             : 
     198             :             /// @returns a constant reference to a map of indices and nodes
     199             :             inline const std::map<index_t, NodeType>& getNodes() const { return nodes_; }
     200             : 
     201             :             /// delete all nodes and edges of the graph
     202           0 :             inline void clear()
     203             :             {
     204           0 :                 edges_.clear();
     205           0 :                 nodes_.clear();
     206           0 :             }
     207             : 
     208             :             /// Remove all outgoing edges from node with index begin
     209           3 :             inline void removeOutgoingEdges(index_t begin)
     210             :             {
     211           3 :                 edges_.erase(std::remove_if(std::begin(edges_), std::end(edges_),
     212          25 :                                             [&begin](auto& edge) {
     213          25 :                                                 return edge.begin()->getIndex() == begin;
     214             :                                             }),
     215           3 :                              std::end(edges_));
     216           3 :             }
     217             : 
     218             :             /// Remove all incoming edges from node with index end
     219           1 :             inline void removeIncomingEdges(index_t end)
     220             :             {
     221           2 :                 edges_.erase(
     222           1 :                     std::remove_if(std::begin(edges_), std::end(edges_),
     223           7 :                                    [&end](auto& edge) { return edge.end()->getIndex() == end; }),
     224           1 :                     std::end(edges_));
     225           1 :             }
     226             : 
     227             :             /// Remove a node from the graph.
     228             :             ///
     229             :             /// If preserveConnectivity is true, all begin-nodes of incoming
     230             :             /// edges get begin-nodes of end-nodes of outgoing edges.
     231           1 :             inline void removeNode(index_t idx, bool preserveConnectivity = true)
     232             :             {
     233           1 :                 if (preserveConnectivity) {
     234             :                     // Add an edge that bypassed the node that is to be removed
     235           2 :                     auto incomingEdges = getIncomingEdges(idx);
     236           2 :                     auto outgoingEdges = getOutgoingEdges(idx);
     237           4 :                     for (const auto& inEdge : incomingEdges) {
     238           3 :                         auto beginIdx = inEdge.begin()->getIndex();
     239           9 :                         for (const auto& outEdge : outgoingEdges) {
     240           6 :                             auto endIdx = outEdge.end()->getIndex();
     241           6 :                             addEdge(beginIdx, endIdx);
     242             :                         }
     243             :                     }
     244             :                 }
     245             :                 // Remove all edges of the node
     246           1 :                 removeOutgoingEdges(idx);
     247           1 :                 removeIncomingEdges(idx);
     248             : 
     249             :                 // Remove the node itself
     250           1 :                 nodes_.erase(idx);
     251           1 :             }
     252             : 
     253             :             /// Visit all nodes of the graph. The visitors have access the node
     254             :             /// indices
     255             :             ///
     256             :             /// @param root Index of the node that serves as starting point
     257             :             /// of the traversal.
     258             :             ///
     259             :             /// @param visitor A function-like object (function, overloaded
     260             :             /// call operator, lambda) with the signature
     261             :             ///   `void(T* data, index_t index)`.
     262             :             /// The visitor will be applied to every node in a breadth-first
     263             :             /// traversal. The semantics of the parameter is
     264             :             /// @param nextVisitor
     265             :             /// @param stop
     266             :             ///
     267             :             /// - `data` is the data held by the current node in the traversal
     268             :             /// - `index` is the index of the current node
     269             :             template <typename Visitor, typename NextVisitor, typename StopFunctor>
     270           0 :             void visitWithIndex(index_t root, Visitor visitor, NextVisitor nextVisitor,
     271             :                                 StopFunctor stop)
     272             :             {
     273             :                 visitImpl<Visitor, NextVisitor, StopFunctor, /* access index */ true,
     274           0 :                           /* forward */ true>(root, std::forward<Visitor>(visitor),
     275           0 :                                               std::forward<NextVisitor>(nextVisitor),
     276           0 :                                               std::forward<StopFunctor>(stop));
     277           0 :             }
     278             : 
     279             :             template <typename Visitor, typename NextVisitor, typename StopFunctor>
     280           3 :             void visit(index_t root, Visitor visitor, NextVisitor nextVisitor, StopFunctor stop)
     281             :             {
     282             :                 visitImpl<Visitor, NextVisitor, StopFunctor, /* access index */ false,
     283           3 :                           /* forward */ true>(root, std::forward<Visitor>(visitor),
     284           3 :                                               std::forward<NextVisitor>(nextVisitor),
     285           3 :                                               std::forward<StopFunctor>(stop));
     286           3 :             }
     287             : 
     288             :             template <typename Visitor, typename NextVisitor>
     289           0 :             void visit(index_t root, Visitor visitor, NextVisitor nextVisitor)
     290             :             {
     291           0 :                 visit(root, visitor, nextVisitor, []([[maybe_unused]] auto node) { return false; });
     292           0 :             }
     293             : 
     294             :             template <typename Visitor>
     295           3 :             void visit(index_t root, Visitor visitor)
     296             :             {
     297          16 :                 visit(
     298             :                     root, visitor,
     299          13 :                     []([[maybe_unused]] auto node, [[maybe_unused]] auto nextNode) {},
     300          16 :                     []([[maybe_unused]] auto node) { return false; });
     301           3 :             }
     302             : 
     303             :             template <typename Visitor, typename NextVisitor, typename StopFunctor>
     304           0 :             void visitBackward(index_t root, Visitor visitor, NextVisitor nextVisitor,
     305             :                                StopFunctor stop)
     306             :             {
     307             :                 visitImpl<Visitor, NextVisitor, StopFunctor, /* access index */ false,
     308           0 :                           /* forward */ false>(root, std::forward<Visitor>(visitor),
     309           0 :                                                std::forward<NextVisitor>(nextVisitor),
     310           0 :                                                std::forward<StopFunctor>(stop));
     311           0 :             }
     312             : 
     313             :             template <typename Visitor, typename NextVisitor, typename StopFunctor>
     314           0 :             void visitBackwardWithIndex(index_t root, Visitor visitor, NextVisitor nextVisitor,
     315             :                                         StopFunctor stop)
     316             :             {
     317             :                 visitImpl<Visitor, NextVisitor, StopFunctor, /* access index */ true,
     318           0 :                           /* forward */ false>(root, std::forward<Visitor>(visitor),
     319           0 :                                                std::forward<NextVisitor>(nextVisitor),
     320           0 :                                                std::forward<StopFunctor>(stop));
     321           0 :             }
     322             : 
     323             :             enum class RankDir { TD, LR };
     324             : 
     325             :             /// Print a representation of the graph in the Dot language.
     326             :             ///
     327             :             /// @param filename the name of the file that is written to
     328             :             /// @param nodePrinter a function (or lambda) that defines how
     329             :             ///                        to print the data of a single node
     330             :             /// @param dpi Dots-per-inch definition for the Dot representation
     331             :             /// @param rankDir RankDir::TD if the graph should be drawn top-down,
     332             :             ///                    rankdDir::LR if it should be drawn left-right
     333             :             template <typename NodePrinter>
     334             :             void toDot(const std::string& filename, NodePrinter nodePrinter, index_t dpi = 90,
     335             :                        RankDir rankDir = RankDir::TD)
     336             :             {
     337             :                 std::ofstream os(filename);
     338             :                 os << "digraph {\n";
     339             :                 os << "graph [" << (rankDir == RankDir::TD ? "rankdir=TD" : "rankdir=LR")
     340             :                    << ", dpi=" << dpi << "];\n";
     341             : 
     342             :                 // print all nodes
     343             :                 for (const auto& node : getNodes()) {
     344             :                     if (node.second.getData()) {
     345             :                         os << nodePrinter(node.second.getData(), node.first) << "\n";
     346             :                     } else {
     347             :                         os << node.first << "\n";
     348             :                     }
     349             :                 }
     350             : 
     351             :                 // print all edges
     352             :                 for (const auto& edge : getEdges())
     353             :                     os << edge.begin()->getIndex() << "->" << edge.end()->getIndex() << ";\n";
     354             : 
     355             :                 os << "}";
     356             :             }
     357             : 
     358             :         private:
     359             :             template <typename Visitor, typename NextVisitor, typename StopFunctor,
     360             :                       bool AccessIndex, bool Forward>
     361           3 :             void visitImpl(index_t root, Visitor visitor, NextVisitor nextVisitor, StopFunctor stop)
     362             :             {
     363           3 :                 if (nodes_.find(root) == nodes_.end())
     364           0 :                     throw std::invalid_argument("Unknown node index");
     365             : 
     366             :                 // We maintain a list of nodes we've already visited and a call-queue to
     367             :                 // ensure the correct order of traversal.
     368           6 :                 std::vector<bool> visited(asUnsigned(getNumberOfNodes()));
     369             :                 // Note that this queue is in fact a deque, so we can push and pop from
     370             :                 // both, the front and back.
     371           6 :                 std::deque<index_t> queue;
     372             : 
     373             :                 // Perform an iterative depth-first traversal through the graph
     374             :                 // Push the input-node onto the call-queue
     375           3 :                 queue.push_back(root);
     376             : 
     377          19 :                 while (!queue.empty()) {
     378          16 :                     index_t s = queue.back();
     379          16 :                     queue.pop_back();
     380             : 
     381          16 :                     if (!visited[static_cast<std::size_t>(s)]) {
     382             : 
     383             :                         bool needStop;
     384             :                         if constexpr (AccessIndex) {
     385           0 :                             needStop = stop(nodes_.at(s).getData(), s);
     386             :                         } else {
     387          16 :                             needStop = stop(nodes_.at(s).getData());
     388             :                         }
     389          16 :                         if (!needStop) {
     390             :                             if constexpr (AccessIndex) {
     391           0 :                                 visitor(nodes_.at(s).getData(), s);
     392             :                             } else {
     393          16 :                                 visitor(nodes_.at(s).getData());
     394             :                             }
     395          16 :                             visited[asUnsigned(s)] = true;
     396             :                         } else {
     397           0 :                             queue.push_front(s);
     398           0 :                             continue;
     399             :                         }
     400             :                     }
     401             : 
     402             :                     if constexpr (Forward) {
     403          35 :                         for (auto& e : getOutgoingEdges(s)) {
     404          19 :                             auto idx = e.end()->getIndex();
     405             : 
     406          19 :                             if (!visited[static_cast<std::size_t>(idx)]) {
     407          13 :                                 queue.push_back(idx);
     408             :                                 if constexpr (AccessIndex) {
     409           0 :                                     nextVisitor(nodes_.at(s).getData(), s, e.end()->getData(), idx);
     410             :                                 } else {
     411          13 :                                     nextVisitor(nodes_.at(s).getData(), e.end()->getData());
     412             :                                 }
     413             :                             }
     414             :                         }
     415             :                     } else {
     416           0 :                         for (auto& e : getIncomingEdges(s)) {
     417           0 :                             auto idx = e.begin()->getIndex();
     418             : 
     419           0 :                             if (!visited[static_cast<std::size_t>(idx)]) {
     420           0 :                                 queue.push_back(idx);
     421             :                                 if constexpr (AccessIndex) {
     422           0 :                                     nextVisitor(nodes_.at(s).getData(), s, e.begin()->getData(),
     423             :                                                 idx);
     424             :                                 } else {
     425           0 :                                     nextVisitor(nodes_.at(s).getData(), e.begin()->getData());
     426             :                                 }
     427             :                             }
     428             :                         }
     429             :                     }
     430             :                 }
     431           3 :             }
     432             : 
     433             :             std::vector<EdgeType> edges_;
     434             :             std::map<index_t, NodeType> nodes_;
     435             :         };
     436             : 
     437             :     } // namespace detail
     438             : } // namespace elsa::ml

Generated by: LCOV version 1.15