Line data Source code
1 : #include "CG.h" 2 : #include "Logger.h" 3 : #include "TypeCasts.hpp" 4 : #include "spdlog/stopwatch.h" 5 : 6 : namespace elsa 7 : { 8 : 9 : template <typename data_t> 10 : CG<data_t>::CG(const Problem<data_t>& problem, data_t epsilon) 11 : : Solver<data_t>(), _problem{QuadricProblem<data_t>{problem}}, _epsilon{epsilon} 12 46 : { 13 46 : } 14 : 15 : template <typename data_t> 16 : CG<data_t>::CG(const Problem<data_t>& problem, 17 : const LinearOperator<data_t>& preconditionerInverse, data_t epsilon) 18 : : Solver<data_t>(), 19 : _problem{QuadricProblem<data_t>{problem}}, 20 : _preconditionerInverse{preconditionerInverse.clone()}, 21 : _epsilon{epsilon} 22 8 : { 23 : // check that preconditioner is compatible with problem 24 8 : if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients() 25 8 : != _problem.getCurrentSolution().getSize() 26 8 : || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients() 27 8 : != _problem.getCurrentSolution().getSize()) { 28 0 : throw InvalidArgumentError("CG: incorrect size of preconditioner"); 29 0 : } 30 8 : } 31 : 32 : template <typename data_t> 33 : DataContainer<data_t>& CG<data_t>::solveImpl(index_t iterations) 34 45 : { 35 45 : if (iterations == 0) 36 0 : iterations = _defaultIterations; 37 : 38 45 : spdlog::stopwatch aggregate_time; 39 45 : Logger::get("CG")->info("Start preparations..."); 40 : 41 : // get references to some variables in the Quadric 42 45 : auto& x = _problem.getCurrentSolution(); 43 45 : const auto& gradientExpr = 44 45 : static_cast<const Quadric<data_t>&>(_problem.getDataTerm()).getGradientExpression(); 45 45 : const LinearOperator<data_t>* A = nullptr; 46 45 : const DataContainer<data_t>* b = nullptr; 47 : 48 45 : if (gradientExpr.hasOperator()) 49 45 : A = &gradientExpr.getOperator(); 50 : 51 45 : if (gradientExpr.hasDataVector()) 52 45 : b = &gradientExpr.getDataVector(); 53 : 54 : // Start CG initialization 55 45 : auto r = _problem.getGradient(); 56 45 : r *= static_cast<data_t>(-1.0); 57 : 58 45 : auto d = _preconditionerInverse ? _preconditionerInverse->apply(r) : r; 59 : 60 : // only allocate space for s if preconditioned 61 45 : std::unique_ptr<DataContainer<data_t>> s{}; 62 45 : if (_preconditionerInverse) 63 4 : s = std::make_unique<DataContainer<data_t>>( 64 4 : _preconditionerInverse->getRangeDescriptor()); 65 : 66 45 : auto deltaNew = r.dot(d); 67 45 : auto deltaZero = deltaNew; 68 : 69 45 : Logger::get("CG")->info("Preparations done, tooke {}s", aggregate_time); 70 : 71 45 : Logger::get("CG")->info("epsilon: {}", _epsilon); 72 45 : Logger::get("CG")->info("delta zero: {}", std::sqrt(deltaZero)); 73 : 74 : // log history legend 75 45 : Logger::get("CG")->info("{:^6}|{:*^16}|{:*^16}|{:*^8}|{:*^8}|", "iter", "deltaNew", 76 45 : "deltaZero", "time", "elapsed"); 77 : 78 473 : for (index_t it = 0; it != iterations; ++it) { 79 450 : spdlog::stopwatch iter_time; 80 450 : auto Ad = A ? A->apply(d) : d; 81 : 82 450 : data_t alpha = deltaNew / d.dot(Ad); 83 : 84 450 : x += alpha * d; 85 450 : r -= alpha * Ad; 86 : 87 450 : if (_preconditionerInverse) 88 4 : _preconditionerInverse->apply(r, *s); 89 : 90 450 : const auto deltaOld = deltaNew; 91 : 92 450 : deltaNew = _preconditionerInverse ? r.dot(*s) : r.squaredL2Norm(); 93 : 94 : // evaluate objective function as -0.5 * x^t[b + (b - Ax)] 95 450 : data_t objVal; 96 450 : if (b == nullptr) { 97 0 : objVal = static_cast<data_t>(-0.5) * x.dot(r); 98 450 : } else { 99 450 : objVal = static_cast<data_t>(-0.5) * x.dot(*b + r); 100 450 : } 101 : 102 450 : Logger::get("CG")->info("{:>5} |{:>15} |{:>15} | {:>6.3} |{:>6.3}s |", it, 103 450 : std::sqrt(deltaNew), objVal, iter_time, aggregate_time); 104 : 105 450 : if (deltaNew <= _epsilon * _epsilon * deltaZero) { 106 : // check that we are not stopping prematurely due to accumulated roundoff error 107 134 : r = _problem.getGradient(); 108 134 : deltaNew = r.squaredL2Norm(); 109 134 : if (deltaNew <= _epsilon * _epsilon * deltaZero) { 110 22 : Logger::get("CG")->info("SUCCESS: Reached convergence at {}/{} iteration", 111 22 : it + 1, iterations); 112 22 : return x; 113 112 : } else { 114 : // we are very close to the desired solution, so do a hard reset 115 112 : r *= static_cast<data_t>(-1.0); 116 112 : d = 0; 117 112 : if (_preconditionerInverse) 118 0 : _preconditionerInverse->apply(r, *s); 119 112 : } 120 134 : } 121 : 122 450 : const auto beta = deltaNew / deltaOld; 123 428 : d = beta * d + (_preconditionerInverse ? *s : r); 124 428 : } 125 : 126 45 : Logger::get("CG")->warn("Failed to reach convergence at {} iterations", iterations); 127 : 128 23 : return x; 129 45 : } 130 : 131 : template <typename data_t> 132 : CG<data_t>* CG<data_t>::cloneImpl() const 133 9 : { 134 9 : if (_preconditionerInverse) 135 4 : return new CG(_problem, *_preconditionerInverse, _epsilon); 136 5 : else 137 5 : return new CG(_problem, _epsilon); 138 9 : } 139 : 140 : template <typename data_t> 141 : bool CG<data_t>::isEqual(const Solver<data_t>& other) const 142 9 : { 143 9 : auto otherCG = downcast_safe<CG>(&other); 144 9 : if (!otherCG) 145 0 : return false; 146 : 147 9 : if (_epsilon != otherCG->_epsilon) 148 0 : return false; 149 : 150 9 : if ((_preconditionerInverse && !otherCG->_preconditionerInverse) 151 9 : || (!_preconditionerInverse && otherCG->_preconditionerInverse)) 152 0 : return false; 153 : 154 9 : if (_preconditionerInverse && otherCG->_preconditionerInverse) 155 4 : if (*_preconditionerInverse != *otherCG->_preconditionerInverse) 156 0 : return false; 157 : 158 9 : return true; 159 9 : } 160 : 161 : // ------------------------------------------ 162 : // explicit template instantiation 163 : template class CG<float>; 164 : template class CG<double>; 165 : 166 : } // namespace elsa