Line data Source code
1 : #include "GradientDescent.h" 2 : #include "Functional.h" 3 : #include "Logger.h" 4 : #include "TypeCasts.hpp" 5 : #include "PowerIterations.h" 6 : #include <iostream> 7 : 8 : namespace elsa 9 : { 10 : 11 : template <typename data_t> 12 : GradientDescent<data_t>::GradientDescent(const Functional<data_t>& problem, data_t stepSize) 13 : : Solver<data_t>(), 14 : _problem(problem.clone()), 15 : _lineSearchMethod(FixedStepSize<data_t>(*_problem, stepSize).clone()) 16 6 : { 17 6 : } 18 : 19 : template <typename data_t> 20 : GradientDescent<data_t>::GradientDescent(const Functional<data_t>& problem) 21 : : Solver<data_t>(), _problem(problem.clone()) 22 4 : { 23 4 : } 24 : 25 : template <typename data_t> 26 : GradientDescent<data_t>::GradientDescent(const Functional<data_t>& problem, 27 : const LineSearchMethod<data_t>& lineSearchMethod) 28 : : Solver<data_t>(), _problem(problem.clone()), _lineSearchMethod(lineSearchMethod.clone()) 29 4 : { 30 4 : } 31 : 32 : template <typename data_t> 33 : DataContainer<data_t> GradientDescent<data_t>::solve(index_t iterations, 34 : std::optional<DataContainer<data_t>> x0) 35 6 : { 36 6 : auto x = extract_or(x0, _problem->getDomainDescriptor()); 37 : 38 : // If stepSize is not initialized yet, we do it know with x0 39 6 : if (!_lineSearchMethod) { 40 2 : _lineSearchMethod = 41 2 : FixedStepSize<data_t>(*_problem, powerIterations(_problem->getHessian(x))).clone(); 42 2 : Logger::get("GradientDescent") 43 2 : ->info("Step length is chosen to be: {:8.5})", _lineSearchMethod->solve(x, x)); 44 2 : } 45 : 46 406 : for (index_t i = 0; i < iterations; ++i) { 47 400 : Logger::get("GradientDescent")->info("iteration {} of {}", i + 1, iterations); 48 400 : auto gradient = _problem->getGradient(x); 49 400 : x -= _lineSearchMethod->solve(x, -gradient) * gradient; 50 400 : } 51 : 52 6 : return x; 53 6 : } 54 : 55 : template <typename data_t> 56 : GradientDescent<data_t>* GradientDescent<data_t>::cloneImpl() const 57 6 : { 58 6 : if (_lineSearchMethod) { 59 4 : return new GradientDescent(*_problem, *_lineSearchMethod); 60 4 : } else { 61 2 : return new GradientDescent(*_problem); 62 2 : } 63 6 : } 64 : 65 : template <typename data_t> 66 : bool GradientDescent<data_t>::isEqual(const Solver<data_t>& other) const 67 6 : { 68 6 : auto otherGD = downcast_safe<GradientDescent<data_t>>(&other); 69 6 : if (!otherGD) 70 0 : return false; 71 : 72 6 : if ((_lineSearchMethod and not otherGD->_lineSearchMethod) 73 6 : or (not _lineSearchMethod and otherGD->_lineSearchMethod)) { 74 0 : return false; 75 6 : } else if (not _lineSearchMethod and not otherGD->_lineSearchMethod) { 76 2 : return true; 77 4 : } else { 78 4 : return _lineSearchMethod->isEqual(*(otherGD->_lineSearchMethod)); 79 4 : } 80 6 : } 81 : 82 : // ------------------------------------------ 83 : // explicit template instantiation 84 : template class GradientDescent<float>; 85 : template class GradientDescent<double>; 86 : } // namespace elsa