Line data Source code
1 : #include "StrongWolfeCondition.h"
2 : #include "utils/utils.h"
3 :
4 : namespace elsa
5 : {
6 :
7 : template <typename data_t>
8 : StrongWolfeCondition<data_t>::StrongWolfeCondition(const Functional<data_t>& problem,
9 : data_t amax, data_t c1, data_t c2,
10 : index_t max_iterations)
11 : : LineSearchMethod<data_t>(problem, max_iterations), _amax(amax), _c1(c1), _c2(c2)
12 40 : {
13 : // sanity checks
14 40 : if (amax <= 0)
15 0 : throw InvalidArgumentError("StrongWolfeCondition: amax has to be greater than 0");
16 40 : if (c1 <= 0 or c1 >= 1.0)
17 0 : throw InvalidArgumentError("StrongWolfeCondition: c1 has to be in the range (0,1)");
18 40 : if (c2 <= c1 or c2 >= 1.0)
19 0 : throw InvalidArgumentError("StrongWolfeCondition: c2 has to be in the range (c1,1)");
20 40 : }
21 :
22 : template <typename data_t>
23 : data_t StrongWolfeCondition<data_t>::_zoom(data_t a_lo, data_t a_hi, data_t f_lo, data_t f_hi,
24 : data_t f0, data_t der_f_lo, data_t der_f0,
25 : const DataContainer<data_t>& xi,
26 : const DataContainer<data_t>& di,
27 : index_t max_iterations)
28 46 : {
29 46 : data_t aj = 0;
30 46 : auto fj = f0;
31 46 : data_t amin = 0;
32 46 : data_t dalpha, a, b, cchk;
33 189 : for (index_t i = 0; i < max_iterations; ++i) {
34 189 : dalpha = a_hi - a_lo;
35 189 : if (dalpha < 0) {
36 9 : a = a_hi;
37 9 : b = a_lo;
38 180 : } else {
39 180 : a = a_lo;
40 180 : b = a_hi;
41 180 : }
42 189 : cchk = 0.2 * dalpha;
43 : // TODO: check if interpolation is successful
44 189 : if (i > 0) {
45 143 : amin = cubic_interpolation(a_lo, f_lo, der_f_lo, a_hi, f_hi, aj, fj);
46 143 : }
47 189 : if (i == 0 or std::isnan(amin) or amin > b - cchk or amin < a + cchk) {
48 159 : amin = a_lo + 0.5 * dalpha;
49 159 : }
50 189 : auto fmin = this->_problem->evaluate(xi + amin * di);
51 189 : if ((fmin > f0 + _c1 * amin * der_f0) or fmin >= f_lo) {
52 134 : aj = a_hi;
53 134 : a_hi = amin;
54 134 : fj = f_hi;
55 134 : f_hi = fmin;
56 134 : } else {
57 55 : auto der_fmin = di.dot(this->_problem->getGradient(xi + amin * di));
58 55 : if (std::abs(der_fmin) <= -_c2 * der_f0) {
59 46 : return amin;
60 46 : }
61 9 : if (der_fmin * dalpha >= 0) {
62 9 : aj = a_hi;
63 9 : a_hi = a_lo;
64 9 : fj = f_hi;
65 9 : f_hi = f_lo;
66 9 : } else {
67 0 : aj = a_lo;
68 0 : fj = f_lo;
69 0 : }
70 9 : a_lo = amin;
71 9 : f_lo = fmin;
72 9 : der_f_lo = der_fmin;
73 9 : }
74 189 : }
75 46 : return amin;
76 46 : }
77 :
78 : template <typename data_t>
79 : data_t StrongWolfeCondition<data_t>::solve(DataContainer<data_t> xi, DataContainer<data_t> di)
80 126 : {
81 126 : data_t ai_1 = 0;
82 126 : data_t f0 = this->_problem->evaluate(xi);
83 126 : data_t der_f0 = di.dot(this->_problem->getGradient(xi));
84 126 : auto fi_1 = f0;
85 126 : auto der_fi_1 = der_f0;
86 126 : auto ai = std::min(static_cast<data_t>(1.0), _amax);
87 126 : for (index_t i = 0; i < this->_max_iterations; ++i) {
88 126 : auto fi = this->_problem->evaluate(xi + ai * di);
89 126 : if ((fi > f0 + _c1 * ai * der_f0) or (fi >= fi_1 and i > 0)) {
90 46 : return _zoom(ai_1, ai, fi_1, fi, f0, der_fi_1, der_f0, xi, di);
91 46 : }
92 80 : auto der_fi = di.dot(this->_problem->getGradient(xi + ai * di));
93 80 : if (std::abs(der_fi) <= _c2 * std::abs(der_f0)) {
94 80 : return ai;
95 80 : }
96 0 : if (der_fi >= 0) {
97 0 : return _zoom(ai, ai_1, fi, fi_1, f0, der_fi, der_f0, xi, di);
98 0 : }
99 0 : ai_1 = ai;
100 0 : ai = std::min(2 * ai, _amax);
101 0 : fi_1 = fi;
102 0 : der_fi_1 = der_fi;
103 0 : }
104 126 : return ai;
105 126 : }
106 :
107 : template <typename data_t>
108 : StrongWolfeCondition<data_t>* StrongWolfeCondition<data_t>::cloneImpl() const
109 24 : {
110 24 : return new StrongWolfeCondition(*this->_problem, _amax, _c1, _c2, this->_max_iterations);
111 24 : }
112 :
113 : template <typename data_t>
114 : bool StrongWolfeCondition<data_t>::isEqual(const LineSearchMethod<data_t>& other) const
115 4 : {
116 4 : auto otherSWC = downcast_safe<StrongWolfeCondition<data_t>>(&other);
117 4 : if (!otherSWC)
118 0 : return false;
119 :
120 4 : return (_amax == otherSWC->_amax && _c1 == otherSWC->_c1 && _c2 == otherSWC->_c2);
121 4 : }
122 : // ------------------------------------------
123 : // explicit template instantiation
124 : template class StrongWolfeCondition<float>;
125 : template class StrongWolfeCondition<double>;
126 : } // namespace elsa
|