Line data Source code
1 : #include "OGM.h" 2 : #include "TypeCasts.hpp" 3 : #include "Logger.h" 4 : 5 : namespace elsa 6 : { 7 : template <typename data_t> 8 0 : OGM<data_t>::OGM(const Problem<data_t>& problem, data_t epsilon) 9 0 : : Solver<data_t>(problem), _epsilon{epsilon} 10 : { 11 0 : } 12 : 13 : template <typename data_t> 14 0 : OGM<data_t>::OGM(const Problem<data_t>& problem, 15 : const LinearOperator<data_t>& preconditionerInverse, data_t epsilon) 16 : : Solver<data_t>(problem), 17 0 : _epsilon{epsilon}, 18 0 : _preconditionerInverse{preconditionerInverse.clone()} 19 : { 20 : // check that preconditioner is compatible with problem 21 0 : if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients() 22 0 : != _problem->getCurrentSolution().getSize() 23 0 : || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients() 24 0 : != _problem->getCurrentSolution().getSize()) { 25 0 : throw InvalidArgumentError("OGM: incorrect size of preconditioner"); 26 : } 27 0 : } 28 : 29 : template <typename data_t> 30 0 : DataContainer<data_t>& OGM<data_t>::solveImpl(index_t iterations) 31 : { 32 0 : if (iterations == 0) 33 0 : iterations = _defaultIterations; 34 : 35 0 : auto prevTheta = static_cast<data_t>(1.0); 36 0 : auto x0 = DataContainer<data_t>(getCurrentSolution()); 37 0 : auto& prevY = x0; 38 : 39 : // OGM is very picky when it comes to the accuracy of the used lipschitz constant therefore 40 : // we use 20 power iterations instead of 5 here to be more precise. 41 : // In some cases OGM might still not converge then an even more precise constant is needed 42 0 : auto lipschitz = _problem->getLipschitzConstant(20); 43 0 : auto deltaZero = _problem->getGradient().squaredL2Norm(); 44 0 : Logger::get("OGM")->info("Starting optimization with lipschitz constant {}", lipschitz); 45 : 46 : // log history legend 47 0 : Logger::get("OGM")->info("{:*^20}|{:*^20}|{:*^20}|{:*^20}|{:*^20}", "iteration", 48 : "thetaRatio0", "thetaRatio1", "y", "gradient"); 49 : 50 0 : for (index_t i = 0; i < iterations; ++i) { 51 0 : auto& x = getCurrentSolution(); 52 : 53 0 : auto gradient = _problem->getGradient(); 54 : 55 0 : if (_preconditionerInverse) 56 0 : gradient = _preconditionerInverse->apply(gradient); 57 : 58 0 : DataContainer<data_t> y = x - gradient / lipschitz; 59 : data_t theta; 60 0 : if (i == iterations - 1) { // last iteration 61 0 : theta = (static_cast<data_t>(1.0) 62 0 : + std::sqrt(static_cast<data_t>(1.0) 63 0 : + static_cast<data_t>(8.0) * prevTheta * prevTheta)) 64 : / static_cast<data_t>(2.0); 65 : } else { 66 0 : theta = (static_cast<data_t>(1.0) 67 0 : + std::sqrt(static_cast<data_t>(1.0) 68 0 : + static_cast<data_t>(4.0) * prevTheta * prevTheta)) 69 : / static_cast<data_t>(2.0); 70 : } 71 : 72 0 : Logger::get("OGM")->info(" {:<19}| {:<19}| {:<19}| {:<19}| {:<19}", i, 73 0 : (prevTheta - 1) / theta, prevTheta / theta, y.squaredL2Norm(), 74 0 : gradient.squaredL2Norm()); 75 : 76 : // x_{i+1} = y_{i+1} + \frac{\theta_i-1}{\theta_{i+1}}(y_{i+1} - y_i) + 77 : // \frac{\theta_i}{\theta_{i+1}}/(y_{i+1} - x_i) 78 0 : x = y + ((prevTheta - static_cast<data_t>(1.0)) / theta) * (y - prevY) 79 0 : - (prevTheta / theta) * (gradient / lipschitz); 80 0 : prevTheta = theta; 81 0 : prevY = y; 82 : 83 : // if the gradient is too small we stop 84 0 : if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) { 85 0 : Logger::get("OGM")->info("SUCCESS: Reached convergence at {}/{} iteration", i + 1, 86 : iterations); 87 0 : return x; 88 : } 89 : } 90 : 91 0 : Logger::get("OGM")->warn("Failed to reach convergence at {} iterations", iterations); 92 : 93 0 : return getCurrentSolution(); 94 0 : } 95 : 96 : template <typename data_t> 97 0 : OGM<data_t>* OGM<data_t>::cloneImpl() const 98 : { 99 0 : if (_preconditionerInverse) 100 0 : return new OGM(*_problem, *_preconditionerInverse, _epsilon); 101 : 102 0 : return new OGM(*_problem, _epsilon); 103 : } 104 : 105 : template <typename data_t> 106 0 : bool OGM<data_t>::isEqual(const Solver<data_t>& other) const 107 : { 108 0 : if (!Solver<data_t>::isEqual(other)) 109 0 : return false; 110 : 111 0 : auto otherOGM = downcast_safe<OGM>(&other); 112 0 : if (!otherOGM) 113 0 : return false; 114 : 115 0 : if (_epsilon != otherOGM->_epsilon) 116 0 : return false; 117 : 118 0 : if ((_preconditionerInverse && !otherOGM->_preconditionerInverse) 119 0 : || (!_preconditionerInverse && otherOGM->_preconditionerInverse)) 120 0 : return false; 121 : 122 0 : if (_preconditionerInverse && otherOGM->_preconditionerInverse) 123 0 : if (*_preconditionerInverse != *otherOGM->_preconditionerInverse) 124 0 : return false; 125 : 126 0 : return true; 127 : } 128 : 129 : // ------------------------------------------ 130 : // explicit template instantiation 131 : template class OGM<float>; 132 : template class OGM<double>; 133 : 134 : } // namespace elsa