Line data Source code
#include "FGM.h"

#include <cmath>
#include <utility>

#include "DataContainer.h"
#include "Error.h"
#include "Functional.h"
#include "Logger.h"
#include "TypeCasts.hpp"
#include "PowerIterations.h"
#include "FixedStepSize.h"
9 :
10 : namespace elsa
11 : {
12 : template <typename data_t>
13 : FGM<data_t>::FGM(const Functional<data_t>& problem, data_t epsilon)
14 : : Solver<data_t>(),
15 : _problem{problem.clone()},
16 : _epsilon{epsilon},
17 : yOld{empty<data_t>(problem.getDomainDescriptor())},
18 : gradient{emptylike(yOld)}
19 :
20 2 : {
21 2 : if (!problem.isDifferentiable()) {
22 0 : throw InvalidArgumentError("FGM: Given problem is not differentiable!");
23 0 : }
24 2 : this->name_ = "FGM";
25 2 : }
26 :
27 : template <typename data_t>
28 : FGM<data_t>::FGM(const Functional<data_t>& problem,
29 : const LinearOperator<data_t>& preconditionerInverse, data_t epsilon)
30 : : Solver<data_t>(),
31 : _problem{problem.clone()},
32 : _epsilon{epsilon},
33 : _preconditionerInverse{preconditionerInverse.clone()},
34 : yOld{empty<data_t>(problem.getDomainDescriptor())},
35 : gradient{emptylike(yOld)}
36 0 : {
37 0 : if (!problem.isDifferentiable()) {
38 0 : throw InvalidArgumentError("FGM: Given problem is not differentiable!");
39 0 : }
40 :
41 : // check that preconditioner is compatible with problem
42 0 : if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
43 0 : != _problem->getDomainDescriptor().getNumberOfCoefficients()
44 0 : || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
45 0 : != _problem->getDomainDescriptor().getNumberOfCoefficients()) {
46 0 : throw InvalidArgumentError("FGM: incorrect size of preconditioner");
47 0 : }
48 0 : this->name_ = "FGM";
49 0 : }
50 : template <typename data_t>
51 : FGM<data_t>::FGM(const Functional<data_t>& problem,
52 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
53 : : Solver<data_t>(),
54 : _problem{problem.clone()},
55 : _epsilon{epsilon},
56 : _lineSearchMethod{lineSearchMethod.clone()},
57 : yOld{empty<data_t>(problem.getDomainDescriptor())},
58 : gradient{emptylike(yOld)}
59 2 : {
60 2 : if (!problem.isDifferentiable()) {
61 0 : throw InvalidArgumentError("FGM: Given problem is not differentiable!");
62 0 : }
63 2 : this->name_ = "FGM";
64 2 : }
65 :
66 : template <typename data_t>
67 : FGM<data_t>::FGM(const Functional<data_t>& problem,
68 : const LinearOperator<data_t>& preconditionerInverse,
69 : const LineSearchMethod<data_t>& lineSearchMethod, data_t epsilon)
70 : : Solver<data_t>(),
71 : _problem{problem.clone()},
72 : _epsilon{epsilon},
73 : _preconditionerInverse{preconditionerInverse.clone()},
74 : _lineSearchMethod{lineSearchMethod.clone()},
75 : yOld{empty<data_t>(problem.getDomainDescriptor())},
76 : gradient{emptylike(yOld)}
77 0 : {
78 : // check that preconditioner is compatible with problem
79 0 : if (_preconditionerInverse->getDomainDescriptor().getNumberOfCoefficients()
80 0 : != _problem->getDomainDescriptor().getNumberOfCoefficients()
81 0 : || _preconditionerInverse->getRangeDescriptor().getNumberOfCoefficients()
82 0 : != _problem->getDomainDescriptor().getNumberOfCoefficients()) {
83 0 : throw InvalidArgumentError("FGM: incorrect size of preconditioner");
84 0 : }
85 0 : this->name_ = "FGM";
86 0 : }
87 : template <typename data_t>
88 : bool FGM<data_t>::shouldStop() const
89 6 : {
90 6 : return gradient.squaredL2Norm() <= _epsilon * _epsilon * deltaZero;
91 6 : }
92 :
93 : template <typename data_t>
94 : DataContainer<data_t> FGM<data_t>::setup(std::optional<DataContainer<data_t>> x0)
95 2 : {
96 2 : auto x = extract_or(x0, _problem->getDomainDescriptor());
97 :
98 2 : this->thetaOld = static_cast<data_t>(1.0);
99 2 : this->yOld = x;
100 2 : this->gradient = _problem->getGradient(x);
101 2 : this->deltaZero = this->gradient.squaredL2Norm();
102 :
103 2 : if (!_lineSearchMethod) {
104 2 : auto L = powerIterations(_problem->getHessian(x), 5);
105 2 : this->_lineSearchMethod = std::make_unique<FixedStepSize<data_t>>(*_problem, 1 / L);
106 2 : Logger::get("FGM")->info("Starting optimization with lipschitz constant {}", L);
107 2 : } else {
108 0 : Logger::get("FGM")->info("Starting optimization with a lineSearchMethod");
109 0 : }
110 2 : this->configured_ = true;
111 2 : return x;
112 2 : }
113 :
114 : template <typename data_t>
115 : DataContainer<data_t> FGM<data_t>::step(DataContainer<data_t> x)
116 4 : {
117 :
118 4 : this->gradient = _problem->getGradient(x);
119 :
120 4 : if (_preconditionerInverse)
121 0 : this->gradient = _preconditionerInverse->apply(this->gradient);
122 :
123 4 : auto alpha = _lineSearchMethod->solve(x, -this->gradient);
124 :
125 4 : auto y = emptylike(x);
126 4 : lincomb(1, x, -alpha, gradient, y);
127 :
128 4 : const auto theta =
129 4 : (data_t{1} + std::sqrt(data_t{1} + data_t{4} * thetaOld * thetaOld)) / data_t{2};
130 :
131 4 : lincomb(1, y, (thetaOld - data_t{1}) / theta, (y - yOld), x);
132 :
133 4 : this->thetaOld = theta;
134 4 : this->yOld = y;
135 :
136 4 : return x;
137 4 : }
138 :
139 : template <typename data_t>
140 : std::string FGM<data_t>::formatHeader() const
141 2 : {
142 2 : return fmt::format("| {:^13} | {:^13} |", "objective", "gradient");
143 2 : }
144 :
145 : template <typename data_t>
146 : std::string FGM<data_t>::formatStep(const DataContainer<data_t>& x) const
147 4 : {
148 4 : auto eval = _problem->evaluate(x);
149 4 : auto gradient = _problem->getGradient(x);
150 :
151 4 : return fmt::format("| {:>13} | {:>13} |", eval, gradient.squaredL2Norm());
152 4 : }
153 :
154 : template <typename data_t>
155 : FGM<data_t>* FGM<data_t>::cloneImpl() const
156 2 : {
157 2 : if (_lineSearchMethod and _preconditionerInverse) {
158 0 : return new FGM(*_problem, *_preconditionerInverse, *_lineSearchMethod, _epsilon);
159 2 : } else if (_preconditionerInverse) {
160 0 : return new FGM(*_problem, *_preconditionerInverse, _epsilon);
161 2 : } else if (_lineSearchMethod) {
162 2 : return new FGM(*_problem, *_lineSearchMethod, _epsilon);
163 2 : }
164 :
165 0 : return new FGM(*_problem, _epsilon);
166 0 : }
167 :
168 : template <typename data_t>
169 : bool FGM<data_t>::isEqual(const Solver<data_t>& other) const
170 2 : {
171 2 : auto otherFGM = downcast_safe<FGM>(&other);
172 2 : if (!otherFGM)
173 0 : return false;
174 :
175 2 : if (_epsilon != otherFGM->_epsilon)
176 0 : return false;
177 :
178 2 : if ((_preconditionerInverse && !otherFGM->_preconditionerInverse)
179 2 : || (!_preconditionerInverse && otherFGM->_preconditionerInverse))
180 0 : return false;
181 :
182 2 : if (_preconditionerInverse && otherFGM->_preconditionerInverse)
183 0 : if (*_preconditionerInverse != *otherFGM->_preconditionerInverse)
184 0 : return false;
185 :
186 2 : if ((_lineSearchMethod and not otherFGM->_lineSearchMethod)
187 2 : or (not _lineSearchMethod and otherFGM->_lineSearchMethod))
188 0 : return false;
189 :
190 2 : if (_lineSearchMethod and otherFGM->_lineSearchMethod)
191 2 : if (not _lineSearchMethod->isEqual(*(otherFGM->_lineSearchMethod)))
192 0 : return false;
193 :
194 2 : return true;
195 2 : }
196 :
// ------------------------------------------
// explicit template instantiation
// The member definitions above live in this .cpp file, so FGM is instantiated here for
// the element types this translation unit provides: float and double.
template class FGM<float>;
template class FGM<double>;
201 :
202 : } // namespace elsa
|