Line data Source code
1 : /** 2 : * @file test_FISTA.cpp 3 : * 4 : * @brief Tests for the FISTA class 5 : * 6 : * @author Andi Braimllari 7 : */ 8 : 9 : #include "doctest/doctest.h" 10 : 11 : #include "Error.h" 12 : #include "FISTA.h" 13 : #include "Identity.h" 14 : #include "Logger.h" 15 : #include "VolumeDescriptor.h" 16 : #include "QuadricProblem.h" 17 : #include "testHelpers.h" 18 : 19 : using namespace elsa; 20 : using namespace doctest; 21 : 22 : TEST_SUITE_BEGIN("solvers"); 23 : 24 : TEST_CASE("FISTA: Solving a LASSOProblem") 25 2 : { 26 : // eliminate the timing info from console for the tests 27 2 : Logger::setLevel(Logger::LogLevel::OFF); 28 : 29 2 : GIVEN("a LASSOProblem") 30 2 : { 31 2 : IndexVector_t numCoeff(2); 32 2 : numCoeff << 25, 31; 33 2 : VolumeDescriptor volDescr(numCoeff); 34 : 35 2 : RealVector_t bVec(volDescr.getNumberOfCoefficients()); 36 2 : bVec.setRandom(); 37 2 : DataContainer dcB(volDescr, bVec); 38 : 39 2 : Identity idOp(volDescr); 40 : 41 2 : WLSProblem wlsProb(idOp, dcB); 42 : 43 2 : L1Norm regFunc(volDescr); 44 2 : RegularizationTerm regTerm(0.000001f, regFunc); 45 : 46 2 : LASSOProblem lassoProb(wlsProb, regTerm); 47 : 48 2 : WHEN("setting up a FISTA solver") 49 2 : { 50 2 : FISTA solver(lassoProb, geometry::Threshold(0.005f)); 51 : 52 2 : THEN("cloned FISTA solver equals original FISTA solver") 53 2 : { 54 1 : auto fistaClone = solver.clone(); 55 : 56 1 : REQUIRE_NE(fistaClone.get(), &solver); 57 1 : REQUIRE_EQ(*fistaClone, solver); 58 1 : } 59 : 60 2 : THEN("the solution is correct") 61 2 : { 62 1 : auto solution = solver.solve(500); 63 1 : REQUIRE_UNARY(checkApproxEq(solution.squaredL2Norm(), bVec.squaredNorm())); 64 1 : } 65 2 : } 66 2 : } 67 2 : } 68 : 69 : TEST_CASE("FISTA: Solving various problems") 70 2 : { 71 : // eliminate the timing info from console for the tests 72 2 : Logger::setLevel(Logger::LogLevel::OFF); 73 : 74 2 : GIVEN("a DataContainer") 75 2 : { 76 2 : IndexVector_t numCoeff(2); 77 2 : numCoeff << 14, 9; 78 2 : VolumeDescriptor volDescr(numCoeff); 79 : 80 2 : RealVector_t bVec(volDescr.getNumberOfCoefficients()); 81 2 : bVec.setRandom(); 82 2 : DataContainer dcB(volDescr, bVec); 83 : 84 2 : WHEN("setting up an FISTA solver for a WLSProblem") 85 2 : { 86 1 : Identity idOp(volDescr); 87 : 88 1 : WLSProblem wlsProb(idOp, dcB); 89 : 90 1 : THEN("an exception is thrown as no regularization term is provided") 91 1 : { 92 1 : REQUIRE_THROWS_AS(FISTA{wlsProb}, InvalidArgumentError); 93 1 : } 94 1 : } 95 : 96 2 : WHEN("setting up an FISTA solver for a QuadricProblem without A and without b") 97 2 : { 98 1 : Identity idOp(volDescr); 99 : 100 1 : QuadricProblem<real_t> quadricProbWithoutAb(Quadric<real_t>{volDescr}); 101 : 102 1 : THEN("the vector b is initialized with zeroes and the operator A becomes an " 103 1 : "identity operator but an exception is thrown due to missing regularization term") 104 1 : { 105 1 : REQUIRE_THROWS_AS(FISTA{quadricProbWithoutAb}, InvalidArgumentError); 106 1 : } 107 1 : } 108 2 : } 109 2 : } 110 : 111 : TEST_SUITE_END();