LCOV - code coverage report
Current view: top level - elsa/solvers - LBFGS.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 57 60 95.0 %
Date: 2024-05-16 04:22:26 Functions: 8 8 100.0 %

          Line data    Source code
       1             : #include "LBFGS.h"
       2             : #include "Logger.h"
       3             : 
       4             : namespace elsa
       5             : {
       6             : 
       7             :     template <typename data_t>
       8             :     LBFGS<data_t>::LBFGS(const Functional<data_t>& problem,
       9             :                          const LineSearchMethod<data_t>& line_search_method, const index_t& memory,
      10             :                          const data_t& tol)
      11             :         : Solver<data_t>(),
      12             :           _problem(problem.clone()),
      13             :           _ls(line_search_method.clone()),
      14             :           _m{memory},
      15             :           _tol{tol}
      16          10 :     {
      17             :         // sanity check
      18          10 :         if (tol < 0)
      19           0 :             throw InvalidArgumentError("LBFGS: tolerance has to be non-negative");
      20          10 :         if (memory < 1)
      21           0 :             throw InvalidArgumentError("LBFGS: memory has to be positive");
      22          10 :     }
      23             : 
      24             :     template <typename data_t>
      25             :     DataContainer<data_t> LBFGS<data_t>::solve(index_t iterations,
      26             :                                                std::optional<DataContainer<data_t>> x0)
      27           4 :     {
      28             : 
      29           4 :         std::vector<DataContainer<data_t>> siVec;
      30           4 :         std::vector<DataContainer<data_t>> yiVec;
      31           4 :         std::vector<data_t> rhoVec(_m, 1);
      32           4 :         std::vector<data_t> alphaVec(_m, 1);
      33             : 
      34           4 :         siVec.reserve(_m);
      35           4 :         yiVec.reserve(_m);
      36             : 
      37           4 :         auto xi = extract_or(x0, _problem->getDomainDescriptor());
      38           4 :         auto xi_1 = DataContainer<data_t>(_problem->getDomainDescriptor());
      39           4 :         auto gi = _problem->getGradient(xi);
      40           4 :         auto gi_1 = DataContainer<data_t>(_problem->getDomainDescriptor());
      41             : 
      42           4 :         auto di = -gi;
      43          44 :         for (index_t i = 0; i < iterations; ++i) {
      44          42 :             Logger::get("LBFGS")->info("iteration {} of {}", i + 1, iterations);
      45             : 
      46          42 :             xi_1 = xi;
      47          42 :             gi_1 = gi;
      48          42 :             xi += _ls->solve(xi, di) * di;
      49          42 :             gi = _problem->getGradient(xi);
      50             : 
      51          42 :             if (gi.l2Norm() < _tol) {
      52           2 :                 return xi;
      53           2 :             }
      54             : 
      55          40 :             if (i < _m) {
      56          38 :                 siVec.push_back(xi - xi_1);
      57          38 :                 yiVec.push_back(gi - gi_1);
      58          38 :             } else {
      59           2 :                 siVec[i % _m] = xi - xi_1;
      60           2 :                 yiVec[i % _m] = gi - gi_1;
      61           2 :             }
      62          40 :             rhoVec[i % _m] = 1 / yiVec[i % _m].dot(siVec[i % _m]);
      63          40 :             di = gi;
      64         260 :             for (index_t j = i % _m, k = 0; k < i + 1 && k < _m; ++k, j = (j - 1 + _m) % _m) {
      65         220 :                 alphaVec[j] = rhoVec[j] * siVec[j].dot(di);
      66         220 :                 di -= alphaVec[j] * yiVec[j];
      67         220 :             }
      68             : 
      69          40 :             auto gamma = yiVec[i % _m].dot(siVec[i % _m]) / yiVec[i % _m].dot(yiVec[i % _m]);
      70          40 :             di *= gamma;
      71             : 
      72         260 :             for (index_t k = 0, j = (i < _m ? 0 : (i + 1) % _m); k < i + 1 && k < _m;
      73         220 :                  ++k, j = (j + 1) % _m) {
      74         220 :                 auto beta = rhoVec[j] * yiVec[j].dot(di);
      75         220 :                 di += siVec[j] * (alphaVec[j] - beta);
      76         220 :             }
      77             : 
      78          40 :             di = -di;
      79          40 :         }
      80             : 
      81           4 :         return xi;
      82           4 :     } // namespace elsa
      83             : 
      84             :     template <typename data_t>
      85             :     LBFGS<data_t>* LBFGS<data_t>::cloneImpl() const
      86           4 :     {
      87           4 :         return new LBFGS(*_problem, *_ls, _m, _tol);
      88           4 :     }
      89             : 
      90             :     template <typename data_t>
      91             :     bool LBFGS<data_t>::isEqual(const Solver<data_t>& other) const
      92           4 :     {
      93           4 :         auto otherLBFGS = downcast_safe<LBFGS<data_t>>(&other);
      94           4 :         if (!otherLBFGS)
      95           0 :             return false;
      96             : 
      97             :         // TODO: compare line search methods
      98           4 :         return _tol == otherLBFGS->_tol && _m == otherLBFGS->_m;
      99           4 :     }
     100             : 
     101             :     // ------------------------------------------
     102             :     // explicit template instantiation
     103             :     template class LBFGS<float>;
     104             :     template class LBFGS<double>;
     105             : } // namespace elsa

Generated by: LCOV version 1.14