LCOV - code coverage report
Current view: top level - elsa/solvers - OGM.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 63 110 57.3 %
Date: 2024-05-16 04:22:26 Functions: 16 22 72.7 %

          Line data    Source code
       1             : #include "OGM.h"
       2             : #include "DataContainer.h"
       3             : #include "Functional.h"
       4             : #include "TypeCasts.hpp"
       5             : #include "Logger.h"
       6             : #include "PowerIterations.h"
       7             : #include "FixedStepSize.h"
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12             :     OGM<data_t>::OGM(const Functional<data_t>& problem, data_t epsilon)
      13             :         : Solver<data_t>(),
      14             :           _problem(problem.clone()),
      15             :           _epsilon{epsilon},
      16             :           yOld{empty<data_t>(problem.getDomainDescriptor())},
      17             :           gradient{emptylike(yOld)}
      18           4 :     {
      19           4 :         if (!problem.isDifferentiable()) {
      20           0 :             throw InvalidArgumentError("OGM: Given problem is not differentiable!");
      21           0 :         }
      22           4 :         this->name_ = "OGM";
      23           4 :     }
      24             : 
      25             :     template <typename data_t>
      26             :     OGM<data_t>::OGM(const Functional<data_t>& problem,
      27             :                      const LinearOperator<data_t>& preconditionerInverse, data_t epsilon)
      28             :         : Solver<data_t>(),
      29             :           _problem(problem.clone()),
      30             :           _epsilon{epsilon},
      31             :           _preconditionerInverse{preconditionerInverse.clone()},
      32             :           yOld{empty<data_t>(problem.getDomainDescriptor())},
      33             :           gradient{emptylike(yOld)}
      34           0 :     {
      35           0 :         if (!problem.isDifferentiable()) {
      36           0 :             throw InvalidArgumentError("OGM: Given problem is not differentiable!");
      37           0 :         }
      38             : 
      39             :         // check that preconditioner is compatible with problem
      40           0 :         if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
      41           0 :                 != _problem->getDomainDescriptor().getNumberOfCoefficients()
      42           0 :             || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
      43           0 :                    != _problem->getDomainDescriptor().getNumberOfCoefficients()) {
      44           0 :             throw InvalidArgumentError("OGM: incorrect size of preconditioner");
      45           0 :         }
      46           0 :         this->name_ = "OGM";
      47           0 :     }
      48             :     template <typename data_t>
      49             :     OGM<data_t>::OGM(const Functional<data_t>& problem,
      50             :                      const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
      51             :         : Solver<data_t>(),
      52             :           _problem(problem.clone()),
      53             :           _epsilon{epsilon},
      54             :           _lineSearchMethod(lineSearchMethod.clone()),
      55             :           yOld{empty<data_t>(problem.getDomainDescriptor())},
      56             :           gradient{emptylike(yOld)}
      57           0 :     {
      58           0 :         if (!problem.isDifferentiable()) {
      59           0 :             throw InvalidArgumentError("OGM: Given problem is not differentiable!");
      60           0 :         }
      61           0 :         this->name_ = "OGM";
      62           0 :     }
      63             : 
      64             :     template <typename data_t>
      65             :     OGM<data_t>::OGM(const Functional<data_t>& problem,
      66             :                      const LinearOperator<data_t>& preconditionerInverse,
      67             :                      const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
      68             :         : Solver<data_t>(),
      69             :           _problem(problem.clone()),
      70             :           _epsilon{epsilon},
      71             :           _preconditionerInverse{preconditionerInverse.clone()},
      72             :           _lineSearchMethod(lineSearchMethod.clone()),
      73             :           yOld{empty<data_t>(problem.getDomainDescriptor())},
      74             :           gradient{emptylike(yOld)}
      75           0 :     {
      76           0 :         if (!problem.isDifferentiable()) {
      77           0 :             throw InvalidArgumentError("OGM: Given problem is not differentiable!");
      78           0 :         }
      79             : 
      80             :         // check that preconditioner is compatible with problem
      81           0 :         if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
      82           0 :                 != _problem->getDomainDescriptor().getNumberOfCoefficients()
      83           0 :             || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
      84           0 :                    != _problem->getDomainDescriptor().getNumberOfCoefficients()) {
      85           0 :             throw InvalidArgumentError("OGM: incorrect size of preconditioner");
      86           0 :         }
      87           0 :         this->name_ = "OGM";
      88           0 :     }
      89             : 
      90             :     template <typename data_t>
      91             :     bool OGM<data_t>::shouldStop() const
      92         400 :     {
      93         400 :         return this->gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero;
      94         400 :     }
      95             : 
      96             :     template <typename data_t>
      97             :     DataContainer<data_t> OGM<data_t>::setup(std::optional<DataContainer<data_t>> x0)
      98           2 :     {
      99           2 :         this->thetaOld = data_t{1};
     100           2 :         auto x = extract_or(x0, _problem->getDomainDescriptor());
     101           2 :         this->yOld = x;
     102           2 :         this->gradient = _problem->getGradient(x);
     103           2 :         this->deltaZero = this->gradient.squaredL2Norm();
     104             : 
     105           2 :         if (!_lineSearchMethod) {
     106             :             // OGM is very picky when it comes to the accuracy of the used lipschitz constant
     107             :             // therefore we use 20 power iterations instead of 5 here to be more precise. In some
     108             :             // cases OGM might still not converge then an even more precise constant is needed
     109           2 :             auto L = powerIterations(_problem->getHessian(x), 20);
     110           2 :             this->_lineSearchMethod = std::make_unique<FixedStepSize<data_t>>(*_problem, 1 / L);
     111             : 
     112           2 :             Logger::get("OGM")->info("Starting optimization with lipschitz constant {}", L);
     113           2 :         } else {
     114           0 :             Logger::get("OGM")->info("Starting optimization with a lineSearchMethod");
     115           0 :         }
     116             : 
     117           2 :         this->configured_ = true;
     118             : 
     119           2 :         return x;
     120           2 :     }
     121             : 
     122             :     template <typename data_t>
     123             :     std::string OGM<data_t>::formatHeader() const
     124           2 :     {
     125           2 :         return fmt::format("| {:^13} | {:^13} |", "objective", "gradient");
     126           2 :     }
     127             : 
     128             :     template <typename data_t>
     129             :     std::string OGM<data_t>::formatStep(const DataContainer<data_t>& x) const
     130         400 :     {
     131         400 :         auto eval = _problem->evaluate(x);
     132         400 :         auto gradient = _problem->getGradient(x);
     133             : 
     134         400 :         return fmt::format("| {:>13} | {:>13} |", eval, gradient.squaredL2Norm());
     135         400 :     }
     136             : 
     137             :     template <typename data_t>
     138             :     DataContainer<data_t> OGM<data_t>::step(DataContainer<data_t> x)
     139         400 :     {
     140             : 
     141         400 :         auto y = emptylike(x);
     142             : 
     143         400 :         this->_problem->getGradient(x, this->gradient);
     144             : 
     145         400 :         if (this->_preconditionerInverse)
     146           0 :             this->_preconditionerInverse->apply(this->gradient, this->gradient);
     147             : 
     148         400 :         auto alpha = this->_lineSearchMethod->solve(x, -this->gradient);
     149             : 
     150         400 :         lincomb(1, x, -alpha, this->gradient, y);
     151             : 
     152             :         // TODO: Consult David on how to handle this in the steppable interface
     153         400 :         const auto theta =
     154         400 :             data_t{0.5} * (data_t{1} + std::sqrt(data_t{1} + data_t{4} * std::pow(thetaOld, 2)));
     155             : 
     156             :         // x_{i+1} = y_{i+1} + \frac{\theta_i-1}{\theta_{i+1}}(y_{i+1} - y_i) +
     157             :         // \frac{\theta_i}{\theta_{i+1}}/(y_{i+1} - x_i)
     158         400 :         lincomb(1, y, ((this->thetaOld - data_t{1}) / theta), (y - this->yOld), x);
     159         400 :         lincomb(1, x,
     160         400 :                 -(this->thetaOld / theta) * this->_lineSearchMethod->solve(x, -this->gradient),
     161         400 :                 this->gradient, x);
     162             : 
     163         400 :         this->thetaOld = theta;
     164         400 :         this->yOld = y;
     165             : 
     166         400 :         return x;
     167         400 :     }
     168             : 
     169             :     template <typename data_t>
     170             :     OGM<data_t>* OGM<data_t>::cloneImpl() const
     171           2 :     {
     172           2 :         if (_lineSearchMethod and _preconditionerInverse) {
     173           0 :             return new OGM(*_problem, *_preconditionerInverse, *_lineSearchMethod, _epsilon);
     174           2 :         } else if (_preconditionerInverse) {
     175           0 :             return new OGM(*_problem, *_preconditionerInverse, _epsilon);
     176           2 :         } else if (_lineSearchMethod) {
     177           0 :             return new OGM(*_problem, *_lineSearchMethod, _epsilon);
     178           0 :         }
     179           2 :         return new OGM(*_problem, _epsilon);
     180           2 :     }
     181             : 
     182             :     template <typename data_t>
     183             :     bool OGM<data_t>::isEqual(const Solver<data_t>& other) const
     184           2 :     {
     185           2 :         auto otherOGM = downcast_safe<OGM>(&other);
     186           2 :         if (!otherOGM)
     187           0 :             return false;
     188             : 
     189           2 :         if (_epsilon != otherOGM->_epsilon)
     190           0 :             return false;
     191             : 
     192           2 :         if ((_preconditionerInverse && !otherOGM->_preconditionerInverse)
     193           2 :             || (!_preconditionerInverse && otherOGM->_preconditionerInverse))
     194           0 :             return false;
     195             : 
     196           2 :         if (_preconditionerInverse && otherOGM->_preconditionerInverse)
     197           0 :             if (*_preconditionerInverse != *otherOGM->_preconditionerInverse)
     198           0 :                 return false;
     199             : 
     200           2 :         if ((_lineSearchMethod and not otherOGM->_lineSearchMethod)
     201           2 :             or (not _lineSearchMethod and otherOGM->_lineSearchMethod))
     202           0 :             return false;
     203             : 
     204           2 :         if (_lineSearchMethod and otherOGM->_lineSearchMethod)
     205           0 :             if (not _lineSearchMethod->isEqual(*(otherOGM->_lineSearchMethod)))
     206           0 :                 return false;
     207             : 
     208           2 :         return true;
     209           2 :     }
     210             : 
     211             :     // ------------------------------------------
     212             :     // explicit template instantiation
     213             :     template class OGM<float>;
     214             :     template class OGM<double>;
     215             : 
     216             : } // namespace elsa

Generated by: LCOV version 1.14