LCOV - code coverage report
Current view: top level - ml/backend/Dnnl/tests - test_DnnlPooling.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 40 40 100.0 %
Date: 2022-02-28 03:37:41 Functions: 1 1 100.0 %

          Line data    Source code
       1             : #include "doctest/doctest.h"
       2             : #include <type_traits>
       3             : #include <random>
       4             : #include <iostream>
       5             : 
       6             : #include "elsaDefines.h"
       7             : #include "DataDescriptor.h"
       8             : #include "DnnlPoolingLayer.h"
       9             : 
      10             : using namespace elsa;
      11             : using namespace elsa::ml;
      12             : using namespace elsa::ml::detail;
      13             : using namespace doctest;
      14             : 
      15             : TEST_SUITE_BEGIN("ml-dnnl");
      16             : 
      17             : // TODO(dfrank): remove and replace with proper doctest usage of test cases
      18             : #define SECTION(name) DOCTEST_SUBCASE(name)
      19             : 
      20           1 : TEST_CASE("DnnlPoooling")
      21             : {
      22             :     // Example from http://cs231n.github.io/convolutional-networks/
      23             : 
      24             :     // Create input data
      25           2 :     IndexVector_t inputVec(4);
      26           1 :     inputVec << 1, 1, 4, 4;
      27           2 :     VolumeDescriptor inputDesc(inputVec);
      28             : 
      29           2 :     Eigen::VectorXf vec(1 * 1 * 4 * 4);
      30             :     // clang-format off
      31           2 :         vec << 1, 1, 2, 4,
      32           1 :                5, 6, 7, 8,
      33           1 :                3, 2, 1, 0,
      34           2 :                1, 2, 3, 4;
      35             :     // clang-format on
      36             : 
      37           2 :     DataContainer<float> input(inputDesc, vec);
      38             : 
      39             :     // Output descriptor
      40           2 :     IndexVector_t outVec(4);
      41           1 :     outVec << 1, 1, 2, 2;
      42           2 :     VolumeDescriptor outDesc(outVec);
      43             : 
      44             :     // Strides and pooling-window
      45           2 :     IndexVector_t stridesVec(2);
      46           1 :     stridesVec << 2, 2;
      47             : 
      48           2 :     IndexVector_t poolingVec(2);
      49           1 :     poolingVec << 2, 2;
      50             : 
      51           2 :     DnnlPoolingLayer<float> layer(inputDesc, outDesc, poolingVec, stridesVec);
      52             : 
      53             :     // Set input and compile layer
      54           1 :     layer.setInput(input);
      55           1 :     layer.compile(PropagationKind::Full);
      56             : 
      57             :     // Get Dnnl exection-stream
      58           2 :     auto engine = layer.getEngine();
      59           2 :     dnnl::stream s(*engine);
      60             : 
      61           1 :     layer.forwardPropagate(s);
      62           2 :     auto output = layer.getOutput();
      63             : 
      64           2 :     Eigen::VectorXf required(1 * 1 * 2 * 2);
      65             :     // clang-format off
      66           2 :     required << 6, 8,
      67           2 :                 3, 4;
      68             :     // clang-format on
      69             : 
      70           5 :     for (int i = 0; i < 4; ++i)
      71           4 :         REQUIRE(output[i] == required[i]);
      72             : 
      73           1 :     layer.setOutputGradient(output);
      74             : 
      75           1 :     layer.backwardPropagate(s);
      76           2 :     auto inputGradient = layer.getInputGradient();
      77             : 
      78           2 :     Eigen::VectorXf requiredGradientInput(4 * 4);
      79             :     // clang-format off
      80           2 :     requiredGradientInput << 0, 0, 0, 0,
      81           1 :                              0, 6, 0, 8,
      82           1 :                              3, 0, 0, 0,
      83           2 :                              0, 0, 0, 4;
      84             :     // clang-format on
      85             : 
      86          17 :     for (int i = 0; i < inputGradient.getSize(); ++i) {
      87          16 :         REQUIRE(inputGradient[i] == requiredGradientInput[i]);
      88             :     }
      89           1 : }
      90             : TEST_SUITE_END();

Generated by: LCOV version 1.15