Line data Source code
1 : /**
2 : * @file test_common.cpp
3 : *
4 : * @brief Tests for common ml functionality
5 : *
6 : * @author David Tellenbach
7 : */
8 :
9 : #include "doctest/doctest.h"
10 :
11 : #include "elsaDefines.h"
12 : #include "DataContainer.h"
13 : #include "VolumeDescriptor.h"
14 : #include "Utils.h"
15 :
16 : using namespace elsa;
17 : using namespace doctest;
18 :
19 : TEST_SUITE_BEGIN("ml");
20 :
21 : // TODO(dfrank): remove and replace with proper doctest usage of test cases
22 : #define SECTION(name) DOCTEST_SUBCASE(name)
23 :
24 2 : TEST_CASE("Encoding")
25 : {
26 3 : SECTION("Encode one-hot")
27 : {
28 1 : index_t numClasses = 10;
29 1 : index_t batchSize = 4;
30 2 : IndexVector_t dims{{batchSize}};
31 2 : VolumeDescriptor desc(dims);
32 :
33 : // For each entry in a batch we require one label
34 2 : Eigen::VectorXf data{{3, 0, 9, 1}};
35 :
36 : /* This should give the following one-hot encoding
37 : {
38 : Idx 0 1 2 3 4 5 6 7 8 9
39 : {0, 0, 0, 1, 0, 0, 0, 0, 0, 0},
40 : {1, 0, 0, 0, 0, 0, 0, 0, 0, 0},
41 : {0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
42 : {0, 1, 0, 0, 0, 0, 0, 0, 0, 0},
43 : }
44 : */
45 :
46 2 : DataContainer<real_t> notOneHot(desc, data);
47 1 : auto oneHot = ml::Utils::Encoding::toOneHot(notOneHot, numClasses, batchSize);
48 :
49 1 : REQUIRE(oneHot[3 + 0 * numClasses] == 1.f);
50 1 : REQUIRE(oneHot[0 + 1 * numClasses] == 1.f);
51 1 : REQUIRE(oneHot[9 + 2 * numClasses] == 1.f);
52 1 : REQUIRE(oneHot[1 + 3 * numClasses] == 1.f);
53 : }
54 :
55 3 : SECTION("Decode one-hot")
56 : {
57 1 : index_t numClasses = 10;
58 2 : IndexVector_t dims{{numClasses, 4}};
59 2 : VolumeDescriptor desc(dims);
60 : Eigen::VectorXf data{{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
61 2 : 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}};
62 2 : DataContainer<real_t> dc(desc, data);
63 1 : auto fromOneHot = ml::Utils::Encoding::fromOneHot(dc, numClasses);
64 1 : REQUIRE(fromOneHot[0] == Approx(3.f));
65 1 : REQUIRE(fromOneHot[1] == Approx(0.f));
66 1 : REQUIRE(fromOneHot[2] == Approx(9.f));
67 1 : REQUIRE(fromOneHot[3] == Approx(1.f));
68 : }
69 2 : }
70 : TEST_SUITE_END();
|