LCOV - code coverage report
Current view: top level - ml/tests - test_CoreLayers.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 33 33 100.0 %
Date: 2022-02-28 03:37:41 Functions: 3 3 100.0 %

          Line data    Source code
       1             : #include "doctest/doctest.h"
       2             : #include "Input.h"
       3             : #include "Dense.h"
       4             : 
       5             : using namespace elsa;
       6             : using namespace doctest;
       7             : 
       8             : TEST_SUITE_BEGIN("ml");
       9             : 
      10             : // TODO(dfrank): remove and replace with proper doctest usage of test cases
      11             : #define SECTION(name) DOCTEST_SUBCASE(name)
      12             : 
      13           4 : TEST_CASE_TEMPLATE("CoreLayers", TestType, float)
      14             : {
      15           3 :     SECTION("Input")
      16             :     {
      17           2 :         IndexVector_t dims{{1, 2, 3}};
      18           2 :         VolumeDescriptor desc(dims);
      19           2 :         auto i = ml::Input<TestType>(desc, 10, "input");
      20           1 :         REQUIRE(i.getInputDescriptor() == desc);
      21           1 :         i.computeOutputDescriptor();
      22           1 :         REQUIRE(i.getOutputDescriptor() == desc);
      23           1 :         REQUIRE(i.getLayerType() == ml::LayerType::Input);
      24           1 :         REQUIRE(i.getBatchSize() == 10);
      25           1 :         REQUIRE(i.getName() == "input");
      26           1 :         REQUIRE(i.getNumberOfTrainableParameters() == 0);
      27             :     }
      28           3 :     SECTION("Dense")
      29             :     {
      30           3 :         auto d =
      31             :             ml::Dense<TestType>(10, ml::Activation::Relu, false, ml::Initializer::GlorotUniform,
      32             :                                 ml::Initializer::Zeros, "dense");
      33             : 
      34             :         // Check number of units
      35           1 :         REQUIRE(d.getNumberOfUnits() == 10);
      36             : 
      37             :         // Check activation
      38           1 :         REQUIRE(d.getActivation() == ml::Activation::Relu);
      39             : 
      40             :         // Check if we are using a bias
      41           1 :         REQUIRE(d.useBias() == false);
      42             : 
      43             :         // Check initializers
      44           1 :         REQUIRE(d.getKernelInitializer() == ml::Initializer::GlorotUniform);
      45           1 :         REQUIRE(d.getBiasInitializer() == ml::Initializer::Zeros);
      46             : 
      47             :         // Check the name
      48           1 :         REQUIRE(d.getName() == "dense");
      49             : 
      50             :         // Check input-descriptor
      51           2 :         IndexVector_t validDims{{1}};
      52           2 :         VolumeDescriptor validDesc(validDims);
      53           1 :         d.setInputDescriptor(validDesc);
      54           1 :         REQUIRE(d.getInputDescriptor() == validDesc);
      55             : 
      56           1 :         d.computeOutputDescriptor();
      57           2 :         IndexVector_t outDims{{10}};
      58           2 :         VolumeDescriptor outDesc(outDims);
      59           1 :         REQUIRE(d.getOutputDescriptor() == outDesc);
      60             : 
      61             :         // Note here that we don't use a bias
      62           1 :         REQUIRE(d.getNumberOfTrainableParameters() == 1 * d.getNumberOfUnits());
      63             : 
      64             :         // Dense requires a 1D input
      65           2 :         IndexVector_t invalidDims{{1, 10}};
      66           2 :         VolumeDescriptor invalidDesc(invalidDims);
      67           2 :         REQUIRE_THROWS_AS(d.setInputDescriptor(invalidDesc), std::invalid_argument);
      68             :     }
      69           2 : }
      70             : TEST_SUITE_END();

Generated by: LCOV version 1.15