Line data Source code
1 : #include "OGM.h"
2 : #include "DataContainer.h"
3 : #include "Functional.h"
4 : #include "TypeCasts.hpp"
5 : #include "Logger.h"
6 : #include "PowerIterations.h"
7 : #include "FixedStepSize.h"
8 :
9 : namespace elsa
10 : {
11 : template <typename data_t>
12 : OGM<data_t>::OGM(const Functional<data_t>& problem, data_t epsilon)
13 : : Solver<data_t>(),
14 : _problem(problem.clone()),
15 : _epsilon{epsilon},
16 : yOld{empty<data_t>(problem.getDomainDescriptor())},
17 : gradient{emptylike(yOld)}
18 4 : {
19 4 : if (!problem.isDifferentiable()) {
20 0 : throw InvalidArgumentError("OGM: Given problem is not differentiable!");
21 0 : }
22 4 : this->name_ = "OGM";
23 4 : }
24 :
25 : template <typename data_t>
26 : OGM<data_t>::OGM(const Functional<data_t>& problem,
27 : const LinearOperator<data_t>& preconditionerInverse, data_t epsilon)
28 : : Solver<data_t>(),
29 : _problem(problem.clone()),
30 : _epsilon{epsilon},
31 : _preconditionerInverse{preconditionerInverse.clone()},
32 : yOld{empty<data_t>(problem.getDomainDescriptor())},
33 : gradient{emptylike(yOld)}
34 0 : {
35 0 : if (!problem.isDifferentiable()) {
36 0 : throw InvalidArgumentError("OGM: Given problem is not differentiable!");
37 0 : }
38 :
39 : // check that preconditioner is compatible with problem
40 0 : if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
41 0 : != _problem->getDomainDescriptor().getNumberOfCoefficients()
42 0 : || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
43 0 : != _problem->getDomainDescriptor().getNumberOfCoefficients()) {
44 0 : throw InvalidArgumentError("OGM: incorrect size of preconditioner");
45 0 : }
46 0 : this->name_ = "OGM";
47 0 : }
48 : template <typename data_t>
49 : OGM<data_t>::OGM(const Functional<data_t>& problem,
50 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
51 : : Solver<data_t>(),
52 : _problem(problem.clone()),
53 : _epsilon{epsilon},
54 : _lineSearchMethod(lineSearchMethod.clone()),
55 : yOld{empty<data_t>(problem.getDomainDescriptor())},
56 : gradient{emptylike(yOld)}
57 0 : {
58 0 : if (!problem.isDifferentiable()) {
59 0 : throw InvalidArgumentError("OGM: Given problem is not differentiable!");
60 0 : }
61 0 : this->name_ = "OGM";
62 0 : }
63 :
64 : template <typename data_t>
65 : OGM<data_t>::OGM(const Functional<data_t>& problem,
66 : const LinearOperator<data_t>& preconditionerInverse,
67 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
68 : : Solver<data_t>(),
69 : _problem(problem.clone()),
70 : _epsilon{epsilon},
71 : _preconditionerInverse{preconditionerInverse.clone()},
72 : _lineSearchMethod(lineSearchMethod.clone()),
73 : yOld{empty<data_t>(problem.getDomainDescriptor())},
74 : gradient{emptylike(yOld)}
75 0 : {
76 0 : if (!problem.isDifferentiable()) {
77 0 : throw InvalidArgumentError("OGM: Given problem is not differentiable!");
78 0 : }
79 :
80 : // check that preconditioner is compatible with problem
81 0 : if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
82 0 : != _problem->getDomainDescriptor().getNumberOfCoefficients()
83 0 : || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
84 0 : != _problem->getDomainDescriptor().getNumberOfCoefficients()) {
85 0 : throw InvalidArgumentError("OGM: incorrect size of preconditioner");
86 0 : }
87 0 : this->name_ = "OGM";
88 0 : }
89 :
90 : template <typename data_t>
91 : bool OGM<data_t>::shouldStop() const
92 400 : {
93 400 : return this->gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero;
94 400 : }
95 :
96 : template <typename data_t>
97 : DataContainer<data_t> OGM<data_t>::setup(std::optional<DataContainer<data_t>> x0)
98 2 : {
99 2 : this->thetaOld = data_t{1};
100 2 : auto x = extract_or(x0, _problem->getDomainDescriptor());
101 2 : this->yOld = x;
102 2 : this->gradient = _problem->getGradient(x);
103 2 : this->deltaZero = this->gradient.squaredL2Norm();
104 :
105 2 : if (!_lineSearchMethod) {
106 : // OGM is very picky when it comes to the accuracy of the used lipschitz constant
107 : // therefore we use 20 power iterations instead of 5 here to be more precise. In some
108 : // cases OGM might still not converge then an even more precise constant is needed
109 2 : auto L = powerIterations(_problem->getHessian(x), 20);
110 2 : this->_lineSearchMethod = std::make_unique<FixedStepSize<data_t>>(*_problem, 1 / L);
111 :
112 2 : Logger::get("OGM")->info("Starting optimization with lipschitz constant {}", L);
113 2 : } else {
114 0 : Logger::get("OGM")->info("Starting optimization with a lineSearchMethod");
115 0 : }
116 :
117 2 : this->configured_ = true;
118 :
119 2 : return x;
120 2 : }
121 :
122 : template <typename data_t>
123 : std::string OGM<data_t>::formatHeader() const
124 2 : {
125 2 : return fmt::format("| {:^13} | {:^13} |", "objective", "gradient");
126 2 : }
127 :
128 : template <typename data_t>
129 : std::string OGM<data_t>::formatStep(const DataContainer<data_t>& x) const
130 400 : {
131 400 : auto eval = _problem->evaluate(x);
132 400 : auto gradient = _problem->getGradient(x);
133 :
134 400 : return fmt::format("| {:>13} | {:>13} |", eval, gradient.squaredL2Norm());
135 400 : }
136 :
137 : template <typename data_t>
138 : DataContainer<data_t> OGM<data_t>::step(DataContainer<data_t> x)
139 400 : {
140 :
141 400 : auto y = emptylike(x);
142 :
143 400 : this->_problem->getGradient(x, this->gradient);
144 :
145 400 : if (this->_preconditionerInverse)
146 0 : this->_preconditionerInverse->apply(this->gradient, this->gradient);
147 :
148 400 : auto alpha = this->_lineSearchMethod->solve(x, -this->gradient);
149 :
150 400 : lincomb(1, x, -alpha, this->gradient, y);
151 :
152 : // TODO: Consult David on how to handle this in the steppable interface
153 400 : const auto theta =
154 400 : data_t{0.5} * (data_t{1} + std::sqrt(data_t{1} + data_t{4} * std::pow(thetaOld, 2)));
155 :
156 : // x_{i+1} = y_{i+1} + \frac{\theta_i-1}{\theta_{i+1}}(y_{i+1} - y_i) +
157 : // \frac{\theta_i}{\theta_{i+1}}/(y_{i+1} - x_i)
158 400 : lincomb(1, y, ((this->thetaOld - data_t{1}) / theta), (y - this->yOld), x);
159 400 : lincomb(1, x,
160 400 : -(this->thetaOld / theta) * this->_lineSearchMethod->solve(x, -this->gradient),
161 400 : this->gradient, x);
162 :
163 400 : this->thetaOld = theta;
164 400 : this->yOld = y;
165 :
166 400 : return x;
167 400 : }
168 :
169 : template <typename data_t>
170 : OGM<data_t>* OGM<data_t>::cloneImpl() const
171 2 : {
172 2 : if (_lineSearchMethod and _preconditionerInverse) {
173 0 : return new OGM(*_problem, *_preconditionerInverse, *_lineSearchMethod, _epsilon);
174 2 : } else if (_preconditionerInverse) {
175 0 : return new OGM(*_problem, *_preconditionerInverse, _epsilon);
176 2 : } else if (_lineSearchMethod) {
177 0 : return new OGM(*_problem, *_lineSearchMethod, _epsilon);
178 0 : }
179 2 : return new OGM(*_problem, _epsilon);
180 2 : }
181 :
182 : template <typename data_t>
183 : bool OGM<data_t>::isEqual(const Solver<data_t>& other) const
184 2 : {
185 2 : auto otherOGM = downcast_safe<OGM>(&other);
186 2 : if (!otherOGM)
187 0 : return false;
188 :
189 2 : if (_epsilon != otherOGM->_epsilon)
190 0 : return false;
191 :
192 2 : if ((_preconditionerInverse && !otherOGM->_preconditionerInverse)
193 2 : || (!_preconditionerInverse && otherOGM->_preconditionerInverse))
194 0 : return false;
195 :
196 2 : if (_preconditionerInverse && otherOGM->_preconditionerInverse)
197 0 : if (*_preconditionerInverse != *otherOGM->_preconditionerInverse)
198 0 : return false;
199 :
200 2 : if ((_lineSearchMethod and not otherOGM->_lineSearchMethod)
201 2 : or (not _lineSearchMethod and otherOGM->_lineSearchMethod))
202 0 : return false;
203 :
204 2 : if (_lineSearchMethod and otherOGM->_lineSearchMethod)
205 0 : if (not _lineSearchMethod->isEqual(*(otherOGM->_lineSearchMethod)))
206 0 : return false;
207 :
208 2 : return true;
209 2 : }
210 :
// ------------------------------------------
// Explicit template instantiation: the member templates defined in this file are
// instantiated here for float and double.
template class OGM<float>;
template class OGM<double>;
215 :
216 : } // namespace elsa
|