Line data Source code
1 : #include "GoldsteinCondition.h"
2 : #include "utils/utils.h"
3 :
4 : namespace elsa
5 : {
6 :
7 : template <typename data_t>
8 : GoldsteinCondition<data_t>::GoldsteinCondition(const Functional<data_t>& problem, data_t amax,
9 : data_t c, index_t max_iterations)
10 : : LineSearchMethod<data_t>(problem, max_iterations), _amax(amax), _c(c)
11 8 : {
12 : // sanity checks
13 8 : if (amax <= 0)
14 0 : throw InvalidArgumentError("GoldsteinCondition: amax has to be greater than 0");
15 8 : if (c <= 0 or c >= 0.5)
16 0 : throw InvalidArgumentError("GoldsteinCondition: c has to be in the range (0,0.5)");
17 8 : }
18 :
19 : template <typename data_t>
20 : data_t GoldsteinCondition<data_t>::_zoom(data_t a_lo, data_t a_hi, data_t f_lo, data_t f_hi,
21 : data_t f0, data_t der_f_lo, data_t der_f0,
22 : const DataContainer<data_t>& xi,
23 : const DataContainer<data_t>& di,
24 : index_t max_iterations)
25 36 : {
26 36 : data_t aj = 0;
27 36 : data_t fj = f0;
28 36 : data_t a_min = 0;
29 396 : for (index_t i = 0; i < max_iterations; ++i) {
30 360 : data_t cchk = static_cast<data_t>(0.2) * (a_hi - a_lo);
31 360 : if (i > 0) {
32 324 : a_min = cubic_interpolation<data_t>(a_lo, f_lo, der_f_lo, a_hi, f_hi, aj, fj);
33 324 : }
34 360 : if (i == 0 or std::isnan(a_min) or (a_min > a_hi - cchk) or (a_min < a_lo + cchk)) {
35 291 : a_min = a_lo + static_cast<data_t>(0.5) * (a_hi - a_lo);
36 291 : }
37 360 : data_t f_min = this->_problem->evaluate(xi + a_min * di);
38 360 : if ((f_min > f0 + _c * a_min * der_f0) or f_min >= f_lo) {
39 250 : aj = a_hi;
40 250 : a_hi = a_min;
41 250 : fj = f_hi;
42 250 : f_hi = f_min;
43 250 : } else {
44 110 : if (f_min <= f0 + (1 - _c) * a_min * der_f0) {
45 0 : return a_min;
46 0 : }
47 110 : data_t der_f_min = di.dot(this->_problem->getGradient(xi + a_min * di));
48 110 : if (der_f_min * (a_hi - a_lo) >= 0) {
49 76 : aj = a_hi;
50 76 : a_hi = a_lo;
51 76 : fj = f_hi;
52 76 : f_hi = f_lo;
53 76 : } else {
54 34 : aj = a_lo;
55 34 : fj = f_lo;
56 34 : }
57 110 : a_lo = a_min;
58 110 : f_lo = f_min;
59 110 : der_f_lo = der_f_min;
60 110 : }
61 360 : }
62 36 : return a_min;
63 36 : }
64 :
65 : template <typename data_t>
66 : data_t GoldsteinCondition<data_t>::solve(DataContainer<data_t> xi, DataContainer<data_t> di)
67 40 : {
68 40 : data_t ai_1 = 0;
69 40 : auto f0 = this->_problem->evaluate(xi);
70 40 : auto der_f0 = di.dot(this->_problem->getGradient(xi));
71 40 : auto fi_1 = f0;
72 40 : auto der_fi_1 = der_f0;
73 40 : auto ai = std::min(static_cast<data_t>(1.0), _amax);
74 41 : for (index_t i = 0; i < this->_max_iterations; ++i) {
75 41 : auto fi = this->_problem->evaluate(xi + ai * di);
76 41 : auto der_fi = di.dot(this->_problem->getGradient(xi + ai * di));
77 41 : if ((fi > f0 + _c * ai * der_f0) or (fi >= fi_1 and i > 0)) {
78 36 : return _zoom(ai_1, ai, fi_1, fi, f0, der_fi_1, der_f0, xi, di);
79 36 : }
80 5 : if (fi >= f0 + (1 - _c) * ai * der_f0) {
81 4 : return ai;
82 4 : }
83 1 : if (der_fi >= 0) {
84 0 : return _zoom(ai, ai_1, fi, fi_1, f0, der_fi, der_f0, xi, di);
85 0 : }
86 1 : auto a = 2 * ai;
87 1 : ai_1 = ai;
88 1 : ai = std::min(a, _amax);
89 1 : fi_1 = fi;
90 1 : der_fi_1 = der_fi;
91 1 : }
92 40 : return ai;
93 40 : }
94 :
95 : template <typename data_t>
96 : GoldsteinCondition<data_t>* GoldsteinCondition<data_t>::cloneImpl() const
97 4 : {
98 4 : return new GoldsteinCondition(*this->_problem, _amax, _c, this->_max_iterations);
99 4 : }
100 :
101 : template <typename data_t>
102 : bool GoldsteinCondition<data_t>::isEqual(const LineSearchMethod<data_t>& other) const
103 4 : {
104 4 : auto otherGC = downcast_safe<GoldsteinCondition<data_t>>(&other);
105 4 : if (!otherGC)
106 0 : return false;
107 :
108 4 : return (_amax == otherGC->_amax && _c == otherGC->_c);
109 4 : }
110 :
// ------------------------------------------
// explicit template instantiation
// The member functions above are defined in this source file, so the class template is
// instantiated here for each supported data type (float and double).
template class GoldsteinCondition<float>;
template class GoldsteinCondition<double>;
115 : } // namespace elsa
|