Line data Source code
1 : #include "ALB.h"
2 : #include "DataContainer.h"
3 : #include "LinearOperator.h"
4 : #include "TypeCasts.hpp"
5 : #include "Logger.h"
6 :
7 : #include "spdlog/stopwatch.h"
8 : #include "PowerIterations.h"
9 :
10 : namespace elsa
11 : {
12 : template <typename data_t>
13 : ALB<data_t>::ALB(const LinearOperator<data_t>& A, const DataContainer<data_t>& b,
14 : ProximalOperator<data_t> prox, data_t mu, std::optional<data_t> beta,
15 : data_t epsilon)
16 : : A_(A.clone()),
17 : b_(b),
18 : v_(A_->getDomainDescriptor()),
19 : vPrev_(A_->getDomainDescriptor()),
20 : vTilda_(A_->getDomainDescriptor()),
21 : residual_(A_->getRangeDescriptor()),
22 : prox_(prox),
23 : mu_(mu),
24 : epsilon_(epsilon)
25 4 : {
26 4 : if (!beta.has_value()) {
27 0 : beta_ = data_t{2} / (mu_ * powerIterations(adjoint(*A_) * (*A_)));
28 0 : Logger::get("ALB")->info("Step length is chosen to be: {:8.5}", beta_);
29 4 : } else {
30 4 : beta_ = *beta;
31 4 : }
32 4 : }
33 :
34 : template <typename data_t>
35 : DataContainer<data_t> ALB<data_t>::setup(std::optional<DataContainer<data_t>> x0)
36 4 : {
37 4 : auto x = extract_or(x0, A_->getDomainDescriptor());
38 :
39 4 : v_ = emptylike(x);
40 4 : vPrev_ = zeroslike(x);
41 4 : vTilda_ = zeroslike(x);
42 :
43 4 : residual_ = emptylike(b_);
44 :
45 4 : return x;
46 4 : }
47 :
48 : template <typename data_t>
49 : DataContainer<data_t> ALB<data_t>::step(DataContainer<data_t> x)
50 68 : {
51 68 : vPrev_ = v_;
52 :
53 : // x^{k+1} = mu * prox(v^k, 1)
54 68 : x = mu_ * prox_.apply(vTilda_, 1);
55 :
56 : // residual = b - Ax^{k+1}
57 68 : lincomb(1, b_, -1, A_->apply(x), residual_);
58 :
59 : // v^{k+1} = v_tilda^k + beta * A^{*}(b - Ax^{k+1})
60 68 : lincomb(1, vTilda_, beta_, A_->applyAdjoint(residual_), v_);
61 :
62 : // a_k = (2i + 3) / (i + 3)
63 68 : auto a = static_cast<data_t>(2 * this->curiter_ + 3) / (this->curiter_ + 3);
64 :
65 : // v_tilda^{k+1} = a_k * v^{k+1} + (1 - a_k) * v^k
66 68 : lincomb(a, v_, (1 - a), vPrev_, vTilda_);
67 :
68 68 : return x;
69 68 : }
70 :
71 : template <typename data_t>
72 : bool ALB<data_t>::shouldStop() const
73 70 : {
74 70 : return this->curiter_ > 1 && residual_.squaredL2Norm() / b_.squaredL2Norm() <= epsilon_;
75 70 : }
76 :
77 : template <typename data_t>
78 : std::string ALB<data_t>::formatHeader() const
79 2 : {
80 2 : return fmt::format("{:^12} | {:^12} | {:^12} | {:^12}", "x-norm", "v-norm", "\\tilde{v}",
81 2 : "error");
82 2 : }
83 :
84 : template <typename data_t>
85 : std::string ALB<data_t>::formatStep(const DataContainer<data_t>& x) const
86 68 : {
87 68 : auto error = residual_.squaredL2Norm() / b_.squaredL2Norm();
88 68 : return fmt::format("{:>12} | {:>12} | {:>12} | {:>12}", x.squaredL2Norm(),
89 68 : v_.squaredL2Norm(), vTilda_.squaredL2Norm(), error);
90 68 : }
91 :
92 : template <typename data_t>
93 : auto ALB<data_t>::cloneImpl() const -> ALB<data_t>*
94 2 : {
95 2 : return new ALB<data_t>(*A_, b_, prox_, mu_, beta_, epsilon_);
96 2 : }
97 :
98 : template <typename data_t>
99 : auto ALB<data_t>::isEqual(const Solver<data_t>& other) const -> bool
100 2 : {
101 2 : auto otherAlb = downcast_safe<ALB>(&other);
102 2 : if (!otherAlb)
103 0 : return false;
104 :
105 2 : if (*A_ != *otherAlb->A_)
106 0 : return false;
107 :
108 2 : if (b_ != otherAlb->b_)
109 0 : return false;
110 :
111 2 : Logger::get("ALB")->info("beta: {}, {}", beta_, otherAlb->beta_);
112 2 : if (std::abs(beta_ - otherAlb->beta_) > 1e-5)
113 0 : return false;
114 :
115 2 : Logger::get("ALB")->info("mu: {}, {}", mu_, otherAlb->mu_);
116 2 : if (mu_ != otherAlb->mu_)
117 0 : return false;
118 :
119 2 : Logger::get("ALB")->info("epsilon: {}, {}", epsilon_, otherAlb->epsilon_);
120 2 : if (epsilon_ != otherAlb->epsilon_)
121 0 : return false;
122 :
123 2 : return true;
124 2 : }
125 :
126 : // ------------------------------------------
127 : // explicit template instantiation
128 : template class ALB<float>;
129 : template class ALB<double>;
130 : } // namespace elsa
|