Line data Source code
1 : #include "BarzilaiBorwein.h" 2 : namespace elsa 3 : { 4 : template <typename data_t> 5 : BarzilaiBorwein<data_t>::BarzilaiBorwein(const Functional<data_t>& problem, uint32_t m, 6 : data_t gamma, data_t sigma1, data_t sigma2, 7 : data_t epsilon, index_t max_iterations) 8 : : LineSearchMethod<data_t>(problem, max_iterations), 9 : _m(m), 10 : _gamma(gamma), 11 : _sigma1(sigma1), 12 : _sigma2(sigma2), 13 : _epsilon(epsilon), 14 : _gi_prev(DataContainer<data_t>(this->_problem->getDomainDescriptor())) 15 8 : { 16 : // sanity checks 17 8 : if (gamma <= 0 or gamma >= 1) 18 0 : throw InvalidArgumentError("BarzilaiBorwein: gamma has to be in the range (0,1)"); 19 8 : if (sigma1 <= 0 or sigma1 >= sigma2 or sigma1 >= 1) 20 0 : throw InvalidArgumentError( 21 0 : "BarzilaiBorwein: sigma1 has to satisfy 0 < sigma1 < sigma2 < 1"); 22 8 : if (sigma2 <= 0 or sigma1 >= sigma2 or sigma2 >= 1) 23 0 : throw InvalidArgumentError( 24 0 : "BarzilaiBorwein: sigma2 has to satisfy 0 < sigma1 < sigma2 < 1"); 25 8 : if (epsilon <= 0) 26 0 : throw InvalidArgumentError("BarzilaiBorwein: epsilon has to be in the range (0,1)"); 27 8 : _invepsilon = 1 / epsilon; 28 8 : _function_vals.reserve(m); 29 8 : _li_prev = 1; 30 8 : _iter = 0; 31 8 : } 32 : template <typename data_t> 33 : data_t BarzilaiBorwein<data_t>::solve(DataContainer<data_t> xi, DataContainer<data_t> di) 34 40 : { 35 40 : auto derphi = di.dot(di); 36 40 : if (_iter == 0) { 37 4 : _gi_prev = -di; 38 4 : if (_m > 0) { 39 4 : _function_vals.push_back(this->_problem->evaluate(xi)); 40 4 : } 41 4 : _derphi_prev = derphi; 42 4 : } 43 40 : auto ai = -_gi_prev.dot(-di - _gi_prev) / (_li_prev * _derphi_prev); 44 40 : if (ai <= _epsilon or ai >= _invepsilon) { 45 4 : auto g_norm = std::sqrt(derphi); 46 4 : if (g_norm > 1) { 47 4 : ai = 1; 48 4 : } else if (g_norm >= 1e-5 and g_norm <= 1) { 49 0 : ai = 1 / g_norm; 50 0 : } else { 51 0 : ai = 1e5; 52 0 : } 53 4 : } 54 40 : auto li = 1 / ai; 55 40 : data_t max_prev_f; 56 40 : if (_m > 0) { 57 40 : max_prev_f = *std::max_element(_function_vals.begin(), _function_vals.end()); 58 40 : } 59 40 : data_t fi; 60 40 : ++_iter; 61 50 : for (index_t i = 0; i < this->_max_iterations and _m > 0; ++i) { 62 50 : fi = this->_problem->evaluate(xi + li * di); 63 50 : if (fi <= max_prev_f - _gamma * li * derphi) { 64 40 : break; 65 40 : } else { 66 10 : li = (_sigma2 - _sigma1) / 2 * li; 67 10 : } 68 50 : } 69 40 : _li_prev = li; 70 40 : if (_m > 0) { 71 : 72 40 : if (_iter < _m) { 73 36 : _function_vals.push_back(fi); 74 36 : } else { 75 4 : _function_vals[_iter % _m] = fi; 76 4 : } 77 40 : } 78 40 : _gi_prev = -di; 79 40 : _derphi_prev = derphi; 80 40 : return li; 81 40 : } 82 : 83 : template <typename data_t> 84 : BarzilaiBorwein<data_t>* BarzilaiBorwein<data_t>::cloneImpl() const 85 4 : { 86 4 : return new BarzilaiBorwein(*this->_problem, _m, _gamma, _sigma1, _sigma2, _epsilon, 87 4 : this->_max_iterations); 88 4 : } 89 : 90 : template <typename data_t> 91 : bool BarzilaiBorwein<data_t>::isEqual(const LineSearchMethod<data_t>& other) const 92 4 : { 93 4 : auto otherBB = downcast_safe<BarzilaiBorwein<data_t>>(&other); 94 4 : if (!otherBB) 95 0 : return false; 96 : 97 4 : return (_m == otherBB->_m and _gamma == otherBB->_gamma and _sigma1 == otherBB->_sigma1 98 4 : and _sigma2 == otherBB->_sigma2 and _epsilon == otherBB->_epsilon); 99 4 : } 100 : // ------------------------------------------ 101 : // explicit template instantiation 102 : template class BarzilaiBorwein<float>; 103 : template class BarzilaiBorwein<double>; 104 : } // namespace elsa