Line data Source code
1 : #include "doctest/doctest.h"
2 :
3 : #include "Graph.h"
4 : #include <numeric>
5 :
6 : using namespace elsa;
7 : using namespace doctest;
8 :
9 : TEST_SUITE_BEGIN("ml");
10 :
11 : // SECTION(name) expands to DOCTEST_SUBCASE(name). TODO(dfrank): remove and replace with proper doctest usage of test cases
12 : #define SECTION(name) DOCTEST_SUBCASE(name)
13 :
14 5 : TEST_CASE("Graph")
15 : { // Exercises ml::detail::Graph: construction, node data, edge queries, visitors, edge/node removal and node insertion
16 6 : SECTION("Graph of int")
17 : { // NOTE(review): the section name says "int", but the graph constructed below is Graph<float>
18 : // Construct a graph from four edges that reference the three node indices 0, 1 and 2
19 2 : ml::detail::Graph<float> g({{0, 1}, {1, 2}, {0, 2}, {2, 0}});
20 1 : REQUIRE(g.getNumberOfNodes() == 3);
21 :
22 : // Adding an edge including a new node increases the overall number of nodes
23 1 : g.addEdge(1, 3);
24 1 : REQUIRE(g.getNumberOfNodes() == 4);
25 :
26 : // Adding an edge that doesn't introduce a new node doesn't increase the number of nodes
27 1 : g.addEdge(0, 1);
28 1 : REQUIRE(g.getNumberOfNodes() == 4);
29 :
30 : // Add some data: one random float per node, handed to the graph as a pointer into v
31 2 : std::vector<float> v(static_cast<std::size_t>(g.getNumberOfNodes()));
32 1 : std::generate(std::begin(v), std::end(v),
33 4 : []() { return Eigen::internal::random<float>(); });
34 :
35 5 : for (index_t i = 0; i < g.getNumberOfNodes(); ++i) {
36 4 : g.setData(i, &v[std::size_t(i)]);
37 : }
38 : // Reading the data back through the graph must yield the value stored in v for each node
39 5 : for (index_t i = 0; i < g.getNumberOfNodes(); ++i) {
40 4 : REQUIRE(*g.getData(i) == Approx(v[std::size_t(i)]));
41 : }
42 :
43 : // Get all edges from given node: node 1 is the source of the two edges (1, 2) and (1, 3)
44 2 : auto edges = g.getOutgoingEdges(1);
45 1 : REQUIRE(edges.size() == 2);
46 3 : for (const auto& e : edges) {
47 2 : REQUIRE(e.begin()->getIndex() == 1);
48 2 : REQUIRE(*e.begin()->getData() == v[1]);
49 : }
50 :
51 : // Get all edges into a given node: node 2 is the target of the two edges (1, 2) and (0, 2)
52 1 : auto inedges = g.getIncomingEdges(2);
53 1 : REQUIRE(inedges.size() == 2);
54 3 : for (const auto& e : inedges) {
55 2 : REQUIRE(e.end()->getIndex() == 2);
56 2 : REQUIRE(*e.end()->getData() == v[2]);
57 : }
58 :
59 : // Changing the received edge list changes nodes in the graph: node 1 must now point at t
60 1 : float t = 123.456f;
61 1 : edges[0].begin()->setData(&t);
62 :
63 1 : REQUIRE(g.getData(1) == &t);
64 1 : REQUIRE(*g.getData(1) == t);
65 : }
66 :
67 6 : SECTION("Graph Visitor")
68 : { // Expects visit(0, f) to apply f to the data of each of the three nodes exactly once
69 2 : ml::detail::Graph<int> g({{0, 1}, {1, 2}, {0, 2}, {2, 0}});
70 2 : std::vector<int> v({0, 1, 2});
71 1 : std::vector<int> cv(v); // copy of the initial values; v itself is modified through the node pointers below
72 4 : for (index_t i = 0; i < g.getNumberOfNodes(); ++i) {
73 3 : g.setData(i, &v[std::size_t(i)]);
74 : }
75 4 : for (index_t i = 0; i < g.getNumberOfNodes(); ++i) {
76 3 : REQUIRE(*g.getData(i) == v[std::size_t(i)]);
77 : }
78 1 : g.visit(0, [](auto node) { *node = *node + 1; });
79 4 : for (index_t i = 0; i < g.getNumberOfNodes(); ++i) {
80 3 : REQUIRE(*g.getData(i) == v[std::size_t(i)]);
81 3 : REQUIRE(*g.getData(i) == cv[std::size_t(i)] + 1);
82 : }
83 : // A visitor that only reads can accumulate over the node data; the result must equal the sum of v
84 1 : int sum = 0;
85 1 : g.visit(0, [&sum](auto node) { sum += *node; });
86 1 : REQUIRE(sum == std::accumulate(std::begin(v), std::end(v), 0));
87 : }
88 6 : SECTION("Remove edges and nodes")
89 : {
90 1 : ml::detail::Graph<int> g({{0, 1}, {1, 2}, {0, 2}, {2, 0}});
91 :
92 : // There is one outgoing edge from 2, i.e., edge (2,0)
93 1 : REQUIRE(g.getOutgoingEdges(2).size() == 1);
94 :
95 : // Add two more, one of them the self-loop (2,2)
96 1 : g.addEdge(2, 1);
97 1 : g.addEdge(2, 2);
98 1 : REQUIRE(g.getOutgoingEdges(2).size() == 3);
99 :
100 : // Remove all outgoing edges of 2
101 1 : g.removeOutgoingEdges(2);
102 1 : REQUIRE(g.getOutgoingEdges(2).size() == 0);
103 :
104 : // Add edges again
105 1 : g.addEdge(2, 1);
106 1 : g.addEdge(2, 2);
107 1 : REQUIRE(g.getOutgoingEdges(2).size() == 2);
108 :
109 : // Remove node 2: no edges from or to it may remain and only two nodes are left
110 1 : g.removeNode(2);
111 1 : REQUIRE(g.getOutgoingEdges(2).size() == 0);
112 1 : REQUIRE(g.getIncomingEdges(2).size() == 0);
113 1 : REQUIRE(g.getNumberOfNodes() == 2);
114 : }
115 6 : SECTION("Insert node")
116 : {
117 : // Build a graph with a central node 2 that is reachable from two nodes
118 : // 0 and 1 and that has outgoing edges to three nodes 3, 4 and 5
119 1 : ml::detail::Graph<int> g({{0, 2}, {1, 2}, {2, 3}, {2, 4}, {2, 5}});
120 1 : REQUIRE(g.getOutgoingEdges(2).size() == 3);
121 1 : REQUIRE(g.getIncomingEdges(2).size() == 2);
122 :
123 : // Insert a node after 2 with index 6. This node should have exactly one
124 : // incoming edge (2, 6) and should overtake all outgoing edges of node
125 : // 2, i.e., it should have the edges (6, 3), (6, 4) and (6, 5).
126 : // Node 2 should have no other edges than (2, 6) anymore.
127 1 : g.insertNodeAfter(2, 6);
128 : // The checks below also require node 6 to return its outgoing edges in the order 3, 4, 5
129 1 : REQUIRE(g.getOutgoingEdges(2).size() == 1);
130 1 : REQUIRE(g.getOutgoingEdges(2).front().end()->getIndex() == 6);
131 1 : REQUIRE(g.getOutgoingEdges(6).size() == 3);
132 1 : REQUIRE(g.getOutgoingEdges(6)[0].end()->getIndex() == 3);
133 1 : REQUIRE(g.getOutgoingEdges(6)[1].end()->getIndex() == 4);
134 1 : REQUIRE(g.getOutgoingEdges(6)[2].end()->getIndex() == 5);
135 : }
136 6 : SECTION("More visitors")
137 : { // Summing visitor on a ten-node graph whose paths split (1 -> 2/3, 5 -> 6/7) and rejoin (at 4 and at 9)
138 : ml::detail::Graph<int> g({{0, 1},
139 : {1, 2},
140 : {1, 3},
141 : {2, 4},
142 : {3, 4},
143 : {4, 5},
144 : {5, 6},
145 : {5, 7},
146 : {6, 9},
147 : {7, 8},
148 2 : {8, 9}});
149 : // Node i carries the value i
150 1 : std::vector<int> v({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
151 11 : for (int i = 0; i < 10; ++i)
152 10 : g.setData(i, &v[asUnsigned(i)]);
153 :
154 1 : int sum = 0;
155 :
156 1 : g.visit(0, [&sum](auto node) { sum += *node; });
157 : // The expected value is the plain sum of v, i.e. each node's value counted once
158 1 : REQUIRE(sum == std::accumulate(v.begin(), v.end(), 0));
159 : }
160 5 : }
161 : TEST_SUITE_END();
|