LCOV - code coverage report
Current view: top level - solvers/tests - test_OGM.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 104 104 100.0 %
Date: 2022-02-28 03:37:41 Functions: 13 13 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_OGM.cpp
       3             :  *
       4             :  * @brief Tests for the Optimized Gradient Method class
       5             :  *
       6             :  * @author Michael Loipführer - initial code
       7             :  */
       8             : 
       9             : #include "doctest/doctest.h"
      10             : 
      11             : #include <iostream>
      12             : #include "OGM.h"
      13             : #include "WLSProblem.h"
      14             : #include "Problem.h"
      15             : #include "Identity.h"
      16             : #include "LinearResidual.h"
      17             : #include "L2NormPow2.h"
      18             : #include "Logger.h"
      19             : #include "VolumeDescriptor.h"
      20             : #include "SiddonsMethod.h"
      21             : #include "CircleTrajectoryGenerator.h"
      22             : #include "PhantomGenerator.h"
      23             : #include "TypeCasts.hpp"
      24             : #include "testHelpers.h"
      25             : 
      26             : using namespace elsa;
      27             : using namespace doctest;
      28             : 
      29             : TEST_SUITE_BEGIN("solvers");
      30             : 
      31          10 : TYPE_TO_STRING(OGM<float>);
      32          10 : TYPE_TO_STRING(OGM<double>);
      33             : 
      34             : template <template <typename> typename T, typename data_t>
      35             : constexpr data_t return_data_t(const T<data_t>&);
      36             : 
      37          19 : TEST_CASE_TEMPLATE("OGM: Solving a simple linear problem", TestType, OGM<float>, OGM<double>)
      38             : {
      39             :     // Set seed for Eigen Matrices!
      40           4 :     srand((unsigned int) 666);
      41             : 
      42             :     using data_t = decltype(return_data_t(std::declval<TestType>()));
      43             :     // eliminate the timing info from console for the tests
      44           4 :     Logger::setLevel(Logger::LogLevel::OFF);
      45             : 
      46           8 :     GIVEN("a linear problem")
      47             :     {
      48           8 :         IndexVector_t numCoeff(2);
      49           4 :         numCoeff << 13, 24;
      50           8 :         VolumeDescriptor dd{numCoeff};
      51             : 
      52           8 :         Eigen::Matrix<data_t, -1, 1> bVec(dd.getNumberOfCoefficients());
      53           4 :         bVec.setRandom();
      54           8 :         DataContainer<data_t> dcB{dd, bVec};
      55             : 
      56           4 :         bVec.setRandom();
      57           4 :         bVec = bVec.cwiseAbs();
      58           8 :         Scaling<data_t> scalingOp{dd, DataContainer<data_t>{dd, bVec}};
      59             : 
      60             :         // using WLS problem here for ease of use
      61             :         // since OGM is very picky with the precision of the lipschitz constant of a problem we need
      62             :         // to pass it explicitly
      63           8 :         WLSProblem<data_t> prob{scalingOp, dcB, static_cast<data_t>(1.0)};
      64             : 
      65           4 :         data_t epsilon = std::numeric_limits<data_t>::epsilon();
      66             : 
      67           6 :         WHEN("setting up an OGM solver")
      68             :         {
      69           4 :             TestType solver{prob, epsilon};
      70             : 
      71           4 :             THEN("the clone works correctly")
      72             :             {
      73           4 :                 auto ogmClone = solver.clone();
      74             : 
      75           2 :                 REQUIRE_NE(ogmClone.get(), &solver);
      76           2 :                 REQUIRE_EQ(*ogmClone, solver);
      77             : 
      78           4 :                 AND_THEN("it works as expected")
      79             :                 {
      80           4 :                     auto solution = solver.solve(1000);
      81             : 
      82           2 :                     DataContainer<data_t> resultsDifference = scalingOp.apply(solution) - dcB;
      83             : 
      84             :                     // should have converged for the given number of iterations
      85           2 :                     REQUIRE_UNARY(checkApproxEq(resultsDifference.squaredL2Norm(),
      86             :                                                 epsilon * epsilon * dcB.squaredL2Norm(), 0.5f));
      87             :                 }
      88             :             }
      89             :         }
      90             : 
      91           6 :         WHEN("setting up a preconditioned OGM solver")
      92             :         {
      93           2 :             bVec = 1 / bVec.array();
      94           6 :             TestType solver{prob, Scaling<data_t>{dd, DataContainer<data_t>{dd, bVec}}, epsilon};
      95             : 
      96           4 :             THEN("the clone works correctly")
      97             :             {
      98           4 :                 auto ogmClone = solver.clone();
      99             : 
     100           2 :                 REQUIRE_NE(ogmClone.get(), &solver);
     101           2 :                 REQUIRE_EQ(*ogmClone, solver);
     102             : 
     103           4 :                 AND_THEN("it works as expected")
     104             :                 {
     105             :                     // with a good preconditioner we should need fewer iterations than without
     106           4 :                     auto solution = solver.solve(500);
     107             : 
     108           2 :                     DataContainer<data_t> resultsDifference = scalingOp.apply(solution) - dcB;
     109             : 
     110             :                     // should have converged for the given number of iterations
     111           2 :                     REQUIRE_UNARY(checkApproxEq(resultsDifference.squaredL2Norm(),
     112             :                                                 epsilon * epsilon * dcB.squaredL2Norm()));
     113             :                 }
     114             :             }
     115             :         }
     116             :     }
     117           4 : }
     118             : 
     119          19 : TEST_CASE_TEMPLATE("OGM: Solving a Tikhonov problem", TestType, OGM<float>, OGM<double>)
     120             : {
     121             :     // Set seed for Eigen Matrices!
     122           4 :     srand((unsigned int) 666);
     123             : 
     124             :     using data_t = decltype(return_data_t(std::declval<TestType>()));
     125             :     // eliminate the timing info from console for the tests
     126           4 :     Logger::setLevel(Logger::LogLevel::OFF);
     127             : 
     128           8 :     GIVEN("a Tikhonov problem")
     129             :     {
     130           8 :         IndexVector_t numCoeff(2);
     131           4 :         numCoeff << 13, 24;
     132           8 :         VolumeDescriptor dd(numCoeff);
     133             : 
     134           8 :         Eigen::Matrix<data_t, -1, 1> bVec(dd.getNumberOfCoefficients());
     135           4 :         bVec.setRandom();
     136           8 :         DataContainer dcB(dd, bVec);
     137             : 
     138           4 :         bVec.setRandom();
     139           4 :         bVec = bVec.cwiseProduct(bVec);
     140           8 :         Scaling<data_t> scalingOp{dd, DataContainer<data_t>{dd, bVec}};
     141             : 
     142           4 :         auto lambda = static_cast<data_t>(0.1);
     143           8 :         Scaling<data_t> lambdaOp{dd, lambda};
     144             : 
     145             :         // using WLS problem here for ease of use
     146             :         // since OGM is very picky with the precision of the lipschitz constant of a problem we need
     147             :         // to pass it explicitly
     148           8 :         WLSProblem<data_t> prob{scalingOp + lambdaOp, dcB, static_cast<data_t>(1.2)};
     149             : 
     150           4 :         data_t epsilon = std::numeric_limits<data_t>::epsilon();
     151             : 
     152           6 :         WHEN("setting up an OGM solver")
     153             :         {
     154           4 :             TestType solver{prob, epsilon};
     155             : 
     156           4 :             THEN("the clone works correctly")
     157             :             {
     158           4 :                 auto ogmClone = solver.clone();
     159             : 
     160           2 :                 REQUIRE_NE(ogmClone.get(), &solver);
     161           2 :                 REQUIRE_EQ(*ogmClone, solver);
     162             : 
     163           4 :                 AND_THEN("it works as expected")
     164             :                 {
     165           4 :                     auto solution = solver.solve(dd.getNumberOfCoefficients());
     166             : 
     167           2 :                     DataContainer<data_t> resultsDifference =
     168             :                         (scalingOp + lambdaOp).apply(solution) - dcB;
     169             : 
     170             :                     // should have converged for the given number of iterations
     171             :                     // does not converge to the optimal solution because of the regularization term
     172           2 :                     REQUIRE_UNARY(checkApproxEq(resultsDifference.squaredL2Norm(),
     173             :                                                 epsilon * epsilon * dcB.squaredL2Norm()));
     174             :                 }
     175             :             }
     176             :         }
     177             : 
     178           6 :         WHEN("setting up a preconditioned OGM solver")
     179             :         {
     180           2 :             bVec = 1 / (bVec.array() + lambda);
     181           6 :             TestType solver{prob, Scaling<data_t>{dd, DataContainer<data_t>{dd, bVec}}, epsilon};
     182             : 
     183           4 :             THEN("the clone works correctly")
     184             :             {
     185           4 :                 auto ogmClone = solver.clone();
     186             : 
     187           2 :                 REQUIRE_NE(ogmClone.get(), &solver);
     188           2 :                 REQUIRE_EQ(*ogmClone, solver);
     189             : 
     190           4 :                 AND_THEN("it works as expected")
     191             :                 {
     192             :                     // a perfect preconditioner should allow for convergence in a single step
     193           4 :                     auto solution = solver.solve(dd.getNumberOfCoefficients());
     194             : 
     195           2 :                     DataContainer<data_t> resultsDifference =
     196             :                         (scalingOp + lambdaOp).apply(solution) - dcB;
     197             : 
     198             :                     // should have converged for the given number of iterations
     199           2 :                     REQUIRE_UNARY(checkApproxEq(resultsDifference.squaredL2Norm(),
     200             :                                                 epsilon * epsilon * dcB.squaredL2Norm()));
     201             :                 }
     202             :             }
     203             :         }
     204             :     }
     205           4 : }
     206             : 
     207           1 : TEST_CASE("OGM: Solving a simple phantom reconstruction")
     208             : {
     209             :     // Set seed for Eigen Matrices!
     210           1 :     srand((unsigned int) 666);
     211             : 
     212             :     // eliminate the timing info from console for the tests
     213           1 :     Logger::setLevel(Logger::LogLevel::OFF);
     214             : 
     215           2 :     GIVEN("a Phantom reconstruction problem")
     216             :     {
     217           2 :         IndexVector_t size(2);
     218           1 :         size << 16, 16; // TODO: determine optimal phantom size for efficient testing
     219           2 :         auto phantom = PhantomGenerator<real_t>::createModifiedSheppLogan(size);
     220           1 :         auto& volumeDescriptor = phantom.getDataDescriptor();
     221             : 
     222           1 :         index_t numAngles{20}, arc{360};
     223             :         auto sinoDescriptor = CircleTrajectoryGenerator::createTrajectory(
     224           2 :             numAngles, phantom.getDataDescriptor(), arc, static_cast<real_t>(size(0)) * 100.0f,
     225           2 :             static_cast<real_t>(size(0)));
     226             : 
     227           2 :         SiddonsMethod projector(downcast<VolumeDescriptor>(volumeDescriptor), *sinoDescriptor);
     228             : 
     229           2 :         auto sinogram = projector.apply(phantom);
     230             : 
     231           2 :         WLSProblem problem(projector, sinogram);
     232           1 :         real_t epsilon = std::numeric_limits<real_t>::epsilon();
     233             : 
     234           2 :         WHEN("setting up a SQS solver")
     235             :         {
     236           2 :             OGM solver{problem, epsilon};
     237             : 
     238           2 :             THEN("the clone works correctly")
     239             :             {
     240           2 :                 auto ogmClone = solver.clone();
     241             : 
     242           1 :                 REQUIRE_NE(ogmClone.get(), &solver);
     243           1 :                 REQUIRE_EQ(*ogmClone, solver);
     244             : 
     245           2 :                 AND_THEN("it works as expected")
     246             :                 {
     247           2 :                     auto reconstruction = solver.solve(15);
     248             : 
     249           1 :                     DataContainer resultsDifference = reconstruction - phantom;
     250             : 
     251             :                     // should have converged for the given number of iterations
     252             :                     // does not converge to the optimal solution because of the regularization term
     253           1 :                     REQUIRE_UNARY(checkApproxEq(resultsDifference.squaredL2Norm(),
     254             :                                                 epsilon * epsilon * phantom.squaredL2Norm(), 0.15));
     255             :                 }
     256             :             }
     257             :         }
     258             :     }
     259           1 : }
     260             : 
     261             : TEST_SUITE_END();

Generated by: LCOV version 1.15