LCOV - code coverage report
Current view: top level - problems/tests - test_Problem.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 194 194 100.0 %
Date: 2022-02-28 03:37:41 Functions: 3 3 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_Problem.cpp
       3             :  *
       4             :  * @brief Tests for the Problem class
       5             :  *
       6             :  * @author David Frank - initial code
       7             :  * @author Tobias Lasser - rewrite
       8             :  */
       9             : 
      10             : #include "doctest/doctest.h"
      11             : 
      12             : #include <iostream>
      13             : #include <Logger.h>
      14             : #include "Problem.h"
      15             : #include "Identity.h"
      16             : #include "Scaling.h"
      17             : #include "LinearResidual.h"
      18             : #include "L2NormPow2.h"
      19             : #include "VolumeDescriptor.h"
      20             : #include "testHelpers.h"
      21             : #include "TypeCasts.hpp"
      22             : 
      23             : using namespace elsa;
      24             : using namespace doctest;
      25             : 
      26             : TEST_SUITE_BEGIN("problems");
      27             : 
      28           4 : TEST_CASE("Problem: Testing without regularization")
      29             : {
      30             :     // eliminate the timing info from console for the tests
      31           4 :     Logger::setLevel(Logger::LogLevel::WARN);
      32             : 
      33           8 :     GIVEN("some data term")
      34             :     {
      35           8 :         IndexVector_t numCoeff(2);
      36           4 :         numCoeff << 17, 23;
      37           8 :         VolumeDescriptor dd(numCoeff);
      38             : 
      39           8 :         RealVector_t scaling(dd.getNumberOfCoefficients());
      40           4 :         scaling.setRandom();
      41           8 :         DataContainer dcScaling(dd, scaling);
      42           8 :         Scaling scaleOp(dd, dcScaling);
      43             : 
      44           8 :         RealVector_t dataVec(dd.getNumberOfCoefficients());
      45           4 :         dataVec.setRandom();
      46           8 :         DataContainer dcData(dd, dataVec);
      47             : 
      48           8 :         LinearResidual linRes(scaleOp, dcData);
      49           8 :         L2NormPow2 func(linRes);
      50             : 
      51           6 :         WHEN("setting up the problem without x0")
      52             :         {
      53           4 :             Problem prob(func);
      54             : 
      55           3 :             THEN("the clone works correctly")
      56             :             {
      57           2 :                 auto probClone = prob.clone();
      58             : 
      59           1 :                 REQUIRE_NE(probClone.get(), &prob);
      60           1 :                 REQUIRE_EQ(*probClone, prob);
      61             :             }
      62             : 
      63           3 :             THEN("the problem behaves as expected")
      64             :             {
      65           2 :                 DataContainer dcZero(dd);
      66           1 :                 dcZero = 0;
      67           1 :                 REQUIRE_EQ(prob.getCurrentSolution(), dcZero);
      68             : 
      69           1 :                 REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
      70           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
      71             : 
      72           2 :                 auto hessian = prob.getHessian();
      73           1 :                 auto result = hessian.apply(dcData);
      74         392 :                 for (index_t i = 0; i < result.getSize(); ++i)
      75         391 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]));
      76             : 
      77           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f));
      78             :             }
      79             :         }
      80             : 
      81           6 :         WHEN("setting up the problem with x0")
      82             :         {
      83           4 :             RealVector_t x0Vec(dd.getNumberOfCoefficients());
      84           2 :             x0Vec.setRandom();
      85           4 :             DataContainer dcX0(dd, x0Vec);
      86             : 
      87           4 :             Problem prob(func, dcX0);
      88             : 
      89           3 :             THEN("the clone works correctly")
      90             :             {
      91           2 :                 auto probClone = prob.clone();
      92             : 
      93           1 :                 REQUIRE_NE(probClone.get(), &prob);
      94           1 :                 REQUIRE_EQ(*probClone, prob);
      95             :             }
      96             : 
      97           3 :             THEN("the problem behaves as expected")
      98             :             {
      99           1 :                 REQUIRE_EQ(prob.getCurrentSolution(), dcX0);
     100             : 
     101           1 :                 REQUIRE_UNARY(checkApproxEq(
     102             :                     prob.evaluate(), 0.5f
     103             :                                          * (scaling.array() * x0Vec.array() - dataVec.array())
     104             :                                                .matrix()
     105             :                                                .squaredNorm()));
     106             : 
     107           2 :                 DataContainer gradientDirect = dcScaling * (dcScaling * dcX0 - dcData);
     108           2 :                 auto gradient = prob.getGradient();
     109         392 :                 for (index_t i = 0; i < gradientDirect.getSize(); ++i)
     110         391 :                     REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
     111             : 
     112           2 :                 auto hessian = prob.getHessian();
     113           1 :                 auto result = hessian.apply(dcData);
     114         392 :                 for (index_t i = 0; i < result.getSize(); ++i)
     115         391 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]));
     116             : 
     117           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f));
     118             :             }
     119             :         }
     120             :     }
     121           4 : }
     122             : 
     123           5 : TEST_CASE("Problem: Testing with one regularization term")
     124             : {
     125             :     // eliminate the timing info from console for the tests
     126           5 :     Logger::setLevel(Logger::LogLevel::WARN);
     127             : 
     128          10 :     GIVEN("some data term and some regularization term")
     129             :     {
     130             :         // least squares data term
     131          10 :         IndexVector_t numCoeff(2);
     132           5 :         numCoeff << 23, 47;
     133          10 :         VolumeDescriptor dd(numCoeff);
     134             : 
     135          10 :         RealVector_t scaling(dd.getNumberOfCoefficients());
     136           5 :         scaling.setRandom();
     137          10 :         DataContainer dcScaling(dd, scaling);
     138          10 :         Scaling scaleOp(dd, dcScaling);
     139             : 
     140          10 :         RealVector_t dataVec(dd.getNumberOfCoefficients());
     141           5 :         dataVec.setRandom();
     142          10 :         DataContainer dcData(dd, dataVec);
     143             : 
     144          10 :         LinearResidual linRes(scaleOp, dcData);
     145          10 :         L2NormPow2 func(linRes);
     146             : 
     147             :         // l2 norm regularization term
     148          10 :         L2NormPow2 regFunc(dd);
     149           5 :         real_t weight = 2.0;
     150          10 :         RegularizationTerm regTerm(weight, regFunc);
     151             : 
     152           7 :         WHEN("setting up the problem without x0")
     153             :         {
     154           4 :             Problem prob(func, regTerm);
     155             : 
     156           3 :             THEN("the clone works correctly")
     157             :             {
     158           2 :                 auto probClone = prob.clone();
     159             : 
     160           1 :                 REQUIRE_NE(probClone.get(), &prob);
     161           1 :                 REQUIRE_EQ(*probClone, prob);
     162             :             }
     163             : 
     164           3 :             THEN("the problem behaves as expected")
     165             :             {
     166           2 :                 DataContainer dcZero(dd);
     167           1 :                 dcZero = 0;
     168           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getCurrentSolution(), dcZero));
     169             : 
     170           1 :                 REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
     171           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
     172             : 
     173           2 :                 auto hessian = prob.getHessian();
     174           1 :                 auto result = hessian.apply(dcData);
     175        1082 :                 for (index_t i = 0; i < result.getSize(); ++i)
     176        1081 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     177             :                                                                + weight * dataVec[i]));
     178             : 
     179           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight));
     180             :             }
     181             :         }
     182             : 
     183           7 :         WHEN("setting up the problem with x0")
     184             :         {
     185           4 :             RealVector_t x0Vec(dd.getNumberOfCoefficients());
     186           2 :             x0Vec.setRandom();
     187           4 :             DataContainer dcX0(dd, x0Vec);
     188             : 
     189           4 :             Problem prob(func, regTerm, dcX0);
     190             : 
     191           3 :             THEN("the clone works correctly")
     192             :             {
     193           2 :                 auto probClone = prob.clone();
     194             : 
     195           1 :                 REQUIRE_NE(probClone.get(), &prob);
     196           1 :                 REQUIRE_EQ(*probClone, prob);
     197             :             }
     198             : 
     199           3 :             THEN("the problem behaves as expected")
     200             :             {
     201           1 :                 REQUIRE_EQ(prob.getCurrentSolution(), dcX0);
     202             : 
     203             :                 auto valueData =
     204             :                     0.5f
     205           1 :                     * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
     206           1 :                 REQUIRE_UNARY(checkApproxEq(prob.evaluate(),
     207             :                                             valueData + weight * 0.5f * x0Vec.squaredNorm()));
     208             : 
     209             :                 DataContainer gradientDirect =
     210           2 :                     dcScaling * (dcScaling * dcX0 - dcData) + weight * dcX0;
     211           2 :                 auto gradient = prob.getGradient();
     212        1082 :                 for (index_t i = 0; i < gradient.getSize(); ++i)
     213        1081 :                     REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
     214             : 
     215           2 :                 auto hessian = prob.getHessian();
     216           1 :                 auto result = hessian.apply(dcData);
     217        1082 :                 for (index_t i = 0; i < result.getSize(); ++i)
     218        1081 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     219             :                                                                + weight * dataVec[i]));
     220             : 
     221           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight));
     222             :             }
     223             :         }
     224             : 
     225           6 :         WHEN("given a different data descriptor and another regularization term with a different "
     226             :              "domain descriptor")
     227             :         {
     228             :             // three-dimensional data descriptor
     229           2 :             IndexVector_t otherNumCoeff(3);
     230           1 :             otherNumCoeff << 15, 38, 22;
     231           2 :             VolumeDescriptor otherDD(otherNumCoeff);
     232             : 
     233             :             // l2 norm regularization term
     234           2 :             L2NormPow2 otherRegFunc(otherDD);
     235           2 :             RegularizationTerm otherRegTerm(weight, otherRegFunc);
     236             : 
     237           2 :             THEN("no exception is thrown when setting up a problem with different domain "
     238             :                  "descriptors")
     239             :             {
     240           1 :                 REQUIRE_NOTHROW(Problem{func, otherRegTerm});
     241             :             }
     242             :         }
     243             :     }
     244           5 : }
     245             : 
     246           5 : TEST_CASE("Problem: Testing with several regularization terms")
     247             : {
     248             :     // eliminate the timing info from console for the tests
     249           5 :     Logger::setLevel(Logger::LogLevel::WARN);
     250             : 
     251          10 :     GIVEN("some data term and several regularization terms")
     252             :     {
     253             :         // least squares data term
     254          10 :         IndexVector_t numCoeff(3);
     255           5 :         numCoeff << 17, 33, 52;
     256          10 :         VolumeDescriptor dd(numCoeff);
     257             : 
     258          10 :         RealVector_t scaling(dd.getNumberOfCoefficients());
     259           5 :         scaling.setRandom();
     260          10 :         DataContainer dcScaling(dd, scaling);
     261          10 :         Scaling scaleOp(dd, dcScaling);
     262             : 
     263          10 :         RealVector_t dataVec(dd.getNumberOfCoefficients());
     264           5 :         dataVec.setRandom();
     265          10 :         DataContainer dcData(dd, dataVec);
     266             : 
     267          10 :         LinearResidual linRes(scaleOp, dcData);
     268          10 :         L2NormPow2 func(linRes);
     269             : 
     270             :         // l2 norm regularization term
     271          10 :         L2NormPow2 regFunc(dd);
     272           5 :         real_t weight1 = 2.0;
     273          10 :         RegularizationTerm regTerm1(weight1, regFunc);
     274             : 
     275           5 :         real_t weight2 = 3.0;
     276          10 :         RegularizationTerm regTerm2(weight2, regFunc);
     277             : 
     278          25 :         std::vector<RegularizationTerm<real_t>> vecReg{regTerm1, regTerm2};
     279             : 
     280           7 :         WHEN("setting up the problem without x0")
     281             :         {
     282           4 :             Problem prob(func, vecReg);
     283             : 
     284           3 :             THEN("the clone works correctly")
     285             :             {
     286           2 :                 auto probClone = prob.clone();
     287             : 
     288           1 :                 REQUIRE_NE(probClone.get(), &prob);
     289           1 :                 REQUIRE_EQ(*probClone, prob);
     290             :             }
     291             : 
     292           3 :             THEN("the problem behaves as expected")
     293             :             {
     294           2 :                 DataContainer dcZero(dd);
     295           1 :                 dcZero = 0;
     296           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getCurrentSolution(), dcZero));
     297             : 
     298           1 :                 REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
     299           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
     300             : 
     301           2 :                 auto hessian = prob.getHessian();
     302           1 :                 auto result = hessian.apply(dcData);
     303       29173 :                 for (index_t i = 0; i < result.getSize(); ++i)
     304       29172 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     305             :                                                                + weight1 * dataVec[i]
     306             :                                                                + weight2 * dataVec[i]));
     307             : 
     308           1 :                 REQUIRE_UNARY(
     309             :                     checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight1 + weight2));
     310             :             }
     311             :         }
     312             : 
     313           7 :         WHEN("setting up the problem with x0")
     314             :         {
     315           4 :             RealVector_t x0Vec(dd.getNumberOfCoefficients());
     316           2 :             x0Vec.setRandom();
     317           4 :             DataContainer dcX0(dd, x0Vec);
     318             : 
     319           4 :             Problem prob(func, vecReg, dcX0);
     320             : 
     321           3 :             THEN("the clone works correctly")
     322             :             {
     323           2 :                 auto probClone = prob.clone();
     324             : 
     325           1 :                 REQUIRE_NE(probClone.get(), &prob);
     326           1 :                 REQUIRE_EQ(*probClone, prob);
     327             :             }
     328             : 
     329           3 :             THEN("the problem behaves as expected")
     330             :             {
     331           1 :                 REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcX0));
     332             : 
     333             :                 auto valueData =
     334             :                     0.5f
     335           1 :                     * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
     336           1 :                 REQUIRE_UNARY(
     337             :                     checkApproxEq(prob.evaluate(), valueData + weight1 * 0.5f * x0Vec.squaredNorm()
     338             :                                                        + weight2 * 0.5f * x0Vec.squaredNorm()));
     339             : 
     340           2 :                 auto gradient = prob.getGradient();
     341             :                 DataContainer gradientDirect =
     342           2 :                     dcScaling * (dcScaling * dcX0 - dcData) + weight1 * dcX0 + weight2 * dcX0;
     343       29173 :                 for (index_t i = 0; i < gradient.getSize(); ++i)
     344       29172 :                     REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
     345             : 
     346           2 :                 auto hessian = prob.getHessian();
     347           1 :                 auto result = hessian.apply(dcData);
     348       29173 :                 for (index_t i = 0; i < result.getSize(); ++i)
     349       29172 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     350             :                                                                + weight1 * dataVec[i]
     351             :                                                                + weight2 * dataVec[i]));
     352             : 
     353           1 :                 REQUIRE_UNARY(
     354             :                     checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight1 + weight2));
     355             :             }
     356             :         }
     357             : 
     358           6 :         WHEN("given two different data descriptors and two regularization terms with a different "
     359             :              "domain descriptor")
     360             :         {
     361             :             // three-dimensional data descriptor
     362           2 :             IndexVector_t otherNumCoeff(3);
     363           1 :             otherNumCoeff << 15, 38, 22;
     364           2 :             VolumeDescriptor otherDD(otherNumCoeff);
     365             : 
     366             :             // four-dimensional data descriptor
     367           2 :             IndexVector_t anotherNumCoeff(4);
     368           1 :             anotherNumCoeff << 7, 9, 21, 17;
     369           2 :             VolumeDescriptor anotherDD(anotherNumCoeff);
     370             : 
     371             :             // l2 norm regularization term
     372           2 :             L2NormPow2 otherRegFunc(otherDD);
     373           2 :             RegularizationTerm otherRegTerm(weight1, otherRegFunc);
     374             : 
     375             :             // l2 norm regularization term
     376           2 :             L2NormPow2 anotherRegFunc(anotherDD);
     377           2 :             RegularizationTerm anotherRegTerm(weight2, anotherRegFunc);
     378             : 
     379           2 :             THEN("no exception is thrown when setting up a problem with different domain "
     380             :                  "descriptors")
     381             :             {
     382           3 :                 REQUIRE_NOTHROW(Problem{func, std::vector{otherRegTerm, anotherRegTerm}});
     383             :             }
     384             :         }
     385             :     }
     386           5 : }
     387             : 
     388             : TEST_SUITE_END();

Generated by: LCOV version 1.15