Line data Source code
1 : #include "GradientDescent.h" 2 : #include "Logger.h" 3 : #include "TypeCasts.hpp" 4 : 5 : namespace elsa 6 : { 7 : template <typename data_t> 8 0 : GradientDescent<data_t>::GradientDescent(const Problem<data_t>& problem, data_t stepSize) 9 0 : : Solver<data_t>(problem), _stepSize{stepSize} 10 : { 11 : // sanity check 12 0 : if (_stepSize <= 0) 13 0 : throw InvalidArgumentError("GradientDescent: step size has to be positive"); 14 0 : } 15 : 16 : template <typename data_t> 17 0 : GradientDescent<data_t>::GradientDescent(const Problem<data_t>& problem) 18 0 : : Solver<data_t>(problem) 19 : { 20 0 : this->_stepSize = 21 0 : static_cast<data_t>(1.0) / static_cast<data_t>(problem.getLipschitzConstant()); 22 0 : } 23 : 24 : template <typename data_t> 25 0 : DataContainer<data_t>& GradientDescent<data_t>::solveImpl(index_t iterations) 26 : { 27 0 : if (iterations == 0) 28 0 : iterations = _defaultIterations; 29 : 30 0 : for (index_t i = 0; i < iterations; ++i) { 31 0 : Logger::get("GradientDescent")->info("iteration {} of {}", i + 1, iterations); 32 0 : auto& x = getCurrentSolution(); 33 : 34 0 : auto gradient = _problem->getGradient(); 35 0 : gradient *= _stepSize; 36 0 : x -= gradient; 37 : } 38 : 39 0 : return getCurrentSolution(); 40 : } 41 : 42 : template <typename data_t> 43 0 : GradientDescent<data_t>* GradientDescent<data_t>::cloneImpl() const 44 : { 45 0 : return new GradientDescent(*_problem, _stepSize); 46 : } 47 : 48 : template <typename data_t> 49 0 : bool GradientDescent<data_t>::isEqual(const Solver<data_t>& other) const 50 : { 51 0 : if (!Solver<data_t>::isEqual(other)) 52 0 : return false; 53 : 54 0 : auto otherGD = downcast_safe<GradientDescent<data_t>>(&other); 55 0 : if (!otherGD) 56 0 : return false; 57 : 58 0 : if (_stepSize != otherGD->_stepSize) 59 0 : return false; 60 : 61 0 : return true; 62 : } 63 : 64 : // ------------------------------------------ 65 : // explicit template instantiation 66 : template class GradientDescent<float>; 67 : template class GradientDescent<double>; 68 : } // namespace elsa