Line data Source code
1 : #include "SteepestDescentStepLS.h" 2 : namespace elsa 3 : { 4 : 5 : template <typename data_t> 6 : SteepestDescentStepLS<data_t>::SteepestDescentStepLS(const LeastSquares<data_t>& problem) 7 : : LineSearchMethod<data_t>(problem), 8 : _A(problem.getOperator().clone()), 9 : _b(problem.getDataVector()) 10 20 : { 11 20 : } 12 : 13 : template <typename data_t> 14 : data_t SteepestDescentStepLS<data_t>::solve(DataContainer<data_t> xi, DataContainer<data_t> di) 15 123 : { 16 123 : auto ad = _A->apply(di); 17 123 : auto nom = -ad.dot(_A->apply(xi)) + ad.dot(_b); 18 123 : auto denom = ad.dot(ad); 19 123 : return nom / denom; 20 123 : } 21 : 22 : template <typename data_t> 23 : SteepestDescentStepLS<data_t>* SteepestDescentStepLS<data_t>::cloneImpl() const 24 12 : { 25 12 : return new SteepestDescentStepLS(LeastSquares<data_t>(*_A, _b)); 26 12 : } 27 : 28 : template <typename data_t> 29 : bool SteepestDescentStepLS<data_t>::isEqual(const LineSearchMethod<data_t>& other) const 30 4 : { 31 4 : auto otherSD = downcast_safe<SteepestDescentStepLS<data_t>>(&other); 32 4 : if (!otherSD) 33 0 : return false; 34 : 35 4 : return (*_A == *otherSD->_A && _b == otherSD->_b); 36 4 : } 37 : 38 : // ------------------------------------------ 39 : // explicit template instantiation 40 : template class SteepestDescentStepLS<float>; 41 : template class SteepestDescentStepLS<double>; 42 : } // namespace elsa