Line data Source code
1 : #include "PGD.h"
2 : #include "DataContainer.h"
3 : #include "Functional.h"
4 : #include "LeastSquares.h"
5 : #include "LinearOperator.h"
6 : #include "LinearResidual.h"
7 : #include "ProximalL1.h"
8 : #include "Solver.h"
9 : #include "TypeCasts.hpp"
10 : #include "Logger.h"
11 : #include "PowerIterations.h"
12 :
13 : #include "WeightedLeastSquares.h"
14 : #include "spdlog/stopwatch.h"
15 :
16 : namespace elsa
17 : {
18 : template <typename data_t>
19 : PGD<data_t>::PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
20 : const Functional<data_t>& h, std::optional<data_t> mu, data_t epsilon)
21 : : g_(LeastSquares<data_t>(A, b).clone()), h_(h.clone()), epsilon_(epsilon)
22 4 : {
23 4 : if (!h.isProxFriendly()) {
24 0 : throw Error("PGD: h must be prox friendly");
25 0 : }
26 :
27 4 : if (mu.has_value()) {
28 4 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
29 4 : } else {
30 0 : Logger::get("PGD")->info("Computing Lipschitz constant for least squares...");
31 : // Chose it a little larger, to be safe
32 0 : auto L = 1.05 * powerIterations(adjoint(A) * A);
33 0 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
34 0 : Logger::get("PGD")->info("Step length chosen to be: {}", 1 / L);
35 0 : }
36 4 : }
37 :
38 : template <typename data_t>
39 : PGD<data_t>::PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
40 : const DataContainer<data_t>& W, const Functional<data_t>& h,
41 : std::optional<data_t> mu, data_t epsilon)
42 : : g_(WeightedLeastSquares<data_t>(A, b, W).clone()), h_(h.clone()), epsilon_(epsilon)
43 2 : {
44 2 : if (!h.isProxFriendly()) {
45 0 : throw Error("APGD: h must be prox friendly");
46 0 : }
47 :
48 2 : if (mu.has_value()) {
49 2 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, *mu).clone();
50 2 : } else {
51 0 : Logger::get("PGD")->info("Computing Lipschitz constant for least squares...");
52 : // Chose it a little larger, to be safe
53 0 : auto L = 1.05 * powerIterations(adjoint(A) * A);
54 0 : lineSearchMethod_ = FixedStepSize<data_t>(*g_, 1 / L).clone();
55 0 : Logger::get("PGD")->info("Step length chosen to be: {}", 1 / L);
56 0 : }
57 2 : }
58 :
59 : template <typename data_t>
60 : PGD<data_t>::PGD(const Functional<data_t>& g, const Functional<data_t>& h, data_t mu,
61 : data_t epsilon)
62 : : g_(g.clone()),
63 : h_(h.clone()),
64 : epsilon_(epsilon),
65 : lineSearchMethod_(FixedStepSize<data_t>(*g_, mu).clone())
66 0 : {
67 0 : if (!h.isProxFriendly()) {
68 0 : throw Error("PGD: h must be prox friendly");
69 0 : }
70 :
71 0 : if (!g.isDifferentiable()) {
72 0 : throw Error("PGD: g must be differentiable");
73 0 : }
74 0 : }
75 :
76 : template <typename data_t>
77 : PGD<data_t>::PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
78 : const Functional<data_t>& h, const LineSearchMethod<data_t>& lineSearchMethod,
79 : data_t epsilon)
80 : : g_(LeastSquares<data_t>(A, b).clone()),
81 : h_(h.clone()),
82 : epsilon_(epsilon),
83 : lineSearchMethod_(lineSearchMethod.clone())
84 0 : {
85 0 : if (!h.isProxFriendly()) {
86 0 : throw Error("PGD: h must be prox friendly");
87 0 : }
88 0 : }
89 :
90 : template <typename data_t>
91 : PGD<data_t>::PGD(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
92 : const DataContainer<data_t>& W, const Functional<data_t>& h,
93 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
94 : : g_(WeightedLeastSquares<data_t>(A, b, W).clone()),
95 : h_(h.clone()),
96 : epsilon_(epsilon),
97 : lineSearchMethod_(lineSearchMethod.clone())
98 0 : {
99 0 : if (!h.isProxFriendly()) {
100 0 : throw Error("APGD: h must be prox friendly");
101 0 : }
102 0 : }
103 :
104 : template <typename data_t>
105 : PGD<data_t>::PGD(const Functional<data_t>& g, const Functional<data_t>& h,
106 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
107 : : g_(g.clone()),
108 : h_(h.clone()),
109 : epsilon_(epsilon),
110 : lineSearchMethod_(lineSearchMethod.clone())
111 6 : {
112 6 : if (!h.isProxFriendly()) {
113 0 : throw Error("PGD: h must be prox friendly");
114 0 : }
115 :
116 6 : if (!g.isDifferentiable()) {
117 0 : throw Error("PGD: g must be differentiable");
118 0 : }
119 6 : }
120 :
121 : template <typename data_t>
122 : auto PGD<data_t>::solve(index_t iterations, std::optional<DataContainer<data_t>> x0)
123 : -> DataContainer<data_t>
124 6 : {
125 6 : spdlog::stopwatch aggregate_time;
126 :
127 6 : auto x = extract_or(x0, g_->getDomainDescriptor());
128 6 : auto grad = emptylike(x);
129 6 : auto y = emptylike(x);
130 :
131 6 : Logger::get("PGD")->info("| {:^6} | {:^12} | {:^12} | {:^9} |", "iter", "g", "gradient",
132 6 : "elapsed");
133 :
134 12 : for (index_t iter = 0; iter < iterations; ++iter) {
135 12 : g_->getGradient(x, grad);
136 12 : auto mu = lineSearchMethod_->solve(x, -grad);
137 :
138 : // y = x - mu_ * grad
139 12 : lincomb(1, x, -mu, grad, y);
140 :
141 : // apply proximal
142 12 : x = h_->proximal(y, mu);
143 :
144 12 : if (grad.squaredL2Norm() <= epsilon_) {
145 6 : Logger::get("PGD")->info("SUCCESS: Reached convergence at {}/{} iteration",
146 6 : iter + 1, iterations);
147 6 : return x;
148 6 : }
149 :
150 6 : Logger::get("PGD")->info("| {:>6} | {:>12.3} | {:>12.3} | {:>8.3}s |", iter,
151 6 : g_->evaluate(x), grad.squaredL2Norm(), aggregate_time);
152 6 : }
153 :
154 6 : Logger::get("PGD")->warn("Failed to reach convergence at {} iterations", iterations);
155 :
156 0 : return x;
157 6 : }
158 :
159 : template <typename data_t>
160 : auto PGD<data_t>::cloneImpl() const -> PGD<data_t>*
161 6 : {
162 6 : return new PGD<data_t>(*g_, *h_, *lineSearchMethod_, epsilon_);
163 6 : }
164 :
165 : template <typename data_t>
166 : auto PGD<data_t>::isEqual(const Solver<data_t>& other) const -> bool
167 6 : {
168 6 : auto otherPGD = downcast_safe<PGD>(&other);
169 6 : if (!otherPGD)
170 0 : return false;
171 :
172 6 : if (not lineSearchMethod_->isEqual(*(otherPGD->lineSearchMethod_)))
173 0 : return false;
174 :
175 6 : if (epsilon_ != otherPGD->epsilon_)
176 0 : return false;
177 :
178 6 : return true;
179 6 : }
180 :
181 : // ------------------------------------------
182 : // explicit template instantiation
183 : template class PGD<float>;
184 : template class PGD<double>;
185 : } // namespace elsa
|