LCOV - code coverage report
Current view: top level - solvers - SQS.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 93 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 10 0.0 %

          Line data    Source code
       1             : #include <Identity.h>
       2             : #include <Scaling.h>
       3             : #include "SQS.h"
       4             : #include "Logger.h"
       5             : #include "SubsetProblem.h"
       6             : #include "TypeCasts.hpp"
       7             : 
       8             : namespace elsa
       9             : {
      10             :     template <typename data_t>
      11           0 :     SQS<data_t>::SQS(const Problem<data_t>& problem, bool momentumAcceleration, data_t epsilon)
      12           0 :         : Solver<data_t>(problem), _epsilon{epsilon}, _momentumAcceleration{momentumAcceleration}
      13             :     {
      14           0 :         if (is<SubsetProblem<data_t>>(problem)) {
      15           0 :             Logger::get("SQS")->info(
      16             :                 "SQS did received a SubsetProblem, running in ordered subset mode");
      17           0 :             _subsetMode = true;
      18             :         } else {
      19           0 :             Logger::get("SQS")->info("SQS did not receive a SubsetProblem, running in normal mode");
      20             :         }
      21           0 :     }
      22             : 
      23             :     template <typename data_t>
      24           0 :     SQS<data_t>::SQS(const Problem<data_t>& problem, const LinearOperator<data_t>& preconditioner,
      25             :                      bool momentumAcceleration, data_t epsilon)
      26             :         : Solver<data_t>(problem),
      27           0 :           _epsilon{epsilon},
      28           0 :           _preconditioner{preconditioner.clone()},
      29           0 :           _momentumAcceleration{momentumAcceleration}
      30             :     {
      31           0 :         if (is<SubsetProblem<data_t>>(problem)) {
      32           0 :             Logger::get("SQS")->info(
      33             :                 "SQS did received a SubsetProblem, running in ordered subset mode");
      34           0 :             _subsetMode = true;
      35             :         } else {
      36           0 :             Logger::get("SQS")->info("SQS did not receive a SubsetProblem, running in normal mode");
      37             :         }
      38             : 
      39             :         // check that preconditioner is compatible with problem
      40           0 :         if (_preconditioner->getDomainDescriptor().getNumberOfCoefficients()
      41           0 :                 != _problem->getCurrentSolution().getSize()
      42           0 :             || _preconditioner->getRangeDescriptor().getNumberOfCoefficients()
      43           0 :                    != _problem->getCurrentSolution().getSize()) {
      44           0 :             throw InvalidArgumentError("SQS: incorrect size of preconditioner");
      45             :         }
      46           0 :     }
      47             : 
      48             :     template <typename data_t>
      49           0 :     DataContainer<data_t>& SQS<data_t>::solveImpl(index_t iterations)
      50             :     {
      51           0 :         if (iterations == 0)
      52           0 :             iterations = _defaultIterations;
      53             : 
      54           0 :         auto convergenceThreshold = _problem->getGradient().squaredL2Norm() * _epsilon * _epsilon;
      55             : 
      56           0 :         auto hessian = _problem->getHessian();
      57             : 
      58           0 :         auto ones = DataContainer<data_t>(getCurrentSolution().getDataDescriptor());
      59           0 :         ones = 1;
      60           0 :         auto diagVector = hessian.apply(ones);
      61           0 :         diagVector = static_cast<data_t>(1.0) / diagVector;
      62           0 :         auto diag = Scaling<data_t>(hessian.getDomainDescriptor(), diagVector);
      63             : 
      64           0 :         data_t prevT = 1;
      65           0 :         data_t t = 1;
      66           0 :         data_t nextT = 0;
      67           0 :         auto& z = getCurrentSolution();
      68           0 :         DataContainer<data_t> x = DataContainer<data_t>(getCurrentSolution());
      69           0 :         DataContainer<data_t> prevX = x;
      70           0 :         DataContainer<data_t> gradient(getCurrentSolution().getDataDescriptor());
      71             : 
      72           0 :         index_t nSubsets = 1;
      73           0 :         if (_subsetMode) {
      74           0 :             const auto& subsetProblem = static_cast<const SubsetProblem<data_t>*>(_problem.get());
      75           0 :             nSubsets = subsetProblem->getNumberOfSubsets();
      76             :         }
      77             : 
      78           0 :         for (index_t i = 0; i < iterations; i++) {
      79           0 :             Logger::get("SQS")->info("iteration {} of {}", i + 1, iterations);
      80             : 
      81           0 :             for (index_t m = 0; m < nSubsets; m++) {
      82           0 :                 if (_subsetMode) {
      83           0 :                     gradient =
      84           0 :                         static_cast<SubsetProblem<data_t>*>(_problem.get())->getSubsetGradient(m);
      85             :                 } else {
      86           0 :                     gradient = _problem->getGradient();
      87             :                 }
      88             : 
      89           0 :                 if (_preconditioner)
      90           0 :                     gradient = _preconditioner->apply(gradient);
      91             : 
      92             :                 // TODO: element wise relu
      93           0 :                 if (_momentumAcceleration) {
      94           0 :                     nextT = as<data_t>(1)
      95           0 :                             + std::sqrt(as<data_t>(1) + as<data_t>(4) * t * t) / as<data_t>(2);
      96             : 
      97           0 :                     x = z - nSubsets * diag.apply(gradient);
      98           0 :                     z = x + prevT / nextT * (x - prevX);
      99             :                 } else {
     100           0 :                     z = z - nSubsets * diag.apply(gradient);
     101             :                 }
     102             : 
     103             :                 // if the gradient is too small we stop
     104           0 :                 if (gradient.squaredL2Norm() <= convergenceThreshold) {
     105           0 :                     if (!_subsetMode
     106           0 :                         || _problem->getGradient().squaredL2Norm() <= convergenceThreshold) {
     107           0 :                         Logger::get("SQS")->info("SUCCESS: Reached convergence at {}/{} iteration",
     108           0 :                                                  i + 1, iterations);
     109             : 
     110             :                         // TODO: make return more sane
     111           0 :                         if (_momentumAcceleration) {
     112           0 :                             z = x;
     113             :                         }
     114           0 :                         return getCurrentSolution();
     115             :                     }
     116             :                 }
     117             : 
     118           0 :                 if (_momentumAcceleration) {
     119           0 :                     prevT = t;
     120           0 :                     t = nextT;
     121           0 :                     prevX = x;
     122             :                 }
     123             :             }
     124             :         }
     125             : 
     126           0 :         Logger::get("SQS")->warn("Failed to reach convergence at {} iterations", iterations);
     127             : 
     128             :         // TODO: make return more sane
     129           0 :         if (_momentumAcceleration) {
     130           0 :             z = x;
     131             :         }
     132           0 :         return getCurrentSolution();
     133           0 :     }
     134             : 
     135             :     template <typename data_t>
     136           0 :     SQS<data_t>* SQS<data_t>::cloneImpl() const
     137             :     {
     138           0 :         if (_preconditioner)
     139           0 :             return new SQS(*_problem, *_preconditioner, _momentumAcceleration, _epsilon);
     140             : 
     141           0 :         return new SQS(*_problem, _momentumAcceleration, _epsilon);
     142             :     }
     143             : 
     144             :     template <typename data_t>
     145           0 :     bool SQS<data_t>::isEqual(const Solver<data_t>& other) const
     146             :     {
     147           0 :         if (!Solver<data_t>::isEqual(other))
     148           0 :             return false;
     149             : 
     150           0 :         auto otherSQS = downcast_safe<SQS>(&other);
     151           0 :         if (!otherSQS)
     152           0 :             return false;
     153             : 
     154           0 :         if (_epsilon != otherSQS->_epsilon)
     155           0 :             return false;
     156             : 
     157           0 :         if ((_preconditioner && !otherSQS->_preconditioner)
     158           0 :             || (!_preconditioner && otherSQS->_preconditioner))
     159           0 :             return false;
     160             : 
     161           0 :         if (_preconditioner && otherSQS->_preconditioner)
     162           0 :             if (*_preconditioner != *otherSQS->_preconditioner)
     163           0 :                 return false;
     164             : 
     165           0 :         return true;
     166             :     }
     167             : 
     168             :     // ------------------------------------------
     169             :     // explicit template instantiation
     170             :     template class SQS<float>;
     171             :     template class SQS<double>;
     172             : 
     173             : } // namespace elsa

Generated by: LCOV version 1.14