Line data Source code
1 : #include "CGNL.h"
2 : #include "Logger.h"
3 : #include "TypeCasts.hpp"
4 : #include "spdlog/stopwatch.h"
5 : #include "LineSearchMethod.h"
6 :
7 : namespace elsa
8 : {
9 : template <typename data_t>
10 : CGNL<data_t>::CGNL(const Functional<data_t>& functional,
11 : const LineSearchMethod<data_t>& line_search_function)
12 : : Solver<data_t>(),
13 : f_{functional.clone()},
14 : r_(empty<data_t>(functional.getDomainDescriptor())),
15 : d_(empty<data_t>(functional.getDomainDescriptor())),
16 : delta_(0),
17 : deltaZero_(0),
18 : beta_(0),
19 : alpha_(0),
20 : lineSearch_{line_search_function.clone()},
21 : beta_function_{betaPolakRibiere}
22 4 : {
23 4 : }
24 :
25 : template <typename data_t>
26 : CGNL<data_t>::CGNL(const Functional<data_t>& functional,
27 : const LineSearchMethod<data_t>& line_search_function,
28 : const BetaFunction& beta_function)
29 : : Solver<data_t>(),
30 : f_{functional.clone()},
31 : r_(empty<data_t>(functional.getDomainDescriptor())),
32 : d_(empty<data_t>(functional.getDomainDescriptor())),
33 : delta_(0),
34 : deltaZero_(0),
35 : beta_(0),
36 : alpha_(0),
37 : lineSearch_{line_search_function.clone()},
38 : beta_function_{beta_function}
39 4 : {
40 4 : }
41 :
42 : template <typename data_t>
43 : DataContainer<data_t> CGNL<data_t>::setup(std::optional<DataContainer<data_t>> x0)
44 8 : {
45 8 : auto x = extract_or(x0, f_->getDomainDescriptor());
46 :
47 : // r <= -f'(x)
48 8 : r_ = data_t{-1.0} * f_->getGradient(x);
49 8 : d_ = r_;
50 :
51 : // delta <= r^T * d
52 8 : delta_ = r_.dot(d_);
53 :
54 : // deltaZero <= delta
55 8 : deltaZero_ = delta_;
56 :
57 : // restart every n
58 8 : restart_ = f_->getDomainDescriptor().getNumberOfCoefficients();
59 :
60 8 : this->name_ = "CGNL";
61 :
62 : // setup done!
63 8 : this->configured_ = true;
64 :
65 8 : return x;
66 8 : }
67 :
68 : template <typename data_t>
69 : DataContainer<data_t> CGNL<data_t>::step(DataContainer<data_t> x)
70 83 : {
71 : // line search
72 83 : alpha_ = lineSearch_->solve(x, d_);
73 :
74 : // update x
75 83 : x += alpha_ * d_;
76 :
77 : // beta function
78 : // r <= -f'(x)
79 83 : r_ = data_t{-1.0} * f_->getGradient(x);
80 :
81 83 : std::tie(beta_, delta_) = beta_function_(d_, r_, delta_);
82 :
83 83 : if (this->curiter_ % restart_ == 0 || beta_ <= 0.0) {
84 23 : d_ = r_;
85 60 : } else {
86 60 : d_ = r_ + beta_ * d_;
87 60 : }
88 :
89 83 : return x;
90 83 : }
91 :
92 : template <typename data_t>
93 : bool CGNL<data_t>::shouldStop() const
94 86 : {
95 86 : return delta_ <= epsilon_ * epsilon_ * deltaZero_;
96 86 : }
97 :
98 : template <typename data_t>
99 : std::string CGNL<data_t>::formatHeader() const
100 8 : {
101 8 : return fmt::format("{:^15} | {:^15} | {:^15}", "delta", "beta", "alpha");
102 8 : }
103 :
104 : template <typename data_t>
105 : std::string CGNL<data_t>::formatStep(const DataContainer<data_t>&) const
106 83 : {
107 83 : return fmt::format("{:>15.10} | {:>15.10} | {:>15.10}", delta_, beta_, alpha_);
108 83 : }
109 :
110 : template <typename data_t>
111 : CGNL<data_t>* CGNL<data_t>::cloneImpl() const
112 4 : {
113 4 : return new CGNL(*f_, *lineSearch_, beta_function_);
114 4 : }
115 :
116 : template <typename data_t>
117 : bool CGNL<data_t>::isEqual(const Solver<data_t>& other) const
118 4 : {
119 4 : auto otherCG = downcast_safe<CGNL>(&other);
120 4 : if (!otherCG)
121 0 : return false;
122 :
123 4 : if (epsilon_ != otherCG->epsilon_)
124 0 : return false;
125 :
126 4 : return true;
127 4 : }
128 :
129 : // ------------------------------------------
130 : // explicit template instantiation
131 : template class CGNL<float>;
132 : template class CGNL<double>;
133 : } // namespace elsa
|