Line data Source code
1 : #include "FGM.h" 2 : #include "Logger.h" 3 : #include "TypeCasts.hpp" 4 : 5 : namespace elsa 6 : { 7 : template <typename data_t> 8 0 : FGM<data_t>::FGM(const Problem<data_t>& problem, data_t epsilon) 9 0 : : Solver<data_t>(problem), _epsilon{epsilon} 10 : { 11 0 : } 12 : 13 : template <typename data_t> 14 0 : FGM<data_t>::FGM(const Problem<data_t>& problem, 15 : const LinearOperator<data_t>& preconditionerInverse, data_t epsilon) 16 : : Solver<data_t>(problem), 17 0 : _epsilon{epsilon}, 18 0 : _preconditionerInverse{preconditionerInverse.clone()} 19 : { 20 : // check that preconditioner is compatible with problem 21 0 : if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients() 22 0 : != _problem->getCurrentSolution().getSize() 23 0 : || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients() 24 0 : != _problem->getCurrentSolution().getSize()) { 25 0 : throw InvalidArgumentError("FGM: incorrect size of preconditioner"); 26 : } 27 0 : } 28 : 29 : template <typename data_t> 30 0 : DataContainer<data_t>& FGM<data_t>::solveImpl(index_t iterations) 31 : { 32 0 : if (iterations == 0) 33 0 : iterations = _defaultIterations; 34 : 35 0 : auto prevTheta = static_cast<data_t>(1.0); 36 0 : auto x0 = DataContainer<data_t>(getCurrentSolution()); 37 0 : auto& prevY = x0; 38 : 39 0 : auto deltaZero = _problem->getGradient().squaredL2Norm(); 40 0 : auto lipschitz = _problem->getLipschitzConstant(); 41 0 : Logger::get("FGM")->info("Starting optimization with lipschitz constant {}", lipschitz); 42 : 43 0 : for (index_t i = 0; i < iterations; ++i) { 44 0 : Logger::get("FGM")->info("iteration {} of {}", i + 1, iterations); 45 0 : auto& x = getCurrentSolution(); 46 : 47 0 : auto gradient = _problem->getGradient(); 48 : 49 0 : if (_preconditionerInverse) 50 0 : gradient = _preconditionerInverse->apply(gradient); 51 : 52 0 : DataContainer<data_t> y = x - gradient / lipschitz; 53 0 : const auto theta = (static_cast<data_t>(1.0) 54 0 : + std::sqrt(static_cast<data_t>(1.0) 55 0 : + static_cast<data_t>(4.0) * prevTheta * prevTheta)) 56 : / static_cast<data_t>(2.0); 57 0 : x = y + (prevTheta - static_cast<data_t>(1.0)) / theta * (y - prevY); 58 0 : prevTheta = theta; 59 0 : prevY = y; 60 : 61 : // if the gradient is too small we stop 62 0 : if (gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero) { 63 0 : Logger::get("FGM")->info("SUCCESS: Reached convergence at {}/{} iteration", i + 1, 64 : iterations); 65 0 : return x; 66 : } 67 : } 68 : 69 0 : Logger::get("FGM")->warn("Failed to reach convergence at {} iterations", iterations); 70 : 71 0 : return getCurrentSolution(); 72 0 : } 73 : 74 : template <typename data_t> 75 0 : FGM<data_t>* FGM<data_t>::cloneImpl() const 76 : { 77 0 : if (_preconditionerInverse) 78 0 : return new FGM(*_problem, *_preconditionerInverse, _epsilon); 79 : 80 0 : return new FGM(*_problem, _epsilon); 81 : } 82 : 83 : template <typename data_t> 84 0 : bool FGM<data_t>::isEqual(const Solver<data_t>& other) const 85 : { 86 0 : if (!Solver<data_t>::isEqual(other)) 87 0 : return false; 88 : 89 0 : auto otherFGM = downcast_safe<FGM>(&other); 90 0 : if (!otherFGM) 91 0 : return false; 92 : 93 0 : if (_epsilon != otherFGM->_epsilon) 94 0 : return false; 95 : 96 0 : if ((_preconditionerInverse && !otherFGM->_preconditionerInverse) 97 0 : || (!_preconditionerInverse && otherFGM->_preconditionerInverse)) 98 0 : return false; 99 : 100 0 : if (_preconditionerInverse && otherFGM->_preconditionerInverse) 101 0 : if (*_preconditionerInverse != *otherFGM->_preconditionerInverse) 102 0 : return false; 103 : 104 0 : return true; 105 : } 106 : 107 : // ------------------------------------------ 108 : // explicit template instantiation 109 : template class FGM<float>; 110 : template class FGM<double>; 111 : 112 : } // namespace elsa