Line data Source code
1 : #include "NewtonRaphson.h" 2 : #include "TypeCasts.hpp" 3 : 4 : namespace elsa 5 : { 6 : 7 : template <class data_t> 8 : NewtonRaphson<data_t>::NewtonRaphson(const Functional<data_t>& f, index_t iterations) 9 : : LineSearchMethod<data_t>(f), iters_(iterations) 10 2 : { 11 2 : } 12 : 13 : template <class data_t> 14 : data_t NewtonRaphson<data_t>::solve(DataContainer<data_t> xi, DataContainer<data_t> di) 15 2 : { 16 2 : auto alpha = data_t{-1.0} * this->_problem->getGradient(xi).dot(di) 17 2 : / di.dot(this->_problem->getHessian(xi).apply(di)); 18 : 19 2 : for (int i = 1; i < iters_; ++i) { 20 0 : xi = xi + alpha * di; 21 : 22 0 : alpha = data_t{-1.0} * this->_problem->getGradient(xi).dot(di) 23 0 : / di.dot(this->_problem->getHessian(xi).apply(di)); 24 0 : } 25 : 26 2 : return alpha; 27 2 : } 28 : 29 : /// implement the polymorphic comparison operation 30 : template <class data_t> 31 : bool NewtonRaphson<data_t>::isEqual(const LineSearchMethod<data_t>& other) const 32 0 : { 33 0 : auto tmp = downcast_safe<NewtonRaphson<data_t>>(&other); 34 0 : return static_cast<bool>(tmp); 35 0 : } 36 : 37 : template <class data_t> 38 : NewtonRaphson<data_t>* NewtonRaphson<data_t>::cloneImpl() const 39 0 : { 40 0 : return new NewtonRaphson(*this->_problem, iters_); 41 0 : } 42 : 43 : // ------------------------------------------ 44 : // explicit template instantiation 45 : template class NewtonRaphson<float>; 46 : template class NewtonRaphson<double>; 47 : } // namespace elsa