LCOV - code coverage report
Current view: top level - solvers/tests - test_ADMM.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 34 34 100.0 %
Date: 2022-02-28 03:37:41 Functions: 5 5 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_ADMM.cpp
       3             :  *
       4             :  * @brief Tests for the ADMM class
       5             :  *
       6             :  * @author Andi Braimllari
       7             :  */
       8             : 
       9             : #include "doctest/doctest.h"
      10             : 
      11             : #include "ADMM.h"
      12             : #include "CG.h"
      13             : #include "SoftThresholding.h"
      14             : #include "HardThresholding.h"
      15             : #include "Identity.h"
      16             : #include "FISTA.h"
      17             : #include "Logger.h"
      18             : #include "VolumeDescriptor.h"
      19             : #include <testHelpers.h>
      20             : 
      21             : using namespace elsa;
      22             : using namespace doctest;
      23             : 
      24             : TEST_SUITE_BEGIN("solvers");
      25             : 
      26          10 : TEST_CASE_TEMPLATE("ADMM: Solving problems", data_t, float, double)
      27             : {
      28           4 :     Logger::setLevel(Logger::LogLevel::OFF);
      29             : 
      30           8 :     GIVEN("some problems and a constraint")
      31             :     {
      32           8 :         IndexVector_t numCoeff(2);
      33           4 :         numCoeff << 21, 11;
      34           8 :         VolumeDescriptor volDescr(numCoeff);
      35             : 
      36           8 :         Vector_t<data_t> bVec(volDescr.getNumberOfCoefficients());
      37           4 :         bVec.setRandom();
      38           8 :         DataContainer<data_t> dcB(volDescr, bVec);
      39             : 
      40           8 :         Identity<data_t> idOp(volDescr);
      41           8 :         Scaling<data_t> negativeIdOp(volDescr, -1);
      42           8 :         DataContainer<data_t> dCC(volDescr);
      43           4 :         dCC = 0;
      44             : 
      45           8 :         WLSProblem<data_t> wlsProb(idOp, dcB);
      46             : 
      47           8 :         Constraint<data_t> constraint(idOp, negativeIdOp, dCC);
      48             : 
      49           6 :         WHEN("setting up ADMM and FISTA to solve a LASSOProblem")
      50             :         {
      51           4 :             L1Norm<data_t> regFunc(volDescr);
      52           4 :             RegularizationTerm<data_t> regTerm(0.000001f, regFunc);
      53             : 
      54           4 :             SplittingProblem<data_t> splittingProblem(wlsProb.getDataTerm(), regTerm, constraint);
      55             : 
      56           4 :             ADMM<CG, SoftThresholding, data_t> admm(splittingProblem);
      57             : 
      58           4 :             LASSOProblem<data_t> lassoProb(wlsProb, regTerm);
      59           4 :             FISTA<data_t> fista(lassoProb);
      60             : 
      61           4 :             THEN("the solutions match")
      62             :             {
      63           2 :                 REQUIRE_UNARY(isApprox(admm.solve(100), fista.solve(100)));
      64             :             }
      65             :         }
      66             : 
      67           6 :         WHEN("setting up ADMM to solve a WLSProblem + L0PseudoNorm")
      68             :         {
      69           4 :             L0PseudoNorm<data_t> regFunc(volDescr);
      70           4 :             RegularizationTerm<data_t> regTerm(0.000001f, regFunc);
      71             : 
      72           4 :             SplittingProblem<data_t> splittingProblem(wlsProb.getDataTerm(), regTerm, constraint);
      73             : 
      74           4 :             ADMM<CG, HardThresholding, data_t> admm(splittingProblem);
      75             : 
      76           4 :             THEN("the solution doesn't throw, is not nan and is approximate to the b vector")
      77             :             {
      78           2 :                 REQUIRE_NOTHROW(admm.solve(10));
      79           2 :                 REQUIRE_UNARY(!std::isnan(admm.solve(20).squaredL2Norm()));
      80           2 :                 REQUIRE_UNARY(isApprox(admm.solve(20), dcB));
      81             :             }
      82             :         }
      83             :     }
      84           4 : }
      85             : 
      86             : TEST_SUITE_END();

Generated by: LCOV version 1.15