Line data Source code
1 : #include "APGD.h"
2 : #include "DataContainer.h"
3 : #include "Error.h"
4 : #include "Functional.h"
5 : #include "LeastSquares.h"
6 : #include "LinearOperator.h"
7 : #include "LinearResidual.h"
8 : #include "ProximalL1.h"
9 : #include "TypeCasts.hpp"
10 : #include "Logger.h"
11 : #include "PowerIterations.h"
12 :
13 : #include "WeightedLeastSquares.h"
14 : #include "spdlog/stopwatch.h"
15 :
16 : namespace elsa
17 : {
18 : template <typename data_t>
19 : APGD<data_t>::APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
20 : const Functional<data_t>& h, std::optional<data_t> mu, data_t epsilon)
21 : : g_(LeastSquares<data_t>(A, b).clone()),
22 : h_(h.clone()),
23 : xPrev_(empty<data_t>(g_->getDomainDescriptor())),
24 : y_(empty<data_t>(g_->getDomainDescriptor())),
25 : z_(empty<data_t>(g_->getDomainDescriptor())),
26 : grad_(empty<data_t>(g_->getDomainDescriptor())),
27 : epsilon_(epsilon)
28 4 : {
29 4 : if (!h.isProxFriendly()) {
30 0 : throw Error("APGD: h must be prox friendly");
31 0 : }
32 :
33 4 : if (mu.has_value()) {
34 4 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
35 4 : } else {
36 0 : Logger::get("APGD")->info("Computing Lipschitz constant for least squares...");
37 : // Chose it a little larger, to be safe
38 0 : auto L = 1.05 * powerIterations(adjoint(A) * A);
39 0 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
40 0 : Logger::get("APGD")->info("Step length chosen to be: {}", 1 / L);
41 0 : }
42 :
43 4 : this->name_ = "APGD";
44 4 : }
45 :
46 : template <typename data_t>
47 : APGD<data_t>::APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
48 : const DataContainer<data_t>& W, const Functional<data_t>& h,
49 : std::optional<data_t> mu, data_t epsilon)
50 : : g_(WeightedLeastSquares<data_t>(A, b, W).clone()),
51 : h_(h.clone()),
52 : xPrev_(empty<data_t>(g_->getDomainDescriptor())),
53 : y_(empty<data_t>(g_->getDomainDescriptor())),
54 : z_(empty<data_t>(g_->getDomainDescriptor())),
55 : grad_(empty<data_t>(g_->getDomainDescriptor())),
56 : epsilon_(epsilon)
57 2 : {
58 2 : if (!h.isProxFriendly()) {
59 0 : throw Error("APGD: h must be prox friendly");
60 0 : }
61 :
62 2 : if (mu.has_value()) {
63 2 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
64 2 : } else {
65 0 : Logger::get("APGD")->info("Computing Lipschitz constant for least squares...");
66 : // Chose it a little larger, to be safe
67 0 : auto L = 1.05 * powerIterations(adjoint(A) * A);
68 0 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
69 0 : Logger::get("APGD")->info("Step length chosen to be: {}", 1 / L);
70 0 : }
71 :
72 2 : this->name_ = "APGD";
73 2 : }
74 :
75 : template <typename data_t>
76 : APGD<data_t>::APGD(const Functional<data_t>& g, const Functional<data_t>& h, data_t mu,
77 : data_t epsilon)
78 : : g_(g.clone()),
79 : h_(h.clone()),
80 : xPrev_(empty<data_t>(g_->getDomainDescriptor())),
81 : y_(empty<data_t>(g_->getDomainDescriptor())),
82 : z_(empty<data_t>(g_->getDomainDescriptor())),
83 : grad_(empty<data_t>(g_->getDomainDescriptor())),
84 : lineSearchMethod_(FixedStepSize<data_t>(*g_, mu).clone()),
85 : epsilon_(epsilon)
86 0 : {
87 0 : if (!h.isProxFriendly()) {
88 0 : throw Error("APGD: h must be prox friendly");
89 0 : }
90 :
91 0 : if (!g.isDifferentiable()) {
92 0 : throw Error("APGD: g must be differentiable");
93 0 : }
94 :
95 0 : this->name_ = "APGD";
96 0 : }
97 :
98 : template <typename data_t>
99 : APGD<data_t>::APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
100 : const Functional<data_t>& h,
101 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
102 : : g_(LeastSquares<data_t>(A, b).clone()),
103 : h_(h.clone()),
104 : xPrev_(empty<data_t>(g_->getDomainDescriptor())),
105 : y_(empty<data_t>(g_->getDomainDescriptor())),
106 : z_(empty<data_t>(g_->getDomainDescriptor())),
107 : grad_(empty<data_t>(g_->getDomainDescriptor())),
108 : lineSearchMethod_(lineSearchMethod.clone()),
109 : epsilon_(epsilon)
110 0 : {
111 0 : if (!h.isProxFriendly()) {
112 0 : throw Error("APGD: h must be prox friendly");
113 0 : }
114 :
115 0 : this->name_ = "APGD";
116 0 : }
117 :
118 : template <typename data_t>
119 : APGD<data_t>::APGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
120 : const DataContainer<data_t>& W, const Functional<data_t>& h,
121 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
122 : : g_(WeightedLeastSquares<data_t>(A, b, W).clone()),
123 : h_(h.clone()),
124 : xPrev_(empty<data_t>(g_->getDomainDescriptor())),
125 : y_(empty<data_t>(g_->getDomainDescriptor())),
126 : z_(empty<data_t>(g_->getDomainDescriptor())),
127 : grad_(empty<data_t>(g_->getDomainDescriptor())),
128 : lineSearchMethod_(lineSearchMethod.clone()),
129 : epsilon_(epsilon)
130 0 : {
131 0 : if (!h.isProxFriendly()) {
132 0 : throw Error("APGD: h must be prox friendly");
133 0 : }
134 :
135 0 : this->name_ = "APGD";
136 0 : }
137 :
138 : template <typename data_t>
139 : APGD<data_t>::APGD(const Functional<data_t>& g, const Functional<data_t>& h,
140 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
141 : : g_(g.clone()),
142 : h_(h.clone()),
143 : xPrev_(empty<data_t>(g_->getDomainDescriptor())),
144 : y_(empty<data_t>(g_->getDomainDescriptor())),
145 : z_(empty<data_t>(g_->getDomainDescriptor())),
146 : grad_(empty<data_t>(g_->getDomainDescriptor())),
147 : lineSearchMethod_(lineSearchMethod.clone()),
148 : epsilon_(epsilon)
149 6 : {
150 6 : if (!h.isProxFriendly()) {
151 0 : throw Error("APGD: h must be prox friendly");
152 0 : }
153 :
154 6 : if (!g.isDifferentiable()) {
155 0 : throw Error("APGD: g must be differentiable");
156 0 : }
157 :
158 6 : this->name_ = "APGD";
159 6 : }
160 :
161 : template <typename data_t>
162 : DataContainer<data_t> APGD<data_t>::setup(std::optional<DataContainer<data_t>> x0)
163 6 : {
164 6 : auto x = extract_or(x0, g_->getDomainDescriptor());
165 :
166 6 : xPrev_ = x;
167 6 : y_ = x;
168 6 : z_ = x;
169 :
170 6 : tPrev_ = 1;
171 :
172 : // update gradient
173 6 : g_->getGradient(x, grad_);
174 :
175 : // setup done!
176 6 : this->configured_ = true;
177 :
178 6 : return x;
179 6 : }
180 :
181 : template <typename data_t>
182 : DataContainer<data_t> APGD<data_t>::step(DataContainer<data_t> x)
183 6 : {
184 6 : auto mu = lineSearchMethod_->solve(x, -grad_);
185 : // z = y - mu_ * grad
186 6 : lincomb(1, y_, -mu, grad_, z_);
187 :
188 : // x_{k+1} = prox_{mu * g}(y - mu * grad)
189 : // x = prox_.apply(z, mu_);
190 6 : x = h_->proximal(z_, mu);
191 :
192 : // t_{k+1} = \frac{\sqrt{1 + 4t_k^2} + 1}{2}
193 6 : data_t t = (1 + std::sqrt(1 + 4 * tPrev_ * tPrev_)) / 2;
194 :
195 : // y_{k+1} = x_k + \frac{t_{k-1} - 1}{t_k}(x_k - x_{k-1})
196 6 : lincomb(1, x, (tPrev_ - 1) / t, x - xPrev_, y_); // 1 temporary
197 :
198 6 : xPrev_ = x;
199 6 : tPrev_ = t;
200 :
201 : // update gradient last
202 6 : g_->getGradient(x, grad_);
203 :
204 6 : return x;
205 6 : }
206 :
207 : template <typename data_t>
208 : bool APGD<data_t>::shouldStop() const
209 12 : {
210 12 : return grad_.squaredL2Norm() <= epsilon_;
211 12 : }
212 :
213 : template <typename data_t>
214 : std::string APGD<data_t>::formatHeader() const
215 6 : {
216 6 : return fmt::format("{:^12} | {:^12}", "objective", "gradient");
217 6 : }
218 :
219 : template <typename data_t>
220 : std::string APGD<data_t>::formatStep(const DataContainer<data_t>& x) const
221 6 : {
222 6 : return fmt::format("{:>12.3} | {:>12.3}", g_->evaluate(x) + h_->evaluate(x),
223 6 : grad_.squaredL2Norm());
224 6 : }
225 :
226 : template <typename data_t>
227 : auto APGD<data_t>::cloneImpl() const -> APGD<data_t>*
228 6 : {
229 6 : return new APGD<data_t>(*g_, *h_, *lineSearchMethod_, epsilon_);
230 6 : }
231 :
232 : template <typename data_t>
233 : auto APGD<data_t>::isEqual(const Solver<data_t>& other) const -> bool
234 6 : {
235 6 : auto otherAPGD = downcast_safe<APGD>(&other);
236 6 : if (!otherAPGD)
237 0 : return false;
238 :
239 6 : if (not lineSearchMethod_->isEqual(*(otherAPGD->lineSearchMethod_)))
240 0 : return false;
241 :
242 6 : if (epsilon_ != otherAPGD->epsilon_)
243 0 : return false;
244 :
245 6 : return true;
246 6 : }
247 :
248 : // ------------------------------------------
249 : // explicit template instantiation
250 : template class APGD<float>;
251 : template class APGD<double>;
252 : } // namespace elsa
|