Line data Source code
1 : #include "POGM.h"
2 : #include "DataContainer.h"
3 : #include "Error.h"
4 : #include "Functional.h"
5 : #include "LeastSquares.h"
6 : #include "LinearOperator.h"
7 : #include "LinearResidual.h"
8 : #include "ProximalL1.h"
9 : #include "TypeCasts.hpp"
10 : #include "Logger.h"
11 : #include "PowerIterations.h"
12 :
13 : #include "WeightedLeastSquares.h"
14 : #include "spdlog/stopwatch.h"
15 : #include <cmath>
16 :
17 : namespace elsa
18 : {
19 : template <typename data_t>
20 : POGM<data_t>::POGM(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
21 : const Functional<data_t>& h, std::optional<data_t> mu, data_t epsilon)
22 : : g_(LeastSquares<data_t>(A, b).clone()), h_(h.clone()), epsilon_(epsilon)
23 2 : {
24 2 : if (!h.isProxFriendly()) {
25 0 : throw Error("POGM: h must be prox friendly");
26 0 : }
27 :
28 2 : if (mu.has_value()) {
29 2 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
30 2 : } else {
31 0 : Logger::get("POGM")->info("Computing Lipschitz constant for least squares...");
32 : // Chose it a little larger, to be safe
33 0 : auto L = 1.05 * powerIterations(adjoint(A) * A);
34 0 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
35 0 : Logger::get("POGM")->info("Step length chosen to be: {}", 1 / L);
36 0 : }
37 2 : }
38 :
39 : template <typename data_t>
40 : POGM<data_t>::POGM(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
41 : const DataContainer<data_t>& W, const Functional<data_t>& h,
42 : std::optional<data_t> mu, data_t epsilon)
43 : : g_(WeightedLeastSquares<data_t>(A, b, W).clone()), h_(h.clone()), epsilon_(epsilon)
44 2 : {
45 2 : if (!h.isProxFriendly()) {
46 0 : throw Error("APGD: h must be prox friendly");
47 0 : }
48 :
49 2 : if (mu.has_value()) {
50 2 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
51 2 : } else {
52 0 : Logger::get("POGM")->info("Computing Lipschitz constant for least squares...");
53 : // Chose it a little larger, to be safe
54 0 : auto L = 1.05 * powerIterations(adjoint(A) * A);
55 0 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
56 0 : Logger::get("POGM")->info("Step length chosen to be: {}", 1 / L);
57 0 : }
58 2 : }
59 :
60 : template <typename data_t>
61 : POGM<data_t>::POGM(const Functional<data_t>& g, const Functional<data_t>& h, data_t mu,
62 : data_t epsilon)
63 : : g_(g.clone()),
64 : h_(h.clone()),
65 : lineSearchMethod_(FixedStepSize<data_t>(*g_, mu).clone()),
66 : epsilon_(epsilon)
67 0 : {
68 0 : if (!h.isProxFriendly()) {
69 0 : throw Error("POGM: h must be prox friendly");
70 0 : }
71 :
72 0 : if (!g.isDifferentiable()) {
73 0 : throw Error("POGM: g must be differentiable");
74 0 : }
75 0 : }
76 :
77 : template <typename data_t>
78 : POGM<data_t>::POGM(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
79 : const Functional<data_t>& h,
80 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
81 : : g_(LeastSquares<data_t>(A, b).clone()),
82 : h_(h.clone()),
83 : lineSearchMethod_(lineSearchMethod.clone()),
84 : epsilon_(epsilon)
85 0 : {
86 0 : if (!h.isProxFriendly()) {
87 0 : throw Error("POGM: h must be prox friendly");
88 0 : }
89 0 : }
90 :
91 : template <typename data_t>
92 : POGM<data_t>::POGM(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
93 : const DataContainer<data_t>& W, const Functional<data_t>& h,
94 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
95 : : g_(WeightedLeastSquares<data_t>(A, b, W).clone()),
96 : h_(h.clone()),
97 : lineSearchMethod_(lineSearchMethod.clone()),
98 : epsilon_(epsilon)
99 0 : {
100 0 : if (!h.isProxFriendly()) {
101 0 : throw Error("APGD: h must be prox friendly");
102 0 : }
103 0 : }
104 :
105 : template <typename data_t>
106 : POGM<data_t>::POGM(const Functional<data_t>& g, const Functional<data_t>& h,
107 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
108 : : g_(g.clone()),
109 : h_(h.clone()),
110 : lineSearchMethod_(lineSearchMethod.clone()),
111 : epsilon_(epsilon)
112 4 : {
113 4 : if (!h.isProxFriendly()) {
114 0 : throw Error("POGM: h must be prox friendly");
115 0 : }
116 :
117 4 : if (!g.isDifferentiable()) {
118 0 : throw Error("POGM: g must be differentiable");
119 0 : }
120 4 : }
121 :
122 : template <typename data_t>
123 : auto POGM<data_t>::solve(index_t iterations, std::optional<DataContainer<data_t>> x0)
124 : -> DataContainer<data_t>
125 4 : {
126 4 : spdlog::stopwatch aggregate_time;
127 :
128 4 : auto x = DataContainer<data_t>(g_->getDomainDescriptor());
129 4 : if (x0.has_value()) {
130 0 : x = *x0;
131 4 : } else {
132 4 : x = 0;
133 4 : }
134 :
135 4 : auto w = x;
136 4 : auto wPrev = x;
137 4 : auto z = x;
138 :
139 4 : data_t theta = 1;
140 4 : data_t thetaPrev = 1;
141 :
142 4 : data_t gamma = 1;
143 4 : data_t gammaPrev = 1;
144 :
145 4 : auto grad = DataContainer<data_t>(g_->getDomainDescriptor());
146 :
147 4 : Logger::get("POGM")->info("| {:^6} | {:^12} | {:^12} | {:^9} |", "iter", "objective",
148 4 : "gradient", "elapsed");
149 :
150 204 : for (index_t iter = 0; iter < iterations; ++iter) {
151 200 : if (iter != iterations - 1) {
152 : // \frac{1}{2}(1 + \sqrt{4 \theta_{k-1}^2 + 1})
153 196 : theta = 0.5 * (1 + std::sqrt(4 * std::pow(thetaPrev, 2) + 1));
154 196 : } else {
155 : // \frac{1}{2}(1 + \sqrt{8 \theta_{k-1}^2 + 1})
156 4 : theta = 0.5 * (1 + std::sqrt(8 * std::pow(thetaPrev, 2) + 1));
157 4 : }
158 :
159 200 : auto mu = lineSearchMethod_->solve(x, -grad);
160 :
161 200 : gamma = mu * ((2 * thetaPrev) + theta - 1) / theta;
162 :
163 : // Compute gradient
164 200 : g_->getGradient(x, grad);
165 :
166 : // w = x - mu_ * grad
167 200 : lincomb(1, x, -mu, grad, w);
168 :
169 : // POGM term: w_k + (\theta_{k-1} - 1) / (L * \gamma_{k-1} * \theta_k) (z_{k-1} -
170 : // x_{k-1}) Use the fact the our mu should be (close to) 1 / L. Start with this term to
171 : // reuse z.
172 200 : data_t weight3 = mu * ((thetaPrev - 1) / (gammaPrev * theta));
173 200 : lincomb(1, w, weight3, z, z);
174 200 : lincomb(1, z, -weight3, x, z);
175 :
176 : // Nesterov momentum: (\theta_{k-1} - 1) / \theta_k (w_k - w_{k-1})
177 200 : auto weight1 = (thetaPrev - 1) / theta;
178 200 : lincomb(1, z, weight1, w, z);
179 200 : lincomb(1, z, -weight1, wPrev, z);
180 :
181 : // OGM mementum term: (\theta_{k-1} / \theta) (w_k - x_{k-1})
182 200 : data_t weight2 = thetaPrev / theta;
183 200 : lincomb(1, z, weight2, w, z);
184 200 : lincomb(1, z, -weight2, x, z);
185 :
186 : // x_{k+1} = prox_{gamma * g}(z)
187 200 : x = h_->proximal(z, gamma);
188 :
189 200 : wPrev = w;
190 :
191 200 : thetaPrev = theta;
192 200 : gammaPrev = gamma;
193 :
194 200 : if (grad.squaredL2Norm() <= epsilon_) {
195 0 : Logger::get("POGM")->info("SUCCESS: Reached convergence at {}/{} iteration",
196 0 : iter + 1, iterations);
197 0 : return x;
198 0 : }
199 :
200 200 : Logger::get("POGM")->info("| {:>6} | {:>12.3} | {:>12.3} | {:>8.3}s |", iter,
201 200 : g_->evaluate(x) + h_->evaluate(x), grad.squaredL2Norm(),
202 200 : aggregate_time);
203 200 : }
204 :
205 4 : Logger::get("POGM")->warn("Failed to reach convergence at {} iterations", iterations);
206 :
207 4 : return x;
208 4 : }
209 :
210 : template <typename data_t>
211 : auto POGM<data_t>::cloneImpl() const -> POGM<data_t>*
212 4 : {
213 4 : return new POGM<data_t>(*g_, *h_, *lineSearchMethod_, epsilon_);
214 4 : }
215 :
216 : template <typename data_t>
217 : auto POGM<data_t>::isEqual(const Solver<data_t>& other) const -> bool
218 4 : {
219 4 : auto otherPOGM = downcast_safe<POGM>(&other);
220 4 : if (!otherPOGM)
221 0 : return false;
222 :
223 4 : if (not lineSearchMethod_->isEqual(*(otherPOGM->lineSearchMethod_)))
224 0 : return false;
225 :
226 4 : if (epsilon_ != otherPOGM->epsilon_)
227 0 : return false;
228 :
229 4 : return true;
230 4 : }
231 :
232 : // ------------------------------------------
233 : // Explicit template instantiation: emit the POGM member definitions of this file for the scalar types float and double
234 : template class POGM<float>;
235 : template class POGM<double>;
236 : } // namespace elsa
|