Line data Source code
1 : /**
2 : * @file test_common.cpp
3 : *
4 : * @brief Tests for common ml functionality
5 : *
6 : * @author David Tellenbach
7 : */
8 :
9 : #include "doctest/doctest.h"
10 : #include <iostream>
11 :
12 : #include "DataContainer.h"
13 : #include "VolumeDescriptor.h"
14 : #include "DnnlMerging.h"
15 :
16 : using namespace elsa;
17 : using namespace elsa::ml;
18 : using namespace elsa::ml::detail;
19 : using namespace doctest;
20 :
21 : TEST_SUITE_BEGIN("ml-dnnl");
22 :
23 : // TODO(dfrank): remove and replace with proper doctest usage of test cases
24 : #define SECTION(name) DOCTEST_SUBCASE(name)
25 :
26 2 : TEST_CASE("DnnlMerging: DnnlSum")
27 : {
28 2 : index_t N = 11;
29 2 : index_t C = 10;
30 2 : index_t H = 9;
31 2 : index_t W = 8;
32 :
33 4 : IndexVector_t dims{{W, H, C, N}};
34 4 : VolumeDescriptor desc(dims);
35 :
36 4 : Eigen::VectorXf data(desc.getNumberOfCoefficients());
37 :
38 2 : data.setRandom();
39 4 : DataContainer<real_t> dc0(desc, data);
40 :
41 2 : data.setRandom();
42 4 : DataContainer<real_t> dc1(desc, data);
43 :
44 2 : data.setRandom();
45 4 : DataContainer<real_t> dc2(desc, data);
46 :
47 4 : IndexVector_t nchw_dims{{N, C, H, W}};
48 4 : VolumeDescriptor nchw_desc(nchw_dims);
49 12 : DnnlSum<real_t> sum({nchw_desc, nchw_desc, nchw_desc}, nchw_desc);
50 :
51 2 : REQUIRE(sum.canMerge() == true);
52 2 : REQUIRE(sum.isTrainable() == false);
53 2 : REQUIRE(sum.needsForwardSynchronisation() == true);
54 :
55 2 : sum.setInput(dc0, 0);
56 2 : sum.setInput(dc1, 1);
57 2 : sum.setInput(dc2, 2);
58 :
59 2 : REQUIRE(sum.getNumberOfInputs() == 3);
60 :
61 4 : auto engine = sum.getEngine();
62 4 : dnnl::stream s(*engine);
63 :
64 3 : SECTION("Forward propagate")
65 : {
66 1 : sum.compile(PropagationKind::Forward);
67 :
68 1 : sum.forwardPropagate(s);
69 :
70 2 : auto output = sum.getOutput();
71 1 : REQUIRE(output.getDataDescriptor().getNumberOfCoefficientsPerDimension()[0] == W);
72 1 : REQUIRE(output.getDataDescriptor().getNumberOfCoefficientsPerDimension()[1] == H);
73 1 : REQUIRE(output.getDataDescriptor().getNumberOfCoefficientsPerDimension()[2] == C);
74 1 : REQUIRE(output.getDataDescriptor().getNumberOfCoefficientsPerDimension()[3] == N);
75 :
76 12 : for (int n = 0; n < N; ++n) {
77 121 : for (int c = 0; c < C; ++c) {
78 1100 : for (int h = 0; h < H; ++h) {
79 8910 : for (int w = 0; w < W; ++w) {
80 7920 : REQUIRE(output(w, h, c, n)
81 : == dc0(w, h, c, n) + dc1(w, h, c, n) + dc2(w, h, c, n));
82 : }
83 : }
84 : }
85 : }
86 : }
87 :
88 3 : SECTION("Backward propagate")
89 : {
90 2 : Eigen::VectorXf outputGradientData(sum.getOutputDescriptor().getNumberOfCoefficients());
91 1 : outputGradientData.setRandom();
92 2 : DataContainer<real_t> outputGradient(desc, outputGradientData);
93 1 : sum.setOutputGradient(outputGradient);
94 :
95 1 : sum.compile(PropagationKind::Backward);
96 1 : sum.backwardPropagate(s);
97 :
98 2 : auto inputGradient0 = sum.getInputGradient(0);
99 2 : auto inputGradient1 = sum.getInputGradient(1);
100 2 : auto inputGradient2 = sum.getInputGradient(2);
101 :
102 7921 : for (int i = 0; i < inputGradient0.getSize(); ++i) {
103 7920 : REQUIRE(inputGradient0[i] == dc0[i] * outputGradient[i]);
104 : }
105 :
106 12 : for (int n = 0; n < N; ++n) {
107 121 : for (int c = 0; c < C; ++c) {
108 1100 : for (int h = 0; h < H; ++h) {
109 8910 : for (int w = 0; w < W; ++w) {
110 7920 : REQUIRE(inputGradient0(w, h, c, n)
111 : == dc0(w, h, c, n) * outputGradient(w, h, c, n));
112 7920 : REQUIRE(inputGradient1(w, h, c, n)
113 : == dc1(w, h, c, n) * outputGradient(w, h, c, n));
114 7920 : REQUIRE(inputGradient2(w, h, c, n)
115 : == dc2(w, h, c, n) * outputGradient(w, h, c, n));
116 : }
117 : }
118 : }
119 : }
120 : }
121 2 : }
122 : TEST_SUITE_END();
|