LCOV - code coverage report
Current view: top level - elsa/solvers - FGM.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 65 104 62.5 %
Date: 2024-05-16 04:22:26 Functions: 18 22 81.8 %

          Line data    Source code
       1             : #include "FGM.h"
       2             : #include "DataContainer.h"
       3             : #include "Error.h"
       4             : #include "Functional.h"
       5             : #include "Logger.h"
       6             : #include "TypeCasts.hpp"
       7             : #include "PowerIterations.h"
       8             : #include "FixedStepSize.h"
       9             : 
      10             : namespace elsa
      11             : {
      12             :     template <typename data_t>
      13             :     FGM<data_t>::FGM(const Functional<data_t>& problem, data_t epsilon)
      14             :         : Solver<data_t>(),
      15             :           _problem{problem.clone()},
      16             :           _epsilon{epsilon},
      17             :           yOld{empty<data_t>(problem.getDomainDescriptor())},
      18             :           gradient{emptylike(yOld)}
      19             : 
      20           2 :     {
      21           2 :         if (!problem.isDifferentiable()) {
      22           0 :             throw InvalidArgumentError("FGM: Given problem is not differentiable!");
      23           0 :         }
      24           2 :         this->name_ = "FGM";
      25           2 :     }
      26             : 
      27             :     template <typename data_t>
      28             :     FGM<data_t>::FGM(const Functional<data_t>& problem,
      29             :                      const LinearOperator<data_t>& preconditionerInverse, data_t epsilon)
      30             :         : Solver<data_t>(),
      31             :           _problem{problem.clone()},
      32             :           _epsilon{epsilon},
      33             :           _preconditionerInverse{preconditionerInverse.clone()},
      34             :           yOld{empty<data_t>(problem.getDomainDescriptor())},
      35             :           gradient{emptylike(yOld)}
      36           0 :     {
      37           0 :         if (!problem.isDifferentiable()) {
      38           0 :             throw InvalidArgumentError("FGM: Given problem is not differentiable!");
      39           0 :         }
      40             : 
      41             :         // check that preconditioner is compatible with problem
      42           0 :         if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
      43           0 :                 != _problem->getDomainDescriptor().getNumberOfCoefficients()
      44           0 :             || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
      45           0 :                    != _problem->getDomainDescriptor().getNumberOfCoefficients()) {
      46           0 :             throw InvalidArgumentError("FGM: incorrect size of preconditioner");
      47           0 :         }
      48           0 :         this->name_ = "FGM";
      49           0 :     }
      50             :     template <typename data_t>
      51             :     FGM<data_t>::FGM(const Functional<data_t>& problem,
      52             :                      const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
      53             :         : Solver<data_t>(),
      54             :           _problem{problem.clone()},
      55             :           _epsilon{epsilon},
      56             :           _lineSearchMethod{lineSearchMethod.clone()},
      57             :           yOld{empty<data_t>(problem.getDomainDescriptor())},
      58             :           gradient{emptylike(yOld)}
      59           2 :     {
      60           2 :         if (!problem.isDifferentiable()) {
      61           0 :             throw InvalidArgumentError("FGM: Given problem is not differentiable!");
      62           0 :         }
      63           2 :         this->name_ = "FGM";
      64           2 :     }
      65             : 
      66             :     template <typename data_t>
      67             :     FGM<data_t>::FGM(const Functional<data_t>& problem,
      68             :                      const LinearOperator<data_t>& preconditionerInverse,
      69             :                      const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
      70             :         : Solver<data_t>(),
      71             :           _problem{problem.clone()},
      72             :           _epsilon{epsilon},
      73             :           _preconditionerInverse{preconditionerInverse.clone()},
      74             :           _lineSearchMethod{lineSearchMethod.clone()},
      75             :           yOld{empty<data_t>(problem.getDomainDescriptor())},
      76             :           gradient{emptylike(yOld)}
      77           0 :     {
      78             :         // check that preconditioner is compatible with problem
      79           0 :         if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
      80           0 :                 != _problem->getDomainDescriptor().getNumberOfCoefficients()
      81           0 :             || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
      82           0 :                    != _problem->getDomainDescriptor().getNumberOfCoefficients()) {
      83           0 :             throw InvalidArgumentError("FGM: incorrect size of preconditioner");
      84           0 :         }
      85           0 :         this->name_ = "FGM";
      86           0 :     }
      87             :     template <typename data_t>
      88             :     bool FGM<data_t>::shouldStop() const
      89           6 :     {
      90           6 :         return gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero;
      91           6 :     }
      92             : 
      93             :     template <typename data_t>
      94             :     DataContainer<data_t> FGM<data_t>::setup(std::optional<DataContainer<data_t>> x0)
      95           2 :     {
      96           2 :         auto x = extract_or(x0, _problem->getDomainDescriptor());
      97             : 
      98           2 :         this->thetaOld = static_cast<data_t>(1.0);
      99           2 :         this->yOld = x;
     100           2 :         this->gradient = _problem->getGradient(x);
     101           2 :         this->deltaZero = this->gradient.squaredL2Norm();
     102             : 
     103           2 :         if (!_lineSearchMethod) {
     104           2 :             auto L = powerIterations(_problem->getHessian(x), 5);
     105           2 :             this->_lineSearchMethod = std::make_unique<FixedStepSize<data_t>>(*_problem, 1 / L);
     106           2 :             Logger::get("FGM")->info("Starting optimization with lipschitz constant {}", L);
     107           2 :         } else {
     108           0 :             Logger::get("FGM")->info("Starting optimization with a lineSearchMethod");
     109           0 :         }
     110           2 :         this->configured_ = true;
     111           2 :         return x;
     112           2 :     }
     113             : 
     114             :     template <typename data_t>
     115             :     DataContainer<data_t> FGM<data_t>::step(DataContainer<data_t> x)
     116           4 :     {
     117             : 
     118           4 :         this->gradient = _problem->getGradient(x);
     119             : 
     120           4 :         if (_preconditionerInverse)
     121           0 :             this->gradient = _preconditionerInverse->apply(this->gradient);
     122             : 
     123           4 :         auto alpha = _lineSearchMethod->solve(x, -this->gradient);
     124             : 
     125           4 :         auto y = emptylike(x);
     126           4 :         lincomb(1, x, -alpha, gradient, y);
     127             : 
     128           4 :         const auto theta =
     129           4 :             (data_t{1} + std::sqrt(data_t{1} + data_t{4} * thetaOld * thetaOld)) / data_t{2};
     130             : 
     131           4 :         lincomb(1, y, (thetaOld - data_t{1}) / theta, (y - yOld), x);
     132             : 
     133           4 :         this->thetaOld = theta;
     134           4 :         this->yOld = y;
     135             : 
     136           4 :         return x;
     137           4 :     }
     138             : 
     139             :     template <typename data_t>
     140             :     std::string FGM<data_t>::formatHeader() const
     141           2 :     {
     142           2 :         return fmt::format("| {:^13} | {:^13} |", "objective", "gradient");
     143           2 :     }
     144             : 
     145             :     template <typename data_t>
     146             :     std::string FGM<data_t>::formatStep(const DataContainer<data_t>& x) const
     147           4 :     {
     148           4 :         auto eval = _problem->evaluate(x);
     149           4 :         auto gradient = _problem->getGradient(x);
     150             : 
     151           4 :         return fmt::format("| {:>13} | {:>13} |", eval, gradient.squaredL2Norm());
     152           4 :     }
     153             : 
     154             :     template <typename data_t>
     155             :     FGM<data_t>* FGM<data_t>::cloneImpl() const
     156           2 :     {
     157           2 :         if (_lineSearchMethod and _preconditionerInverse) {
     158           0 :             return new FGM(*_problem, *_preconditionerInverse, *_lineSearchMethod, _epsilon);
     159           2 :         } else if (_preconditionerInverse) {
     160           0 :             return new FGM(*_problem, *_preconditionerInverse, _epsilon);
     161           2 :         } else if (_lineSearchMethod) {
     162           2 :             return new FGM(*_problem, *_lineSearchMethod, _epsilon);
     163           2 :         }
     164             : 
     165           0 :         return new FGM(*_problem, _epsilon);
     166           0 :     }
     167             : 
     168             :     template <typename data_t>
     169             :     bool FGM<data_t>::isEqual(const Solver<data_t>& other) const
     170           2 :     {
     171           2 :         auto otherFGM = downcast_safe<FGM>(&other);
     172           2 :         if (!otherFGM)
     173           0 :             return false;
     174             : 
     175           2 :         if (_epsilon != otherFGM->_epsilon)
     176           0 :             return false;
     177             : 
     178           2 :         if ((_preconditionerInverse && !otherFGM->_preconditionerInverse)
     179           2 :             || (!_preconditionerInverse && otherFGM->_preconditionerInverse))
     180           0 :             return false;
     181             : 
     182           2 :         if (_preconditionerInverse && otherFGM->_preconditionerInverse)
     183           0 :             if (*_preconditionerInverse != *otherFGM->_preconditionerInverse)
     184           0 :                 return false;
     185             : 
     186           2 :         if ((_lineSearchMethod and not otherFGM->_lineSearchMethod)
     187           2 :             or (not _lineSearchMethod and otherFGM->_lineSearchMethod))
     188           0 :             return false;
     189             : 
     190           2 :         if (_lineSearchMethod and otherFGM->_lineSearchMethod)
     191           2 :             if (not _lineSearchMethod->isEqual(*(otherFGM->_lineSearchMethod)))
     192           0 :                 return false;
     193             : 
     194           2 :         return true;
     195           2 :     }
     196             : 
     197             :     // ------------------------------------------
     198             :     // explicit template instantiation
     199             :     template class FGM<float>;
     200             :     template class FGM<double>;
     201             : 
     202             : } // namespace elsa

Generated by: LCOV version 1.14