Line data Source code
1 : #include "doctest/doctest.h" 2 : #include "Input.h" 3 : #include "Dense.h" 4 : 5 : using namespace elsa; 6 : using namespace doctest; 7 : 8 : TEST_SUITE_BEGIN("ml"); 9 : 10 : // TODO(dfrank): remove and replace with proper doctest usage of test cases 11 : #define SECTION(name) DOCTEST_SUBCASE(name) 12 : 13 4 : TEST_CASE_TEMPLATE("CoreLayers", TestType, float) 14 : { 15 3 : SECTION("Input") 16 : { 17 2 : IndexVector_t dims{{1, 2, 3}}; 18 2 : VolumeDescriptor desc(dims); 19 2 : auto i = ml::Input<TestType>(desc, 10, "input"); 20 1 : REQUIRE(i.getInputDescriptor() == desc); 21 1 : i.computeOutputDescriptor(); 22 1 : REQUIRE(i.getOutputDescriptor() == desc); 23 1 : REQUIRE(i.getLayerType() == ml::LayerType::Input); 24 1 : REQUIRE(i.getBatchSize() == 10); 25 1 : REQUIRE(i.getName() == "input"); 26 1 : REQUIRE(i.getNumberOfTrainableParameters() == 0); 27 : } 28 3 : SECTION("Dense") 29 : { 30 3 : auto d = 31 : ml::Dense<TestType>(10, ml::Activation::Relu, false, ml::Initializer::GlorotUniform, 32 : ml::Initializer::Zeros, "dense"); 33 : 34 : // Check number of units 35 1 : REQUIRE(d.getNumberOfUnits() == 10); 36 : 37 : // Check activation 38 1 : REQUIRE(d.getActivation() == ml::Activation::Relu); 39 : 40 : // Check if we are using a bias 41 1 : REQUIRE(d.useBias() == false); 42 : 43 : // Check initializers 44 1 : REQUIRE(d.getKernelInitializer() == ml::Initializer::GlorotUniform); 45 1 : REQUIRE(d.getBiasInitializer() == ml::Initializer::Zeros); 46 : 47 : // Check the name 48 1 : REQUIRE(d.getName() == "dense"); 49 : 50 : // Check input-descriptor 51 2 : IndexVector_t validDims{{1}}; 52 2 : VolumeDescriptor validDesc(validDims); 53 1 : d.setInputDescriptor(validDesc); 54 1 : REQUIRE(d.getInputDescriptor() == validDesc); 55 : 56 1 : d.computeOutputDescriptor(); 57 2 : IndexVector_t outDims{{10}}; 58 2 : VolumeDescriptor outDesc(outDims); 59 1 : REQUIRE(d.getOutputDescriptor() == outDesc); 60 : 61 : // Note here that we don't use a bias 62 1 : REQUIRE(d.getNumberOfTrainableParameters() == 1 * d.getNumberOfUnits()); 63 : 64 : // Dense requires a 1D input 65 2 : IndexVector_t invalidDims{{1, 10}}; 66 2 : VolumeDescriptor invalidDesc(invalidDims); 67 2 : REQUIRE_THROWS_AS(d.setInputDescriptor(invalidDesc), std::invalid_argument); 68 : } 69 2 : } 70 : TEST_SUITE_END();