Line data Source code
1 : #include "CG.h" 2 : #include "Logger.h" 3 : #include "TypeCasts.hpp" 4 : #include "spdlog/stopwatch.h" 5 : 6 : namespace elsa 7 : { 8 : 9 : template <typename data_t> 10 0 : CG<data_t>::CG(const Problem<data_t>& problem, data_t epsilon) 11 0 : : Solver<data_t>{QuadricProblem<data_t>{problem}}, _epsilon{epsilon} 12 : { 13 0 : } 14 : 15 : template <typename data_t> 16 0 : CG<data_t>::CG(const Problem<data_t>& problem, 17 : const LinearOperator<data_t>& preconditionerInverse, data_t epsilon) 18 : : Solver<data_t>{QuadricProblem<data_t>{problem}}, 19 0 : _preconditionerInverse{preconditionerInverse.clone()}, 20 0 : _epsilon{epsilon} 21 : { 22 : // check that preconditioner is compatible with problem 23 0 : if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients() 24 0 : != _problem->getCurrentSolution().getSize() 25 0 : || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients() 26 0 : != _problem->getCurrentSolution().getSize()) { 27 0 : throw InvalidArgumentError("CG: incorrect size of preconditioner"); 28 : } 29 0 : } 30 : 31 : template <typename data_t> 32 0 : DataContainer<data_t>& CG<data_t>::solveImpl(index_t iterations) 33 : { 34 0 : if (iterations == 0) 35 0 : iterations = _defaultIterations; 36 : 37 0 : spdlog::stopwatch aggregate_time; 38 0 : Logger::get("CG")->info("Start preparations..."); 39 : 40 : // get references to some variables in the Quadric 41 0 : auto& x = _problem->getCurrentSolution(); 42 : const auto& gradientExpr = 43 0 : static_cast<const Quadric<data_t>&>(_problem->getDataTerm()).getGradientExpression(); 44 0 : const LinearOperator<data_t>* A = nullptr; 45 0 : const DataContainer<data_t>* b = nullptr; 46 : 47 0 : if (gradientExpr.hasOperator()) 48 0 : A = &gradientExpr.getOperator(); 49 : 50 0 : if (gradientExpr.hasDataVector()) 51 0 : b = &gradientExpr.getDataVector(); 52 : 53 : // Start CG initialization 54 0 : auto r = _problem->getGradient(); 55 0 : r *= static_cast<data_t>(-1.0); 56 : 57 0 : auto d = _preconditionerInverse ? _preconditionerInverse->apply(r) : r; 58 : 59 : // only allocate space for s if preconditioned 60 0 : std::unique_ptr<DataContainer<data_t>> s{}; 61 0 : if (_preconditionerInverse) 62 0 : s = std::make_unique<DataContainer<data_t>>( 63 : _preconditionerInverse->getRangeDescriptor()); 64 : 65 0 : auto deltaNew = r.dot(d); 66 0 : auto deltaZero = deltaNew; 67 : 68 0 : Logger::get("CG")->info("Preparations done, tooke {}s", aggregate_time); 69 : 70 0 : Logger::get("CG")->info("epsilon: {}", _epsilon); 71 0 : Logger::get("CG")->info("delta zero: {}", std::sqrt(deltaZero)); 72 : 73 : // log history legend 74 0 : Logger::get("CG")->info("{:^6}|{:*^16}|{:*^16}|{:*^8}|{:*^8}|", "iter", "deltaNew", 75 : "deltaZero", "time", "elapsed"); 76 : 77 0 : for (index_t it = 0; it != iterations; ++it) { 78 0 : spdlog::stopwatch iter_time; 79 0 : auto Ad = A ? A->apply(d) : d; 80 : 81 0 : data_t alpha = deltaNew / d.dot(Ad); 82 : 83 0 : x += alpha * d; 84 0 : r -= alpha * Ad; 85 : 86 0 : if (_preconditionerInverse) 87 0 : _preconditionerInverse->apply(r, *s); 88 : 89 0 : const auto deltaOld = deltaNew; 90 : 91 0 : deltaNew = _preconditionerInverse ? r.dot(*s) : r.squaredL2Norm(); 92 : 93 : // evaluate objective function as -0.5 * x^t[b + (b - Ax)] 94 : data_t objVal; 95 0 : if (b == nullptr) { 96 0 : objVal = static_cast<data_t>(-0.5) * x.dot(r); 97 : } else { 98 0 : objVal = static_cast<data_t>(-0.5) * x.dot(*b + r); 99 : } 100 : 101 0 : Logger::get("CG")->info("{:>5} |{:>15} |{:>15} | {:>6.3} |{:>6.3}s |", it, 102 0 : std::sqrt(deltaNew), objVal, iter_time, aggregate_time); 103 : 104 0 : if (deltaNew <= _epsilon * _epsilon * deltaZero) { 105 : // check that we are not stopping prematurely due to accumulated roundoff error 106 0 : r = _problem->getGradient(); 107 0 : deltaNew = r.squaredL2Norm(); 108 0 : if (deltaNew <= _epsilon * _epsilon * deltaZero) { 109 0 : Logger::get("CG")->info("SUCCESS: Reached convergence at {}/{} iteration", 110 0 : it + 1, iterations); 111 0 : return x; 112 : } else { 113 : // we are very close to the desired solution, so do a hard reset 114 0 : r *= static_cast<data_t>(-1.0); 115 0 : d = 0; 116 0 : if (_preconditionerInverse) 117 0 : _preconditionerInverse->apply(r, *s); 118 : } 119 : } 120 : 121 0 : const auto beta = deltaNew / deltaOld; 122 0 : d = beta * d + (_preconditionerInverse ? *s : r); 123 : } 124 : 125 0 : Logger::get("CG")->warn("Failed to reach convergence at {} iterations", iterations); 126 : 127 0 : return x; 128 0 : } 129 : 130 : template <typename data_t> 131 0 : CG<data_t>* CG<data_t>::cloneImpl() const 132 : { 133 0 : if (_preconditionerInverse) 134 0 : return new CG(*_problem, *_preconditionerInverse, _epsilon); 135 : else 136 0 : return new CG(*_problem, _epsilon); 137 : } 138 : 139 : template <typename data_t> 140 0 : bool CG<data_t>::isEqual(const Solver<data_t>& other) const 141 : { 142 0 : if (!Solver<data_t>::isEqual(other)) 143 0 : return false; 144 : 145 0 : auto otherCG = downcast_safe<CG>(&other); 146 0 : if (!otherCG) 147 0 : return false; 148 : 149 0 : if (_epsilon != otherCG->_epsilon) 150 0 : return false; 151 : 152 0 : if ((_preconditionerInverse && !otherCG->_preconditionerInverse) 153 0 : || (!_preconditionerInverse && otherCG->_preconditionerInverse)) 154 0 : return false; 155 : 156 0 : if (_preconditionerInverse && otherCG->_preconditionerInverse) 157 0 : if (*_preconditionerInverse != *otherCG->_preconditionerInverse) 158 0 : return false; 159 : 160 0 : return true; 161 : } 162 : 163 : // ------------------------------------------ 164 : // explicit template instantiation 165 : template class CG<float>; 166 : template class CG<double>; 167 : 168 : } // namespace elsa