Line data Source code
1 : /** 2 : * @file test_common.cpp 3 : * 4 : * @brief Tests for common ml functionality 5 : * 6 : * @author David Tellenbach 7 : */ 8 : 9 : #include "doctest/doctest.h" 10 : #include "DataContainer.h" 11 : #include "VolumeDescriptor.h" 12 : #include "Common.h" 13 : 14 : using namespace elsa; 15 : using namespace doctest; 16 : 17 : TEST_SUITE_BEGIN("ml"); 18 : 19 : // TODO(dfrank): remove and replace with proper doctest usage of test cases 20 : #define SECTION(name) DOCTEST_SUBCASE(name) 21 : 22 4 : TEST_CASE("Common") 23 : { 24 5 : SECTION("LayerType") 25 : { 26 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::LayerType::Input) == "Input"); 27 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::LayerType::Dense) == "Dense"); 28 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::LayerType::Conv1D) == "Conv1D"); 29 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::LayerType::Conv2D) == "Conv2D"); 30 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::LayerType::Conv3D) == "Conv3D"); 31 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::LayerType::Sum) == "Sum"); 32 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::LayerType::Concatenate) == "Concatenate"); 33 : } 34 5 : SECTION("PropagationKind") 35 : { 36 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::PropagationKind::Forward) == "Forward"); 37 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::PropagationKind::Backward) == "Backward"); 38 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::PropagationKind::Full) == "Full"); 39 : } 40 5 : SECTION("MlBackend") 41 : { 42 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::MlBackend::Dnnl) == "Dnnl"); 43 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::MlBackend::Cudnn) == "Cudnn"); 44 : } 45 5 : SECTION("Initializer") 46 : { 47 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::Initializer::Zeros) == "Zeros"); 48 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::Initializer::Ones) == "Ones"); 49 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::Initializer::Normal) == "Normal"); 50 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::Initializer::Uniform) == "Uniform"); 51 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::Initializer::GlorotNormal) == "GlorotNormal"); 52 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::Initializer::GlorotUniform) 53 : == "GlorotUniform"); 54 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::Initializer::HeNormal) == "HeNormal"); 55 1 : REQUIRE(ml::detail::getEnumMemberAsString(ml::Initializer::HeUniform) == "HeUniform"); 56 : } 57 4 : } 58 : TEST_SUITE_END();