Line data Source code
1 : #include "LB.h"
2 : #include "DataContainer.h"
3 : #include "LinearOperator.h"
4 : #include "TypeCasts.hpp"
5 : #include "Logger.h"
6 :
7 : #include "spdlog/stopwatch.h"
8 : #include "PowerIterations.h"
9 :
10 : namespace elsa
11 : {
12 : template <typename data_t>
13 : LB<data_t>::LB(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
14 : const ProximalOperator<data_t>& prox, data_t mu, std::optional<data_t> beta,
15 : data_t epsilon)
16 : : A_(A.clone()), b_(b), prox_(prox), mu_(mu), epsilon_(epsilon)
17 4 : {
18 4 : if (!beta.has_value()) {
19 :
20 0 : beta_ = data_t{2} / (mu_ * powerIterations(adjoint(*A_) * (*A_)));
21 0 : Logger::get("LB")->info("Step length is chosen to be: {:8.5}", beta_);
22 :
23 4 : } else {
24 4 : beta_ = *beta;
25 4 : }
26 4 : }
27 :
28 : template <typename data_t>
29 : auto LB<data_t>::solve(index_t iterations, std::optional<DataContainer<data_t>> x0)
30 : -> DataContainer<data_t>
31 2 : {
32 2 : spdlog::stopwatch iter_time;
33 :
34 2 : auto v = DataContainer<data_t>(A_->getDomainDescriptor());
35 2 : auto x = DataContainer<data_t>(A_->getDomainDescriptor());
36 :
37 2 : v = 0;
38 :
39 2 : if (x0.has_value()) {
40 0 : x = *x0;
41 2 : } else {
42 2 : x = 0;
43 2 : }
44 :
45 2 : auto residual = DataContainer<data_t>(b_.getDataDescriptor());
46 :
47 20 : for (int i = 0; i < iterations; ++i) {
48 :
49 : // x^{k+1} = mu * prox(v^k, 1)
50 20 : x = mu_ * prox_.apply(v, 1);
51 :
52 : // residual = b - Ax^{k+1}
53 20 : lincomb(1, b_, -1, A_->apply(x), residual);
54 :
55 : // v^{k+1} += beta * A^{*}(b - Ax^{k+1})
56 20 : v += beta_ * A_->applyAdjoint(residual);
57 :
58 20 : auto error = residual.squaredL2Norm() / b_.squaredL2Norm();
59 20 : if (residual.squaredL2Norm() / b_.squaredL2Norm() <= epsilon_) {
60 2 : Logger::get("LB")->info("SUCCESS: Reached convergence at {}/{} iteration", i + 1,
61 2 : iterations);
62 2 : return x;
63 2 : }
64 :
65 18 : Logger::get("LB")->info(
66 18 : "|iter: {:>6} | x: {:>12} | v: {:>12} | error: {:>12} | time: {:>8.3} |", i,
67 18 : x.squaredL2Norm(), v.squaredL2Norm(), error, iter_time);
68 18 : }
69 :
70 2 : Logger::get("LB")->warn("Failed to reach convergence at {} iterations", iterations);
71 :
72 0 : return x;
73 2 : }
74 :
75 : template <typename data_t>
76 : auto LB<data_t>::cloneImpl() const -> LB<data_t>*
77 2 : {
78 2 : return new LB<data_t>(*A_, b_, prox_, mu_, beta_, epsilon_);
79 2 : }
80 :
81 : template <typename data_t>
82 : auto LB<data_t>::isEqual(const Solver<data_t>& other) const -> bool
83 2 : {
84 2 : auto otherLb = downcast_safe<LB>(&other);
85 2 : if (!otherLb)
86 0 : return false;
87 :
88 2 : if (*A_ != *otherLb->A_)
89 0 : return false;
90 :
91 2 : if (b_ != otherLb->b_)
92 0 : return false;
93 :
94 2 : Logger::get("LB")->info("beta: {}, {}", beta_, otherLb->beta_);
95 2 : if (std::abs(beta_ - otherLb->beta_) > 1e-5)
96 0 : return false;
97 :
98 2 : Logger::get("LB")->info("mu: {}, {}", mu_, otherLb->mu_);
99 2 : if (mu_ != otherLb->mu_)
100 0 : return false;
101 :
102 2 : Logger::get("LB")->info("epsilon: {}, {}", epsilon_, otherLb->epsilon_);
103 2 : if (epsilon_ != otherLb->epsilon_)
104 0 : return false;
105 :
106 2 : return true;
107 2 : }
108 :
109 : // ------------------------------------------
110 : // explicit template instantiation
111 : template class LB<float>;
112 : template class LB<double>;
113 : } // namespace elsa
|