LCOV - code coverage report
Current view: top level - elsa/solvers - BFGS.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 53 57 93.0 %
Date: 2024-12-21 07:37:52 Functions: 10 10 100.0 %

          Line data    Source code
       1             : #include "BFGS.h"
       2             : #include "Logger.h"
       3             : 
       4             : namespace elsa
       5             : {
       6             : 
       7             :     template <typename data_t>
                           // Build a BFGS solver from a functional, a line-search method and a tolerance.
                           //
                           // problem:            functional to work on; a clone (problem.clone()) is stored.
                           // line_search_method: used to pick the step length in solve(); a clone is stored.
                           // tol:                threshold compared against the gradient's l2 norm in solve().
                           // throws:             InvalidArgumentError when tol is negative (zero is accepted).
       8             :     BFGS<data_t>::BFGS(const Functional<data_t>& problem,
       9             :                        const LineSearchMethod<data_t>& line_search_method, const data_t& tol)
      10             :         : Solver<data_t>(), _problem(problem.clone()), _ls(line_search_method.clone()), _tol{tol}
      11          10 :     {
      12             :         // sanity check: runs after the members above have already been initialised
      13          10 :         if (tol < 0)
      14           0 :             throw InvalidArgumentError("BFGS: tolerance has to be non-negative");
      15          10 :     }
      16             : 
      17             :     template <typename data_t>
                           // Run BFGS on the stored problem and return the final iterate.
                           //
                           // The first step is taken along the negative gradient before the loop, so it
                           // is executed even when `iterations` is 0 or negative; the loop then performs
                           // the remaining `iterations - 1` quasi-Newton steps.
                           //
                           // iterations: number of steps requested (the pre-loop step counts as step 1).
                           // x0:         optional starting point, forwarded to extract_or() together with
                           //             the domain descriptor (presumably a descriptor-shaped default is
                           //             used when x0 is empty -- confirm against extract_or).
                           // returns:    the current iterate `xi`, either as soon as the gradient's l2 norm
                           //             is below `_tol` at the top of a loop pass, or after the last pass.
      18             :     DataContainer<data_t> BFGS<data_t>::solve(index_t iterations,
      19             :                                               std::optional<DataContainer<data_t>> x0)
      20           4 :     {
                               // xi / gi hold the current iterate and its gradient, xi_1 / gi_1 the values
                               // from the previous step.
      21           4 :         auto xi = extract_or(x0, _problem->getDomainDescriptor());
      22           4 :         auto xi_1 = DataContainer<data_t>(_problem->getDomainDescriptor());
      23           4 :         auto gi = _problem->getGradient(xi);
      24           4 :         auto gi_1 = DataContainer<data_t>(_problem->getDomainDescriptor());
      25           4 :         auto n = xi.getSize();
      26             : 
                               // Expose a container's storage as an Eigen vector by wrapping its raw data
                               // pointer in an Eigen::Map (the map refers to the existing buffer).
                               // NOTE(review): assumes that pointer can be dereferenced on the host --
                               // confirm for device-backed storage.
      27         120 :         auto to_map = [](auto&& vect) -> Eigen::Map<Vector_t<data_t>> {
      28         120 :             return Eigen::Map<Vector_t<data_t>>(thrust::raw_pointer_cast(vect.storage().data()),
      29         120 :                                                 vect.getSize());
      30         120 :         };
      31             : 
                               // H is an n x n matrix, so its size grows quadratically with the number of
                               // unknowns. It is multiplied onto the gradient below to obtain the search
                               // direction, i.e. it acts as the inverse-Hessian approximation.
      32           4 :         auto H = Matrix_t<data_t>(n, n);
      33           4 :         auto I = Matrix_t<data_t>::Identity(n, n);
      34             : 
                               // Step 1: move along the negative gradient, step length from the line search.
      35           4 :         auto di = -gi;
      36             : 
      37           4 :         xi_1 = xi;
      38           4 :         gi_1 = gi;
      39           4 :         xi += _ls->solve(xi, di) * di;
      40           4 :         gi = _problem->getGradient(xi);
                               // si: change of the iterate, yi: change of the gradient over the last step.
      41           4 :         auto si = xi - xi_1;
      42           4 :         auto yi = gi - gi_1;
                               // Initial H: the identity scaled by (y.s) / (y.y).
                               // NOTE(review): neither this line nor the rho computations guard against
                               // y.s == 0 or y.y == 0 (division by zero) -- confirm this cannot happen.
      43           4 :         H = yi.dot(si) / yi.dot(yi) * I;
      44           4 :         auto rho = 1 / yi.dot(si);
      45           4 :         Logger::get("BFGS")->info("iteration {} of {}", 1, iterations);
      46             : 
                               // i starts at 1 because the first step has already been taken above.
      47          44 :         for (index_t i = 1; i < iterations; ++i) {
      48          40 :             Logger::get("BFGS")->info("iteration {} of {}", i + 1, iterations);
                                   // Stop as soon as the gradient norm is below the tolerance. The gradient
                                   // produced by the final pass is not checked before returning.
      49          40 :             if (gi.l2Norm() < _tol) {
      50           0 :                 return xi;
      51           0 :             }
      52          40 :             auto si_map = to_map(si);
      53          40 :             auto yi_map = to_map(yi);
      54          40 :             auto gi_map = to_map(gi);
                                   // Update: H <- (I - rho s y^T) H (I - rho y s^T) + rho s s^T
      55          40 :             H = (I - rho * si_map * yi_map.transpose()) * H
      56          40 :                     * (I - rho * yi_map * si_map.transpose())
      57          40 :                 + rho * si_map * si_map.transpose();
                                   // New search direction -H * g, wrapped back into a DataContainer.
      58          40 :             di = DataContainer<data_t>{gi.getDataDescriptor(), -H * gi_map};
                                   // Remember the old iterate and gradient, take the line-search step, then
                                   // refresh s, y and rho for the next pass.
      59          40 :             xi_1 = xi;
      60          40 :             gi_1 = gi;
      61          40 :             xi += _ls->solve(xi, di) * di;
      62          40 :             gi = _problem->getGradient(xi);
      63          40 :             si = xi - xi_1;
      64          40 :             yi = gi - gi_1;
      65          40 :             rho = 1 / yi.dot(si);
      66          40 :         }
      67             : 
      68           4 :         return xi;
      69           4 :     } // BFGS::solve
      70             : 
      71             :     template <typename data_t>
                           // Create a new BFGS solver from this one's problem, line-search method and
                           // tolerance. The constructor clones the problem and the line search again, so
                           // the new object holds its own clones rather than this object's pointers.
                           // The object is allocated with `new` and returned as a raw pointer;
                           // presumably the caller (e.g. a clone() wrapper in the base class) takes
                           // ownership -- confirm.
      72             :     BFGS<data_t>* BFGS<data_t>::cloneImpl() const
      73           4 :     {
      74           4 :         return new BFGS(*_problem, *_ls, _tol);
      75           4 :     }
      76             : 
      77             :     template <typename data_t>
                           // Compare this solver with another one.
                           // Returns false when `other` cannot be downcast to BFGS<data_t>; otherwise the
                           // result depends only on the tolerances being equal.
      78             :     bool BFGS<data_t>::isEqual(const Solver<data_t>& other) const
      79           4 :     {
      80           4 :         auto otherBFGS = downcast_safe<BFGS<data_t>>(&other);
      81           4 :         if (!otherBFGS)
      82           0 :             return false;
      83             : 
      84             :         // TODO: compare line search methods
                               // NOTE(review): the stored problem (_problem) is not compared either.
      85           4 :         return _tol == otherBFGS->_tol;
      86           4 :     }
      87             : 
      88             :     // ------------------------------------------
      89             :     // explicit template instantiation
                           // The member functions above are defined in this .cpp file, so the float and
                           // double versions of the class are instantiated here explicitly.
      90             :     template class BFGS<float>;
      91             :     template class BFGS<double>;
      92             : } // namespace elsa

Generated by: LCOV version 1.14