LCOV - code coverage report
Current view: top level - elsa/solvers - FGM.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 50 61 82.0 %
Date: 2022-08-25 03:05:39 Functions: 10 10 100.0 %

          Line data    Source code
       1             : #include "FGM.h"
       2             : #include "Logger.h"
       3             : #include "TypeCasts.hpp"
       4             : 
       5             : namespace elsa
       6             : {
       7             :     template <typename data_t>
       8             :     FGM<data_t>::FGM(const Problem<data_t>& problem, data_t epsilon)
       9             :         : Solver<data_t>(), _problem(problem.clone()), _epsilon{epsilon}
      10          10 :     {
      11          10 :     }
      12             : 
      13             :     template <typename data_t>
      14             :     FGM<data_t>::FGM(const Problem<data_t>& problem,
      15             :                      const LinearOperator<data_t>& preconditionerInverse, data_t epsilon)
      16             :         : Solver<data_t>(),
      17             :           _problem(problem.clone()),
      18             :           _epsilon{epsilon},
      19             :           _preconditionerInverse{preconditionerInverse.clone()}
      20           8 :     {
      21             :         // check that preconditioner is compatible with problem
      22           8 :         if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
      23           8 :                 != _problem->getCurrentSolution().getSize()
      24           8 :             || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
      25           8 :                    != _problem->getCurrentSolution().getSize()) {
      26           0 :             throw InvalidArgumentError("FGM: incorrect size of preconditioner");
      27           0 :         }
      28           8 :     }
      29             : 
      30             :     template <typename data_t>
      31             :     DataContainer<data_t>& FGM<data_t>::solveImpl(index_t iterations)
      32           9 :     {
      33           9 :         if (iterations == 0)
      34           0 :             iterations = _defaultIterations;
      35             : 
      36           9 :         auto prevTheta = static_cast<data_t>(1.0);
      37           9 :         auto x0 = _problem->getCurrentSolution();
      38           9 :         auto& prevY = x0;
      39             : 
      40           9 :         auto deltaZero = _problem->getGradient().squaredL2Norm();
      41           9 :         auto lipschitz = _problem->getLipschitzConstant();
      42           9 :         Logger::get("FGM")->info("Starting optimization with lipschitz constant {}", lipschitz);
      43             : 
      44        5272 :         for (index_t i = 0; i < iterations; ++i) {
      45        5263 :             Logger::get("FGM")->info("iteration {} of {}", i + 1, iterations);
      46        5263 :             auto& x = _problem->getCurrentSolution();
      47             : 
      48        5263 :             auto gradient = _problem->getGradient();
      49             : 
      50        5263 :             if (_preconditionerInverse)
      51        2624 :                 gradient = _preconditionerInverse->apply(gradient);
      52             : 
      53        5263 :             DataContainer<data_t> y = x - gradient / lipschitz;
      54        5263 :             const auto theta = (static_cast<data_t>(1.0)
      55        5263 :                                 + std::sqrt(static_cast<data_t>(1.0)
      56        5263 :                                             + static_cast<data_t>(4.0) * prevTheta * prevTheta))
      57        5263 :                                / static_cast<data_t>(2.0);
      58        5263 :             x = y + (prevTheta - static_cast<data_t>(1.0)) / theta * (y - prevY);
      59        5263 :             prevTheta = theta;
      60        5263 :             prevY = y;
      61             : 
      62             :             // if the gradient is too small we stop
      63        5263 :             if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) {
      64           0 :                 Logger::get("FGM")->info("SUCCESS: Reached convergence at {}/{} iteration", i + 1,
      65           0 :                                          iterations);
      66           0 :                 return x;
      67           0 :             }
      68        5263 :         }
      69             : 
      70           9 :         Logger::get("FGM")->warn("Failed to reach convergence at {} iterations", iterations);
      71             : 
      72           9 :         return _problem->getCurrentSolution();
      73           9 :     }
      74             : 
      75             :     template <typename data_t>
      76             :     FGM<data_t>* FGM<data_t>::cloneImpl() const
      77           9 :     {
      78           9 :         if (_preconditionerInverse)
      79           4 :             return new FGM(*_problem, *_preconditionerInverse, _epsilon);
      80             : 
      81           5 :         return new FGM(*_problem, _epsilon);
      82           5 :     }
      83             : 
      84             :     template <typename data_t>
      85             :     bool FGM<data_t>::isEqual(const Solver<data_t>& other) const
      86           9 :     {
      87           9 :         auto otherFGM = downcast_safe<FGM>(&other);
      88           9 :         if (!otherFGM)
      89           0 :             return false;
      90             : 
      91           9 :         if (_epsilon != otherFGM->_epsilon)
      92           0 :             return false;
      93             : 
      94           9 :         if ((_preconditionerInverse && !otherFGM->_preconditionerInverse)
      95           9 :             || (!_preconditionerInverse && otherFGM->_preconditionerInverse))
      96           0 :             return false;
      97             : 
      98           9 :         if (_preconditionerInverse && otherFGM->_preconditionerInverse)
      99           4 :             if (*_preconditionerInverse != *otherFGM->_preconditionerInverse)
     100           0 :                 return false;
     101             : 
     102           9 :         return true;
     103           9 :     }
     104             : 
     105             :     // ------------------------------------------
     106             :     // explicit template instantiation
     107             :     template class FGM<float>;
     108             :     template class FGM<double>;
     109             : 
     110             : } // namespace elsa

Generated by: LCOV version 1.14