Line data Source code
1 : #include "FGM.h" 2 : #include "Logger.h" 3 : #include "TypeCasts.hpp" 4 : 5 : namespace elsa 6 : { 7 : template <typename data_t> 8 : FGM<data_t>::FGM(const Problem<data_t>& problem, data_t epsilon) 9 : : Solver<data_t>(), _problem(problem.clone()), _epsilon{epsilon} 10 10 : { 11 10 : } 12 : 13 : template <typename data_t> 14 : FGM<data_t>::FGM(const Problem<data_t>& problem, 15 : const LinearOperator<data_t>& preconditionerInverse, data_t epsilon) 16 : : Solver<data_t>(), 17 : _problem(problem.clone()), 18 : _epsilon{epsilon}, 19 : _preconditionerInverse{preconditionerInverse.clone()} 20 8 : { 21 : // check that preconditioner is compatible with problem 22 8 : if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients() 23 8 : != _problem->getCurrentSolution().getSize() 24 8 : || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients() 25 8 : != _problem->getCurrentSolution().getSize()) { 26 0 : throw InvalidArgumentError("FGM: incorrect size of preconditioner"); 27 0 : } 28 8 : } 29 : 30 : template <typename data_t> 31 : DataContainer<data_t>& FGM<data_t>::solveImpl(index_t iterations) 32 9 : { 33 9 : if (iterations == 0) 34 0 : iterations = _defaultIterations; 35 : 36 9 : auto prevTheta = static_cast<data_t>(1.0); 37 9 : auto x0 = _problem->getCurrentSolution(); 38 9 : auto& prevY = x0; 39 : 40 9 : auto deltaZero = _problem->getGradient().squaredL2Norm(); 41 9 : auto lipschitz = _problem->getLipschitzConstant(); 42 9 : Logger::get("FGM")->info("Starting optimization with lipschitz constant {}", lipschitz); 43 : 44 5272 : for (index_t i = 0; i < iterations; ++i) { 45 5263 : Logger::get("FGM")->info("iteration {} of {}", i + 1, iterations); 46 5263 : auto& x = _problem->getCurrentSolution(); 47 : 48 5263 : auto gradient = _problem->getGradient(); 49 : 50 5263 : if (_preconditionerInverse) 51 2624 : gradient = _preconditionerInverse->apply(gradient); 52 : 53 5263 : DataContainer<data_t> y = x - gradient / lipschitz; 54 5263 : const auto theta = (static_cast<data_t>(1.0) 55 5263 : + std::sqrt(static_cast<data_t>(1.0) 56 5263 : + static_cast<data_t>(4.0) * prevTheta * prevTheta)) 57 5263 : / static_cast<data_t>(2.0); 58 5263 : x = y + (prevTheta - static_cast<data_t>(1.0)) / theta * (y - prevY); 59 5263 : prevTheta = theta; 60 5263 : prevY = y; 61 : 62 : // if the gradient is too small we stop 63 5263 : if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) { 64 0 : Logger::get("FGM")->info("SUCCESS: Reached convergence at {}/{} iteration", i + 1, 65 0 : iterations); 66 0 : return x; 67 0 : } 68 5263 : } 69 : 70 9 : Logger::get("FGM")->warn("Failed to reach convergence at {} iterations", iterations); 71 : 72 9 : return _problem->getCurrentSolution(); 73 9 : } 74 : 75 : template <typename data_t> 76 : FGM<data_t>* FGM<data_t>::cloneImpl() const 77 9 : { 78 9 : if (_preconditionerInverse) 79 4 : return new FGM(*_problem, *_preconditionerInverse, _epsilon); 80 : 81 5 : return new FGM(*_problem, _epsilon); 82 5 : } 83 : 84 : template <typename data_t> 85 : bool FGM<data_t>::isEqual(const Solver<data_t>& other) const 86 9 : { 87 9 : auto otherFGM = downcast_safe<FGM>(&other); 88 9 : if (!otherFGM) 89 0 : return false; 90 : 91 9 : if (_epsilon != otherFGM->_epsilon) 92 0 : return false; 93 : 94 9 : if ((_preconditionerInverse && !otherFGM->_preconditionerInverse) 95 9 : || (!_preconditionerInverse && otherFGM->_preconditionerInverse)) 96 0 : return false; 97 : 98 9 : if (_preconditionerInverse && otherFGM->_preconditionerInverse) 99 4 : if (*_preconditionerInverse != *otherFGM->_preconditionerInverse) 100 0 : return false; 101 : 102 9 : return true; 103 9 : } 104 : 105 : // ------------------------------------------ 106 : // explicit template instantiation 107 : template class FGM<float>; 108 : template class FGM<double>; 109 : 110 : } // namespace elsa