Line data Source code
1 : #include "OGM.h" 2 : #include "TypeCasts.hpp" 3 : #include "Logger.h" 4 : 5 : namespace elsa 6 : { 7 : template <typename data_t> 8 : OGM<data_t>::OGM(const Problem<data_t>& problem, data_t epsilon) 9 : : Solver<data_t>(), _problem(problem.clone()), _epsilon{epsilon} 10 10 : { 11 10 : } 12 : 13 : template <typename data_t> 14 : OGM<data_t>::OGM(const Problem<data_t>& problem, 15 : const LinearOperator<data_t>& preconditionerInverse, data_t epsilon) 16 : : Solver<data_t>(), 17 : _problem(problem.clone()), 18 : _epsilon{epsilon}, 19 : _preconditionerInverse{preconditionerInverse.clone()} 20 8 : { 21 : // check that preconditioner is compatible with problem 22 8 : if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients() 23 8 : != _problem->getCurrentSolution().getSize() 24 8 : || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients() 25 8 : != _problem->getCurrentSolution().getSize()) { 26 0 : throw InvalidArgumentError("OGM: incorrect size of preconditioner"); 27 0 : } 28 8 : } 29 : 30 : template <typename data_t> 31 : DataContainer<data_t>& OGM<data_t>::solveImpl(index_t iterations) 32 9 : { 33 9 : if (iterations == 0) 34 0 : iterations = _defaultIterations; 35 : 36 9 : auto prevTheta = static_cast<data_t>(1.0); 37 9 : auto x0 = _problem->getCurrentSolution(); 38 9 : auto& prevY = x0; 39 : 40 : // OGM is very picky when it comes to the accuracy of the used lipschitz constant therefore 41 : // we use 20 power iterations instead of 5 here to be more precise. 42 : // In some cases OGM might still not converge then an even more precise constant is needed 43 9 : auto lipschitz = _problem->getLipschitzConstant(20); 44 9 : auto deltaZero = _problem->getGradient().squaredL2Norm(); 45 9 : Logger::get("OGM")->info("Starting optimization with lipschitz constant {}", lipschitz); 46 : 47 : // log history legend 48 9 : Logger::get("OGM")->info("{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}", "iteration", 49 9 : "thetaRatio0", "thetaRatio1", "y", "gradient"); 50 : 51 4272 : for (index_t i = 0; i < iterations; ++i) { 52 4263 : auto& x = _problem->getCurrentSolution(); 53 : 54 4263 : auto gradient = _problem->getGradient(); 55 : 56 4263 : if (_preconditionerInverse) 57 1624 : gradient = _preconditionerInverse->apply(gradient); 58 : 59 4263 : DataContainer<data_t> y = x - gradient / lipschitz; 60 4263 : data_t theta; 61 4263 : if (i == iterations - 1) { // last iteration 62 9 : theta = (static_cast<data_t>(1.0) 63 9 : + std::sqrt(static_cast<data_t>(1.0) 64 9 : + static_cast<data_t>(8.0) * prevTheta * prevTheta)) 65 9 : / static_cast<data_t>(2.0); 66 4254 : } else { 67 4254 : theta = (static_cast<data_t>(1.0) 68 4254 : + std::sqrt(static_cast<data_t>(1.0) 69 4254 : + static_cast<data_t>(4.0) * prevTheta * prevTheta)) 70 4254 : / static_cast<data_t>(2.0); 71 4254 : } 72 : 73 4263 : Logger::get("OGM")->info(" {:<19}| {:<19}| {:<19}| {:<19}| {:<19}", i, 74 4263 : (prevTheta - 1) / theta, prevTheta / theta, y.squaredL2Norm(), 75 4263 : gradient.squaredL2Norm()); 76 : 77 : // x_{i+1} = y_{i+1} + \frac{\theta_i-1}{\theta_{i+1}}(y_{i+1} - y_i) + 78 : // \frac{\theta_i}{\theta_{i+1}}/(y_{i+1} - x_i) 79 4263 : x = y + ((prevTheta - static_cast<data_t>(1.0)) / theta) * (y - prevY) 80 4263 : - (prevTheta / theta) * (gradient / lipschitz); 81 4263 : prevTheta = theta; 82 4263 : prevY = y; 83 : 84 : // if the gradient is too small we stop 85 4263 : if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) { 86 0 : Logger::get("OGM")->info("SUCCESS: Reached convergence at {}/{} iteration", i + 1, 87 0 : iterations); 88 0 : return x; 89 0 : } 90 4263 : } 91 : 92 9 : Logger::get("OGM")->warn("Failed to reach convergence at {} iterations", iterations); 93 : 94 9 : return _problem->getCurrentSolution(); 95 9 : } 96 : 97 : template <typename data_t> 98 : OGM<data_t>* OGM<data_t>::cloneImpl() const 99 9 : { 100 9 : if (_preconditionerInverse) 101 4 : return new OGM(*_problem, *_preconditionerInverse, _epsilon); 102 : 103 5 : return new OGM(*_problem, _epsilon); 104 5 : } 105 : 106 : template <typename data_t> 107 : bool OGM<data_t>::isEqual(const Solver<data_t>& other) const 108 9 : { 109 9 : auto otherOGM = downcast_safe<OGM>(&other); 110 9 : if (!otherOGM) 111 0 : return false; 112 : 113 9 : if (_epsilon != otherOGM->_epsilon) 114 0 : return false; 115 : 116 9 : if ((_preconditionerInverse && !otherOGM->_preconditionerInverse) 117 9 : || (!_preconditionerInverse && otherOGM->_preconditionerInverse)) 118 0 : return false; 119 : 120 9 : if (_preconditionerInverse && otherOGM->_preconditionerInverse) 121 4 : if (*_preconditionerInverse != *otherOGM->_preconditionerInverse) 122 0 : return false; 123 : 124 9 : return true; 125 9 : } 126 : 127 : // ------------------------------------------ 128 : // explicit template instantiation 129 : template class OGM<float>; 130 : template class OGM<double>; 131 : 132 : } // namespace elsa