Line data Source code
1 : /** 2 : * @file test_FISTA.cpp 3 : * 4 : * @brief Tests for the FISTA class 5 : * 6 : * @author Andi Braimllari 7 : */ 8 : 9 : #include "doctest/doctest.h" 10 : 11 : #include "Error.h" 12 : #include "FISTA.h" 13 : #include "Identity.h" 14 : #include "Logger.h" 15 : #include "VolumeDescriptor.h" 16 : #include "QuadricProblem.h" 17 : #include "testHelpers.h" 18 : 19 : using namespace elsa; 20 : using namespace doctest; 21 : 22 : TEST_SUITE_BEGIN("solvers"); 23 : 24 2 : TEST_CASE("FISTA: Solving a LASSOProblem") 25 : { 26 : // eliminate the timing info from console for the tests 27 2 : Logger::setLevel(Logger::LogLevel::OFF); 28 : 29 4 : GIVEN("a LASSOProblem") 30 : { 31 4 : IndexVector_t numCoeff(2); 32 2 : numCoeff << 25, 31; 33 4 : VolumeDescriptor volDescr(numCoeff); 34 : 35 4 : RealVector_t bVec(volDescr.getNumberOfCoefficients()); 36 2 : bVec.setRandom(); 37 4 : DataContainer dcB(volDescr, bVec); 38 : 39 4 : Identity idOp(volDescr); 40 : 41 4 : WLSProblem wlsProb(idOp, dcB); 42 : 43 4 : L1Norm regFunc(volDescr); 44 4 : RegularizationTerm regTerm(0.000001f, regFunc); 45 : 46 4 : LASSOProblem lassoProb(wlsProb, regTerm); 47 : 48 4 : WHEN("setting up a FISTA solver") 49 : { 50 4 : FISTA solver(lassoProb, geometry::Threshold(0.005f)); 51 : 52 3 : THEN("cloned FISTA solver equals original FISTA solver") 53 : { 54 2 : auto fistaClone = solver.clone(); 55 : 56 1 : REQUIRE_NE(fistaClone.get(), &solver); 57 1 : REQUIRE_EQ(*fistaClone, solver); 58 : } 59 : 60 3 : THEN("the solution is correct") 61 : { 62 1 : auto solution = solver.solve(500); 63 1 : REQUIRE_UNARY(checkApproxEq(solution.squaredL2Norm(), bVec.squaredNorm())); 64 : } 65 : } 66 : } 67 2 : } 68 : 69 2 : TEST_CASE("FISTA: Solving various problems") 70 : { 71 : // eliminate the timing info from console for the tests 72 2 : Logger::setLevel(Logger::LogLevel::OFF); 73 : 74 4 : GIVEN("a DataContainer") 75 : { 76 4 : IndexVector_t numCoeff(2); 77 2 : numCoeff << 14, 9; 78 4 : VolumeDescriptor volDescr(numCoeff); 79 : 80 4 : RealVector_t bVec(volDescr.getNumberOfCoefficients()); 81 2 : bVec.setRandom(); 82 4 : DataContainer dcB(volDescr, bVec); 83 : 84 3 : WHEN("setting up an FISTA solver for a WLSProblem") 85 : { 86 2 : Identity idOp(volDescr); 87 : 88 2 : WLSProblem wlsProb(idOp, dcB); 89 : 90 2 : THEN("an exception is thrown as no regularization term is provided") 91 : { 92 2 : REQUIRE_THROWS_AS(FISTA{wlsProb}, InvalidArgumentError); 93 : } 94 : } 95 : 96 3 : WHEN("setting up an FISTA solver for a QuadricProblem without A and without b") 97 : { 98 2 : Identity idOp(volDescr); 99 : 100 2 : QuadricProblem<real_t> quadricProbWithoutAb(Quadric<real_t>{volDescr}); 101 : 102 2 : THEN("the vector b is initialized with zeroes and the operator A becomes an " 103 : "identity operator but an exception is thrown due to missing regularization term") 104 : { 105 2 : REQUIRE_THROWS_AS(FISTA{quadricProbWithoutAb}, InvalidArgumentError); 106 : } 107 : } 108 : } 109 2 : } 110 : 111 : TEST_SUITE_END();