LCOV - code coverage report
Current view: top level - elsa/solvers - GradientDescent.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 1 1 100.0 %
Date: 2025-01-22 07:37:33 Functions: 2 2 100.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <optional>
       4             : 
       5             : #include "Solver.h"
       6             : #include "Functional.h"
       7             : #include "MaybeUninitialized.hpp"
       8             : #include "LineSearchMethod.h"
       9             : #include "FixedStepSize.h"
      10             : 
      11             : namespace elsa
      12             : {
      13             : 
      14             :     /**
      15             :      * @brief Class representing a simple gradient descent solver.
      16             :      *
      17             :      * This class implements a simple gradient descent iterative solver. The step size is either
      18             :      * fixed, computed as 1/L (L being the Lipschitz constant), or chosen by a line search method.
      19             :      * No particular stopping rule is currently implemented (only the iteration count given to solve()).
      20             :      *
      21             :      * @tparam data_t data type for the domain and range of the problem, defaulting to real_t
      22             :      *
      23             :      * @author
      24             :      * - Tobias Lasser - initial code
      25             :      */
      26             :     template <typename data_t = real_t>
      27             :     class GradientDescent : public Solver<data_t>
      28             :     {
      29             :     public:
      30             :         /// Scalar alias
      31             :         using Scalar = typename Solver<data_t>::Scalar;
      32             : 
      33             :         /**
      34             :          * @brief Constructor for gradient descent, accepting a problem and a fixed step size
      35             :          *
      36             :          * @param[in] problem the problem that is supposed to be solved
      37             :          * @param[in] stepSize the fixed step size to be used while solving
      38             :          */
      39             :         GradientDescent(const Functional<data_t>& problem, data_t stepSize);
      40             : 
      41             :         /**
      42             :          * @brief Constructor for gradient descent, accepting a problem. The step size will be
      43             :          * computed as \f$ 1 \over L \f$ with \f$ L \f$ being the Lipschitz constant of the
      44             :          * function.
      45             :          *
      46             :          * @param[in] problem the problem that is supposed to be solved
      47             :          */
      48             :         explicit GradientDescent(const Functional<data_t>& problem);
      49             : 
      50             :         /**
      51             :          * @brief Constructor for gradient descent, accepting a problem and a line
      52             :          * search method
      53             :          *
      54             :          * @param[in] problem the problem that is supposed to be solved
      55             :          * @param[in] lineSearchMethod the line search method to find the step size at
      56             :          * each iteration
      57             :          */
      58             :         GradientDescent(const Functional<data_t>& problem,
      59             :                         const LineSearchMethod<data_t>& lineSearchMethod);
      60             : 
      61             :         /// make copy constructor deletion explicit
      62             :         GradientDescent(const GradientDescent<data_t>&) = delete;
      63             : 
      64             :         /// default destructor
      65          14 :         ~GradientDescent() override = default;
      66             : 
      67             :         /**
      68             :          * @brief Solve the optimization problem, i.e. apply `iterations` iterations of gradient
      69             :          * descent
      70             :          *
      71             :          * @param[in] iterations number of iterations to execute
      72             :          * @param[in] x0 optional initial guess (std::nullopt by default)
      73             :          * @returns the approximated solution
      74             :          */
      75             :         DataContainer<data_t>
      76             :             solve(index_t iterations,
      77             :                   std::optional<DataContainer<data_t>> x0 = std::nullopt) override;
      78             : 
      79             :     private:
      80             :         /// the differentiable optimization problem
      81             :         std::unique_ptr<Functional<data_t>> _problem;
      82             : 
      83             :         /// the line search method used to find the step size at each iteration
      84             :         std::unique_ptr<LineSearchMethod<data_t>> _lineSearchMethod;
      85             : 
      86             :         /// implement the polymorphic clone operation
      87             :         GradientDescent<data_t>* cloneImpl() const override;
      88             : 
      89             :         /// implement the polymorphic comparison operation
      90             :         bool isEqual(const Solver<data_t>& other) const override;
      91             :     };
      92             : } // namespace elsa

Generated by: LCOV version 1.14