Line data Source code
1 : /** 2 : * @file test_ISTA.cpp 3 : * 4 : * @brief Tests for the ISTA class 5 : * 6 : * @author Andi Braimllari 7 : */ 8 : 9 : #include "doctest/doctest.h" 10 : 11 : #include "Error.h" 12 : #include "ISTA.h" 13 : #include "Identity.h" 14 : #include "Logger.h" 15 : #include "VolumeDescriptor.h" 16 : #include "QuadricProblem.h" 17 : #include "testHelpers.h" 18 : 19 : using namespace elsa; 20 : using namespace doctest; 21 : 22 : TEST_SUITE_BEGIN("solvers"); 23 : 24 2 : TEST_CASE("ISTA: Solving a LASSOProblem") 25 : { 26 : // eliminate the timing info from console for the tests 27 2 : Logger::setLevel(Logger::LogLevel::OFF); 28 : 29 4 : GIVEN("a LASSOProblem") 30 : { 31 4 : IndexVector_t numCoeff(2); 32 2 : numCoeff << 25, 31; 33 4 : VolumeDescriptor volDescr(numCoeff); 34 : 35 4 : RealVector_t bVec(volDescr.getNumberOfCoefficients()); 36 2 : bVec.setRandom(); 37 4 : DataContainer dcB(volDescr, bVec); 38 : 39 4 : Identity idOp(volDescr); 40 : 41 4 : WLSProblem wlsProb(idOp, dcB); 42 : 43 4 : L1Norm regFunc(volDescr); 44 4 : RegularizationTerm regTerm(0.000001f, regFunc); 45 : 46 4 : LASSOProblem lassoProb(wlsProb, regTerm); 47 : 48 4 : WHEN("setting up an ISTA solver") 49 : { 50 4 : ISTA solver(lassoProb); 51 : 52 3 : THEN("cloned ISTA solver equals original ISTA solver") 53 : { 54 2 : auto istaClone = solver.clone(); 55 : 56 1 : REQUIRE_NE(istaClone.get(), &solver); 57 1 : REQUIRE_EQ(*istaClone, solver); 58 : } 59 : 60 3 : THEN("the solution is correct") 61 : { 62 1 : auto solution = solver.solve(100); 63 1 : REQUIRE_UNARY(checkApproxEq(solution.squaredL2Norm(), bVec.squaredNorm())); 64 : } 65 : } 66 : } 67 2 : } 68 : 69 2 : TEST_CASE("ISTA: Solving various problems") 70 : { 71 : // eliminate the timing info from console for the tests 72 2 : Logger::setLevel(Logger::LogLevel::OFF); 73 : 74 4 : GIVEN("a DataContainer") 75 : { 76 4 : IndexVector_t numCoeff(2); 77 2 : numCoeff << 25, 31; 78 4 : VolumeDescriptor volDescr(numCoeff); 79 : 80 4 : RealVector_t bVec(volDescr.getNumberOfCoefficients()); 81 2 : bVec.setRandom(); 82 4 : DataContainer dcB(volDescr, bVec); 83 : 84 3 : WHEN("setting up an ISTA solver for a WLSProblem") 85 : { 86 2 : Identity idOp(volDescr); 87 : 88 2 : WLSProblem wlsProb(idOp, dcB); 89 : 90 2 : THEN("an exception is thrown as no regularization term is provided") 91 : { 92 2 : REQUIRE_THROWS_AS(ISTA{wlsProb}, InvalidArgumentError); 93 : } 94 : } 95 : 96 3 : WHEN("setting up an ISTA solver for a QuadricProblem without A and without b") 97 : { 98 2 : Identity idOp(volDescr); 99 : 100 2 : QuadricProblem<real_t> quadricProbWithoutAb(Quadric<real_t>{volDescr}); 101 : 102 2 : THEN("the vector b is initialized with zeroes and the operator A becomes an " 103 : "identity operator but an exception is thrown due to missing regularization term") 104 : { 105 2 : REQUIRE_THROWS_AS(ISTA{quadricProbWithoutAb}, InvalidArgumentError); 106 : } 107 : } 108 : } 109 2 : } 110 : 111 : TEST_SUITE_END();