LCOV - code coverage report
Current view: top level - elsa/solvers - GradientDescent.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 39 41 95.1 %
Date: 2024-12-21 07:37:52 Functions: 12 12 100.0 %

          Line data    Source code
       1             : #include "GradientDescent.h"
       2             : #include "Functional.h"
       3             : #include "Logger.h"
       4             : #include "TypeCasts.hpp"
       5             : #include "PowerIterations.h"
       6             : #include <iostream>
       7             : 
       8             : namespace elsa
       9             : {
      10             : 
      11             :     template <typename data_t>
      12             :     GradientDescent<data_t>::GradientDescent(const Functional<data_t>& problem, data_t stepSize)
      13             :         : Solver<data_t>(),
      14             :           _problem(problem.clone()),
      15             :           _lineSearchMethod(FixedStepSize<data_t>(*_problem, stepSize).clone())
      16           6 :     {
      17           6 :     }
      18             : 
      19             :     template <typename data_t>
      20             :     GradientDescent<data_t>::GradientDescent(const Functional<data_t>& problem)
      21             :         : Solver<data_t>(), _problem(problem.clone())
      22           4 :     {
      23           4 :     }
      24             : 
      25             :     template <typename data_t>
      26             :     GradientDescent<data_t>::GradientDescent(const Functional<data_t>& problem,
      27             :                                              const LineSearchMethod<data_t>& lineSearchMethod)
      28             :         : Solver<data_t>(), _problem(problem.clone()), _lineSearchMethod(lineSearchMethod.clone())
      29           4 :     {
      30           4 :     }
      31             : 
      32             :     template <typename data_t>
      33             :     DataContainer<data_t> GradientDescent<data_t>::solve(index_t iterations,
      34             :                                                          std::optional<DataContainer<data_t>> x0)
      35           6 :     {
      36           6 :         auto x = extract_or(x0, _problem->getDomainDescriptor());
      37             : 
      38             :         // If stepSize is not initialized yet, we do it know with x0
      39           6 :         if (!_lineSearchMethod) {
      40           2 :             _lineSearchMethod =
      41           2 :                 FixedStepSize<data_t>(*_problem, powerIterations(_problem->getHessian(x))).clone();
      42           2 :             Logger::get("GradientDescent")
      43           2 :                 ->info("Step length is chosen to be: {:8.5})", _lineSearchMethod->solve(x, x));
      44           2 :         }
      45             : 
      46         406 :         for (index_t i = 0; i < iterations; ++i) {
      47         400 :             Logger::get("GradientDescent")->info("iteration {} of {}", i + 1, iterations);
      48         400 :             auto gradient = _problem->getGradient(x);
      49         400 :             x -= _lineSearchMethod->solve(x, -gradient) * gradient;
      50         400 :         }
      51             : 
      52           6 :         return x;
      53           6 :     }
      54             : 
      55             :     template <typename data_t>
      56             :     GradientDescent<data_t>* GradientDescent<data_t>::cloneImpl() const
      57           6 :     {
      58           6 :         if (_lineSearchMethod) {
      59           4 :             return new GradientDescent(*_problem, *_lineSearchMethod);
      60           4 :         } else {
      61           2 :             return new GradientDescent(*_problem);
      62           2 :         }
      63           6 :     }
      64             : 
      65             :     template <typename data_t>
      66             :     bool GradientDescent<data_t>::isEqual(const Solver<data_t>& other) const
      67           6 :     {
      68           6 :         auto otherGD = downcast_safe<GradientDescent<data_t>>(&other);
      69           6 :         if (!otherGD)
      70           0 :             return false;
      71             : 
      72           6 :         if ((_lineSearchMethod and not otherGD->_lineSearchMethod)
      73           6 :             or (not _lineSearchMethod and otherGD->_lineSearchMethod)) {
      74           0 :             return false;
      75           6 :         } else if (not _lineSearchMethod and not otherGD->_lineSearchMethod) {
      76           2 :             return true;
      77           4 :         } else {
      78           4 :             return _lineSearchMethod->isEqual(*(otherGD->_lineSearchMethod));
      79           4 :         }
      80           6 :     }
      81             : 
      82             :     // ------------------------------------------
      83             :     // explicit template instantiation
      84             :     template class GradientDescent<float>;
      85             :     template class GradientDescent<double>;
      86             : } // namespace elsa

Generated by: LCOV version 1.14