LCOV - code coverage report
Current view: top level - elsa/line_search - GoldsteinCondition.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 74 81 91.4 %
Date: 2024-05-16 04:22:26 Functions: 10 10 100.0 %

          Line data    Source code
       1             : #include "GoldsteinCondition.h"
       2             : #include "utils/utils.h"
       3             : 
       4             : namespace elsa
       5             : {
       6             : 
       7             :     template <typename data_t>
       8             :     GoldsteinCondition<data_t>::GoldsteinCondition(const Functional<data_t>& problem, data_t amax,
       9             :                                                    data_t c, index_t max_iterations)
      10             :         : LineSearchMethod<data_t>(problem, max_iterations), _amax(amax), _c(c)
      11           8 :     {
      12             :         // sanity checks
      13           8 :         if (amax <= 0)
      14           0 :             throw InvalidArgumentError("GoldsteinCondition: amax has to be greater than 0");
      15           8 :         if (c <= 0 or c >= 0.5)
      16           0 :             throw InvalidArgumentError("GoldsteinCondition: c has to be in the range (0,0.5)");
      17           8 :     }
      18             : 
      19             :     template <typename data_t>
      20             :     data_t GoldsteinCondition<data_t>::_zoom(data_t a_lo, data_t a_hi, data_t f_lo, data_t f_hi,
      21             :                                              data_t f0, data_t der_f_lo, data_t der_f0,
      22             :                                              const DataContainer<data_t>& xi,
      23             :                                              const DataContainer<data_t>& di,
      24             :                                              index_t max_iterations)
      25          36 :     {
      26          36 :         data_t aj = 0;
      27          36 :         data_t fj = f0;
      28          36 :         data_t a_min = 0;
      29         396 :         for (index_t i = 0; i < max_iterations; ++i) {
      30         360 :             data_t cchk = static_cast<data_t>(0.2) * (a_hi - a_lo);
      31         360 :             if (i > 0) {
      32         324 :                 a_min = cubic_interpolation<data_t>(a_lo, f_lo, der_f_lo, a_hi, f_hi, aj, fj);
      33         324 :             }
      34         360 :             if (i == 0 or std::isnan(a_min) or (a_min > a_hi - cchk) or (a_min < a_lo + cchk)) {
      35         291 :                 a_min = a_lo + static_cast<data_t>(0.5) * (a_hi - a_lo);
      36         291 :             }
      37         360 :             data_t f_min = this->_problem->evaluate(xi + a_min * di);
      38         360 :             if ((f_min > f0 + _c * a_min * der_f0) or f_min >= f_lo) {
      39         250 :                 aj = a_hi;
      40         250 :                 a_hi = a_min;
      41         250 :                 fj = f_hi;
      42         250 :                 f_hi = f_min;
      43         250 :             } else {
      44         110 :                 if (f_min <= f0 + (1 - _c) * a_min * der_f0) {
      45           0 :                     return a_min;
      46           0 :                 }
      47         110 :                 data_t der_f_min = di.dot(this->_problem->getGradient(xi + a_min * di));
      48         110 :                 if (der_f_min * (a_hi - a_lo) >= 0) {
      49          76 :                     aj = a_hi;
      50          76 :                     a_hi = a_lo;
      51          76 :                     fj = f_hi;
      52          76 :                     f_hi = f_lo;
      53          76 :                 } else {
      54          34 :                     aj = a_lo;
      55          34 :                     fj = f_lo;
      56          34 :                 }
      57         110 :                 a_lo = a_min;
      58         110 :                 f_lo = f_min;
      59         110 :                 der_f_lo = der_f_min;
      60         110 :             }
      61         360 :         }
      62          36 :         return a_min;
      63          36 :     }
      64             : 
      65             :     template <typename data_t>
      66             :     data_t GoldsteinCondition<data_t>::solve(DataContainer<data_t> xi, DataContainer<data_t> di)
      67          40 :     {
      68          40 :         data_t ai_1 = 0;
      69          40 :         auto f0 = this->_problem->evaluate(xi);
      70          40 :         auto der_f0 = di.dot(this->_problem->getGradient(xi));
      71          40 :         auto fi_1 = f0;
      72          40 :         auto der_fi_1 = der_f0;
      73          40 :         auto ai = std::min(static_cast<data_t>(1.0), _amax);
      74          41 :         for (index_t i = 0; i < this->_max_iterations; ++i) {
      75          41 :             auto fi = this->_problem->evaluate(xi + ai * di);
      76          41 :             auto der_fi = di.dot(this->_problem->getGradient(xi + ai * di));
      77          41 :             if ((fi > f0 + _c * ai * der_f0) or (fi >= fi_1 and i > 0)) {
      78          36 :                 return _zoom(ai_1, ai, fi_1, fi, f0, der_fi_1, der_f0, xi, di);
      79          36 :             }
      80           5 :             if (fi >= f0 + (1 - _c) * ai * der_f0) {
      81           4 :                 return ai;
      82           4 :             }
      83           1 :             if (der_fi >= 0) {
      84           0 :                 return _zoom(ai, ai_1, fi, fi_1, f0, der_fi, der_f0, xi, di);
      85           0 :             }
      86           1 :             auto a = 2 * ai;
      87           1 :             ai_1 = ai;
      88           1 :             ai = std::min(a, _amax);
      89           1 :             fi_1 = fi;
      90           1 :             der_fi_1 = der_fi;
      91           1 :         }
      92          40 :         return ai;
      93          40 :     }
      94             : 
      95             :     template <typename data_t>
      96             :     GoldsteinCondition<data_t>* GoldsteinCondition<data_t>::cloneImpl() const
      97           4 :     {
      98           4 :         return new GoldsteinCondition(*this->_problem, _amax, _c, this->_max_iterations);
      99           4 :     }
     100             : 
     101             :     template <typename data_t>
     102             :     bool GoldsteinCondition<data_t>::isEqual(const LineSearchMethod<data_t>& other) const
     103           4 :     {
     104           4 :         auto otherGC = downcast_safe<GoldsteinCondition<data_t>>(&other);
     105           4 :         if (!otherGC)
     106           0 :             return false;
     107             : 
     108           4 :         return (_amax == otherGC->_amax && _c == otherGC->_c);
     109           4 :     }
     110             : 
    // ------------------------------------------
    // explicit template instantiation
    // The member functions above are defined in this translation unit, so the class template
    // is instantiated here for the two supported floating point types.
    template class GoldsteinCondition<float>;
    template class GoldsteinCondition<double>;
     115             : } // namespace elsa

Generated by: LCOV version 1.14