LCOV - code coverage report
Current view: top level - elsa/problems/tests - test_Problem.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 269 269 100.0 %
Date: 2022-08-25 03:05:39 Functions: 3 3 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_Problem.cpp
       3             :  *
       4             :  * @brief Tests for the Problem class
       5             :  *
       6             :  * @author David Frank - initial code
       7             :  * @author Tobias Lasser - rewrite
       8             :  */
       9             : 
      10             : #include "doctest/doctest.h"
      11             : 
      12             : #include <iostream>
      13             : #include <Logger.h>
      14             : #include "Problem.h"
      15             : #include "Identity.h"
      16             : #include "Scaling.h"
      17             : #include "LinearResidual.h"
      18             : #include "L2NormPow2.h"
      19             : #include "VolumeDescriptor.h"
      20             : #include "testHelpers.h"
      21             : #include "TypeCasts.hpp"
      22             : 
      23             : using namespace elsa;
      24             : using namespace doctest;
      25             : 
      26             : TEST_SUITE_BEGIN("problems");
      27             : 
      28             : TEST_CASE("Problem: Testing without regularization")
      29           4 : {
      30             :     // eliminate the timing info from console for the tests
      31           4 :     Logger::setLevel(Logger::LogLevel::WARN);
      32             : 
      33           4 :     GIVEN("some data term")
      34           4 :     {
      35           4 :         IndexVector_t numCoeff(2);
      36           4 :         numCoeff << 17, 23;
      37           4 :         VolumeDescriptor dd(numCoeff);
      38             : 
      39           4 :         RealVector_t scaling(dd.getNumberOfCoefficients());
      40           4 :         scaling.setRandom();
      41           4 :         DataContainer dcScaling(dd, scaling);
      42           4 :         Scaling scaleOp(dd, dcScaling);
      43             : 
      44           4 :         RealVector_t dataVec(dd.getNumberOfCoefficients());
      45           4 :         dataVec.setRandom();
      46           4 :         DataContainer dcData(dd, dataVec);
      47             : 
      48           4 :         LinearResidual linRes(scaleOp, dcData);
      49           4 :         L2NormPow2 func(linRes);
      50             : 
      51           4 :         WHEN("setting up the problem without x0")
      52           4 :         {
      53           2 :             Problem prob(func);
      54             : 
      55           2 :             THEN("the clone works correctly")
      56           2 :             {
      57           1 :                 auto probClone = prob.clone();
      58             : 
      59           1 :                 REQUIRE_NE(probClone.get(), &prob);
      60           1 :                 REQUIRE_EQ(*probClone, prob);
      61           1 :             }
      62             : 
      63           2 :             THEN("the problem behaves as expected")
      64           2 :             {
      65           1 :                 DataContainer dcZero(dd);
      66           1 :                 dcZero = 0;
      67           1 :                 REQUIRE_EQ(prob.getCurrentSolution(), dcZero);
      68             : 
      69           1 :                 REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
      70           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
      71             : 
      72           1 :                 auto hessian = prob.getHessian();
      73           1 :                 auto result = hessian.apply(dcData);
      74         392 :                 for (index_t i = 0; i < result.getSize(); ++i)
      75           1 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]));
      76             : 
      77           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f));
      78           1 :             }
      79           2 :         }
      80             : 
      81           4 :         WHEN("setting up the problem with x0")
      82           4 :         {
      83           2 :             RealVector_t x0Vec(dd.getNumberOfCoefficients());
      84           2 :             x0Vec.setRandom();
      85           2 :             DataContainer dcX0(dd, x0Vec);
      86             : 
      87           2 :             Problem prob(func, dcX0);
      88             : 
      89           2 :             THEN("the clone works correctly")
      90           2 :             {
      91           1 :                 auto probClone = prob.clone();
      92             : 
      93           1 :                 REQUIRE_NE(probClone.get(), &prob);
      94           1 :                 REQUIRE_EQ(*probClone, prob);
      95           1 :             }
      96             : 
      97           2 :             THEN("the problem behaves as expected")
      98           2 :             {
      99           1 :                 REQUIRE_EQ(prob.getCurrentSolution(), dcX0);
     100             : 
     101           1 :                 REQUIRE_UNARY(checkApproxEq(
     102           1 :                     prob.evaluate(), 0.5f
     103           1 :                                          * (scaling.array() * x0Vec.array() - dataVec.array())
     104           1 :                                                .matrix()
     105           1 :                                                .squaredNorm()));
     106             : 
     107           1 :                 DataContainer gradientDirect = dcScaling * (dcScaling * dcX0 - dcData);
     108           1 :                 auto gradient = prob.getGradient();
     109         392 :                 for (index_t i = 0; i < gradientDirect.getSize(); ++i)
     110           1 :                     REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
     111             : 
     112           1 :                 auto hessian = prob.getHessian();
     113           1 :                 auto result = hessian.apply(dcData);
     114         392 :                 for (index_t i = 0; i < result.getSize(); ++i)
     115           1 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]));
     116             : 
     117           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f));
     118           1 :             }
     119           2 :         }
     120           4 :     }
     121           4 : }
     122             : 
     123             : TEST_CASE("Problem: Testing with one regularization term")
     124           5 : {
     125             :     // eliminate the timing info from console for the tests
     126           5 :     Logger::setLevel(Logger::LogLevel::WARN);
     127             : 
     128           5 :     GIVEN("some data term and some regularization term")
     129           5 :     {
     130             :         // least squares data term
     131           5 :         IndexVector_t numCoeff(2);
     132           5 :         numCoeff << 23, 47;
     133           5 :         VolumeDescriptor dd(numCoeff);
     134             : 
     135           5 :         RealVector_t scaling(dd.getNumberOfCoefficients());
     136           5 :         scaling.setRandom();
     137           5 :         DataContainer dcScaling(dd, scaling);
     138           5 :         Scaling scaleOp(dd, dcScaling);
     139             : 
     140           5 :         RealVector_t dataVec(dd.getNumberOfCoefficients());
     141           5 :         dataVec.setRandom();
     142           5 :         DataContainer dcData(dd, dataVec);
     143             : 
     144           5 :         LinearResidual linRes(scaleOp, dcData);
     145           5 :         L2NormPow2 func(linRes);
     146             : 
     147             :         // l2 norm regularization term
     148           5 :         L2NormPow2 regFunc(dd);
     149           5 :         real_t weight = 2.0;
     150           5 :         RegularizationTerm regTerm(weight, regFunc);
     151             : 
     152           5 :         WHEN("setting up the problem without x0")
     153           5 :         {
     154           2 :             Problem prob(func, regTerm);
     155             : 
     156           2 :             THEN("the clone works correctly")
     157           2 :             {
     158           1 :                 auto probClone = prob.clone();
     159             : 
     160           1 :                 REQUIRE_NE(probClone.get(), &prob);
     161           1 :                 REQUIRE_EQ(*probClone, prob);
     162           1 :             }
     163             : 
     164           2 :             THEN("the problem behaves as expected")
     165           2 :             {
     166           1 :                 DataContainer dcZero(dd);
     167           1 :                 dcZero = 0;
     168           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getCurrentSolution(), dcZero));
     169             : 
     170           1 :                 REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
     171           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
     172             : 
     173           1 :                 auto hessian = prob.getHessian();
     174           1 :                 auto result = hessian.apply(dcData);
     175        1082 :                 for (index_t i = 0; i < result.getSize(); ++i)
     176           1 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     177           1 :                                                                + weight * dataVec[i]));
     178             : 
     179           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight));
     180           1 :             }
     181           2 :         }
     182             : 
     183           5 :         WHEN("setting up the problem with x0")
     184           5 :         {
     185           2 :             RealVector_t x0Vec(dd.getNumberOfCoefficients());
     186           2 :             x0Vec.setRandom();
     187           2 :             DataContainer dcX0(dd, x0Vec);
     188             : 
     189           2 :             Problem prob(func, regTerm, dcX0);
     190             : 
     191           2 :             THEN("the clone works correctly")
     192           2 :             {
     193           1 :                 auto probClone = prob.clone();
     194             : 
     195           1 :                 REQUIRE_NE(probClone.get(), &prob);
     196           1 :                 REQUIRE_EQ(*probClone, prob);
     197           1 :             }
     198             : 
     199           2 :             THEN("the problem behaves as expected")
     200           2 :             {
     201           1 :                 REQUIRE_EQ(prob.getCurrentSolution(), dcX0);
     202             : 
     203           1 :                 auto valueData =
     204           1 :                     0.5f
     205           1 :                     * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
     206           1 :                 REQUIRE_UNARY(checkApproxEq(prob.evaluate(),
     207           1 :                                             valueData + weight * 0.5f * x0Vec.squaredNorm()));
     208             : 
     209           1 :                 DataContainer gradientDirect =
     210           1 :                     dcScaling * (dcScaling * dcX0 - dcData) + weight * dcX0;
     211           1 :                 auto gradient = prob.getGradient();
     212        1082 :                 for (index_t i = 0; i < gradient.getSize(); ++i)
     213           1 :                     REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
     214             : 
     215           1 :                 auto hessian = prob.getHessian();
     216           1 :                 auto result = hessian.apply(dcData);
     217        1082 :                 for (index_t i = 0; i < result.getSize(); ++i)
     218           1 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     219           1 :                                                                + weight * dataVec[i]));
     220             : 
     221           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight));
     222           1 :             }
     223           2 :         }
     224             : 
     225           5 :         WHEN("given a different data descriptor and another regularization term with a different "
     226           5 :              "domain descriptor")
     227           5 :         {
     228             :             // three-dimensional data descriptor
     229           1 :             IndexVector_t otherNumCoeff(3);
     230           1 :             otherNumCoeff << 15, 38, 22;
     231           1 :             VolumeDescriptor otherDD(otherNumCoeff);
     232             : 
     233             :             // l2 norm regularization term
     234           1 :             L2NormPow2 otherRegFunc(otherDD);
     235           1 :             RegularizationTerm otherRegTerm(weight, otherRegFunc);
     236             : 
     237           1 :             THEN("no exception is thrown when setting up a problem with different domain "
     238           1 :                  "descriptors")
     239           1 :             {
     240           1 :                 REQUIRE_NOTHROW(Problem{func, otherRegTerm});
     241           1 :             }
     242           1 :         }
     243           5 :     }
     244           5 : }
     245             : 
     246             : TEST_CASE("Problem: Testing with several regularization terms")
     247           5 : {
     248             :     // eliminate the timing info from console for the tests
     249           5 :     Logger::setLevel(Logger::LogLevel::WARN);
     250             : 
     251           5 :     GIVEN("some data term and several regularization terms")
     252           5 :     {
     253             :         // least squares data term
     254           5 :         IndexVector_t numCoeff(3);
     255           5 :         numCoeff << 17, 33, 52;
     256           5 :         VolumeDescriptor dd(numCoeff);
     257             : 
     258           5 :         RealVector_t scaling(dd.getNumberOfCoefficients());
     259           5 :         scaling.setRandom();
     260           5 :         DataContainer dcScaling(dd, scaling);
     261           5 :         Scaling scaleOp(dd, dcScaling);
     262             : 
     263           5 :         RealVector_t dataVec(dd.getNumberOfCoefficients());
     264           5 :         dataVec.setRandom();
     265           5 :         DataContainer dcData(dd, dataVec);
     266             : 
     267           5 :         LinearResidual linRes(scaleOp, dcData);
     268           5 :         L2NormPow2 func(linRes);
     269             : 
     270             :         // l2 norm regularization term
     271           5 :         L2NormPow2 regFunc(dd);
     272           5 :         real_t weight1 = 2.0;
     273           5 :         RegularizationTerm regTerm1(weight1, regFunc);
     274             : 
     275           5 :         real_t weight2 = 3.0;
     276           5 :         RegularizationTerm regTerm2(weight2, regFunc);
     277             : 
     278           5 :         std::vector<RegularizationTerm<real_t>> vecReg{regTerm1, regTerm2};
     279             : 
     280           5 :         WHEN("setting up the problem without x0")
     281           5 :         {
     282           2 :             Problem prob(func, vecReg);
     283             : 
     284           2 :             THEN("the clone works correctly")
     285           2 :             {
     286           1 :                 auto probClone = prob.clone();
     287             : 
     288           1 :                 REQUIRE_NE(probClone.get(), &prob);
     289           1 :                 REQUIRE_EQ(*probClone, prob);
     290           1 :             }
     291             : 
     292           2 :             THEN("the problem behaves as expected")
     293           2 :             {
     294           1 :                 DataContainer dcZero(dd);
     295           1 :                 dcZero = 0;
     296           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getCurrentSolution(), dcZero));
     297             : 
     298           1 :                 REQUIRE_UNARY(checkApproxEq(prob.evaluate(), 0.5f * dataVec.squaredNorm()));
     299           1 :                 REQUIRE_UNARY(checkApproxEq(prob.getGradient(), -1.0f * dcScaling * dcData));
     300             : 
     301           1 :                 auto hessian = prob.getHessian();
     302           1 :                 auto result = hessian.apply(dcData);
     303       29173 :                 for (index_t i = 0; i < result.getSize(); ++i)
     304           1 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     305           1 :                                                                + weight1 * dataVec[i]
     306           1 :                                                                + weight2 * dataVec[i]));
     307             : 
     308           1 :                 REQUIRE_UNARY(
     309           1 :                     checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight1 + weight2));
     310           1 :             }
     311           2 :         }
     312             : 
     313           5 :         WHEN("setting up the problem with x0")
     314           5 :         {
     315           2 :             RealVector_t x0Vec(dd.getNumberOfCoefficients());
     316           2 :             x0Vec.setRandom();
     317           2 :             DataContainer dcX0(dd, x0Vec);
     318             : 
     319           2 :             Problem prob(func, vecReg, dcX0);
     320             : 
     321           2 :             THEN("the clone works correctly")
     322           2 :             {
     323           1 :                 auto probClone = prob.clone();
     324             : 
     325           1 :                 REQUIRE_NE(probClone.get(), &prob);
     326           1 :                 REQUIRE_EQ(*probClone, prob);
     327           1 :             }
     328             : 
     329           2 :             THEN("the problem behaves as expected")
     330           2 :             {
     331           1 :                 REQUIRE_UNARY(isApprox(prob.getCurrentSolution(), dcX0));
     332             : 
     333           1 :                 auto valueData =
     334           1 :                     0.5f
     335           1 :                     * (scaling.array() * x0Vec.array() - dataVec.array()).matrix().squaredNorm();
     336           1 :                 REQUIRE_UNARY(
     337           1 :                     checkApproxEq(prob.evaluate(), valueData + weight1 * 0.5f * x0Vec.squaredNorm()
     338           1 :                                                        + weight2 * 0.5f * x0Vec.squaredNorm()));
     339             : 
     340           1 :                 auto gradient = prob.getGradient();
     341           1 :                 DataContainer gradientDirect =
     342           1 :                     dcScaling * (dcScaling * dcX0 - dcData) + weight1 * dcX0 + weight2 * dcX0;
     343       29173 :                 for (index_t i = 0; i < gradient.getSize(); ++i)
     344           1 :                     REQUIRE_UNARY(checkApproxEq(gradient[i], gradientDirect[i]));
     345             : 
     346           1 :                 auto hessian = prob.getHessian();
     347           1 :                 auto result = hessian.apply(dcData);
     348       29173 :                 for (index_t i = 0; i < result.getSize(); ++i)
     349           1 :                     REQUIRE_UNARY(checkApproxEq(result[i], scaling[i] * scaling[i] * dataVec[i]
     350           1 :                                                                + weight1 * dataVec[i]
     351           1 :                                                                + weight2 * dataVec[i]));
     352             : 
     353           1 :                 REQUIRE_UNARY(
     354           1 :                     checkApproxEq(prob.getLipschitzConstant(100), 1.0f + weight1 + weight2));
     355           1 :             }
     356           2 :         }
     357             : 
     358           5 :         WHEN("given two different data descriptors and two regularization terms with a different "
     359           5 :              "domain descriptor")
     360           5 :         {
     361             :             // three-dimensional data descriptor
     362           1 :             IndexVector_t otherNumCoeff(3);
     363           1 :             otherNumCoeff << 15, 38, 22;
     364           1 :             VolumeDescriptor otherDD(otherNumCoeff);
     365             : 
     366             :             // four-dimensional data descriptor
     367           1 :             IndexVector_t anotherNumCoeff(4);
     368           1 :             anotherNumCoeff << 7, 9, 21, 17;
     369           1 :             VolumeDescriptor anotherDD(anotherNumCoeff);
     370             : 
     371             :             // l2 norm regularization term
     372           1 :             L2NormPow2 otherRegFunc(otherDD);
     373           1 :             RegularizationTerm otherRegTerm(weight1, otherRegFunc);
     374             : 
     375             :             // l2 norm regularization term
     376           1 :             L2NormPow2 anotherRegFunc(anotherDD);
     377           1 :             RegularizationTerm anotherRegTerm(weight2, anotherRegFunc);
     378             : 
     379           1 :             THEN("no exception is thrown when setting up a problem with different domain "
     380           1 :                  "descriptors")
     381           1 :             {
     382           1 :                 REQUIRE_NOTHROW(Problem{func, std::vector{otherRegTerm, anotherRegTerm}});
     383           1 :             }
     384           1 :         }
     385           5 :     }
     386           5 : }
     387             : 
     388             : TEST_SUITE_END();

Generated by: LCOV version 1.14