LCOV - code coverage report
Current view: top level - ml/tests - test_Utils.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 25 25 100.0 %
Date: 2022-02-28 03:37:41 Functions: 1 1 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_common.cpp
       3             :  *
       4             :  * @brief Tests for common ml functionality
       5             :  *
       6             :  * @author David Tellenbach
       7             :  */
       8             : 
       9             : #include "doctest/doctest.h"
      10             : 
      11             : #include "elsaDefines.h"
      12             : #include "DataContainer.h"
      13             : #include "VolumeDescriptor.h"
      14             : #include "Utils.h"
      15             : 
      16             : using namespace elsa;
      17             : using namespace doctest;
      18             : 
      19             : TEST_SUITE_BEGIN("ml");
      20             : 
      21             : // TODO(dfrank): remove and replace with proper doctest usage of test cases
      22             : #define SECTION(name) DOCTEST_SUBCASE(name)
      23             : 
      24           2 : TEST_CASE("Encoding")
      25             : {
      26           3 :     SECTION("Encode one-hot")
      27             :     {
      28           1 :         index_t numClasses = 10;
      29           1 :         index_t batchSize = 4;
      30           2 :         IndexVector_t dims{{batchSize}};
      31           2 :         VolumeDescriptor desc(dims);
      32             : 
      33             :         // For each entry in a batch we require one label
      34           2 :         Eigen::VectorXf data{{3, 0, 9, 1}};
      35             : 
      36             :         /* This should give the following one-hot encoding
      37             :         {
      38             :       Idx  0  1  2  3  4  5  6  7  8  9
      39             :           {0, 0, 0, 1, 0, 0, 0, 0, 0, 0},
      40             :           {1, 0, 0, 0, 0, 0, 0, 0, 0, 0},
      41             :           {0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
      42             :           {0, 1, 0, 0, 0, 0, 0, 0, 0, 0},
      43             :         }
      44             :         */
      45             : 
      46           2 :         DataContainer<real_t> notOneHot(desc, data);
      47           1 :         auto oneHot = ml::Utils::Encoding::toOneHot(notOneHot, numClasses, batchSize);
      48             : 
      49           1 :         REQUIRE(oneHot[3 + 0 * numClasses] == 1.f);
      50           1 :         REQUIRE(oneHot[0 + 1 * numClasses] == 1.f);
      51           1 :         REQUIRE(oneHot[9 + 2 * numClasses] == 1.f);
      52           1 :         REQUIRE(oneHot[1 + 3 * numClasses] == 1.f);
      53             :     }
      54             : 
      55           3 :     SECTION("Decode one-hot")
      56             :     {
      57           1 :         index_t numClasses = 10;
      58           2 :         IndexVector_t dims{{numClasses, 4}};
      59           2 :         VolumeDescriptor desc(dims);
      60             :         Eigen::VectorXf data{{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      61           2 :                               0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}};
      62           2 :         DataContainer<real_t> dc(desc, data);
      63           1 :         auto fromOneHot = ml::Utils::Encoding::fromOneHot(dc, numClasses);
      64           1 :         REQUIRE(fromOneHot[0] == Approx(3.f));
      65           1 :         REQUIRE(fromOneHot[1] == Approx(0.f));
      66           1 :         REQUIRE(fromOneHot[2] == Approx(9.f));
      67           1 :         REQUIRE(fromOneHot[3] == Approx(1.f));
      68             :     }
      69           2 : }
      70             : TEST_SUITE_END();

Generated by: LCOV version 1.15