LCOV - code coverage report
Current view: top level - elsa/solvers - SQS.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 97 109 89.0 %
Date: 2022-08-25 03:05:39 Functions: 10 10 100.0 %

          Line data    Source code
#include "SQS.h"

#include <cmath>

#include <Identity.h>
#include <Scaling.h>
#include "Logger.h"
#include "SubsetProblem.h"
#include "TypeCasts.hpp"

       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11             :     SQS<data_t>::SQS(const Problem<data_t>& problem, bool momentumAcceleration, data_t epsilon)
      12             :         : Solver<data_t>(),
      13             :           _problem(problem.clone()),
      14             :           _epsilon{epsilon},
      15             :           _momentumAcceleration{momentumAcceleration}
      16          14 :     {
      17          14 :         if (is<SubsetProblem<data_t>>(problem)) {
      18           4 :             Logger::get("SQS")->info(
      19           4 :                 "SQS did received a SubsetProblem, running in ordered subset mode");
      20           4 :             _subsetMode = true;
      21          10 :         } else {
      22          10 :             Logger::get("SQS")->info("SQS did not receive a SubsetProblem, running in normal mode");
      23          10 :         }
      24          14 :     }
      25             : 
      26             :     template <typename data_t>
      27             :     SQS<data_t>::SQS(const Problem<data_t>& problem, const LinearOperator<data_t>& preconditioner,
      28             :                      bool momentumAcceleration, data_t epsilon)
      29             :         : Solver<data_t>(),
      30             :           _problem(problem.clone()),
      31             :           _epsilon{epsilon},
      32             :           _preconditioner{preconditioner.clone()},
      33             :           _momentumAcceleration{momentumAcceleration}
      34           8 :     {
      35           8 :         if (is<SubsetProblem<data_t>>(problem)) {
      36           0 :             Logger::get("SQS")->info(
      37           0 :                 "SQS did received a SubsetProblem, running in ordered subset mode");
      38           0 :             _subsetMode = true;
      39           8 :         } else {
      40           8 :             Logger::get("SQS")->info("SQS did not receive a SubsetProblem, running in normal mode");
      41           8 :         }
      42             : 
      43             :         // check that preconditioner is compatible with problem
      44           8 :         if (_preconditioner->getDomainDescriptor().getNumberOfCoefficients()
      45           8 :                 != _problem->getCurrentSolution().getSize()
      46           8 :             || _preconditioner->getRangeDescriptor().getNumberOfCoefficients()
      47           8 :                    != _problem->getCurrentSolution().getSize()) {
      48           0 :             throw InvalidArgumentError("SQS: incorrect size of preconditioner");
      49           0 :         }
      50           8 :     }
      51             : 
      52             :     template <typename data_t>
      53             :     DataContainer<data_t>& SQS<data_t>::solveImpl(index_t iterations)
      54          11 :     {
      55          11 :         if (iterations == 0)
      56           0 :             iterations = _defaultIterations;
      57             : 
      58          11 :         auto convergenceThreshold = _problem->getGradient().squaredL2Norm() * _epsilon * _epsilon;
      59             : 
      60          11 :         auto hessian = _problem->getHessian();
      61             : 
      62          11 :         auto& solutionDesc = _problem->getCurrentSolution().getDataDescriptor();
      63             : 
      64          11 :         auto ones = DataContainer<data_t>(solutionDesc);
      65          11 :         ones = 1;
      66          11 :         auto diagVector = hessian.apply(ones);
      67          11 :         diagVector = static_cast<data_t>(1.0) / diagVector;
      68          11 :         auto diag = Scaling<data_t>(hessian.getDomainDescriptor(), diagVector);
      69             : 
      70          11 :         data_t prevT = 1;
      71          11 :         data_t t = 1;
      72          11 :         data_t nextT = 0;
      73          11 :         auto& z = _problem->getCurrentSolution();
      74          11 :         DataContainer<data_t> x = _problem->getCurrentSolution();
      75          11 :         DataContainer<data_t> prevX = x;
      76          11 :         DataContainer<data_t> gradient(solutionDesc);
      77             : 
      78          11 :         index_t nSubsets = 1;
      79          11 :         if (_subsetMode) {
      80           2 :             const auto& subsetProblem = static_cast<const SubsetProblem<data_t>*>(_problem.get());
      81           2 :             nSubsets = subsetProblem->getNumberOfSubsets();
      82           2 :         }
      83             : 
      84         915 :         for (index_t i = 0; i < iterations; i++) {
      85         909 :             Logger::get("SQS")->info("iteration {} of {}", i + 1, iterations);
      86             : 
      87        1873 :             for (index_t m = 0; m < nSubsets; m++) {
      88         969 :                 if (_subsetMode) {
      89          80 :                     gradient =
      90          80 :                         static_cast<SubsetProblem<data_t>*>(_problem.get())->getSubsetGradient(m);
      91         889 :                 } else {
      92         889 :                     gradient = _problem->getGradient();
      93         889 :                 }
      94             : 
      95         969 :                 if (_preconditioner)
      96         837 :                     gradient = _preconditioner->apply(gradient);
      97             : 
      98             :                 // TODO: element wise relu
      99         969 :                 if (_momentumAcceleration) {
     100         969 :                     nextT = as<data_t>(1)
     101         969 :                             + std::sqrt(as<data_t>(1) + as<data_t>(4) * t * t) / as<data_t>(2);
     102             : 
     103         969 :                     x = z - nSubsets * diag.apply(gradient);
     104         969 :                     z = x + prevT / nextT * (x - prevX);
     105         969 :                 } else {
     106           0 :                     z = z - nSubsets * diag.apply(gradient);
     107           0 :                 }
     108             : 
     109             :                 // if the gradient is too small we stop
     110         969 :                 if (gradient.squaredL2Norm() <= convergenceThreshold) {
     111           5 :                     if (!_subsetMode
     112           5 :                         || _problem->getGradient().squaredL2Norm() <= convergenceThreshold) {
     113           5 :                         Logger::get("SQS")->info("SUCCESS: Reached convergence at {}/{} iteration",
     114           5 :                                                  i + 1, iterations);
     115             : 
     116             :                         // TODO: make return more sane
     117           5 :                         if (_momentumAcceleration) {
     118           5 :                             z = x;
     119           5 :                         }
     120           5 :                         return _problem->getCurrentSolution();
     121           5 :                     }
     122         964 :                 }
     123             : 
     124         964 :                 if (_momentumAcceleration) {
     125         964 :                     prevT = t;
     126         964 :                     t = nextT;
     127         964 :                     prevX = x;
     128         964 :                 }
     129         964 :             }
     130         909 :         }
     131             : 
     132          11 :         Logger::get("SQS")->warn("Failed to reach convergence at {} iterations", iterations);
     133             : 
     134             :         // TODO: make return more sane
     135           6 :         if (_momentumAcceleration) {
     136           6 :             z = x;
     137           6 :         }
     138           6 :         return _problem->getCurrentSolution();
     139          11 :     }
     140             : 
     141             :     template <typename data_t>
     142             :     SQS<data_t>* SQS<data_t>::cloneImpl() const
     143          11 :     {
     144          11 :         if (_preconditioner)
     145           4 :             return new SQS(*_problem, *_preconditioner, _momentumAcceleration, _epsilon);
     146             : 
     147           7 :         return new SQS(*_problem, _momentumAcceleration, _epsilon);
     148           7 :     }
     149             : 
     150             :     template <typename data_t>
     151             :     bool SQS<data_t>::isEqual(const Solver<data_t>& other) const
     152          11 :     {
     153          11 :         auto otherSQS = downcast_safe<SQS>(&other);
     154          11 :         if (!otherSQS)
     155           0 :             return false;
     156             : 
     157          11 :         if (_epsilon != otherSQS->_epsilon)
     158           0 :             return false;
     159             : 
     160          11 :         if ((_preconditioner && !otherSQS->_preconditioner)
     161          11 :             || (!_preconditioner && otherSQS->_preconditioner))
     162           0 :             return false;
     163             : 
     164          11 :         if (_preconditioner && otherSQS->_preconditioner)
     165           4 :             if (*_preconditioner != *otherSQS->_preconditioner)
     166           0 :                 return false;
     167             : 
     168          11 :         return true;
     169          11 :     }
     170             : 
     171             :     // ------------------------------------------
     172             :     // explicit template instantiation
     173             :     template class SQS<float>;
     174             :     template class SQS<double>;
     175             : 
     176             : } // namespace elsa

Generated by: LCOV version 1.14