LCOV - code coverage report
Current view: top level - elsa/solvers - CG.cpp (source / functions)
Test: coverage-all.lcov
Date: 2022-08-25 03:05:39
Coverage (hit / total / percentage): Lines: 87 / 96 / 90.6 %; Functions: 10 / 10 / 100.0 %

          Line data    Source code
       1             : #include "CG.h"
       2             : #include "Logger.h"
       3             : #include "TypeCasts.hpp"
       4             : #include "spdlog/stopwatch.h"
       5             : 
       6             : namespace elsa
       7             : {
       8             : 
       9             :     template <typename data_t>
      10             :     CG<data_t>::CG(const Problem<data_t>& problem, data_t epsilon)
      11             :         : Solver<data_t>(), _problem{QuadricProblem<data_t>{problem}}, _epsilon{epsilon}
      12          46 :     {
      13          46 :     }
      14             : 
      15             :     template <typename data_t>
      16             :     CG<data_t>::CG(const Problem<data_t>& problem,
      17             :                    const LinearOperator<data_t>& preconditionerInverse, data_t epsilon)
      18             :         : Solver<data_t>(),
      19             :           _problem{QuadricProblem<data_t>{problem}},
      20             :           _preconditionerInverse{preconditionerInverse.clone()},
      21             :           _epsilon{epsilon}
      22           8 :     {
      23             :         // check that preconditioner is compatible with problem
      24           8 :         if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
      25           8 :                 != _problem.getCurrentSolution().getSize()
      26           8 :             || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
      27           8 :                    != _problem.getCurrentSolution().getSize()) {
      28           0 :             throw InvalidArgumentError("CG: incorrect size of preconditioner");
      29           0 :         }
      30           8 :     }
      31             : 
      32             :     template <typename data_t>
      33             :     DataContainer<data_t>& CG<data_t>::solveImpl(index_t iterations)
      34          45 :     {
      35          45 :         if (iterations == 0)
      36           0 :             iterations = _defaultIterations;
      37             : 
      38          45 :         spdlog::stopwatch aggregate_time;
      39          45 :         Logger::get("CG")->info("Start preparations...");
      40             : 
      41             :         // get references to some variables in the Quadric
      42          45 :         auto& x = _problem.getCurrentSolution();
      43          45 :         const auto& gradientExpr =
      44          45 :             static_cast<const Quadric<data_t>&>(_problem.getDataTerm()).getGradientExpression();
      45          45 :         const LinearOperator<data_t>* A = nullptr;
      46          45 :         const DataContainer<data_t>* b = nullptr;
      47             : 
      48          45 :         if (gradientExpr.hasOperator())
      49          45 :             A = &gradientExpr.getOperator();
      50             : 
      51          45 :         if (gradientExpr.hasDataVector())
      52          45 :             b = &gradientExpr.getDataVector();
      53             : 
      54             :         // Start CG initialization
      55          45 :         auto r = _problem.getGradient();
      56          45 :         r *= static_cast<data_t>(-1.0);
      57             : 
      58          45 :         auto d = _preconditionerInverse ? _preconditionerInverse->apply(r) : r;
      59             : 
      60             :         // only allocate space for s if preconditioned
      61          45 :         std::unique_ptr<DataContainer<data_t>> s{};
      62          45 :         if (_preconditionerInverse)
      63           4 :             s = std::make_unique<DataContainer<data_t>>(
      64           4 :                 _preconditionerInverse->getRangeDescriptor());
      65             : 
      66          45 :         auto deltaNew = r.dot(d);
      67          45 :         auto deltaZero = deltaNew;
      68             : 
      69          45 :         Logger::get("CG")->info("Preparations done, tooke {}s", aggregate_time);
      70             : 
      71          45 :         Logger::get("CG")->info("epsilon: {}", _epsilon);
      72          45 :         Logger::get("CG")->info("delta zero: {}", std::sqrt(deltaZero));
      73             : 
      74             :         // log history legend
      75          45 :         Logger::get("CG")->info("{:^6}|{:*^16}|{:*^16}|{:*^8}|{:*^8}|", "iter", "deltaNew",
      76          45 :                                 "deltaZero", "time", "elapsed");
      77             : 
      78         473 :         for (index_t it = 0; it != iterations; ++it) {
      79         450 :             spdlog::stopwatch iter_time;
      80         450 :             auto Ad = A ? A->apply(d) : d;
      81             : 
      82         450 :             data_t alpha = deltaNew / d.dot(Ad);
      83             : 
      84         450 :             x += alpha * d;
      85         450 :             r -= alpha * Ad;
      86             : 
      87         450 :             if (_preconditionerInverse)
      88           4 :                 _preconditionerInverse->apply(r, *s);
      89             : 
      90         450 :             const auto deltaOld = deltaNew;
      91             : 
      92         450 :             deltaNew = _preconditionerInverse ? r.dot(*s) : r.squaredL2Norm();
      93             : 
      94             :             // evaluate objective function as -0.5 * x^t[b + (b - Ax)]
      95         450 :             data_t objVal;
      96         450 :             if (b == nullptr) {
      97           0 :                 objVal = static_cast<data_t>(-0.5) * x.dot(r);
      98         450 :             } else {
      99         450 :                 objVal = static_cast<data_t>(-0.5) * x.dot(*b + r);
     100         450 :             }
     101             : 
     102         450 :             Logger::get("CG")->info("{:>5} |{:>15} |{:>15} | {:>6.3} |{:>6.3}s |", it,
     103         450 :                                     std::sqrt(deltaNew), objVal, iter_time, aggregate_time);
     104             : 
     105         450 :             if (deltaNew <= _epsilon * _epsilon * deltaZero) {
     106             :                 // check that we are not stopping prematurely due to accumulated roundoff error
     107         134 :                 r = _problem.getGradient();
     108         134 :                 deltaNew = r.squaredL2Norm();
     109         134 :                 if (deltaNew <= _epsilon * _epsilon * deltaZero) {
     110          22 :                     Logger::get("CG")->info("SUCCESS: Reached convergence at {}/{} iteration",
     111          22 :                                             it + 1, iterations);
     112          22 :                     return x;
     113         112 :                 } else {
     114             :                     // we are very close to the desired solution, so do a hard reset
     115         112 :                     r *= static_cast<data_t>(-1.0);
     116         112 :                     d = 0;
     117         112 :                     if (_preconditionerInverse)
     118           0 :                         _preconditionerInverse->apply(r, *s);
     119         112 :                 }
     120         134 :             }
     121             : 
     122         450 :             const auto beta = deltaNew / deltaOld;
     123         428 :             d = beta * d + (_preconditionerInverse ? *s : r);
     124         428 :         }
     125             : 
     126          45 :         Logger::get("CG")->warn("Failed to reach convergence at {} iterations", iterations);
     127             : 
     128          23 :         return x;
     129          45 :     }
     130             : 
     131             :     template <typename data_t>
     132             :     CG<data_t>* CG<data_t>::cloneImpl() const
     133           9 :     {
     134           9 :         if (_preconditionerInverse)
     135           4 :             return new CG(_problem, *_preconditionerInverse, _epsilon);
     136           5 :         else
     137           5 :             return new CG(_problem, _epsilon);
     138           9 :     }
     139             : 
     140             :     template <typename data_t>
     141             :     bool CG<data_t>::isEqual(const Solver<data_t>& other) const
     142           9 :     {
     143           9 :         auto otherCG = downcast_safe<CG>(&other);
     144           9 :         if (!otherCG)
     145           0 :             return false;
     146             : 
     147           9 :         if (_epsilon != otherCG->_epsilon)
     148           0 :             return false;
     149             : 
     150           9 :         if ((_preconditionerInverse && !otherCG->_preconditionerInverse)
     151           9 :             || (!_preconditionerInverse && otherCG->_preconditionerInverse))
     152           0 :             return false;
     153             : 
     154           9 :         if (_preconditionerInverse && otherCG->_preconditionerInverse)
     155           4 :             if (*_preconditionerInverse != *otherCG->_preconditionerInverse)
     156           0 :                 return false;
     157             : 
     158           9 :         return true;
     159           9 :     }
     160             : 
     161             :     // ------------------------------------------
     162             :     // explicit template instantiation
     163             :     template class CG<float>;
     164             :     template class CG<double>;
     165             : 
     166             : } // namespace elsa

Generated by: LCOV version 1.14