LCOV - code coverage report
Current view: top level - solvers - FGM.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 59 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 10 0.0 %

          Line data    Source code
       1             : #include "FGM.h"
       2             : #include "Logger.h"
       3             : #include "TypeCasts.hpp"
       4             : 
       5             : namespace elsa
       6             : {
       7             :     template <typename data_t>
       8           0 :     FGM<data_t>::FGM(const Problem<data_t>& problem, data_t epsilon)
       9           0 :         : Solver<data_t>(problem), _epsilon{epsilon}
      10             :     {
      11           0 :     }
      12             : 
      13             :     template <typename data_t>
      14           0 :     FGM<data_t>::FGM(const Problem<data_t>& problem,
      15             :                      const LinearOperator<data_t>& preconditionerInverse, data_t epsilon)
      16             :         : Solver<data_t>(problem),
      17           0 :           _epsilon{epsilon},
      18           0 :           _preconditionerInverse{preconditionerInverse.clone()}
      19             :     {
      20             :         // check that preconditioner is compatible with problem
      21           0 :         if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
      22           0 :                 != _problem->getCurrentSolution().getSize()
      23           0 :             || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
      24           0 :                    != _problem->getCurrentSolution().getSize()) {
      25           0 :             throw InvalidArgumentError("FGM: incorrect size of preconditioner");
      26             :         }
      27           0 :     }
      28             : 
      29             :     template <typename data_t>
      30           0 :     DataContainer<data_t>& FGM<data_t>::solveImpl(index_t iterations)
      31             :     {
      32           0 :         if (iterations == 0)
      33           0 :             iterations = _defaultIterations;
      34             : 
      35           0 :         auto prevTheta = static_cast<data_t>(1.0);
      36           0 :         auto x0 = DataContainer<data_t>(getCurrentSolution());
      37           0 :         auto& prevY = x0;
      38             : 
      39           0 :         auto deltaZero = _problem->getGradient().squaredL2Norm();
      40           0 :         auto lipschitz = _problem->getLipschitzConstant();
      41           0 :         Logger::get("FGM")->info("Starting optimization with lipschitz constant {}", lipschitz);
      42             : 
      43           0 :         for (index_t i = 0; i < iterations; ++i) {
      44           0 :             Logger::get("FGM")->info("iteration {} of {}", i + 1, iterations);
      45           0 :             auto& x = getCurrentSolution();
      46             : 
      47           0 :             auto gradient = _problem->getGradient();
      48             : 
      49           0 :             if (_preconditionerInverse)
      50           0 :                 gradient = _preconditionerInverse->apply(gradient);
      51             : 
      52           0 :             DataContainer<data_t> y = x - gradient / lipschitz;
      53           0 :             const auto theta = (static_cast<data_t>(1.0)
      54           0 :                                 + std::sqrt(static_cast<data_t>(1.0)
      55           0 :                                             + static_cast<data_t>(4.0) * prevTheta * prevTheta))
      56             :                                / static_cast<data_t>(2.0);
      57           0 :             x = y + (prevTheta - static_cast<data_t>(1.0)) / theta * (y - prevY);
      58           0 :             prevTheta = theta;
      59           0 :             prevY = y;
      60             : 
      61             :             // if the gradient is too small we stop
      62           0 :             if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) {
      63           0 :                 Logger::get("FGM")->info("SUCCESS: Reached convergence at {}/{} iteration", i + 1,
      64             :                                          iterations);
      65           0 :                 return x;
      66             :             }
      67             :         }
      68             : 
      69           0 :         Logger::get("FGM")->warn("Failed to reach convergence at {} iterations", iterations);
      70             : 
      71           0 :         return getCurrentSolution();
      72           0 :     }
      73             : 
      74             :     template <typename data_t>
      75           0 :     FGM<data_t>* FGM<data_t>::cloneImpl() const
      76             :     {
      77           0 :         if (_preconditionerInverse)
      78           0 :             return new FGM(*_problem, *_preconditionerInverse, _epsilon);
      79             : 
      80           0 :         return new FGM(*_problem, _epsilon);
      81             :     }
      82             : 
      83             :     template <typename data_t>
      84           0 :     bool FGM<data_t>::isEqual(const Solver<data_t>& other) const
      85             :     {
      86           0 :         if (!Solver<data_t>::isEqual(other))
      87           0 :             return false;
      88             : 
      89           0 :         auto otherFGM = downcast_safe<FGM>(&other);
      90           0 :         if (!otherFGM)
      91           0 :             return false;
      92             : 
      93           0 :         if (_epsilon != otherFGM->_epsilon)
      94           0 :             return false;
      95             : 
      96           0 :         if ((_preconditionerInverse && !otherFGM->_preconditionerInverse)
      97           0 :             || (!_preconditionerInverse && otherFGM->_preconditionerInverse))
      98           0 :             return false;
      99             : 
     100           0 :         if (_preconditionerInverse && otherFGM->_preconditionerInverse)
     101           0 :             if (*_preconditionerInverse != *otherFGM->_preconditionerInverse)
     102           0 :                 return false;
     103             : 
     104           0 :         return true;
     105             :     }
     106             : 
     107             :     // ------------------------------------------
     108             :     // explicit template instantiation
     109             :     template class FGM<float>;
     110             :     template class FGM<double>;
     111             : 
     112             : } // namespace elsa

Generated by: LCOV version 1.14