LCOV - code coverage report
Current view: top level - elsa/line_search - BarzilaiBorwein.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 63 74 85.1 %
Date: 2024-05-16 04:22:26 Functions: 8 8 100.0 %

          Line data    Source code
       1             : #include "BarzilaiBorwein.h"
       2             : namespace elsa
       3             : {
       4             :     template <typename data_t>
       5             :     BarzilaiBorwein<data_t>::BarzilaiBorwein(const Functional<data_t>& problem, uint32_t m,
       6             :                                              data_t gamma, data_t sigma1, data_t sigma2,
       7             :                                              data_t epsilon, index_t max_iterations)
       8             :         : LineSearchMethod<data_t>(problem, max_iterations),
       9             :           _m(m),
      10             :           _gamma(gamma),
      11             :           _sigma1(sigma1),
      12             :           _sigma2(sigma2),
      13             :           _epsilon(epsilon),
      14             :           _gi_prev(DataContainer<data_t>(this->_problem->getDomainDescriptor()))
      15           8 :     {
      16             :         // sanity checks
      17           8 :         if (gamma <= 0 or gamma >= 1)
      18           0 :             throw InvalidArgumentError("BarzilaiBorwein: gamma has to be in the range (0,1)");
      19           8 :         if (sigma1 <= 0 or sigma1 >= sigma2 or sigma1 >= 1)
      20           0 :             throw InvalidArgumentError(
      21           0 :                 "BarzilaiBorwein: sigma1 has to satisfy 0 < sigma1 < sigma2 < 1");
      22           8 :         if (sigma2 <= 0 or sigma1 >= sigma2 or sigma2 >= 1)
      23           0 :             throw InvalidArgumentError(
      24           0 :                 "BarzilaiBorwein: sigma2 has to satisfy 0 < sigma1 < sigma2 < 1");
      25           8 :         if (epsilon <= 0)
      26           0 :             throw InvalidArgumentError("BarzilaiBorwein: epsilon has to be in the range (0,1)");
      27           8 :         _invepsilon = 1 / epsilon;
      28           8 :         _function_vals.reserve(m);
      29           8 :         _li_prev = 1;
      30           8 :         _iter = 0;
      31           8 :     }
      32             :     template <typename data_t>
      33             :     data_t BarzilaiBorwein<data_t>::solve(DataContainer<data_t> xi, DataContainer<data_t> di)
      34          40 :     {
      35          40 :         auto derphi = di.dot(di);
      36          40 :         if (_iter == 0) {
      37           4 :             _gi_prev = -di;
      38           4 :             if (_m > 0) {
      39           4 :                 _function_vals.push_back(this->_problem->evaluate(xi));
      40           4 :             }
      41           4 :             _derphi_prev = derphi;
      42           4 :         }
      43          40 :         auto ai = -_gi_prev.dot(-di - _gi_prev) / (_li_prev * _derphi_prev);
      44          40 :         if (ai <= _epsilon or ai >= _invepsilon) {
      45           4 :             auto g_norm = std::sqrt(derphi);
      46           4 :             if (g_norm > 1) {
      47           4 :                 ai = 1;
      48           4 :             } else if (g_norm >= 1e-5 and g_norm <= 1) {
      49           0 :                 ai = 1 / g_norm;
      50           0 :             } else {
      51           0 :                 ai = 1e5;
      52           0 :             }
      53           4 :         }
      54          40 :         auto li = 1 / ai;
      55          40 :         data_t max_prev_f;
      56          40 :         if (_m > 0) {
      57          40 :             max_prev_f = *std::max_element(_function_vals.begin(), _function_vals.end());
      58          40 :         }
      59          40 :         data_t fi;
      60          40 :         ++_iter;
      61          50 :         for (index_t i = 0; i < this->_max_iterations and _m > 0; ++i) {
      62          50 :             fi = this->_problem->evaluate(xi + li * di);
      63          50 :             if (fi <= max_prev_f - _gamma * li * derphi) {
      64          40 :                 break;
      65          40 :             } else {
      66          10 :                 li = (_sigma2 - _sigma1) / 2 * li;
      67          10 :             }
      68          50 :         }
      69          40 :         _li_prev = li;
      70          40 :         if (_m > 0) {
      71             : 
      72          40 :             if (_iter < _m) {
      73          36 :                 _function_vals.push_back(fi);
      74          36 :             } else {
      75           4 :                 _function_vals[_iter % _m] = fi;
      76           4 :             }
      77          40 :         }
      78          40 :         _gi_prev = -di;
      79          40 :         _derphi_prev = derphi;
      80          40 :         return li;
      81          40 :     }
      82             : 
      83             :     template <typename data_t>
      84             :     BarzilaiBorwein<data_t>* BarzilaiBorwein<data_t>::cloneImpl() const
      85           4 :     {
      86           4 :         return new BarzilaiBorwein(*this->_problem, _m, _gamma, _sigma1, _sigma2, _epsilon,
      87           4 :                                    this->_max_iterations);
      88           4 :     }
      89             : 
      90             :     template <typename data_t>
      91             :     bool BarzilaiBorwein<data_t>::isEqual(const LineSearchMethod<data_t>& other) const
      92           4 :     {
      93           4 :         auto otherBB = downcast_safe<BarzilaiBorwein<data_t>>(&other);
      94           4 :         if (!otherBB)
      95           0 :             return false;
      96             : 
      97           4 :         return (_m == otherBB->_m and _gamma == otherBB->_gamma and _sigma1 == otherBB->_sigma1
      98           4 :                 and _sigma2 == otherBB->_sigma2 and _epsilon == otherBB->_epsilon);
      99           4 :     }
     100             :     // ------------------------------------------
     101             :     // explicit template instantiation
     102             :     template class BarzilaiBorwein<float>;
     103             :     template class BarzilaiBorwein<double>;
     104             : } // namespace elsa

Generated by: LCOV version 1.14