LCOV - code coverage report
Current view: top level - solvers - GradientDescent.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 31 0.0 %
Date: 2022-08-04 03:43:28 Functions: 0 10 0.0 %

          Line data    Source code
       1             : #include "GradientDescent.h"
       2             : #include "Logger.h"
       3             : #include "TypeCasts.hpp"
       4             : 
       5             : namespace elsa
       6             : {
       7             :     template <typename data_t>
       8           0 :     GradientDescent<data_t>::GradientDescent(const Problem<data_t>& problem, data_t stepSize)
       9           0 :         : Solver<data_t>(problem), _stepSize{stepSize}
      10             :     {
      11             :         // sanity check
      12           0 :         if (_stepSize <= 0)
      13           0 :             throw InvalidArgumentError("GradientDescent: step size has to be positive");
      14           0 :     }
      15             : 
      16             :     template <typename data_t>
      17           0 :     GradientDescent<data_t>::GradientDescent(const Problem<data_t>& problem)
      18           0 :         : Solver<data_t>(problem)
      19             :     {
      20           0 :         this->_stepSize =
      21           0 :             static_cast<data_t>(1.0) / static_cast<data_t>(problem.getLipschitzConstant());
      22           0 :     }
      23             : 
      24             :     template <typename data_t>
      25           0 :     DataContainer<data_t>& GradientDescent<data_t>::solveImpl(index_t iterations)
      26             :     {
      27           0 :         if (iterations == 0)
      28           0 :             iterations = _defaultIterations;
      29             : 
      30           0 :         for (index_t i = 0; i < iterations; ++i) {
      31           0 :             Logger::get("GradientDescent")->info("iteration {} of {}", i + 1, iterations);
      32           0 :             auto& x = getCurrentSolution();
      33             : 
      34           0 :             auto gradient = _problem->getGradient();
      35           0 :             gradient *= _stepSize;
      36           0 :             x -= gradient;
      37             :         }
      38             : 
      39           0 :         return getCurrentSolution();
      40             :     }
      41             : 
      42             :     template <typename data_t>
      43           0 :     GradientDescent<data_t>* GradientDescent<data_t>::cloneImpl() const
      44             :     {
      45           0 :         return new GradientDescent(*_problem, _stepSize);
      46             :     }
      47             : 
      48             :     template <typename data_t>
      49           0 :     bool GradientDescent<data_t>::isEqual(const Solver<data_t>& other) const
      50             :     {
      51           0 :         if (!Solver<data_t>::isEqual(other))
      52           0 :             return false;
      53             : 
      54           0 :         auto otherGD = downcast_safe<GradientDescent<data_t>>(&other);
      55           0 :         if (!otherGD)
      56           0 :             return false;
      57             : 
      58           0 :         if (_stepSize != otherGD->_stepSize)
      59           0 :             return false;
      60             : 
      61           0 :         return true;
      62             :     }
      63             : 
      64             :     // ------------------------------------------
      65             :     // explicit template instantiation
      66             :     template class GradientDescent<float>;
      67             :     template class GradientDescent<double>;
      68             : } // namespace elsa

Generated by: LCOV version 1.14