LCOV - code coverage report
Current view: top level - elsa/problems/tests - test_LASSOProblem.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 89 89 100.0 %
Date: 2022-08-25 03:05:39 Functions: 2 2 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_LASSOProblem.cpp
       3             :  *
       4             :  * @brief Tests for the LASSOProblem class
       5             :  *
       6             :  * @author Andi Braimllari
       7             :  */
       8             : 
       9             : #include "doctest/doctest.h"
      10             : 
      11             : #include "Error.h"
      12             : #include "L2NormPow2.h"
      13             : #include "LASSOProblem.h"
      14             : #include "VolumeDescriptor.h"
      15             : #include "Identity.h"
      16             : #include "testHelpers.h"
      17             : 
      18             : using namespace elsa;
      19             : using namespace doctest;
      20             : 
      21             : TEST_SUITE_BEGIN("problems");
      22             : 
      23             : TEST_CASE_TEMPLATE("Scenario: Testing LASSOProblem", data_t, float, double)
      24          12 : {
      25          12 :     GIVEN("some data term and a regularization term")
      26          12 :     {
      27          12 :         IndexVector_t numCoeff(2);
      28          12 :         numCoeff << 17, 53;
      29          12 :         VolumeDescriptor dd(numCoeff);
      30             : 
      31          12 :         Vector_t<data_t> scaling(dd.getNumberOfCoefficients());
      32          12 :         scaling.setRandom();
      33          12 :         DataContainer<data_t> dcScaling(dd, scaling);
      34          12 :         Scaling scaleOp(dd, dcScaling);
      35             : 
      36          12 :         Vector_t<data_t> dataVec(dd.getNumberOfCoefficients());
      37          12 :         dataVec.setRandom();
      38          12 :         DataContainer<data_t> dcData(dd, dataVec);
      39             : 
      40          12 :         WLSProblem<data_t> wlsProblem(scaleOp, dcData);
      41             : 
      42          12 :         auto invalidWeight = static_cast<data_t>(-2.0);
      43          12 :         auto weight = static_cast<data_t>(2.0);
      44             : 
      45             :         // l1 norm regularization term
      46          12 :         L1Norm<data_t> regFunc(dd);
      47             : 
      48          12 :         WHEN("setting up a LASSOProblem with a negative regularization weight")
      49          12 :         {
      50           2 :             RegularizationTerm<data_t> invalidRegTerm(invalidWeight, regFunc);
      51           2 :             THEN("an invalid_argument exception is thrown")
      52           2 :             {
      53           2 :                 REQUIRE_THROWS_AS(LASSOProblem<data_t>(wlsProblem, invalidRegTerm),
      54           2 :                                   InvalidArgumentError);
      55           2 :             }
      56           2 :         }
      57             : 
      58             :         // l2 norm regularization term
      59          12 :         L2NormPow2<data_t> invalidRegFunc(dd);
      60             : 
      61          12 :         WHEN("setting up a LASSOProblem with a L2NormPow2 regularization term")
      62          12 :         {
      63           2 :             RegularizationTerm<data_t> invalidRegTerm(weight, invalidRegFunc);
      64           2 :             THEN("an invalid_argument exception is thrown")
      65           2 :             {
      66           2 :                 REQUIRE_THROWS_AS(LASSOProblem<data_t>(wlsProblem, invalidRegTerm),
      67           2 :                                   InvalidArgumentError);
      68           2 :             }
      69           2 :         }
      70             : 
      71          12 :         RegularizationTerm<data_t> regTerm(weight, regFunc);
      72             : 
      73          12 :         WHEN("setting up a LASSOProblem without an x0")
      74          12 :         {
      75           2 :             LASSOProblem<data_t> lassoProb(wlsProblem, regTerm);
      76             : 
      77           2 :             THEN("cloned LASSOProblem equals original LASSOProblem")
      78           2 :             {
      79           2 :                 auto lassoProbClone = lassoProb.clone();
      80             : 
      81           2 :                 REQUIRE_NE(lassoProbClone.get(), &lassoProb);
      82           2 :                 REQUIRE_EQ(*lassoProbClone, lassoProb);
      83           2 :             }
      84           2 :         }
      85             : 
      86          12 :         WHEN("setting up a LASSOProblem with an x0")
      87          12 :         {
      88           2 :             Eigen::Matrix<data_t, Eigen::Dynamic, 1> x0Vec(dd.getNumberOfCoefficients());
      89           2 :             x0Vec.setRandom();
      90           2 :             DataContainer<data_t> dcX0(dd, x0Vec);
      91             : 
      92           2 :             wlsProblem.getCurrentSolution() = dcX0;
      93           2 :             LASSOProblem<data_t> lassoProb(wlsProblem, regTerm);
      94             : 
      95           2 :             THEN("cloned LASSOProblem equals original LASSOProblem")
      96           2 :             {
      97           2 :                 auto lassoProbClone = lassoProb.clone();
      98             : 
      99           2 :                 REQUIRE_NE(lassoProbClone.get(), &lassoProb);
     100           2 :                 REQUIRE_EQ(*lassoProbClone, lassoProb);
     101           2 :             }
     102           2 :         }
     103             : 
     104          12 :         Identity<data_t> idOp(dd);
     105          12 :         WLSProblem<data_t> wlsProblemForLC(idOp, dcData);
     106             : 
     107          12 :         WHEN("setting up the Lipschitz Constant of a LASSOProblem without an x0")
     108          12 :         {
     109           2 :             LASSOProblem<data_t> lassoProb(wlsProblemForLC, regTerm);
     110             : 
     111           2 :             auto lipschitzConstant = lassoProb.getLipschitzConstant();
     112             : 
     113           2 :             THEN("the Lipschitz Constant of a LASSOProblem with an Identity Operator as the "
     114           2 :                  "Linear Operator A is 1")
     115           2 :             {
     116           2 :                 REQUIRE_UNARY(checkApproxEq(lipschitzConstant, as<data_t>(1.0)));
     117           2 :             }
     118           2 :         }
     119             : 
     120          12 :         WHEN("setting up the Lipschitz Constant of a LASSOProblem with an x0")
     121          12 :         {
     122           2 :             Vector_t<data_t> x0Vec(dd.getNumberOfCoefficients());
     123           2 :             x0Vec.setRandom();
     124           2 :             DataContainer<data_t> dcX0(dd, x0Vec);
     125           2 :             wlsProblemForLC.getCurrentSolution() = dcX0;
     126             : 
     127           2 :             LASSOProblem<data_t> lassoProb(wlsProblemForLC, regTerm);
     128             : 
     129           2 :             auto lipschitzConstant = lassoProb.getLipschitzConstant();
     130             : 
     131           2 :             THEN("the Lipschitz Constant of a LASSOProblem with an Identity Operator as the "
     132           2 :                  "Linear Operator A is 1")
     133           2 :             {
     134           2 :                 REQUIRE_EQ(lipschitzConstant, Approx(as<data_t>(1.0)));
     135           2 :             }
     136           2 :         }
     137          12 :     }
     138          12 : }
     139             : 
     140             : TEST_SUITE_END();

Generated by: LCOV version 1.14