LCOV - code coverage report
Current view: top level - elsa/solvers - OGM.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 63 74 85.1 %
Date: 2022-08-25 03:05:39 Functions: 10 10 100.0 %

          Line data    Source code
       1             : #include "OGM.h"
       2             : #include "TypeCasts.hpp"
       3             : #include "Logger.h"
       4             : 
       5             : namespace elsa
       6             : {
       7             :     template <typename data_t>
       8             :     OGM<data_t>::OGM(const Problem<data_t>& problem, data_t epsilon)
       9             :         : Solver<data_t>(), _problem(problem.clone()), _epsilon{epsilon}
      10          10 :     {
      11          10 :     }
      12             : 
      13             :     template <typename data_t>
      14             :     OGM<data_t>::OGM(const Problem<data_t>& problem,
      15             :                      const LinearOperator<data_t>& preconditionerInverse, data_t epsilon)
      16             :         : Solver<data_t>(),
      17             :           _problem(problem.clone()),
      18             :           _epsilon{epsilon},
      19             :           _preconditionerInverse{preconditionerInverse.clone()}
      20           8 :     {
      21             :         // check that preconditioner is compatible with problem
      22           8 :         if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
      23           8 :                 != _problem->getCurrentSolution().getSize()
      24           8 :             || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
      25           8 :                    != _problem->getCurrentSolution().getSize()) {
      26           0 :             throw InvalidArgumentError("OGM: incorrect size of preconditioner");
      27           0 :         }
      28           8 :     }
      29             : 
      30             :     template <typename data_t>
      31             :     DataContainer<data_t>& OGM<data_t>::solveImpl(index_t iterations)
      32           9 :     {
      33           9 :         if (iterations == 0)
      34           0 :             iterations = _defaultIterations;
      35             : 
      36           9 :         auto prevTheta = static_cast<data_t>(1.0);
      37           9 :         auto x0 = _problem->getCurrentSolution();
      38           9 :         auto& prevY = x0;
      39             : 
      40             :         // OGM is very picky when it comes to the accuracy of the used lipschitz constant therefore
      41             :         // we use 20 power iterations instead of 5 here to be more precise.
      42             :         // In some cases OGM might still not converge then an even more precise constant is needed
      43           9 :         auto lipschitz = _problem->getLipschitzConstant(20);
      44           9 :         auto deltaZero = _problem->getGradient().squaredL2Norm();
      45           9 :         Logger::get("OGM")->info("Starting optimization with lipschitz constant {}", lipschitz);
      46             : 
      47             :         // log history legend
      48           9 :         Logger::get("OGM")->info("{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}", "iteration",
      49           9 :                                  "thetaRatio0", "thetaRatio1", "y", "gradient");
      50             : 
      51        4272 :         for (index_t i = 0; i < iterations; ++i) {
      52        4263 :             auto& x = _problem->getCurrentSolution();
      53             : 
      54        4263 :             auto gradient = _problem->getGradient();
      55             : 
      56        4263 :             if (_preconditionerInverse)
      57        1624 :                 gradient = _preconditionerInverse->apply(gradient);
      58             : 
      59        4263 :             DataContainer<data_t> y = x - gradient / lipschitz;
      60        4263 :             data_t theta;
      61        4263 :             if (i == iterations - 1) { // last iteration
      62           9 :                 theta = (static_cast<data_t>(1.0)
      63           9 :                          + std::sqrt(static_cast<data_t>(1.0)
      64           9 :                                      + static_cast<data_t>(8.0) * prevTheta * prevTheta))
      65           9 :                         / static_cast<data_t>(2.0);
      66        4254 :             } else {
      67        4254 :                 theta = (static_cast<data_t>(1.0)
      68        4254 :                          + std::sqrt(static_cast<data_t>(1.0)
      69        4254 :                                      + static_cast<data_t>(4.0) * prevTheta * prevTheta))
      70        4254 :                         / static_cast<data_t>(2.0);
      71        4254 :             }
      72             : 
      73        4263 :             Logger::get("OGM")->info(" {:<19}| {:<19}| {:<19}| {:<19}| {:<19}", i,
      74        4263 :                                      (prevTheta - 1) / theta, prevTheta / theta, y.squaredL2Norm(),
      75        4263 :                                      gradient.squaredL2Norm());
      76             : 
      77             :             // x_{i+1} = y_{i+1} + \frac{\theta_i-1}{\theta_{i+1}}(y_{i+1} - y_i) +
      78             :             // \frac{\theta_i}{\theta_{i+1}}/(y_{i+1} - x_i)
      79        4263 :             x = y + ((prevTheta - static_cast<data_t>(1.0)) / theta) * (y - prevY)
      80        4263 :                 - (prevTheta / theta) * (gradient / lipschitz);
      81        4263 :             prevTheta = theta;
      82        4263 :             prevY = y;
      83             : 
      84             :             // if the gradient is too small we stop
      85        4263 :             if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) {
      86           0 :                 Logger::get("OGM")->info("SUCCESS: Reached convergence at {}/{} iteration", i + 1,
      87           0 :                                          iterations);
      88           0 :                 return x;
      89           0 :             }
      90        4263 :         }
      91             : 
      92           9 :         Logger::get("OGM")->warn("Failed to reach convergence at {} iterations", iterations);
      93             : 
      94           9 :         return _problem->getCurrentSolution();
      95           9 :     }
      96             : 
      97             :     template <typename data_t>
      98             :     OGM<data_t>* OGM<data_t>::cloneImpl() const
      99           9 :     {
     100           9 :         if (_preconditionerInverse)
     101           4 :             return new OGM(*_problem, *_preconditionerInverse, _epsilon);
     102             : 
     103           5 :         return new OGM(*_problem, _epsilon);
     104           5 :     }
     105             : 
     106             :     template <typename data_t>
     107             :     bool OGM<data_t>::isEqual(const Solver<data_t>& other) const
     108           9 :     {
     109           9 :         auto otherOGM = downcast_safe<OGM>(&other);
     110           9 :         if (!otherOGM)
     111           0 :             return false;
     112             : 
     113           9 :         if (_epsilon != otherOGM->_epsilon)
     114           0 :             return false;
     115             : 
     116           9 :         if ((_preconditionerInverse && !otherOGM->_preconditionerInverse)
     117           9 :             || (!_preconditionerInverse && otherOGM->_preconditionerInverse))
     118           0 :             return false;
     119             : 
     120           9 :         if (_preconditionerInverse && otherOGM->_preconditionerInverse)
     121           4 :             if (*_preconditionerInverse != *otherOGM->_preconditionerInverse)
     122           0 :                 return false;
     123             : 
     124           9 :         return true;
     125           9 :     }
     126             : 
     127             :     // ------------------------------------------
     128             :     // explicit template instantiation
     129             :     template class OGM<float>;
     130             :     template class OGM<double>;
     131             : 
     132             : } // namespace elsa

Generated by: LCOV version 1.14