Line data Source code
1 : #include "BFGS.h" 2 : #include "Logger.h" 3 : 4 : namespace elsa 5 : { 6 : 7 : template <typename data_t> 8 : BFGS<data_t>::BFGS(const Functional<data_t>& problem, 9 : const LineSearchMethod<data_t>& line_search_method, const data_t& tol) 10 : : Solver<data_t>(), _problem(problem.clone()), _ls(line_search_method.clone()), _tol{tol} 11 10 : { 12 : // sanity check 13 10 : if (tol < 0) 14 0 : throw InvalidArgumentError("BFGS: tolerance has to be non-negative"); 15 10 : } 16 : 17 : template <typename data_t> 18 : DataContainer<data_t> BFGS<data_t>::solve(index_t iterations, 19 : std::optional<DataContainer<data_t>> x0) 20 4 : { 21 4 : auto xi = extract_or(x0, _problem->getDomainDescriptor()); 22 4 : auto xi_1 = DataContainer<data_t>(_problem->getDomainDescriptor()); 23 4 : auto gi = _problem->getGradient(xi); 24 4 : auto gi_1 = DataContainer<data_t>(_problem->getDomainDescriptor()); 25 4 : auto n = xi.getSize(); 26 : 27 120 : auto to_map = [](auto&& vect) -> Eigen::Map<Vector_t<data_t>> { 28 120 : return Eigen::Map<Vector_t<data_t>>(thrust::raw_pointer_cast(vect.storage().data()), 29 120 : vect.getSize()); 30 120 : }; 31 : 32 4 : auto H = Matrix_t<data_t>(n, n); 33 4 : auto I = Matrix_t<data_t>::Identity(n, n); 34 : 35 4 : auto di = -gi; 36 : 37 4 : xi_1 = xi; 38 4 : gi_1 = gi; 39 4 : xi += _ls->solve(xi, di) * di; 40 4 : gi = _problem->getGradient(xi); 41 4 : auto si = xi - xi_1; 42 4 : auto yi = gi - gi_1; 43 4 : H = yi.dot(si) / yi.dot(yi) * I; 44 4 : auto rho = 1 / yi.dot(si); 45 4 : Logger::get("BFGS")->info("iteration {} of {}", 1, iterations); 46 : 47 44 : for (index_t i = 1; i < iterations; ++i) { 48 40 : Logger::get("BFGS")->info("iteration {} of {}", i + 1, iterations); 49 40 : if (gi.l2Norm() < _tol) { 50 0 : return xi; 51 0 : } 52 40 : auto si_map = to_map(si); 53 40 : auto yi_map = to_map(yi); 54 40 : auto gi_map = to_map(gi); 55 40 : H = (I - rho * si_map * yi_map.transpose()) * H 56 40 : * (I - rho * yi_map * si_map.transpose()) 57 40 : + rho * si_map * si_map.transpose(); 58 40 : di = DataContainer<data_t>{gi.getDataDescriptor(), -H * gi_map}; 59 40 : xi_1 = xi; 60 40 : gi_1 = gi; 61 40 : xi += _ls->solve(xi, di) * di; 62 40 : gi = _problem->getGradient(xi); 63 40 : si = xi - xi_1; 64 40 : yi = gi - gi_1; 65 40 : rho = 1 / yi.dot(si); 66 40 : } 67 : 68 4 : return xi; 69 4 : } // namespace elsa 70 : 71 : template <typename data_t> 72 : BFGS<data_t>* BFGS<data_t>::cloneImpl() const 73 4 : { 74 4 : return new BFGS(*_problem, *_ls, _tol); 75 4 : } 76 : 77 : template <typename data_t> 78 : bool BFGS<data_t>::isEqual(const Solver<data_t>& other) const 79 4 : { 80 4 : auto otherBFGS = downcast_safe<BFGS<data_t>>(&other); 81 4 : if (!otherBFGS) 82 0 : return false; 83 : 84 : // TODO: compare line search methods 85 4 : return _tol == otherBFGS->_tol; 86 4 : } 87 : 88 : // ------------------------------------------ 89 : // explicit template instantiation 90 : template class BFGS<float>; 91 : template class BFGS<double>; 92 : } // namespace elsa