Line data Source code
1 : /** 2 : * @file test_LASSOProblem.cpp 3 : * 4 : * @brief Tests for the LASSOProblem class 5 : * 6 : * @author Andi Braimllari 7 : */ 8 : 9 : #include "doctest/doctest.h" 10 : 11 : #include "Error.h" 12 : #include "L2NormPow2.h" 13 : #include "LASSOProblem.h" 14 : #include "VolumeDescriptor.h" 15 : #include "Identity.h" 16 : #include "testHelpers.h" 17 : 18 : using namespace elsa; 19 : using namespace doctest; 20 : 21 : TEST_SUITE_BEGIN("problems"); 22 : 23 : TEST_CASE_TEMPLATE("Scenario: Testing LASSOProblem", data_t, float, double) 24 12 : { 25 12 : GIVEN("some data term and a regularization term") 26 12 : { 27 12 : IndexVector_t numCoeff(2); 28 12 : numCoeff << 17, 53; 29 12 : VolumeDescriptor dd(numCoeff); 30 : 31 12 : Vector_t<data_t> scaling(dd.getNumberOfCoefficients()); 32 12 : scaling.setRandom(); 33 12 : DataContainer<data_t> dcScaling(dd, scaling); 34 12 : Scaling scaleOp(dd, dcScaling); 35 : 36 12 : Vector_t<data_t> dataVec(dd.getNumberOfCoefficients()); 37 12 : dataVec.setRandom(); 38 12 : DataContainer<data_t> dcData(dd, dataVec); 39 : 40 12 : WLSProblem<data_t> wlsProblem(scaleOp, dcData); 41 : 42 12 : auto invalidWeight = static_cast<data_t>(-2.0); 43 12 : auto weight = static_cast<data_t>(2.0); 44 : 45 : // l1 norm regularization term 46 12 : L1Norm<data_t> regFunc(dd); 47 : 48 12 : WHEN("setting up a LASSOProblem with a negative regularization weight") 49 12 : { 50 2 : RegularizationTerm<data_t> invalidRegTerm(invalidWeight, regFunc); 51 2 : THEN("an invalid_argument exception is thrown") 52 2 : { 53 2 : REQUIRE_THROWS_AS(LASSOProblem<data_t>(wlsProblem, invalidRegTerm), 54 2 : InvalidArgumentError); 55 2 : } 56 2 : } 57 : 58 : // l2 norm regularization term 59 12 : L2NormPow2<data_t> invalidRegFunc(dd); 60 : 61 12 : WHEN("setting up a LASSOProblem with a L2NormPow2 regularization term") 62 12 : { 63 2 : RegularizationTerm<data_t> invalidRegTerm(weight, invalidRegFunc); 64 2 : THEN("an invalid_argument exception is thrown") 65 2 : { 66 2 : REQUIRE_THROWS_AS(LASSOProblem<data_t>(wlsProblem, invalidRegTerm), 67 2 : InvalidArgumentError); 68 2 : } 69 2 : } 70 : 71 12 : RegularizationTerm<data_t> regTerm(weight, regFunc); 72 : 73 12 : WHEN("setting up a LASSOProblem without an x0") 74 12 : { 75 2 : LASSOProblem<data_t> lassoProb(wlsProblem, regTerm); 76 : 77 2 : THEN("cloned LASSOProblem equals original LASSOProblem") 78 2 : { 79 2 : auto lassoProbClone = lassoProb.clone(); 80 : 81 2 : REQUIRE_NE(lassoProbClone.get(), &lassoProb); 82 2 : REQUIRE_EQ(*lassoProbClone, lassoProb); 83 2 : } 84 2 : } 85 : 86 12 : WHEN("setting up a LASSOProblem with an x0") 87 12 : { 88 2 : Eigen::Matrix<data_t, Eigen::Dynamic, 1> x0Vec(dd.getNumberOfCoefficients()); 89 2 : x0Vec.setRandom(); 90 2 : DataContainer<data_t> dcX0(dd, x0Vec); 91 : 92 2 : wlsProblem.getCurrentSolution() = dcX0; 93 2 : LASSOProblem<data_t> lassoProb(wlsProblem, regTerm); 94 : 95 2 : THEN("cloned LASSOProblem equals original LASSOProblem") 96 2 : { 97 2 : auto lassoProbClone = lassoProb.clone(); 98 : 99 2 : REQUIRE_NE(lassoProbClone.get(), &lassoProb); 100 2 : REQUIRE_EQ(*lassoProbClone, lassoProb); 101 2 : } 102 2 : } 103 : 104 12 : Identity<data_t> idOp(dd); 105 12 : WLSProblem<data_t> wlsProblemForLC(idOp, dcData); 106 : 107 12 : WHEN("setting up the Lipschitz Constant of a LASSOProblem without an x0") 108 12 : { 109 2 : LASSOProblem<data_t> lassoProb(wlsProblemForLC, regTerm); 110 : 111 2 : auto lipschitzConstant = lassoProb.getLipschitzConstant(); 112 : 113 2 : THEN("the Lipschitz Constant of a LASSOProblem with an Identity Operator as the " 114 2 : "Linear Operator A is 1") 115 2 : { 116 2 : REQUIRE_UNARY(checkApproxEq(lipschitzConstant, as<data_t>(1.0))); 117 2 : } 118 2 : } 119 : 120 12 : WHEN("setting up the Lipschitz Constant of a LASSOProblem with an x0") 121 12 : { 122 2 : Vector_t<data_t> x0Vec(dd.getNumberOfCoefficients()); 123 2 : x0Vec.setRandom(); 124 2 : DataContainer<data_t> dcX0(dd, x0Vec); 125 2 : wlsProblemForLC.getCurrentSolution() = dcX0; 126 : 127 2 : LASSOProblem<data_t> lassoProb(wlsProblemForLC, regTerm); 128 : 129 2 : auto lipschitzConstant = lassoProb.getLipschitzConstant(); 130 : 131 2 : THEN("the Lipschitz Constant of a LASSOProblem with an Identity Operator as the " 132 2 : "Linear Operator A is 1") 133 2 : { 134 2 : REQUIRE_EQ(lipschitzConstant, Approx(as<data_t>(1.0))); 135 2 : } 136 2 : } 137 12 : } 138 12 : } 139 : 140 : TEST_SUITE_END();