LCOV - code coverage report
Current view: top level - elsa/line_search - ArmijoCondition.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 26 30 86.7 %
Date: 2024-12-21 07:37:52 Functions: 8 8 100.0 %

          Line data    Source code
       1             : #include "ArmijoCondition.h"
       2             : namespace elsa
       3             : {
       4             : 
       5             :     template <typename data_t>
       6             :     ArmijoCondition<data_t>::ArmijoCondition(const Functional<data_t>& problem, data_t amax,
       7             :                                              data_t c, data_t rho, index_t max_iterations)
       8             :         : LineSearchMethod<data_t>(problem, max_iterations), _amax(amax), _c(c), _rho(rho)
       9           8 :     {
      10             :         // sanity checks
      11           8 :         if (amax <= 0)
      12           0 :             throw InvalidArgumentError("ArmijoCondition: amax has to be positive");
      13           8 :         if (c <= 0 or c >= 1)
      14           0 :             throw InvalidArgumentError("ArmijoCondition: c has to be in the range (0,1)");
      15           8 :         if (rho <= 0 or rho >= 1)
      16           0 :             throw InvalidArgumentError("ArmijoCondition: rho has to be in the range (0,1)");
      17           8 :     }
      18             : 
      19             :     template <typename data_t>
      20             :     data_t ArmijoCondition<data_t>::solve(DataContainer<data_t> xi, DataContainer<data_t> di)
      21          40 :     {
      22          40 :         auto ai = _amax;
      23          40 :         auto f0 = this->_problem->evaluate(xi);
      24          40 :         auto f0_prime = di.dot(this->_problem->getGradient(xi));
      25         224 :         for (index_t i = 0; i < this->_max_iterations; ++i) {
      26         224 :             if (this->_problem->evaluate(xi + ai * di) <= f0 + _c * ai * f0_prime) {
      27          40 :                 return ai;
      28         184 :             } else {
      29         184 :                 ai = ai * _rho;
      30         184 :             }
      31         224 :         }
      32          40 :         return ai;
      33          40 :     }
      34             : 
      35             :     template <typename data_t>
      36             :     ArmijoCondition<data_t>* ArmijoCondition<data_t>::cloneImpl() const
      37           4 :     {
      38           4 :         return new ArmijoCondition(*this->_problem, _amax, _c, _rho, this->_max_iterations);
      39           4 :     }
      40             : 
      41             :     template <typename data_t>
      42             :     bool ArmijoCondition<data_t>::isEqual(const LineSearchMethod<data_t>& other) const
      43           4 :     {
      44           4 :         auto otherAC = downcast_safe<ArmijoCondition<data_t>>(&other);
      45           4 :         if (!otherAC)
      46           0 :             return false;
      47             : 
      48           4 :         return (_amax == otherAC->_amax && _c == otherAC->_c && _rho == otherAC->_rho);
      49           4 :     }
      50             : 
      51             :     // ------------------------------------------
      52             :     // explicit template instantiation
      53             :     template class ArmijoCondition<float>;
      54             :     template class ArmijoCondition<double>;
      55             : } // namespace elsa

Generated by: LCOV version 1.14