LCOV - code coverage report
Current view: top level - elsa/solvers/tests - test_OGM.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 148 148 100.0 %
Date: 2022-08-25 03:05:39 Functions: 5 5 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file test_OGM.cpp
       3             :  *
       4             :  * @brief Tests for the Optimized Gradient Method class
       5             :  *
       6             :  * @author Michael Loipführer - initial code
       7             :  */
       8             : 
       9             : #include "doctest/doctest.h"
      10             : 
      11             : #include <iostream>
      12             : #include "OGM.h"
      13             : #include "WLSProblem.h"
      14             : #include "Problem.h"
      15             : #include "Identity.h"
      16             : #include "LinearResidual.h"
      17             : #include "L2NormPow2.h"
      18             : #include "Logger.h"
      19             : #include "VolumeDescriptor.h"
      20             : #include "SiddonsMethod.h"
      21             : #include "CircleTrajectoryGenerator.h"
      22             : #include "PhantomGenerator.h"
      23             : #include "TypeCasts.hpp"
      24             : #include "testHelpers.h"
      25             : 
      26             : using namespace elsa;
      27             : using namespace doctest;
      28             : 
      29             : TEST_SUITE_BEGIN("solvers");
      30             : 
      31             : TYPE_TO_STRING(OGM<float>);
      32             : TYPE_TO_STRING(OGM<double>);
      33             : 
      34             : template <template <typename> typename T, typename data_t>
      35             : constexpr data_t return_data_t(const T<data_t>&);
      36             : 
      37             : TEST_CASE_TEMPLATE("OGM: Solving a simple linear problem", TestType, OGM<float>, OGM<double>)
      38           4 : {
      39             :     // Set seed for Eigen Matrices!
      40           4 :     srand((unsigned int) 666);
      41             : 
      42           4 :     using data_t = decltype(return_data_t(std::declval<TestType>()));
      43             :     // eliminate the timing info from console for the tests
      44           4 :     Logger::setLevel(Logger::LogLevel::OFF);
      45             : 
      46           4 :     GIVEN("a linear problem")
      47           4 :     {
      48           4 :         IndexVector_t numCoeff(2);
      49           4 :         numCoeff << 13, 24;
      50           4 :         VolumeDescriptor dd{numCoeff};
      51             : 
      52           4 :         Eigen::Matrix<data_t, -1, 1> bVec(dd.getNumberOfCoefficients());
      53           4 :         bVec.setRandom();
      54           4 :         DataContainer<data_t> dcB{dd, bVec};
      55             : 
      56           4 :         bVec.setRandom();
      57           4 :         bVec = bVec.cwiseAbs();
      58           4 :         Scaling<data_t> scalingOp{dd, DataContainer<data_t>{dd, bVec}};
      59             : 
      60             :         // using WLS problem here for ease of use
      61             :         // since OGM is very picky with the precision of the lipschitz constant of a problem we need
      62             :         // to pass it explicitly
      63           4 :         WLSProblem<data_t> prob{scalingOp, dcB, static_cast<data_t>(1.0)};
      64             : 
      65           4 :         data_t epsilon = std::numeric_limits<data_t>::epsilon();
      66             : 
      67           4 :         WHEN("setting up an OGM solver")
      68           4 :         {
      69           2 :             TestType solver{prob, epsilon};
      70             : 
      71           2 :             THEN("the clone works correctly")
      72           2 :             {
      73           2 :                 auto ogmClone = solver.clone();
      74             : 
      75           2 :                 REQUIRE_NE(ogmClone.get(), &solver);
      76           2 :                 REQUIRE_EQ(*ogmClone, solver);
      77             : 
      78           2 :                 AND_THEN("it works as expected")
      79           2 :                 {
      80           2 :                     auto solution = solver.solve(1000);
      81             : 
      82           2 :                     DataContainer<data_t> resultsDifference = scalingOp.apply(solution) - dcB;
      83             : 
      84             :                     // should have converged for the given number of iterations
      85           2 :                     REQUIRE_UNARY(checkApproxEq(resultsDifference.squaredL2Norm(),
      86           2 :                                                 epsilon * epsilon * dcB.squaredL2Norm(), 0.5f));
      87           2 :                 }
      88           2 :             }
      89           2 :         }
      90             : 
      91           4 :         WHEN("setting up a preconditioned OGM solver")
      92           4 :         {
      93           2 :             bVec = 1 / bVec.array();
      94           2 :             TestType solver{prob, Scaling<data_t>{dd, DataContainer<data_t>{dd, bVec}}, epsilon};
      95             : 
      96           2 :             THEN("the clone works correctly")
      97           2 :             {
      98           2 :                 auto ogmClone = solver.clone();
      99             : 
     100           2 :                 REQUIRE_NE(ogmClone.get(), &solver);
     101           2 :                 REQUIRE_EQ(*ogmClone, solver);
     102             : 
     103           2 :                 AND_THEN("it works as expected")
     104           2 :                 {
     105             :                     // with a good preconditioner we should need fewer iterations than without
     106           2 :                     auto solution = solver.solve(500);
     107             : 
     108           2 :                     DataContainer<data_t> resultsDifference = scalingOp.apply(solution) - dcB;
     109             : 
     110             :                     // should have converged for the given number of iterations
     111           2 :                     REQUIRE_UNARY(checkApproxEq(resultsDifference.squaredL2Norm(),
     112           2 :                                                 epsilon * epsilon * dcB.squaredL2Norm()));
     113           2 :                 }
     114           2 :             }
     115           2 :         }
     116           4 :     }
     117           4 : }
     118             : 
     119             : TEST_CASE_TEMPLATE("OGM: Solving a Tikhonov problem", TestType, OGM<float>, OGM<double>)
     120           4 : {
     121             :     // Set seed for Eigen Matrices!
     122           4 :     srand((unsigned int) 666);
     123             : 
     124           4 :     using data_t = decltype(return_data_t(std::declval<TestType>()));
     125             :     // eliminate the timing info from console for the tests
     126           4 :     Logger::setLevel(Logger::LogLevel::OFF);
     127             : 
     128           4 :     GIVEN("a Tikhonov problem")
     129           4 :     {
     130           4 :         IndexVector_t numCoeff(2);
     131           4 :         numCoeff << 13, 24;
     132           4 :         VolumeDescriptor dd(numCoeff);
     133             : 
     134           4 :         Eigen::Matrix<data_t, -1, 1> bVec(dd.getNumberOfCoefficients());
     135           4 :         bVec.setRandom();
     136           4 :         DataContainer dcB(dd, bVec);
     137             : 
     138           4 :         bVec.setRandom();
     139           4 :         bVec = bVec.cwiseProduct(bVec);
     140           4 :         Scaling<data_t> scalingOp{dd, DataContainer<data_t>{dd, bVec}};
     141             : 
     142           4 :         auto lambda = static_cast<data_t>(0.1);
     143           4 :         Scaling<data_t> lambdaOp{dd, lambda};
     144             : 
     145             :         // using WLS problem here for ease of use
     146             :         // since OGM is very picky with the precision of the lipschitz constant of a problem we need
     147             :         // to pass it explicitly
     148           4 :         WLSProblem<data_t> prob{scalingOp + lambdaOp, dcB, static_cast<data_t>(1.2)};
     149             : 
     150           4 :         data_t epsilon = std::numeric_limits<data_t>::epsilon();
     151             : 
     152           4 :         WHEN("setting up an OGM solver")
     153           4 :         {
     154           2 :             TestType solver{prob, epsilon};
     155             : 
     156           2 :             THEN("the clone works correctly")
     157           2 :             {
     158           2 :                 auto ogmClone = solver.clone();
     159             : 
     160           2 :                 REQUIRE_NE(ogmClone.get(), &solver);
     161           2 :                 REQUIRE_EQ(*ogmClone, solver);
     162             : 
     163           2 :                 AND_THEN("it works as expected")
     164           2 :                 {
     165           2 :                     auto solution = solver.solve(dd.getNumberOfCoefficients());
     166             : 
     167           2 :                     DataContainer<data_t> resultsDifference =
     168           2 :                         (scalingOp + lambdaOp).apply(solution) - dcB;
     169             : 
     170             :                     // should have converged for the given number of iterations
     171             :                     // does not converge to the optimal solution because of the regularization term
     172           2 :                     REQUIRE_UNARY(checkApproxEq(resultsDifference.squaredL2Norm(),
     173           2 :                                                 epsilon * epsilon * dcB.squaredL2Norm()));
     174           2 :                 }
     175           2 :             }
     176           2 :         }
     177             : 
     178           4 :         WHEN("setting up a preconditioned OGM solver")
     179           4 :         {
     180           2 :             bVec = 1 / (bVec.array() + lambda);
     181           2 :             TestType solver{prob, Scaling<data_t>{dd, DataContainer<data_t>{dd, bVec}}, epsilon};
     182             : 
     183           2 :             THEN("the clone works correctly")
     184           2 :             {
     185           2 :                 auto ogmClone = solver.clone();
     186             : 
     187           2 :                 REQUIRE_NE(ogmClone.get(), &solver);
     188           2 :                 REQUIRE_EQ(*ogmClone, solver);
     189             : 
     190           2 :                 AND_THEN("it works as expected")
     191           2 :                 {
     192             :                     // a perfect preconditioner should allow for convergence in a single step
     193           2 :                     auto solution = solver.solve(dd.getNumberOfCoefficients());
     194             : 
     195           2 :                     DataContainer<data_t> resultsDifference =
     196           2 :                         (scalingOp + lambdaOp).apply(solution) - dcB;
     197             : 
     198             :                     // should have converged for the given number of iterations
     199           2 :                     REQUIRE_UNARY(checkApproxEq(resultsDifference.squaredL2Norm(),
     200           2 :                                                 epsilon * epsilon * dcB.squaredL2Norm()));
     201           2 :                 }
     202           2 :             }
     203           2 :         }
     204           4 :     }
     205           4 : }
     206             : 
     207             : TEST_CASE("OGM: Solving a simple phantom reconstruction")
     208           1 : {
     209             :     // Set seed for Eigen Matrices!
     210           1 :     srand((unsigned int) 666);
     211             : 
     212             :     // eliminate the timing info from console for the tests
     213           1 :     Logger::setLevel(Logger::LogLevel::OFF);
     214             : 
     215           1 :     GIVEN("a Phantom reconstruction problem")
     216           1 :     {
     217           1 :         IndexVector_t size(2);
     218           1 :         size << 16, 16; // TODO: determine optimal phantom size for efficient testing
     219           1 :         auto phantom = PhantomGenerator<real_t>::createModifiedSheppLogan(size);
     220           1 :         auto& volumeDescriptor = phantom.getDataDescriptor();
     221             : 
     222           1 :         index_t numAngles{20}, arc{360};
     223           1 :         auto sinoDescriptor = CircleTrajectoryGenerator::createTrajectory(
     224           1 :             numAngles, phantom.getDataDescriptor(), arc, static_cast<real_t>(size(0)) * 100.0f,
     225           1 :             static_cast<real_t>(size(0)));
     226             : 
     227           1 :         SiddonsMethod projector(downcast<VolumeDescriptor>(volumeDescriptor), *sinoDescriptor);
     228             : 
     229           1 :         auto sinogram = projector.apply(phantom);
     230             : 
     231           1 :         WLSProblem problem(projector, sinogram);
     232           1 :         real_t epsilon = std::numeric_limits<real_t>::epsilon();
     233             : 
     234           1 :         WHEN("setting up a SQS solver")
     235           1 :         {
     236           1 :             OGM solver{problem, epsilon};
     237             : 
     238           1 :             THEN("the clone works correctly")
     239           1 :             {
     240           1 :                 auto ogmClone = solver.clone();
     241             : 
     242           1 :                 REQUIRE_NE(ogmClone.get(), &solver);
     243           1 :                 REQUIRE_EQ(*ogmClone, solver);
     244             : 
     245           1 :                 AND_THEN("it works as expected")
     246           1 :                 {
     247           1 :                     auto reconstruction = solver.solve(15);
     248             : 
     249           1 :                     DataContainer resultsDifference = reconstruction - phantom;
     250             : 
     251             :                     // should have converged for the given number of iterations
     252             :                     // does not converge to the optimal solution because of the regularization term
     253           1 :                     REQUIRE_UNARY(checkApproxEq(resultsDifference.squaredL2Norm(),
     254           1 :                                                 epsilon * epsilon * phantom.squaredL2Norm(), 0.15));
     255           1 :                 }
     256           1 :             }
     257           1 :         }
     258           1 :     }
     259           1 : }
     260             : 
     261             : TEST_SUITE_END();

Generated by: LCOV version 1.14