Line data Source code
1 : #include "LBFGS.h" 2 : #include "Logger.h" 3 : 4 : namespace elsa 5 : { 6 : 7 : template <typename data_t> 8 : LBFGS<data_t>::LBFGS(const Functional<data_t>& problem, 9 : const LineSearchMethod<data_t>& line_search_method, const index_t& memory, 10 : const data_t& tol) 11 : : Solver<data_t>(), 12 : _problem(problem.clone()), 13 : _ls(line_search_method.clone()), 14 : _m{memory}, 15 : _tol{tol} 16 10 : { 17 : // sanity check 18 10 : if (tol < 0) 19 0 : throw InvalidArgumentError("LBFGS: tolerance has to be non-negative"); 20 10 : if (memory < 1) 21 0 : throw InvalidArgumentError("LBFGS: memory has to be positive"); 22 10 : } 23 : 24 : template <typename data_t> 25 : DataContainer<data_t> LBFGS<data_t>::solve(index_t iterations, 26 : std::optional<DataContainer<data_t>> x0) 27 4 : { 28 : 29 4 : std::vector<DataContainer<data_t>> siVec; 30 4 : std::vector<DataContainer<data_t>> yiVec; 31 4 : std::vector<data_t> rhoVec(_m, 1); 32 4 : std::vector<data_t> alphaVec(_m, 1); 33 : 34 4 : siVec.reserve(_m); 35 4 : yiVec.reserve(_m); 36 : 37 4 : auto xi = extract_or(x0, _problem->getDomainDescriptor()); 38 4 : auto xi_1 = DataContainer<data_t>(_problem->getDomainDescriptor()); 39 4 : auto gi = _problem->getGradient(xi); 40 4 : auto gi_1 = DataContainer<data_t>(_problem->getDomainDescriptor()); 41 : 42 4 : auto di = -gi; 43 44 : for (index_t i = 0; i < iterations; ++i) { 44 42 : Logger::get("LBFGS")->info("iteration {} of {}", i + 1, iterations); 45 : 46 42 : xi_1 = xi; 47 42 : gi_1 = gi; 48 42 : xi += _ls->solve(xi, di) * di; 49 42 : gi = _problem->getGradient(xi); 50 : 51 42 : if (gi.l2Norm() < _tol) { 52 2 : return xi; 53 2 : } 54 : 55 40 : if (i < _m) { 56 38 : siVec.push_back(xi - xi_1); 57 38 : yiVec.push_back(gi - gi_1); 58 38 : } else { 59 2 : siVec[i % _m] = xi - xi_1; 60 2 : yiVec[i % _m] = gi - gi_1; 61 2 : } 62 40 : rhoVec[i % _m] = 1 / yiVec[i % _m].dot(siVec[i % _m]); 63 40 : di = gi; 64 260 : for (index_t j = i % _m, k = 0; k < i + 1 && k < _m; ++k, j = (j - 1 + _m) % _m) { 65 220 : alphaVec[j] = rhoVec[j] * siVec[j].dot(di); 66 220 : di -= alphaVec[j] * yiVec[j]; 67 220 : } 68 : 69 40 : auto gamma = yiVec[i % _m].dot(siVec[i % _m]) / yiVec[i % _m].dot(yiVec[i % _m]); 70 40 : di *= gamma; 71 : 72 260 : for (index_t k = 0, j = (i < _m ? 0 : (i + 1) % _m); k < i + 1 && k < _m; 73 220 : ++k, j = (j + 1) % _m) { 74 220 : auto beta = rhoVec[j] * yiVec[j].dot(di); 75 220 : di += siVec[j] * (alphaVec[j] - beta); 76 220 : } 77 : 78 40 : di = -di; 79 40 : } 80 : 81 4 : return xi; 82 4 : } // namespace elsa 83 : 84 : template <typename data_t> 85 : LBFGS<data_t>* LBFGS<data_t>::cloneImpl() const 86 4 : { 87 4 : return new LBFGS(*_problem, *_ls, _m, _tol); 88 4 : } 89 : 90 : template <typename data_t> 91 : bool LBFGS<data_t>::isEqual(const Solver<data_t>& other) const 92 4 : { 93 4 : auto otherLBFGS = downcast_safe<LBFGS<data_t>>(&other); 94 4 : if (!otherLBFGS) 95 0 : return false; 96 : 97 : // TODO: compare line search methods 98 4 : return _tol == otherLBFGS->_tol && _m == otherLBFGS->_m; 99 4 : } 100 : 101 : // ------------------------------------------ 102 : // explicit template instantiation 103 : template class LBFGS<float>; 104 : template class LBFGS<double>; 105 : } // namespace elsa