Line data Source code
1 : #include "ArmijoCondition.h" 2 : namespace elsa 3 : { 4 : 5 : template <typename data_t> 6 : ArmijoCondition<data_t>::ArmijoCondition(const Functional<data_t>& problem, data_t amax, 7 : data_t c, data_t rho, index_t max_iterations) 8 : : LineSearchMethod<data_t>(problem, max_iterations), _amax(amax), _c(c), _rho(rho) 9 8 : { 10 : // sanity checks 11 8 : if (amax <= 0) 12 0 : throw InvalidArgumentError("ArmijoCondition: amax has to be positive"); 13 8 : if (c <= 0 or c >= 1) 14 0 : throw InvalidArgumentError("ArmijoCondition: c has to be in the range (0,1)"); 15 8 : if (rho <= 0 or rho >= 1) 16 0 : throw InvalidArgumentError("ArmijoCondition: rho has to be in the range (0,1)"); 17 8 : } 18 : 19 : template <typename data_t> 20 : data_t ArmijoCondition<data_t>::solve(DataContainer<data_t> xi, DataContainer<data_t> di) 21 40 : { 22 40 : auto ai = _amax; 23 40 : auto f0 = this->_problem->evaluate(xi); 24 40 : auto f0_prime = di.dot(this->_problem->getGradient(xi)); 25 224 : for (index_t i = 0; i < this->_max_iterations; ++i) { 26 224 : if (this->_problem->evaluate(xi + ai * di) <= f0 + _c * ai * f0_prime) { 27 40 : return ai; 28 184 : } else { 29 184 : ai = ai * _rho; 30 184 : } 31 224 : } 32 40 : return ai; 33 40 : } 34 : 35 : template <typename data_t> 36 : ArmijoCondition<data_t>* ArmijoCondition<data_t>::cloneImpl() const 37 4 : { 38 4 : return new ArmijoCondition(*this->_problem, _amax, _c, _rho, this->_max_iterations); 39 4 : } 40 : 41 : template <typename data_t> 42 : bool ArmijoCondition<data_t>::isEqual(const LineSearchMethod<data_t>& other) const 43 4 : { 44 4 : auto otherAC = downcast_safe<ArmijoCondition<data_t>>(&other); 45 4 : if (!otherAC) 46 0 : return false; 47 : 48 4 : return (_amax == otherAC->_amax && _c == otherAC->_c && _rho == otherAC->_rho); 49 4 : } 50 : 51 : // ------------------------------------------ 52 : // explicit template instantiation 53 : template class ArmijoCondition<float>; 54 : template class ArmijoCondition<double>; 55 : } // namespace elsa