LCOV - code coverage report
Current view: top level - elsa/line_search - StrongWolfeCondition.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 76 91 83.5 %
Date: 2024-05-16 04:22:26 Functions: 10 10 100.0 %

          Line data    Source code
       1             : #include "StrongWolfeCondition.h"
       2             : #include "utils/utils.h"
       3             : 
       4             : namespace elsa
       5             : {
       6             : 
       7             :     template <typename data_t>
       8             :     StrongWolfeCondition<data_t>::StrongWolfeCondition(const Functional<data_t>& problem,
       9             :                                                        data_t amax, data_t c1, data_t c2,
      10             :                                                        index_t max_iterations)
      11             :         : LineSearchMethod<data_t>(problem, max_iterations), _amax(amax), _c1(c1), _c2(c2)
      12          40 :     {
      13             :         // sanity checks
      14          40 :         if (amax <= 0)
      15           0 :             throw InvalidArgumentError("StrongWolfeCondition: amax has to be greater than 0");
      16          40 :         if (c1 <= 0 or c1 >= 1.0)
      17           0 :             throw InvalidArgumentError("StrongWolfeCondition: c1 has to be in the range (0,1)");
      18          40 :         if (c2 <= c1 or c2 >= 1.0)
      19           0 :             throw InvalidArgumentError("StrongWolfeCondition: c2 has to be in the range (c1,1)");
      20          40 :     }
      21             : 
      22             :     template <typename data_t>
      23             :     data_t StrongWolfeCondition<data_t>::_zoom(data_t a_lo, data_t a_hi, data_t f_lo, data_t f_hi,
      24             :                                                data_t f0, data_t der_f_lo, data_t der_f0,
      25             :                                                const DataContainer<data_t>& xi,
      26             :                                                const DataContainer<data_t>& di,
      27             :                                                index_t max_iterations)
      28          46 :     {
      29          46 :         data_t aj = 0;
      30          46 :         auto fj = f0;
      31          46 :         data_t amin = 0;
      32          46 :         data_t dalpha, a, b, cchk;
      33         189 :         for (index_t i = 0; i < max_iterations; ++i) {
      34         189 :             dalpha = a_hi - a_lo;
      35         189 :             if (dalpha < 0) {
      36           9 :                 a = a_hi;
      37           9 :                 b = a_lo;
      38         180 :             } else {
      39         180 :                 a = a_lo;
      40         180 :                 b = a_hi;
      41         180 :             }
      42         189 :             cchk = 0.2 * dalpha;
      43             :             // TODO: check if interpolation is successful
      44         189 :             if (i > 0) {
      45         143 :                 amin = cubic_interpolation(a_lo, f_lo, der_f_lo, a_hi, f_hi, aj, fj);
      46         143 :             }
      47         189 :             if (i == 0 or std::isnan(amin) or amin > b - cchk or amin < a + cchk) {
      48         159 :                 amin = a_lo + 0.5 * dalpha;
      49         159 :             }
      50         189 :             auto fmin = this->_problem->evaluate(xi + amin * di);
      51         189 :             if ((fmin > f0 + _c1 * amin * der_f0) or fmin >= f_lo) {
      52         134 :                 aj = a_hi;
      53         134 :                 a_hi = amin;
      54         134 :                 fj = f_hi;
      55         134 :                 f_hi = fmin;
      56         134 :             } else {
      57          55 :                 auto der_fmin = di.dot(this->_problem->getGradient(xi + amin * di));
      58          55 :                 if (std::abs(der_fmin) <= -_c2 * der_f0) {
      59          46 :                     return amin;
      60          46 :                 }
      61           9 :                 if (der_fmin * dalpha >= 0) {
      62           9 :                     aj = a_hi;
      63           9 :                     a_hi = a_lo;
      64           9 :                     fj = f_hi;
      65           9 :                     f_hi = f_lo;
      66           9 :                 } else {
      67           0 :                     aj = a_lo;
      68           0 :                     fj = f_lo;
      69           0 :                 }
      70           9 :                 a_lo = amin;
      71           9 :                 f_lo = fmin;
      72           9 :                 der_f_lo = der_fmin;
      73           9 :             }
      74         189 :         }
      75          46 :         return amin;
      76          46 :     }
      77             : 
      78             :     template <typename data_t>
      79             :     data_t StrongWolfeCondition<data_t>::solve(DataContainer<data_t> xi, DataContainer<data_t> di)
      80         126 :     {
      81         126 :         data_t ai_1 = 0;
      82         126 :         data_t f0 = this->_problem->evaluate(xi);
      83         126 :         data_t der_f0 = di.dot(this->_problem->getGradient(xi));
      84         126 :         auto fi_1 = f0;
      85         126 :         auto der_fi_1 = der_f0;
      86         126 :         auto ai = std::min(static_cast<data_t>(1.0), _amax);
      87         126 :         for (index_t i = 0; i < this->_max_iterations; ++i) {
      88         126 :             auto fi = this->_problem->evaluate(xi + ai * di);
      89         126 :             if ((fi > f0 + _c1 * ai * der_f0) or (fi >= fi_1 and i > 0)) {
      90          46 :                 return _zoom(ai_1, ai, fi_1, fi, f0, der_fi_1, der_f0, xi, di);
      91          46 :             }
      92          80 :             auto der_fi = di.dot(this->_problem->getGradient(xi + ai * di));
      93          80 :             if (std::abs(der_fi) <= _c2 * std::abs(der_f0)) {
      94          80 :                 return ai;
      95          80 :             }
      96           0 :             if (der_fi >= 0) {
      97           0 :                 return _zoom(ai, ai_1, fi, fi_1, f0, der_fi, der_f0, xi, di);
      98           0 :             }
      99           0 :             ai_1 = ai;
     100           0 :             ai = std::min(2 * ai, _amax);
     101           0 :             fi_1 = fi;
     102           0 :             der_fi_1 = der_fi;
     103           0 :         }
     104         126 :         return ai;
     105         126 :     }
     106             : 
     107             :     template <typename data_t>
     108             :     StrongWolfeCondition<data_t>* StrongWolfeCondition<data_t>::cloneImpl() const
     109          24 :     {
     110          24 :         return new StrongWolfeCondition(*this->_problem, _amax, _c1, _c2, this->_max_iterations);
     111          24 :     }
     112             : 
     113             :     template <typename data_t>
     114             :     bool StrongWolfeCondition<data_t>::isEqual(const LineSearchMethod<data_t>& other) const
     115           4 :     {
     116           4 :         auto otherSWC = downcast_safe<StrongWolfeCondition<data_t>>(&other);
     117           4 :         if (!otherSWC)
     118           0 :             return false;
     119             : 
     120           4 :         return (_amax == otherSWC->_amax && _c1 == otherSWC->_c1 && _c2 == otherSWC->_c2);
     121           4 :     }
     122             :     // ------------------------------------------
     123             :     // explicit template instantiation
     124             :     template class StrongWolfeCondition<float>;
     125             :     template class StrongWolfeCondition<double>;
     126             : } // namespace elsa

Generated by: LCOV version 1.14