Line data Source code
1 : #include "doctest/doctest.h" 2 : #include "Input.h" 3 : #include "Reshape.h" 4 : 5 : using namespace elsa; 6 : using namespace doctest; 7 : 8 : TEST_SUITE_BEGIN("ml"); 9 : 10 : // TODO(dfrank): remove and replace with proper doctest usage of test cases 11 : #define SECTION(name) DOCTEST_SUBCASE(name) 12 : 13 1 : TEST_CASE("Reshape") 14 : { 15 2 : IndexVector_t inputDims{{3, 4, 5, 2}}; 16 2 : VolumeDescriptor inputDesc(inputDims); 17 : 18 2 : VolumeDescriptor targetShape({60, 2}); 19 2 : auto layer = ml::Reshape(targetShape, "Reshape"); 20 1 : REQUIRE(layer.getName() == "Reshape"); 21 : 22 1 : layer.setInputDescriptor(inputDesc); 23 1 : REQUIRE(layer.getInputDescriptor() == inputDesc); 24 : 25 1 : layer.computeOutputDescriptor(); 26 : 27 1 : REQUIRE(layer.getOutputDescriptor() == targetShape); 28 1 : } 29 : 30 1 : TEST_CASE("Flatten") 31 : { 32 2 : IndexVector_t inputDims{{3, 4, 5, 2}}; 33 : 34 2 : VolumeDescriptor inputDesc(inputDims); 35 : 36 3 : auto input = ml::Input(inputDesc); 37 : 38 3 : auto layer = ml::Flatten("Flatten"); 39 1 : REQUIRE(layer.getName() == "Flatten"); 40 : 41 1 : layer.setInput(&input); 42 : 43 1 : input.computeOutputDescriptor(); 44 : 45 1 : REQUIRE(input.getOutputDescriptor() == inputDesc); 46 : 47 1 : layer.setInputDescriptor(input.getOutputDescriptor()); 48 1 : layer.computeOutputDescriptor(); 49 1 : REQUIRE(layer.getInputDescriptor(0) == inputDesc); 50 : 51 2 : IndexVector_t requiredOutDims{{inputDims[0] * inputDims[1] * inputDims[2] * inputDims[3]}}; 52 : 53 1 : VolumeDescriptor requiredOutDesc(requiredOutDims); 54 1 : REQUIRE(layer.getOutputDescriptor() == requiredOutDesc); 55 1 : } 56 : TEST_SUITE_END();