LCOV - code coverage report
Current view: top level - solvers - CG.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 87 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 10 0.0 %

          Line data    Source code
       1             : #include "CG.h"
       2             : #include "Logger.h"
       3             : #include "TypeCasts.hpp"
       4             : #include "spdlog/stopwatch.h"
       5             : 
       6             : namespace elsa
       7             : {
       8             : 
       9             :     template <typename data_t>
      10           0 :     CG<data_t>::CG(const Problem<data_t>& problem, data_t epsilon)
      11           0 :         : Solver<data_t>{QuadricProblem<data_t>{problem}}, _epsilon{epsilon}
      12             :     {
      13           0 :     }
      14             : 
      15             :     template <typename data_t>
      16           0 :     CG<data_t>::CG(const Problem<data_t>& problem,
      17             :                    const LinearOperator<data_t>& preconditionerInverse, data_t epsilon)
      18             :         : Solver<data_t>{QuadricProblem<data_t>{problem}},
      19           0 :           _preconditionerInverse{preconditionerInverse.clone()},
      20           0 :           _epsilon{epsilon}
      21             :     {
      22             :         // check that preconditioner is compatible with problem
      23           0 :         if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
      24           0 :                 != _problem->getCurrentSolution().getSize()
      25           0 :             || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
      26           0 :                    != _problem->getCurrentSolution().getSize()) {
      27           0 :             throw InvalidArgumentError("CG: incorrect size of preconditioner");
      28             :         }
      29           0 :     }
      30             : 
      31             :     template <typename data_t>
      32           0 :     DataContainer<data_t>& CG<data_t>::solveImpl(index_t iterations)
      33             :     {
      34           0 :         if (iterations == 0)
      35           0 :             iterations = _defaultIterations;
      36             : 
      37           0 :         spdlog::stopwatch aggregate_time;
      38           0 :         Logger::get("CG")->info("Start preparations...");
      39             : 
      40             :         // get references to some variables in the Quadric
      41           0 :         auto& x = _problem->getCurrentSolution();
      42             :         const auto& gradientExpr =
      43           0 :             static_cast<const Quadric<data_t>&>(_problem->getDataTerm()).getGradientExpression();
      44           0 :         const LinearOperator<data_t>* A = nullptr;
      45           0 :         const DataContainer<data_t>* b = nullptr;
      46             : 
      47           0 :         if (gradientExpr.hasOperator())
      48           0 :             A = &gradientExpr.getOperator();
      49             : 
      50           0 :         if (gradientExpr.hasDataVector())
      51           0 :             b = &gradientExpr.getDataVector();
      52             : 
      53             :         // Start CG initialization
      54           0 :         auto r = _problem->getGradient();
      55           0 :         r *= static_cast<data_t>(-1.0);
      56             : 
      57           0 :         auto d = _preconditionerInverse ? _preconditionerInverse->apply(r) : r;
      58             : 
      59             :         // only allocate space for s if preconditioned
      60           0 :         std::unique_ptr<DataContainer<data_t>> s{};
      61           0 :         if (_preconditionerInverse)
      62           0 :             s = std::make_unique<DataContainer<data_t>>(
      63             :                 _preconditionerInverse->getRangeDescriptor());
      64             : 
      65           0 :         auto deltaNew = r.dot(d);
      66           0 :         auto deltaZero = deltaNew;
      67             : 
      68           0 :         Logger::get("CG")->info("Preparations done, tooke {}s", aggregate_time);
      69             : 
      70           0 :         Logger::get("CG")->info("epsilon: {}", _epsilon);
      71           0 :         Logger::get("CG")->info("delta zero: {}", std::sqrt(deltaZero));
      72             : 
      73             :         // log history legend
      74           0 :         Logger::get("CG")->info("{:^6}|{:*^16}|{:*^16}|{:*^8}|{:*^8}|", "iter", "deltaNew",
      75             :                                 "deltaZero", "time", "elapsed");
      76             : 
      77           0 :         for (index_t it = 0; it != iterations; ++it) {
      78           0 :             spdlog::stopwatch iter_time;
      79           0 :             auto Ad = A ? A->apply(d) : d;
      80             : 
      81           0 :             data_t alpha = deltaNew / d.dot(Ad);
      82             : 
      83           0 :             x += alpha * d;
      84           0 :             r -= alpha * Ad;
      85             : 
      86           0 :             if (_preconditionerInverse)
      87           0 :                 _preconditionerInverse->apply(r, *s);
      88             : 
      89           0 :             const auto deltaOld = deltaNew;
      90             : 
      91           0 :             deltaNew = _preconditionerInverse ? r.dot(*s) : r.squaredL2Norm();
      92             : 
      93             :             // evaluate objective function as -0.5 * x^t[b + (b - Ax)]
      94             :             data_t objVal;
      95           0 :             if (b == nullptr) {
      96           0 :                 objVal = static_cast<data_t>(-0.5) * x.dot(r);
      97             :             } else {
      98           0 :                 objVal = static_cast<data_t>(-0.5) * x.dot(*b + r);
      99             :             }
     100             : 
     101           0 :             Logger::get("CG")->info("{:>5} |{:>15} |{:>15} | {:>6.3} |{:>6.3}s |", it,
     102           0 :                                     std::sqrt(deltaNew), objVal, iter_time, aggregate_time);
     103             : 
     104           0 :             if (deltaNew <= _epsilon * _epsilon * deltaZero) {
     105             :                 // check that we are not stopping prematurely due to accumulated roundoff error
     106           0 :                 r = _problem->getGradient();
     107           0 :                 deltaNew = r.squaredL2Norm();
     108           0 :                 if (deltaNew <= _epsilon * _epsilon * deltaZero) {
     109           0 :                     Logger::get("CG")->info("SUCCESS: Reached convergence at {}/{} iteration",
     110           0 :                                             it + 1, iterations);
     111           0 :                     return x;
     112             :                 } else {
     113             :                     // we are very close to the desired solution, so do a hard reset
     114           0 :                     r *= static_cast<data_t>(-1.0);
     115           0 :                     d = 0;
     116           0 :                     if (_preconditionerInverse)
     117           0 :                         _preconditionerInverse->apply(r, *s);
     118             :                 }
     119             :             }
     120             : 
     121           0 :             const auto beta = deltaNew / deltaOld;
     122           0 :             d = beta * d + (_preconditionerInverse ? *s : r);
     123             :         }
     124             : 
     125           0 :         Logger::get("CG")->warn("Failed to reach convergence at {} iterations", iterations);
     126             : 
     127           0 :         return x;
     128           0 :     }
     129             : 
     130             :     template <typename data_t>
     131           0 :     CG<data_t>* CG<data_t>::cloneImpl() const
     132             :     {
     133           0 :         if (_preconditionerInverse)
     134           0 :             return new CG(*_problem, *_preconditionerInverse, _epsilon);
     135             :         else
     136           0 :             return new CG(*_problem, _epsilon);
     137             :     }
     138             : 
     139             :     template <typename data_t>
     140           0 :     bool CG<data_t>::isEqual(const Solver<data_t>& other) const
     141             :     {
     142           0 :         if (!Solver<data_t>::isEqual(other))
     143           0 :             return false;
     144             : 
     145           0 :         auto otherCG = downcast_safe<CG>(&other);
     146           0 :         if (!otherCG)
     147           0 :             return false;
     148             : 
     149           0 :         if (_epsilon != otherCG->_epsilon)
     150           0 :             return false;
     151             : 
     152           0 :         if ((_preconditionerInverse && !otherCG->_preconditionerInverse)
     153           0 :             || (!_preconditionerInverse && otherCG->_preconditionerInverse))
     154           0 :             return false;
     155             : 
     156           0 :         if (_preconditionerInverse && otherCG->_preconditionerInverse)
     157           0 :             if (*_preconditionerInverse != *otherCG->_preconditionerInverse)
     158           0 :                 return false;
     159             : 
     160           0 :         return true;
     161             :     }
     162             : 
    // ------------------------------------------
    // explicit template instantiation
    // CG is instantiated here for float and double; the member definitions above are
    // compiled in this translation unit for exactly these two element types.
    template class CG<float>;
    template class CG<double>;
     167             : 
     168             : } // namespace elsa

Generated by: LCOV version 1.14