LCOV - code coverage report
Current view: top level - problems/tests - test_LASSOProblem.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 59 59 100.0 %
Date: 2022-02-28 03:37:41 Functions: 5 5 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_LASSOProblem.cpp
       3             :  *
       4             :  * @brief Tests for the LASSOProblem class
       5             :  *
       6             :  * @author Andi Braimllari
       7             :  */
       8             : 
       9             : #include "doctest/doctest.h"
      10             : 
      11             : #include "Error.h"
      12             : #include "L2NormPow2.h"
      13             : #include "LASSOProblem.h"
      14             : #include "VolumeDescriptor.h"
      15             : #include "Identity.h"
      16             : #include "testHelpers.h"
      17             : 
      18             : using namespace elsa;
      19             : using namespace doctest;
      20             : 
      21             : TEST_SUITE_BEGIN("problems");
      22             : 
      23          18 : TEST_CASE_TEMPLATE("Scenario: Testing LASSOProblem", data_t, float, double)
      24             : {
      25          24 :     GIVEN("some data term and a regularization term")
      26             :     {
      27          24 :         IndexVector_t numCoeff(2);
      28          12 :         numCoeff << 17, 53;
      29          24 :         VolumeDescriptor dd(numCoeff);
      30             : 
      31          24 :         Vector_t<data_t> scaling(dd.getNumberOfCoefficients());
      32          12 :         scaling.setRandom();
      33          24 :         DataContainer<data_t> dcScaling(dd, scaling);
      34          24 :         Scaling scaleOp(dd, dcScaling);
      35             : 
      36          24 :         Vector_t<data_t> dataVec(dd.getNumberOfCoefficients());
      37          12 :         dataVec.setRandom();
      38          24 :         DataContainer<data_t> dcData(dd, dataVec);
      39             : 
      40          24 :         WLSProblem<data_t> wlsProblem(scaleOp, dcData);
      41             : 
      42          12 :         auto invalidWeight = static_cast<data_t>(-2.0);
      43          12 :         auto weight = static_cast<data_t>(2.0);
      44             : 
      45             :         // l1 norm regularization term
      46          24 :         L1Norm<data_t> regFunc(dd);
      47             : 
      48          14 :         WHEN("setting up a LASSOProblem with a negative regularization weight")
      49             :         {
      50           4 :             RegularizationTerm<data_t> invalidRegTerm(invalidWeight, regFunc);
      51           4 :             THEN("an invalid_argument exception is thrown")
      52             :             {
      53           6 :                 REQUIRE_THROWS_AS(LASSOProblem<data_t>(wlsProblem, invalidRegTerm),
      54             :                                   InvalidArgumentError);
      55             :             }
      56             :         }
      57             : 
      58             :         // l2 norm regularization term
      59          24 :         L2NormPow2<data_t> invalidRegFunc(dd);
      60             : 
      61          14 :         WHEN("setting up a LASSOProblem with a L2NormPow2 regularization term")
      62             :         {
      63           4 :             RegularizationTerm<data_t> invalidRegTerm(weight, invalidRegFunc);
      64           4 :             THEN("an invalid_argument exception is thrown")
      65             :             {
      66           6 :                 REQUIRE_THROWS_AS(LASSOProblem<data_t>(wlsProblem, invalidRegTerm),
      67             :                                   InvalidArgumentError);
      68             :             }
      69             :         }
      70             : 
      71          24 :         RegularizationTerm<data_t> regTerm(weight, regFunc);
      72             : 
      73          14 :         WHEN("setting up a LASSOProblem without an x0")
      74             :         {
      75           4 :             LASSOProblem<data_t> lassoProb(wlsProblem, regTerm);
      76             : 
      77           4 :             THEN("cloned LASSOProblem equals original LASSOProblem")
      78             :             {
      79           4 :                 auto lassoProbClone = lassoProb.clone();
      80             : 
      81           2 :                 REQUIRE_NE(lassoProbClone.get(), &lassoProb);
      82           2 :                 REQUIRE_EQ(*lassoProbClone, lassoProb);
      83             :             }
      84             :         }
      85             : 
      86          14 :         WHEN("setting up a LASSOProblem with an x0")
      87             :         {
      88           4 :             Eigen::Matrix<data_t, Eigen::Dynamic, 1> x0Vec(dd.getNumberOfCoefficients());
      89           2 :             x0Vec.setRandom();
      90           4 :             DataContainer<data_t> dcX0(dd, x0Vec);
      91             : 
      92           2 :             wlsProblem.getCurrentSolution() = dcX0;
      93           4 :             LASSOProblem<data_t> lassoProb(wlsProblem, regTerm);
      94             : 
      95           4 :             THEN("cloned LASSOProblem equals original LASSOProblem")
      96             :             {
      97           4 :                 auto lassoProbClone = lassoProb.clone();
      98             : 
      99           2 :                 REQUIRE_NE(lassoProbClone.get(), &lassoProb);
     100           2 :                 REQUIRE_EQ(*lassoProbClone, lassoProb);
     101             :             }
     102             :         }
     103             : 
     104          24 :         Identity<data_t> idOp(dd);
     105          24 :         WLSProblem<data_t> wlsProblemForLC(idOp, dcData);
     106             : 
     107          14 :         WHEN("setting up the Lipschitz Constant of a LASSOProblem without an x0")
     108             :         {
     109           4 :             LASSOProblem<data_t> lassoProb(wlsProblemForLC, regTerm);
     110             : 
     111           2 :             auto lipschitzConstant = lassoProb.getLipschitzConstant();
     112             : 
     113           4 :             THEN("the Lipschitz Constant of a LASSOProblem with an Identity Operator as the "
     114             :                  "Linear Operator A is 1")
     115             :             {
     116           2 :                 REQUIRE_UNARY(checkApproxEq(lipschitzConstant, as<data_t>(1.0)));
     117             :             }
     118             :         }
     119             : 
     120          14 :         WHEN("setting up the Lipschitz Constant of a LASSOProblem with an x0")
     121             :         {
     122           4 :             Vector_t<data_t> x0Vec(dd.getNumberOfCoefficients());
     123           2 :             x0Vec.setRandom();
     124           4 :             DataContainer<data_t> dcX0(dd, x0Vec);
     125           2 :             wlsProblemForLC.getCurrentSolution() = dcX0;
     126             : 
     127           4 :             LASSOProblem<data_t> lassoProb(wlsProblemForLC, regTerm);
     128             : 
     129           2 :             auto lipschitzConstant = lassoProb.getLipschitzConstant();
     130             : 
     131           4 :             THEN("the Lipschitz Constant of a LASSOProblem with an Identity Operator as the "
     132             :                  "Linear Operator A is 1")
     133             :             {
     134           2 :                 REQUIRE_EQ(lipschitzConstant, as<data_t>(1.0));
     135             :             }
     136             :         }
     137             :     }
     138          12 : }
     139             : 
     140             : TEST_SUITE_END();

Generated by: LCOV version 1.15