LCOV - code coverage report
Current view: top level - solvers - OGM.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 67 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 10 0.0 %

          Line data    Source code
       1             : #include "OGM.h"
       2             : #include "TypeCasts.hpp"
       3             : #include "Logger.h"
       4             : 
       5             : namespace elsa
       6             : {
       7             :     template <typename data_t>
       8           0 :     OGM<data_t>::OGM(const Problem<data_t>& problem, data_t epsilon)
       9           0 :         : Solver<data_t>(problem), _epsilon{epsilon}
      10             :     {
      11           0 :     }
      12             : 
      13             :     template <typename data_t>
      14           0 :     OGM<data_t>::OGM(const Problem<data_t>& problem,
      15             :                      const LinearOperator<data_t>& preconditionerInverse, data_t epsilon)
      16             :         : Solver<data_t>(problem),
      17           0 :           _epsilon{epsilon},
      18           0 :           _preconditionerInverse{preconditionerInverse.clone()}
      19             :     {
      20             :         // check that preconditioner is compatible with problem
      21           0 :         if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
      22           0 :                 != _problem->getCurrentSolution().getSize()
      23           0 :             || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
      24           0 :                    != _problem->getCurrentSolution().getSize()) {
      25           0 :             throw InvalidArgumentError("OGM: incorrect size of preconditioner");
      26             :         }
      27           0 :     }
      28             : 
      29             :     template <typename data_t>
      30           0 :     DataContainer<data_t>& OGM<data_t>::solveImpl(index_t iterations)
      31             :     {
      32           0 :         if (iterations == 0)
      33           0 :             iterations = _defaultIterations;
      34             : 
      35           0 :         auto prevTheta = static_cast<data_t>(1.0);
      36           0 :         auto x0 = DataContainer<data_t>(getCurrentSolution());
      37           0 :         auto& prevY = x0;
      38             : 
      39             :         // OGM is very picky when it comes to the accuracy of the used lipschitz constant therefore
      40             :         // we use 20 power iterations instead of 5 here to be more precise.
      41             :         // In some cases OGM might still not converge then an even more precise constant is needed
      42           0 :         auto lipschitz = _problem->getLipschitzConstant(20);
      43           0 :         auto deltaZero = _problem->getGradient().squaredL2Norm();
      44           0 :         Logger::get("OGM")->info("Starting optimization with lipschitz constant {}", lipschitz);
      45             : 
      46             :         // log history legend
      47           0 :         Logger::get("OGM")->info("{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}", "iteration",
      48             :                                  "thetaRatio0", "thetaRatio1", "y", "gradient");
      49             : 
      50           0 :         for (index_t i = 0; i < iterations; ++i) {
      51           0 :             auto& x = getCurrentSolution();
      52             : 
      53           0 :             auto gradient = _problem->getGradient();
      54             : 
      55           0 :             if (_preconditionerInverse)
      56           0 :                 gradient = _preconditionerInverse->apply(gradient);
      57             : 
      58           0 :             DataContainer<data_t> y = x - gradient / lipschitz;
      59             :             data_t theta;
      60           0 :             if (i == iterations - 1) { // last iteration
      61           0 :                 theta = (static_cast<data_t>(1.0)
      62           0 :                          + std::sqrt(static_cast<data_t>(1.0)
      63           0 :                                      + static_cast<data_t>(8.0) * prevTheta * prevTheta))
      64             :                         / static_cast<data_t>(2.0);
      65             :             } else {
      66           0 :                 theta = (static_cast<data_t>(1.0)
      67           0 :                          + std::sqrt(static_cast<data_t>(1.0)
      68           0 :                                      + static_cast<data_t>(4.0) * prevTheta * prevTheta))
      69             :                         / static_cast<data_t>(2.0);
      70             :             }
      71             : 
      72           0 :             Logger::get("OGM")->info(" {:<19}| {:<19}| {:<19}| {:<19}| {:<19}", i,
      73           0 :                                      (prevTheta - 1) / theta, prevTheta / theta, y.squaredL2Norm(),
      74           0 :                                      gradient.squaredL2Norm());
      75             : 
      76             :             // x_{i+1} = y_{i+1} + \frac{\theta_i-1}{\theta_{i+1}}(y_{i+1} - y_i) +
      77             :             // \frac{\theta_i}{\theta_{i+1}}/(y_{i+1} - x_i)
      78           0 :             x = y + ((prevTheta - static_cast<data_t>(1.0)) / theta) * (y - prevY)
      79           0 :                 - (prevTheta / theta) * (gradient / lipschitz);
      80           0 :             prevTheta = theta;
      81           0 :             prevY = y;
      82             : 
      83             :             // if the gradient is too small we stop
      84           0 :             if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) {
      85           0 :                 Logger::get("OGM")->info("SUCCESS: Reached convergence at {}/{} iteration", i + 1,
      86             :                                          iterations);
      87           0 :                 return x;
      88             :             }
      89             :         }
      90             : 
      91           0 :         Logger::get("OGM")->warn("Failed to reach convergence at {} iterations", iterations);
      92             : 
      93           0 :         return getCurrentSolution();
      94           0 :     }
      95             : 
      96             :     template <typename data_t>
      97           0 :     OGM<data_t>* OGM<data_t>::cloneImpl() const
      98             :     {
      99           0 :         if (_preconditionerInverse)
     100           0 :             return new OGM(*_problem, *_preconditionerInverse, _epsilon);
     101             : 
     102           0 :         return new OGM(*_problem, _epsilon);
     103             :     }
     104             : 
     105             :     template <typename data_t>
     106           0 :     bool OGM<data_t>::isEqual(const Solver<data_t>& other) const
     107             :     {
     108           0 :         if (!Solver<data_t>::isEqual(other))
     109           0 :             return false;
     110             : 
     111           0 :         auto otherOGM = downcast_safe<OGM>(&other);
     112           0 :         if (!otherOGM)
     113           0 :             return false;
     114             : 
     115           0 :         if (_epsilon != otherOGM->_epsilon)
     116           0 :             return false;
     117             : 
     118           0 :         if ((_preconditionerInverse && !otherOGM->_preconditionerInverse)
     119           0 :             || (!_preconditionerInverse && otherOGM->_preconditionerInverse))
     120           0 :             return false;
     121             : 
     122           0 :         if (_preconditionerInverse && otherOGM->_preconditionerInverse)
     123           0 :             if (*_preconditionerInverse != *otherOGM->_preconditionerInverse)
     124           0 :                 return false;
     125             : 
     126           0 :         return true;
     127             :     }
     128             : 
     129             :     // ------------------------------------------
     130             :     // explicit template instantiation
     131             :     template class OGM<float>;
     132             :     template class OGM<double>;
     133             : 
     134             : } // namespace elsa

Generated by: LCOV version 1.14