Line data Source code
1 : #include "LinearOperator.h"
2 :
3 : #include <stdexcept>
4 : #include <typeinfo>
5 :
6 : #include "DescriptorUtils.h"
7 :
8 : namespace elsa
9 : {
10 : template <typename data_t>
11 : LinearOperator<data_t>::LinearOperator(const DataDescriptor& domainDescriptor,
12 : const DataDescriptor& rangeDescriptor)
13 : : _domainDescriptor{domainDescriptor.clone()}, _rangeDescriptor{rangeDescriptor.clone()}
14 22343 : {
15 22343 : }
16 :
17 : template <typename data_t>
18 : LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& other)
19 : : Cloneable<LinearOperator<data_t>>(),
20 : _domainDescriptor{other._domainDescriptor->clone()},
21 : _rangeDescriptor{other._rangeDescriptor->clone()},
22 : _scalar{other._scalar},
23 : _isLeaf{other._isLeaf},
24 : _isAdjoint{other._isAdjoint},
25 : _isComposite{other._isComposite},
26 : _mode{other._mode}
27 183 : {
28 183 : if (_isLeaf)
29 8 : _lhs = other._lhs->clone();
30 :
31 183 : if (_isComposite) {
32 152 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
33 132 : _lhs = other._lhs->clone();
34 132 : _rhs = other._rhs->clone();
35 132 : }
36 :
37 152 : if (_mode == CompositeMode::SCALAR_MULT) {
38 20 : _rhs = other._rhs->clone();
39 20 : }
40 152 : }
41 183 : }
42 :
43 : template <typename data_t>
44 : LinearOperator<data_t>& LinearOperator<data_t>::operator=(const LinearOperator<data_t>& other)
45 90 : {
46 90 : if (*this != other) {
47 90 : _domainDescriptor = other._domainDescriptor->clone();
48 90 : _rangeDescriptor = other._rangeDescriptor->clone();
49 90 : _scalar = other._scalar;
50 90 : _isLeaf = other._isLeaf;
51 90 : _isAdjoint = other._isAdjoint;
52 90 : _isComposite = other._isComposite;
53 90 : _mode = other._mode;
54 :
55 90 : if (_isLeaf)
56 8 : _lhs = other._lhs->clone();
57 :
58 90 : if (_isComposite) {
59 82 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
60 78 : _lhs = other._lhs->clone();
61 78 : _rhs = other._rhs->clone();
62 78 : }
63 :
64 82 : if (_mode == CompositeMode::SCALAR_MULT) {
65 4 : _rhs = other._rhs->clone();
66 4 : }
67 82 : }
68 90 : }
69 :
70 90 : return *this;
71 90 : }
72 :
73 : template <typename data_t>
74 : const DataDescriptor& LinearOperator<data_t>::getDomainDescriptor() const
75 73444 : {
76 73444 : return *_domainDescriptor;
77 73444 : }
78 :
79 : template <typename data_t>
80 : const DataDescriptor& LinearOperator<data_t>::getRangeDescriptor() const
81 58646 : {
82 58646 : return *_rangeDescriptor;
83 58646 : }
84 :
85 : template <typename data_t>
86 : DataContainer<data_t> LinearOperator<data_t>::apply(const DataContainer<data_t>& x) const
87 18441 : {
88 18441 : DataContainer<data_t> result(*_rangeDescriptor, x.getDataHandlerType());
89 18441 : apply(x, result);
90 18441 : return result;
91 18441 : }
92 :
93 : template <typename data_t>
94 : void LinearOperator<data_t>::apply(const DataContainer<data_t>& x,
95 : DataContainer<data_t>& Ax) const
96 48552 : {
97 48552 : applyImpl(x, Ax);
98 48552 : }
99 :
100 : template <typename data_t>
101 : void LinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x,
102 : DataContainer<data_t>& Ax) const
103 12046 : {
104 12046 : if (_isLeaf) {
105 3633 : if (_isAdjoint) {
106 : // sanity check the arguments for the intended evaluation tree leaf operation
107 1405 : if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != x.getSize()
108 1405 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Ax.getSize())
109 4 : throw InvalidArgumentError(
110 4 : "LinearOperator::apply: incorrect input/output sizes for adjoint leaf");
111 :
112 1401 : _lhs->applyAdjoint(x, Ax);
113 1401 : return;
114 2228 : } else {
115 : // sanity check the arguments for the intended evaluation tree leaf operation
116 2228 : if (_lhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
117 2228 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
118 4 : throw InvalidArgumentError(
119 4 : "LinearOperator::apply: incorrect input/output sizes for leaf");
120 :
121 2224 : _lhs->apply(x, Ax);
122 2224 : return;
123 2224 : }
124 3633 : }
125 :
126 8413 : if (_isComposite) {
127 8409 : if (_mode == CompositeMode::ADD) {
128 : // sanity check the arguments for the intended evaluation tree leaf operation
129 5289 : if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
130 5289 : || _rhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize()
131 5289 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
132 5289 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
133 4 : throw InvalidArgumentError(
134 4 : "LinearOperator::apply: incorrect input/output sizes for add leaf");
135 :
136 5285 : _rhs->apply(x, Ax);
137 5285 : Ax += _lhs->apply(x);
138 5285 : return;
139 5285 : }
140 :
141 3120 : if (_mode == CompositeMode::MULT) {
142 : // sanity check the arguments for the intended evaluation tree leaf operation
143 3112 : if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
144 3112 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
145 8 : throw InvalidArgumentError(
146 8 : "LinearOperator::apply: incorrect input/output sizes for mult leaf");
147 :
148 3104 : DataContainer<data_t> temp(_rhs->getRangeDescriptor(), x.getDataHandlerType());
149 3104 : _rhs->apply(x, temp);
150 3104 : _lhs->apply(temp, Ax);
151 3104 : return;
152 3104 : }
153 :
154 8 : if (_mode == CompositeMode::SCALAR_MULT) {
155 : // sanity check the arguments for the intended evaluation tree leaf operation
156 8 : if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize())
157 4 : throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
158 4 : "sizes for scalar mult. leaf");
159 : // sanity check the scalar in the optional
160 4 : if (!_scalar.has_value())
161 0 : throw InvalidArgumentError(
162 0 : "LinearOperator::apply: no value found in the scalar optional");
163 :
164 4 : _rhs->apply(x, Ax);
165 4 : Ax *= _scalar.value();
166 4 : return;
167 4 : }
168 8 : }
169 :
170 4 : throw LogicError("LinearOperator: apply called on ill-formed object");
171 4 : }
172 :
173 : template <typename data_t>
174 : DataContainer<data_t> LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y) const
175 4208 : {
176 4208 : DataContainer<data_t> result(*_domainDescriptor, y.getDataHandlerType());
177 4208 : applyAdjoint(y, result);
178 4208 : return result;
179 4208 : }
180 :
181 : template <typename data_t>
182 : void LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y,
183 : DataContainer<data_t>& Aty) const
184 31188 : {
185 31188 : applyAdjointImpl(y, Aty);
186 31188 : }
187 :
188 : template <typename data_t>
189 : void LinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
190 : DataContainer<data_t>& Aty) const
191 14641 : {
192 14641 : if (_isLeaf) {
193 11622 : if (_isAdjoint) {
194 : // sanity check the arguments for the intended evaluation tree leaf operation
195 12 : if (_lhs->getDomainDescriptor().getNumberOfCoefficients() != y.getSize()
196 12 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Aty.getSize())
197 4 : throw InvalidArgumentError("LinearOperator::applyAdjoint: incorrect "
198 4 : "input/output sizes for adjoint leaf");
199 :
200 8 : _lhs->apply(y, Aty);
201 8 : return;
202 11610 : } else {
203 : // sanity check the arguments for the intended evaluation tree leaf operation
204 11610 : if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
205 11610 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
206 4 : throw InvalidArgumentError(
207 4 : "LinearOperator::applyAdjoint: incorrect input/output sizes for leaf");
208 :
209 11606 : _lhs->applyAdjoint(y, Aty);
210 11606 : return;
211 11606 : }
212 11622 : }
213 :
214 3019 : if (_isComposite) {
215 3015 : if (_mode == CompositeMode::ADD) {
216 : // sanity check the arguments for the intended evaluation tree leaf operation
217 2991 : if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
218 2991 : || _rhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize()
219 2991 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
220 2991 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
221 4 : throw InvalidArgumentError(
222 4 : "LinearOperator::applyAdjoint: incorrect input/output sizes for add leaf");
223 :
224 2987 : _rhs->applyAdjoint(y, Aty);
225 2987 : Aty += _lhs->applyAdjoint(y);
226 2987 : return;
227 2987 : }
228 :
229 24 : if (_mode == CompositeMode::MULT) {
230 : // sanity check the arguments for the intended evaluation tree leaf operation
231 16 : if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
232 16 : || _rhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
233 8 : throw InvalidArgumentError(
234 8 : "LinearOperator::applyAdjoint: incorrect input/output sizes for mult leaf");
235 :
236 8 : DataContainer<data_t> temp(_lhs->getDomainDescriptor(), y.getDataHandlerType());
237 8 : _lhs->applyAdjoint(y, temp);
238 8 : _rhs->applyAdjoint(temp, Aty);
239 8 : return;
240 8 : }
241 :
242 8 : if (_mode == CompositeMode::SCALAR_MULT) {
243 : // sanity check the arguments for the intended evaluation tree leaf operation
244 8 : if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize())
245 4 : throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
246 4 : "sizes for scalar mult. leaf");
247 : // sanity check the scalar in the optional
248 4 : if (!_scalar.has_value())
249 0 : throw InvalidArgumentError(
250 0 : "LinearOperator::apply: no value found in the scalar optional");
251 :
252 4 : _rhs->applyAdjoint(y, Aty);
253 4 : Aty *= _scalar.value();
254 4 : return;
255 4 : }
256 8 : }
257 :
258 4 : throw LogicError("LinearOperator: applyAdjoint called on ill-formed object");
259 4 : }
260 :
261 : template <typename data_t>
262 : LinearOperator<data_t>* LinearOperator<data_t>::cloneImpl() const
263 6055 : {
264 6055 : if (_isLeaf)
265 1854 : return new LinearOperator<data_t>(*_lhs, _isAdjoint);
266 :
267 4201 : if (_isComposite) {
268 4173 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
269 4157 : return new LinearOperator<data_t>(*_lhs, *_rhs, _mode);
270 4157 : }
271 :
272 16 : if (_mode == CompositeMode::SCALAR_MULT) {
273 16 : return new LinearOperator<data_t>(*this);
274 16 : }
275 28 : }
276 :
277 28 : return new LinearOperator<data_t>(*_domainDescriptor, *_rangeDescriptor);
278 28 : }
279 :
280 : template <typename data_t>
281 : bool LinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const
282 902 : {
283 902 : if (typeid(other) != typeid(*this))
284 8 : return false;
285 :
286 894 : if (*_domainDescriptor != *other._domainDescriptor
287 894 : || *_rangeDescriptor != *other._rangeDescriptor)
288 24 : return false;
289 :
290 870 : if (_isLeaf ^ other._isLeaf || _isComposite ^ other._isComposite)
291 42 : return false;
292 :
293 828 : if (_isLeaf)
294 154 : return (_isAdjoint == other._isAdjoint) && (*_lhs == *other._lhs);
295 :
296 674 : if (_isComposite) {
297 124 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
298 104 : return _mode == other._mode && (*_lhs == *other._lhs) && (*_rhs == *other._rhs);
299 104 : }
300 :
301 20 : if (_mode == CompositeMode::SCALAR_MULT) {
302 20 : return (_isAdjoint == other._isAdjoint) && (*_rhs == *other._rhs);
303 20 : }
304 550 : }
305 :
306 550 : return true;
307 550 : }
308 :
309 : template <typename data_t>
310 : LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& op, bool isAdjoint)
311 : : _domainDescriptor{(isAdjoint) ? op.getRangeDescriptor().clone()
312 : : op.getDomainDescriptor().clone()},
313 : _rangeDescriptor{(isAdjoint) ? op.getDomainDescriptor().clone()
314 : : op.getRangeDescriptor().clone()},
315 : _lhs{op.clone()},
316 : _scalar{op._scalar},
317 : _isLeaf{true},
318 : _isAdjoint{isAdjoint}
319 13292 : {
320 13292 : }
321 :
322 : template <typename data_t>
323 : LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& lhs,
324 : const LinearOperator<data_t>& rhs, CompositeMode mode)
325 : : _domainDescriptor{mode == CompositeMode::MULT
326 : ? rhs.getDomainDescriptor().clone()
327 : : bestCommon(*lhs._domainDescriptor, *rhs._domainDescriptor)},
328 : _rangeDescriptor{mode == CompositeMode::MULT
329 : ? lhs.getRangeDescriptor().clone()
330 : : bestCommon(*lhs._rangeDescriptor, *rhs._rangeDescriptor)},
331 : _lhs{lhs.clone()},
332 : _rhs{rhs.clone()},
333 : _isComposite{true},
334 : _mode{mode}
335 4636 : {
336 : // sanity check the descriptors
337 4636 : switch (_mode) {
338 3379 : case CompositeMode::ADD:
339 : /// feasibility checked by bestCommon()
340 3379 : break;
341 :
342 1257 : case CompositeMode::MULT:
343 : // for multiplication, domain of _lhs should match range of _rhs
344 1257 : if (_lhs->getDomainDescriptor().getNumberOfCoefficients()
345 1257 : != _rhs->getRangeDescriptor().getNumberOfCoefficients())
346 0 : throw InvalidArgumentError(
347 0 : "LinearOperator: composite mult domain/range mismatch");
348 1257 : break;
349 :
350 1257 : default:
351 0 : throw LogicError("LinearOperator: unknown composition mode");
352 4636 : }
353 4636 : }
354 :
355 : template <typename data_t>
356 : LinearOperator<data_t>::LinearOperator(data_t scalar, const LinearOperator<data_t>& rhs)
357 : : _domainDescriptor{rhs.getDomainDescriptor().clone()},
358 : _rangeDescriptor{rhs.getRangeDescriptor().clone()},
359 : _rhs{rhs.clone()},
360 : _scalar{scalar},
361 : _isComposite{true},
362 : _mode{CompositeMode::SCALAR_MULT}
363 24 : {
364 24 : }
365 :
366 : // ------------------------------------------
367 : // explicit template instantiation
368 : template class LinearOperator<float>;
369 : template class LinearOperator<complex<float>>;
370 : template class LinearOperator<double>;
371 : template class LinearOperator<complex<double>>;
372 :
373 : } // namespace elsa
|