LCOV - code coverage report
Current view: top level - ml/backend/Dnnl/tests - test_DnnlMerging.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 59 59 100.0 %
Date: 2022-02-28 03:37:41 Functions: 1 1 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_common.cpp
       3             :  *
       4             :  * @brief Tests for common ml functionality
       5             :  *
       6             :  * @author David Tellenbach
       7             :  */
       8             : 
       9             : #include "doctest/doctest.h"
      10             : #include <iostream>
      11             : 
      12             : #include "DataContainer.h"
      13             : #include "VolumeDescriptor.h"
      14             : #include "DnnlMerging.h"
      15             : 
      16             : using namespace elsa;
      17             : using namespace elsa::ml;
      18             : using namespace elsa::ml::detail;
      19             : using namespace doctest;
      20             : 
      21             : TEST_SUITE_BEGIN("ml-dnnl");
      22             : 
      23             : // TODO(dfrank): remove and replace with proper doctest usage of test cases
      24             : #define SECTION(name) DOCTEST_SUBCASE(name)
      25             : 
      26           2 : TEST_CASE("DnnlMerging: DnnlSum")
      27             : {
      28           2 :     index_t N = 11;
      29           2 :     index_t C = 10;
      30           2 :     index_t H = 9;
      31           2 :     index_t W = 8;
      32             : 
      33           4 :     IndexVector_t dims{{W, H, C, N}};
      34           4 :     VolumeDescriptor desc(dims);
      35             : 
      36           4 :     Eigen::VectorXf data(desc.getNumberOfCoefficients());
      37             : 
      38           2 :     data.setRandom();
      39           4 :     DataContainer<real_t> dc0(desc, data);
      40             : 
      41           2 :     data.setRandom();
      42           4 :     DataContainer<real_t> dc1(desc, data);
      43             : 
      44           2 :     data.setRandom();
      45           4 :     DataContainer<real_t> dc2(desc, data);
      46             : 
      47           4 :     IndexVector_t nchw_dims{{N, C, H, W}};
      48           4 :     VolumeDescriptor nchw_desc(nchw_dims);
      49          12 :     DnnlSum<real_t> sum({nchw_desc, nchw_desc, nchw_desc}, nchw_desc);
      50             : 
      51           2 :     REQUIRE(sum.canMerge() == true);
      52           2 :     REQUIRE(sum.isTrainable() == false);
      53           2 :     REQUIRE(sum.needsForwardSynchronisation() == true);
      54             : 
      55           2 :     sum.setInput(dc0, 0);
      56           2 :     sum.setInput(dc1, 1);
      57           2 :     sum.setInput(dc2, 2);
      58             : 
      59           2 :     REQUIRE(sum.getNumberOfInputs() == 3);
      60             : 
      61           4 :     auto engine = sum.getEngine();
      62           4 :     dnnl::stream s(*engine);
      63             : 
      64           3 :     SECTION("Forward propagate")
      65             :     {
      66           1 :         sum.compile(PropagationKind::Forward);
      67             : 
      68           1 :         sum.forwardPropagate(s);
      69             : 
      70           2 :         auto output = sum.getOutput();
      71           1 :         REQUIRE(output.getDataDescriptor().getNumberOfCoefficientsPerDimension()[0] == W);
      72           1 :         REQUIRE(output.getDataDescriptor().getNumberOfCoefficientsPerDimension()[1] == H);
      73           1 :         REQUIRE(output.getDataDescriptor().getNumberOfCoefficientsPerDimension()[2] == C);
      74           1 :         REQUIRE(output.getDataDescriptor().getNumberOfCoefficientsPerDimension()[3] == N);
      75             : 
      76          12 :         for (int n = 0; n < N; ++n) {
      77         121 :             for (int c = 0; c < C; ++c) {
      78        1100 :                 for (int h = 0; h < H; ++h) {
      79        8910 :                     for (int w = 0; w < W; ++w) {
      80        7920 :                         REQUIRE(output(w, h, c, n)
      81             :                                 == dc0(w, h, c, n) + dc1(w, h, c, n) + dc2(w, h, c, n));
      82             :                     }
      83             :                 }
      84             :             }
      85             :         }
      86             :     }
      87             : 
      88           3 :     SECTION("Backward propagate")
      89             :     {
      90           2 :         Eigen::VectorXf outputGradientData(sum.getOutputDescriptor().getNumberOfCoefficients());
      91           1 :         outputGradientData.setRandom();
      92           2 :         DataContainer<real_t> outputGradient(desc, outputGradientData);
      93           1 :         sum.setOutputGradient(outputGradient);
      94             : 
      95           1 :         sum.compile(PropagationKind::Backward);
      96           1 :         sum.backwardPropagate(s);
      97             : 
      98           2 :         auto inputGradient0 = sum.getInputGradient(0);
      99           2 :         auto inputGradient1 = sum.getInputGradient(1);
     100           2 :         auto inputGradient2 = sum.getInputGradient(2);
     101             : 
     102        7921 :         for (int i = 0; i < inputGradient0.getSize(); ++i) {
     103        7920 :             REQUIRE(inputGradient0[i] == dc0[i] * outputGradient[i]);
     104             :         }
     105             : 
     106          12 :         for (int n = 0; n < N; ++n) {
     107         121 :             for (int c = 0; c < C; ++c) {
     108        1100 :                 for (int h = 0; h < H; ++h) {
     109        8910 :                     for (int w = 0; w < W; ++w) {
     110        7920 :                         REQUIRE(inputGradient0(w, h, c, n)
     111             :                                 == dc0(w, h, c, n) * outputGradient(w, h, c, n));
     112        7920 :                         REQUIRE(inputGradient1(w, h, c, n)
     113             :                                 == dc1(w, h, c, n) * outputGradient(w, h, c, n));
     114        7920 :                         REQUIRE(inputGradient2(w, h, c, n)
     115             :                                 == dc2(w, h, c, n) * outputGradient(w, h, c, n));
     116             :                     }
     117             :                 }
     118             :             }
     119             :         }
     120             :     }
     121           2 : }
     122             : TEST_SUITE_END();

Generated by: LCOV version 1.15