LCOV - code coverage report
Current view: top level - problems/tests - test_TikhonovProblem.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 151 151 100.0 %
Date: 2022-02-28 03:37:41 Functions: 10 10 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_TikhonovProblem.cpp
       3             :  *
       4             :  * @brief Tests for the TikhonovProblem class
       5             :  *
       6             :  * @author Nikola Dinev
       7             :  */
       8             : 
       9             : #include "doctest/doctest.h"
      10             : 
      11             : #include "Problem.h"
      12             : #include "Identity.h"
      13             : #include "Scaling.h"
      14             : #include "LinearResidual.h"
      15             : #include "L2NormPow2.h"
      16             : #include "L1Norm.h"
      17             : #include "TikhonovProblem.h"
      18             : #include "VolumeDescriptor.h"
      19             : #include "Logger.h"
      20             : #include "testHelpers.h"
      21             : 
      22             : using namespace elsa;
      23             : using namespace doctest;
      24             : 
      25             : TEST_SUITE_BEGIN("problems");
      26             : 
      27          28 : TEST_CASE_TEMPLATE("TikhonovProblem: Testing with one regularization term", data_t, float, double)
      28             : {
      29          16 :     Logger::setLevel(Logger::LogLevel::WARN);
      30             : 
      31          32 :     GIVEN("some data term and some regularization term")
      32             :     {
      33             :         // least squares data term
      34          32 :         IndexVector_t numCoeff(2);
      35          16 :         numCoeff << 23, 47;
      36          32 :         VolumeDescriptor dd(numCoeff);
      37             : 
      38          32 :         Vector_t<data_t> scaling(dd.getNumberOfCoefficients());
      39          16 :         scaling.setRandom();
      40          32 :         DataContainer<data_t> dcScaling(dd, scaling);
      41          32 :         Scaling scaleOp(dd, dcScaling);
      42             : 
      43          32 :         Vector_t<data_t> dataVec(dd.getNumberOfCoefficients());
      44          16 :         dataVec.setRandom();
      45          32 :         DataContainer<data_t> dcData(dd, dataVec);
      46             : 
      47          32 :         WLSProblem<data_t> wls(scaleOp, dcData);
      48             : 
      49             :         // l2 norm regularization term
      50          32 :         L2NormPow2<data_t> regFunc(dd);
      51          16 :         auto weight = data_t{2.0};
      52          32 :         RegularizationTerm<data_t> regTerm(weight, regFunc);
      53             : 
      54          18 :         WHEN("setting up a TikhonovProblem without regularization terms")
      55             :         {
      56           4 :             THEN("an exception is thrown")
      57             :             {
      58           6 :                 REQUIRE_THROWS_AS(
      59             :                     TikhonovProblem<data_t>(wls, std::vector<RegularizationTerm<data_t>>{}),
      60             :                     InvalidArgumentError);
      61             :             }
      62             :         }
      63             : 
      64          18 :         WHEN("setting up a TikhonovProblem with a non (Weighted)L2NormPow2 regularization term")
      65             :         {
      66           4 :             L1Norm<data_t> invalidRegFunc(dd);
      67           4 :             RegularizationTerm<data_t> invalidRegTerm(1.0, invalidRegFunc);
      68           4 :             THEN("an exception is thrown")
      69             :             {
      70           4 :                 REQUIRE_THROWS_AS(TikhonovProblem<data_t>(wls, invalidRegTerm),
      71             :                                   InvalidArgumentError);
      72             :             }
      73             :         }
      74             : 
      75          22 :         WHEN("setting up the TikhonovProblem without x0")
      76             :         {
      77          12 :             TikhonovProblem<data_t> prob(wls, regTerm);
      78             : 
      79           8 :             THEN("the clone works correctly")
      80             :             {
      81           4 :                 auto probClone = prob.clone();
      82             : 
      83           2 :                 REQUIRE_NE(probClone.get(), &prob);
      84           2 :                 REQUIRE_EQ(*probClone, prob);
      85             :             }
      86             : 
      87           8 :             THEN("the TikhonovProblem behaves as expected")
      88             :             {
      89           4 :                 DataContainer<data_t> dcZero(dd);
      90           2 :                 dcZero = 0;
      91           2 :                 REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcZero));
      92             : 
      93           2 :                 REQUIRE_UNARY(
      94             :                     checkApproxEq(prob.evaluate(), as<data_t>(0.5) * dataVec.squaredNorm()));
      95           2 :                 REQUIRE_UNARY(isApprox(prob.getGradient(), as<data_t>(-1.0) * dcScaling * dcData));
      96             : 
      97           4 :                 auto hessian = prob.getHessian();
      98           4 :                 auto result = hessian.apply(dcData);
      99        2164 :                 for (index_t i = 0; i < result.getSize(); ++i)
     100        2162 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     101             :                                                                + weight * dataVec[i]));
     102             :             }
     103             : 
     104           8 :             THEN("the TikhonovProblem is different from a Problem with the same terms")
     105             :             {
     106           4 :                 Problem optProb(prob.getDataTerm(), prob.getRegularizationTerms());
     107           2 :                 REQUIRE_NE(prob, optProb);
     108           2 :                 REQUIRE_NE(optProb, prob);
     109             :             }
     110             :         }
     111             : 
     112          22 :         WHEN("setting up the TikhonovProblem with x0")
     113             :         {
     114          12 :             Vector_t<data_t> x0Vec(dd.getNumberOfCoefficients());
     115           6 :             x0Vec.setRandom();
     116          12 :             DataContainer<data_t> dcX0(dd, x0Vec);
     117             : 
     118           6 :             wls.getCurrentSolution() = dcX0;
     119          12 :             TikhonovProblem<data_t> prob(wls, regTerm);
     120             : 
     121           8 :             THEN("the clone works correctly")
     122             :             {
     123           4 :                 auto probClone = prob.clone();
     124             : 
     125           2 :                 REQUIRE_NE(probClone.get(), &prob);
     126           2 :                 REQUIRE_EQ(*probClone, prob);
     127             :             }
     128             : 
     129           8 :             THEN("the TikhonovProblem behaves as expected")
     130             :             {
     131           2 :                 REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcX0));
     132             : 
     133           2 :                 auto valueData =
     134           2 :                     as<data_t>(0.5)
     135           2 :                     * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
     136           2 :                 REQUIRE_UNARY(checkApproxEq(
     137             :                     prob.evaluate(), valueData + weight * as<data_t>(0.5) * x0Vec.squaredNorm()));
     138             : 
     139           4 :                 auto gradient = prob.getGradient();
     140           4 :                 DataContainer gradientDirect =
     141             :                     dcScaling * (dcScaling * dcX0 - dcData) + weight * dcX0;
     142             : 
     143        2164 :                 for (index_t i = 0; i < gradient.getSize(); ++i)
     144        2162 :                     REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
     145             : 
     146           4 :                 auto hessian = prob.getHessian();
     147           4 :                 auto result = hessian.apply(dcData);
     148        2164 :                 for (index_t i = 0; i < result.getSize(); ++i)
     149        2162 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     150             :                                                                + weight * dataVec[i]));
     151             :             }
     152             : 
     153           8 :             THEN("the TikhonovProblem is different from a Problem with the same terms")
     154             :             {
     155           4 :                 Problem optProb(prob.getDataTerm(), prob.getRegularizationTerms());
     156           2 :                 REQUIRE_NE(prob, optProb);
     157           2 :                 REQUIRE_NE(optProb, prob);
     158             :             }
     159             :         }
     160             :     }
     161          16 : }
     162             : 
     163          26 : TEST_CASE_TEMPLATE("TikhonovProblem: Testing with several regularization terms", data_t, float,
     164             :                    double)
     165             : {
     166          14 :     Logger::setLevel(Logger::LogLevel::WARN);
     167             : 
     168          28 :     GIVEN("some data term and several regularization terms")
     169             :     {
     170             :         // least squares data term
     171          28 :         IndexVector_t numCoeff(3);
     172          14 :         numCoeff << 17, 33, 52;
     173          28 :         VolumeDescriptor dd(numCoeff);
     174             : 
     175          28 :         Vector_t<data_t> scaling(dd.getNumberOfCoefficients());
     176          14 :         scaling.setRandom();
     177          28 :         DataContainer<data_t> dcScaling(dd, scaling);
     178          28 :         Scaling scaleOp(dd, dcScaling);
     179             : 
     180          28 :         Vector_t<data_t> dataVec(dd.getNumberOfCoefficients());
     181          14 :         dataVec.setRandom();
     182          28 :         DataContainer<data_t> dcData(dd, dataVec);
     183             : 
     184          28 :         WLSProblem<data_t> wls(scaleOp, dcData);
     185             : 
     186             :         // l2 norm regularization term
     187          28 :         L2NormPow2<data_t> regFunc(dd);
     188          14 :         auto weight1 = data_t{2.0};
     189          28 :         RegularizationTerm<data_t> regTerm1(weight1, regFunc);
     190             : 
     191          14 :         auto weight2 = data_t{3.0};
     192          28 :         RegularizationTerm<data_t> regTerm2(weight2, regFunc);
     193             : 
     194          70 :         std::vector<RegularizationTerm<data_t>> vecReg{regTerm1, regTerm2};
     195             : 
     196          16 :         WHEN("setting up a TikhonovProblem with a non (Weighted)L2NormPow2 regularization term")
     197             :         {
     198           4 :             L1Norm<data_t> invalidRegFunc(dd);
     199           4 :             RegularizationTerm<data_t> invalidRegTerm(1.0, invalidRegFunc);
     200          10 :             std::vector<RegularizationTerm<data_t>> invalidVecReg1{regTerm1, invalidRegTerm};
     201          10 :             std::vector<RegularizationTerm<data_t>> invalidVecReg2{invalidRegTerm, regTerm2};
     202           4 :             THEN("an exception is thrown")
     203             :             {
     204           4 :                 REQUIRE_THROWS_AS(TikhonovProblem<data_t>(wls, invalidVecReg1),
     205             :                                   InvalidArgumentError);
     206           4 :                 REQUIRE_THROWS_AS(TikhonovProblem<data_t>(wls, invalidVecReg2),
     207             :                                   InvalidArgumentError);
     208             :             }
     209             :         }
     210             : 
     211          20 :         WHEN("setting up the TikhonovProblem without x0")
     212             :         {
     213          12 :             TikhonovProblem<data_t> prob(wls, vecReg);
     214             : 
     215           8 :             THEN("the clone works correctly")
     216             :             {
     217           4 :                 auto probClone = prob.clone();
     218             : 
     219           2 :                 REQUIRE_NE(probClone.get(), &prob);
     220           2 :                 REQUIRE_EQ(*probClone, prob);
     221             :             }
     222             : 
     223           8 :             THEN("the TikhonovProblem behaves as expected")
     224             :             {
     225           4 :                 DataContainer<data_t> dcZero(dd);
     226           2 :                 dcZero = 0;
     227           2 :                 REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcZero));
     228             : 
     229           2 :                 REQUIRE_UNARY(
     230             :                     checkApproxEq(prob.evaluate(), as<data_t>(0.5) * dataVec.squaredNorm()));
     231           2 :                 REQUIRE_UNARY(isApprox(prob.getGradient(), as<data_t>(-1.0) * dcScaling * dcData));
     232             : 
     233           4 :                 auto hessian = prob.getHessian();
     234           4 :                 auto result = hessian.apply(dcData);
     235       58346 :                 for (index_t i = 0; i < result.getSize(); ++i)
     236       58344 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     237             :                                                                + weight1 * dataVec[i]
     238             :                                                                + weight2 * dataVec[i]));
     239             :             }
     240             : 
     241           8 :             THEN("the TikhonovProblem is different from a Problem with the same terms")
     242             :             {
     243           4 :                 Problem optProb(prob.getDataTerm(), prob.getRegularizationTerms());
     244           2 :                 REQUIRE_NE(prob, optProb);
     245           2 :                 REQUIRE_NE(optProb, prob);
     246             :             }
     247             :         }
     248             : 
     249          20 :         WHEN("setting up the TikhonovProblem with x0")
     250             :         {
     251          12 :             Vector_t<data_t> x0Vec(dd.getNumberOfCoefficients());
     252           6 :             x0Vec.setRandom();
     253          12 :             DataContainer<data_t> dcX0(dd, x0Vec);
     254             : 
     255           6 :             wls.getCurrentSolution() = dcX0;
     256          12 :             TikhonovProblem<data_t> prob(wls, vecReg);
     257             : 
     258           8 :             THEN("the clone works correctly")
     259             :             {
     260           4 :                 auto probClone = prob.clone();
     261             : 
     262           2 :                 REQUIRE_NE(probClone.get(), &prob);
     263           2 :                 REQUIRE_EQ(*probClone, prob);
     264             :             }
     265             : 
     266           8 :             THEN("the TikhonovProblem behaves as expected")
     267             :             {
     268           2 :                 REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcX0));
     269             : 
     270           2 :                 auto valueData =
     271           2 :                     as<data_t>(0.5)
     272           2 :                     * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
     273           2 :                 REQUIRE_UNARY(checkApproxEq(
     274             :                     prob.evaluate(), valueData + weight1 * as<data_t>(0.5) * x0Vec.squaredNorm()
     275             :                                          + weight2 * as<data_t>(0.5) * x0Vec.squaredNorm()));
     276             : 
     277           4 :                 auto gradient = prob.getGradient();
     278           4 :                 DataContainer gradientDirect =
     279             :                     dcScaling * (dcScaling * dcX0 - dcData) + weight1 * dcX0 + weight2 * dcX0;
     280       58346 :                 for (index_t i = 0; i < gradient.getSize(); ++i)
     281       58344 :                     REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
     282             : 
     283           4 :                 auto hessian = prob.getHessian();
     284           4 :                 auto result = hessian.apply(dcData);
     285       58346 :                 for (index_t i = 0; i < result.getSize(); ++i)
     286       58344 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     287             :                                                                + weight1 * dataVec[i]
     288             :                                                                + weight2 * dataVec[i]));
     289             :             }
     290             : 
     291           8 :             THEN("the TikhonovProblem is different from a Problem with the same terms")
     292             :             {
     293           4 :                 Problem optProb(prob.getDataTerm(), prob.getRegularizationTerms());
     294           2 :                 REQUIRE_NE(prob, optProb);
     295           2 :                 REQUIRE_NE(optProb, prob);
     296             :             }
     297             :         }
     298             :     }
     299          14 : }
     300             : 
     301             : TEST_SUITE_END();

Generated by: LCOV version 1.15