Line data Source code
1 : #include "ADMML2.h"
2 :
3 : #include "DataContainer.h"
4 : #include "LinearOperator.h"
5 : #include "Solver.h"
6 : #include "ProximalOperator.h"
7 : #include "TypeCasts.hpp"
8 : #include "elsaDefines.h"
9 : #include "Logger.h"
10 : #include "RegularizedInversion.h"
11 : #include "PowerIterations.h"
12 :
13 : #include <cmath>
14 : #include <memory>
15 : #include <optional>
16 : #include <vector>
17 :
18 : namespace elsa
19 : {
20 : template <class data_t>
21 : ADMML2<data_t>::ADMML2(const LinearOperator<data_t>& op, const DataContainer<data_t>& b,
22 : const LinearOperator<data_t>& A, const ProximalOperator<data_t>& proxg,
23 : std::optional<data_t> tau, index_t ninneriters)
24 : : Solver<data_t>(),
25 : op_(op.clone()),
26 : b_(b),
27 : A_(A.clone()),
28 : proxg_(proxg),
29 : tau_(0),
30 : ninneriters_(ninneriters)
31 2 : {
32 2 : auto eigenval = data_t{1} / powerIterations(adjoint(A) * A);
33 :
34 2 : if (tau.has_value()) {
35 2 : tau_ = *tau;
36 2 : if (tau_ < 0 || tau_ > eigenval) {
37 0 : Logger::get("ADMML2")->info("tau ({:8.5}), should be between 0 and {:8.5}", tau_,
38 0 : eigenval);
39 0 : }
40 2 : } else {
41 0 : tau_ = 0.9 * eigenval;
42 0 : Logger::get("ADMML2")->info("tau is chosen {}", tau_, eigenval);
43 0 : }
44 2 : }
45 :
46 : template <class data_t>
47 : ADMML2<data_t>::ADMML2(const LinearOperator<data_t>& op, const DataContainer<data_t>& b,
48 : const DataContainer<data_t>& W, const LinearOperator<data_t>& A,
49 : const ProximalOperator<data_t>& proxg, std::optional<data_t> tau,
50 : index_t ninneriters)
51 : : Solver<data_t>(),
52 : op_(op.clone()),
53 : b_(b),
54 : A_(A.clone()),
55 : W_(W),
56 : proxg_(proxg),
57 : tau_(0),
58 : ninneriters_(ninneriters)
59 0 : {
60 0 : auto eigenval = data_t{1} / powerIterations(adjoint(A) * A);
61 :
62 0 : if (tau.has_value()) {
63 0 : tau_ = *tau;
64 0 : if (tau_ < 0 || tau_ > eigenval) {
65 0 : Logger::get("ADMML2")->info("tau ({:8.5}), should be between 0 and {:8.5}", tau_,
66 0 : eigenval);
67 0 : }
68 0 : } else {
69 0 : tau_ = 0.9 * eigenval;
70 0 : Logger::get("ADMML2")->info("tau is chosen {}", tau_, eigenval);
71 0 : }
72 0 : }
73 :
74 : template <class data_t>
75 : DataContainer<data_t> ADMML2<data_t>::solve(index_t iterations,
76 : std::optional<DataContainer<data_t>> x0)
77 2 : {
78 2 : auto x = extract_or(x0, op_->getDomainDescriptor());
79 :
80 2 : const auto& range = A_->getRangeDescriptor();
81 2 : auto z = zeros<data_t>(range);
82 2 : auto u = zeros<data_t>(range);
83 :
84 2 : auto Ax = empty<data_t>(range);
85 2 : auto tmp = empty<data_t>(range);
86 :
87 2 : auto sqrttau = data_t{1} / std::sqrt(tau_);
88 :
89 2 : auto loglevel = Logger::getLevel();
90 2 : Logger::get("ADMML2")->info("| {:^4} | {:^12} | {:^12} | {:^12} |", "iter", "f", "z", "u");
91 22 : for (index_t iter = 0; iter < iterations; ++iter) {
92 20 : Logger::setLevel(Logger::LogLevel::ERR);
93 :
94 : // x_{k+1} = \min_x 0.5 ||Op x - b||_2^2 + \frac{1}{2\tau}||Ax - z_k + u_k||_2^2
95 20 : x = reguarlizedInversion<data_t>(*op_, b_, *A_, z - u, sqrttau, ninneriters_, W_, x);
96 :
97 20 : Logger::setLevel(loglevel);
98 :
99 20 : A_->apply(x, Ax);
100 :
101 : // Ax_{k+1} + u_k
102 20 : lincomb(1, Ax, 1, u, tmp);
103 :
104 : // z_{k+1} = prox_{\tau * g}(Ax_{k+1} + u_k)
105 20 : z = proxg_.apply(tmp, tau_);
106 :
107 : // u_{k+1} = u_k + Ax_{k+1} - z_{k+1}
108 20 : u += Ax;
109 20 : u -= z;
110 :
111 20 : Logger::get("ADMML2")->info("| {:>4} | {:12.7} | {:12.7} | {:12.7} |", iter,
112 20 : 0.5 * (op_->apply(x) - b_).l2Norm(), z.l2Norm(),
113 20 : u.l2Norm());
114 20 : }
115 :
116 2 : return x;
117 2 : }
118 :
119 : template <class data_t>
120 : ADMML2<data_t>* ADMML2<data_t>::cloneImpl() const
121 0 : {
122 0 : return new ADMML2(*op_, b_, *A_, proxg_, tau_, ninneriters_);
123 0 : }
124 :
125 : template <class data_t>
126 : bool ADMML2<data_t>::isEqual(const Solver<data_t>& other) const
127 0 : {
128 0 : auto otherADMM = downcast_safe<ADMML2>(&other);
129 0 : if (!otherADMM)
130 0 : return false;
131 :
132 0 : if (*op_ != *otherADMM->op_)
133 0 : return false;
134 :
135 0 : if (*A_ != *otherADMM->A_)
136 0 : return false;
137 :
138 0 : if (tau_ != otherADMM->tau_)
139 0 : return false;
140 :
141 0 : if (ninneriters_ != otherADMM->ninneriters_)
142 0 : return false;
143 :
144 0 : return true;
145 0 : }
146 :
147 : // ------------------------------------------
148 : // explicit template instantiation
149 : template class ADMML2<float>;
150 : template class ADMML2<double>;
151 : } // namespace elsa
|