Line data Source code
1 : #include "GradientDescent.h" 2 : #include "Logger.h" 3 : #include "TypeCasts.hpp" 4 : 5 : namespace elsa 6 : { 7 : template <typename data_t> 8 : GradientDescent<data_t>::GradientDescent(const Problem<data_t>& problem, data_t stepSize) 9 : : Solver<data_t>(), _problem(problem.clone()), _stepSize{stepSize} 10 12 : { 11 : // sanity check 12 12 : if (_stepSize <= 0) 13 0 : throw InvalidArgumentError("GradientDescent: step size has to be positive"); 14 12 : } 15 : 16 : template <typename data_t> 17 : GradientDescent<data_t>::GradientDescent(const Problem<data_t>& problem) 18 : : Solver<data_t>(), _problem(problem.clone()) 19 2 : { 20 2 : this->_stepSize = 21 2 : static_cast<data_t>(1.0) / static_cast<data_t>(problem.getLipschitzConstant()); 22 2 : } 23 : 24 : template <typename data_t> 25 : DataContainer<data_t>& GradientDescent<data_t>::solveImpl(index_t iterations) 26 6 : { 27 6 : if (iterations == 0) 28 0 : iterations = _defaultIterations; 29 : 30 246 : for (index_t i = 0; i < iterations; ++i) { 31 240 : Logger::get("GradientDescent")->info("iteration {} of {}", i + 1, iterations); 32 240 : auto& x = _problem->getCurrentSolution(); 33 : 34 240 : auto gradient = _problem->getGradient(); 35 240 : gradient *= _stepSize; 36 240 : x -= gradient; 37 240 : } 38 : 39 6 : return _problem->getCurrentSolution(); 40 6 : } 41 : 42 : template <typename data_t> 43 : GradientDescent<data_t>* GradientDescent<data_t>::cloneImpl() const 44 6 : { 45 6 : return new GradientDescent(*_problem, _stepSize); 46 6 : } 47 : 48 : template <typename data_t> 49 : bool GradientDescent<data_t>::isEqual(const Solver<data_t>& other) const 50 6 : { 51 6 : auto otherGD = downcast_safe<GradientDescent<data_t>>(&other); 52 6 : if (!otherGD) 53 0 : return false; 54 : 55 6 : if (_stepSize != otherGD->_stepSize) 56 0 : return false; 57 : 58 6 : return true; 59 6 : } 60 : 61 : // ------------------------------------------ 62 : // explicit template instantiation 63 : template class GradientDescent<float>; 64 : template class GradientDescent<double>; 65 : } // namespace elsa