Line data Source code
1 : #include "SQS.h"
2 : #include "Identity.h"
3 : #include "Scaling.h"
4 : #include "Logger.h"
5 : #include "Solver.h"
6 : #include "TypeCasts.hpp"
7 :
8 : namespace elsa
9 : {
10 : template <typename data_t>
11 : SQS<data_t>::SQS(const LeastSquares<data_t>& problem,
12 : std::vector<std::unique_ptr<LeastSquares<data_t>>>&& subsets,
13 : bool momentumAcceleration, data_t epsilon)
14 : : Solver<data_t>(),
15 : fullProblem_(downcast<LeastSquares<data_t>>(problem.clone())),
16 : subsets_(std::move(subsets)),
17 : epsilon_{epsilon},
18 : momentumAcceleration_{momentumAcceleration},
19 : subsetMode_(!subsets.empty())
20 :
21 4 : {
22 4 : Logger::get("SQS")->info("SQS running in ordered subset mode");
23 4 : }
24 :
25 : template <typename data_t>
26 : SQS<data_t>::SQS(const LeastSquares<data_t>& problem,
27 : std::vector<std::unique_ptr<LeastSquares<data_t>>>&& subsets,
28 : const LinearOperator<data_t>& preconditioner, bool momentumAcceleration,
29 : data_t epsilon)
30 : : Solver<data_t>(),
31 : fullProblem_(downcast<LeastSquares<data_t>>(problem.clone())),
32 : subsets_(std::move(subsets)),
33 : epsilon_{epsilon},
34 : preconditioner_{preconditioner.clone()},
35 : momentumAcceleration_{momentumAcceleration},
36 : subsetMode_(!subsets.empty())
37 4 : {
38 4 : Logger::get("SQS")->info("SQS running in ordered subset mode");
39 :
40 : // check that preconditioner is compatible with problem
41 4 : if (preconditioner_->getDomainDescriptor().getNumberOfCoefficients()
42 4 : != fullProblem_->getDomainDescriptor().getNumberOfCoefficients()
43 4 : || preconditioner_->getRangeDescriptor().getNumberOfCoefficients()
44 4 : != fullProblem_->getDomainDescriptor().getNumberOfCoefficients()) {
45 0 : throw InvalidArgumentError("SQS: incorrect size of preconditioner");
46 0 : }
47 4 : }
48 :
49 : template <typename data_t>
50 : SQS<data_t>::SQS(const LeastSquares<data_t>& problem, bool momentumAcceleration, data_t epsilon)
51 : : Solver<data_t>(),
52 : fullProblem_(downcast<LeastSquares<data_t>>(problem.clone())),
53 : epsilon_{epsilon},
54 : momentumAcceleration_{momentumAcceleration},
55 : subsetMode_(false)
56 :
57 8 : {
58 8 : Logger::get("SQS")->info("SQS running in normal mode");
59 8 : }
60 :
61 : template <typename data_t>
62 : SQS<data_t>::SQS(const LeastSquares<data_t>& problem,
63 : const LinearOperator<data_t>& preconditioner, bool momentumAcceleration,
64 : data_t epsilon)
65 : : Solver<data_t>(),
66 : fullProblem_(downcast<LeastSquares<data_t>>(problem.clone())),
67 : epsilon_{epsilon},
68 : preconditioner_{preconditioner.clone()},
69 : momentumAcceleration_{momentumAcceleration},
70 : subsetMode_(false)
71 4 : {
72 4 : Logger::get("SQS")->info("SQS running in normal mode");
73 :
74 : // check that preconditioner is compatible with problem
75 4 : if (preconditioner_->getDomainDescriptor().getNumberOfCoefficients()
76 4 : != fullProblem_->getDomainDescriptor().getNumberOfCoefficients()
77 4 : || preconditioner_->getRangeDescriptor().getNumberOfCoefficients()
78 4 : != fullProblem_->getDomainDescriptor().getNumberOfCoefficients()) {
79 0 : throw InvalidArgumentError("SQS: incorrect size of preconditioner");
80 0 : }
81 4 : }
82 :
83 : template <typename data_t>
84 : DataContainer<data_t> SQS<data_t>::solve(index_t iterations,
85 : std::optional<DataContainer<data_t>> x0)
86 8 : {
87 8 : auto& domain = fullProblem_->getDomainDescriptor();
88 8 : auto x = extract_or(x0, domain);
89 :
90 8 : auto convergenceThreshold =
91 8 : fullProblem_->getGradient(x).squaredL2Norm() * epsilon_ * epsilon_;
92 :
93 8 : auto hessian = fullProblem_->getHessian(x);
94 :
95 8 : auto rowsum = hessian.apply(ones<data_t>(domain));
96 8 : rowsum = static_cast<data_t>(1.0) / rowsum;
97 8 : auto diag = Scaling<data_t>(rowsum);
98 :
99 8 : data_t tOld = 1;
100 8 : data_t t = 1;
101 8 : data_t tNew = 0;
102 :
103 8 : auto& z = x;
104 :
105 8 : DataContainer<data_t> xOld = x;
106 8 : auto gradient = empty<data_t>(domain);
107 :
108 8 : index_t nSubsets = subsetMode_ ? subsets_.size() : 1;
109 :
110 1008 : for (index_t i = 0; i < iterations; i++) {
111 1000 : Logger::get("SQS")->info("iteration {} of {}", i + 1, iterations);
112 :
113 2000 : for (index_t m = 0; m < nSubsets; m++) {
114 1000 : if (subsetMode_) {
115 0 : subsets_[m]->getGradient(x, gradient);
116 1000 : } else {
117 1000 : fullProblem_->getGradient(x, gradient);
118 1000 : }
119 :
120 1000 : if (preconditioner_) {
121 800 : preconditioner_->apply(gradient, gradient);
122 800 : }
123 :
124 : // TODO: element wise relu
125 1000 : if (momentumAcceleration_) {
126 1000 : tNew = as<data_t>(1)
127 1000 : + std::sqrt(as<data_t>(1) + as<data_t>(4) * t * t) / as<data_t>(2);
128 :
129 1000 : lincomb(1, z, -nSubsets, diag.apply(gradient), x);
130 1000 : lincomb(1, x, tOld / tNew, (x - xOld), z);
131 1000 : } else {
132 0 : lincomb(1, z, -nSubsets, diag.apply(gradient), z);
133 0 : }
134 :
135 : // if the gradient is too small we stop
136 1000 : if (gradient.squaredL2Norm() <= convergenceThreshold) {
137 0 : if (!subsetMode_
138 0 : || fullProblem_->getGradient(x).squaredL2Norm() <= convergenceThreshold) {
139 0 : Logger::get("SQS")->info("SUCCESS: Reached convergence at {}/{} iteration",
140 0 : i + 1, iterations);
141 :
142 : // TODO: make return more sane
143 0 : if (momentumAcceleration_) {
144 0 : z = x;
145 0 : }
146 0 : return x;
147 0 : }
148 1000 : }
149 :
150 1000 : if (momentumAcceleration_) {
151 1000 : tOld = t;
152 1000 : t = tNew;
153 1000 : xOld = x;
154 1000 : }
155 1000 : }
156 1000 : }
157 :
158 8 : Logger::get("SQS")->warn("Failed to reach convergence at {} iterations", iterations);
159 :
160 : // TODO: make return more sane
161 8 : if (momentumAcceleration_) {
162 8 : z = x;
163 8 : }
164 8 : return x;
165 8 : }
166 :
167 : template <typename data_t>
168 : SQS<data_t>* SQS<data_t>::cloneImpl() const
169 8 : {
170 8 : std::vector<std::unique_ptr<LeastSquares<data_t>>> newsubsets;
171 8 : newsubsets.reserve(subsets_.size());
172 8 : for (const auto& ptr : subsets_) {
173 0 : newsubsets.emplace_back(downcast<LeastSquares<data_t>>(ptr->clone()));
174 0 : }
175 :
176 8 : if (preconditioner_)
177 4 : return new SQS(*fullProblem_, std::move(newsubsets), *preconditioner_,
178 4 : momentumAcceleration_, epsilon_);
179 :
180 4 : return new SQS(*fullProblem_, std::move(newsubsets), momentumAcceleration_, epsilon_);
181 4 : }
182 :
183 : template <typename data_t>
184 : bool SQS<data_t>::isEqual(const Solver<data_t>& other) const
185 8 : {
186 8 : auto otherSQS = downcast_safe<SQS>(&other);
187 8 : if (!otherSQS)
188 0 : return false;
189 :
190 8 : if (epsilon_ != otherSQS->epsilon_)
191 0 : return false;
192 :
193 8 : if ((preconditioner_ && !otherSQS->preconditioner_)
194 8 : || (!preconditioner_ && otherSQS->preconditioner_))
195 0 : return false;
196 :
197 8 : if (preconditioner_ && otherSQS->preconditioner_)
198 4 : if (*preconditioner_ != *otherSQS->preconditioner_)
199 0 : return false;
200 :
201 8 : return true;
202 8 : }
203 :
204 : // ------------------------------------------
205 : // explicit template instantiation
206 : template class SQS<float>;
207 : template class SQS<double>;
208 :
209 : } // namespace elsa
|