Line data Source code
1 : #include "LinearOperator.h"
2 :
3 : #include <stdexcept>
4 : #include <typeinfo>
5 :
6 : #include "DescriptorUtils.h"
7 :
8 : namespace elsa
9 : {
10 : template <typename data_t>
11 : LinearOperator<data_t>::LinearOperator(const DataDescriptor& domainDescriptor,
12 : const DataDescriptor& rangeDescriptor)
13 : : _domainDescriptor{domainDescriptor.clone()}, _rangeDescriptor{rangeDescriptor.clone()}
14 11595 : {
15 11595 : }
16 :
17 : template <typename data_t>
18 : LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& other)
19 : : Cloneable<LinearOperator<data_t>>(),
20 : _domainDescriptor{other._domainDescriptor->clone()},
21 : _rangeDescriptor{other._rangeDescriptor->clone()},
22 : _scalar{other._scalar},
23 : _isLeaf{other._isLeaf},
24 : _isAdjoint{other._isAdjoint},
25 : _isComposite{other._isComposite},
26 : _mode{other._mode}
27 1093 : {
28 1093 : if (_isLeaf)
29 8 : _lhs = other._lhs->clone();
30 :
31 1093 : if (_isComposite) {
32 224 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
33 12 : _lhs = other._lhs->clone();
34 12 : _rhs = other._rhs->clone();
35 12 : }
36 :
37 224 : if (_mode == CompositeMode::SCALAR_MULT) {
38 212 : _rhs = other._rhs->clone();
39 212 : }
40 224 : }
41 1093 : }
42 :
43 : template <typename data_t>
44 : LinearOperator<data_t>& LinearOperator<data_t>::operator=(const LinearOperator<data_t>& other)
45 24 : {
46 24 : if (*this != other) {
47 24 : _domainDescriptor = other._domainDescriptor->clone();
48 24 : _rangeDescriptor = other._rangeDescriptor->clone();
49 24 : _scalar = other._scalar;
50 24 : _isLeaf = other._isLeaf;
51 24 : _isAdjoint = other._isAdjoint;
52 24 : _isComposite = other._isComposite;
53 24 : _mode = other._mode;
54 :
55 24 : if (_isLeaf)
56 8 : _lhs = other._lhs->clone();
57 :
58 24 : if (_isComposite) {
59 16 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
60 12 : _lhs = other._lhs->clone();
61 12 : _rhs = other._rhs->clone();
62 12 : }
63 :
64 16 : if (_mode == CompositeMode::SCALAR_MULT) {
65 4 : _rhs = other._rhs->clone();
66 4 : }
67 16 : }
68 24 : }
69 :
70 24 : return *this;
71 24 : }
72 :
73 : template <typename data_t>
74 : const DataDescriptor& LinearOperator<data_t>::getDomainDescriptor() const
75 10163 : {
76 10163 : return *_domainDescriptor;
77 10163 : }
78 :
79 : template <typename data_t>
80 : const DataDescriptor& LinearOperator<data_t>::getRangeDescriptor() const
81 7316 : {
82 7316 : return *_rangeDescriptor;
83 7316 : }
84 :
85 : template <typename data_t>
86 : DataContainer<data_t> LinearOperator<data_t>::apply(const DataContainer<data_t>& x) const
87 10586 : {
88 10586 : DataContainer<data_t> result(*_rangeDescriptor);
89 10586 : apply(x, result);
90 10586 : return result;
91 10586 : }
92 :
93 : template <typename data_t>
94 : void LinearOperator<data_t>::apply(const DataContainer<data_t>& x,
95 : DataContainer<data_t>& Ax) const
96 19741 : {
97 19741 : applyImpl(x, Ax);
98 19741 : }
99 :
100 : template <typename data_t>
101 : void LinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x,
102 : DataContainer<data_t>& Ax) const
103 854 : {
104 854 : if (_isLeaf) {
105 380 : if (_isAdjoint) {
106 : // sanity check the arguments for the intended evaluation tree leaf operation
107 328 : if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != x.getSize()
108 328 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Ax.getSize())
109 4 : throw InvalidArgumentError(
110 4 : "LinearOperator::apply: incorrect input/output sizes for adjoint leaf");
111 :
112 324 : _lhs->applyAdjoint(x, Ax);
113 324 : return;
114 324 : } else {
115 : // sanity check the arguments for the intended evaluation tree leaf operation
116 52 : if (_lhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
117 52 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
118 4 : throw InvalidArgumentError(
119 4 : "LinearOperator::apply: incorrect input/output sizes for leaf");
120 :
121 48 : _lhs->apply(x, Ax);
122 48 : return;
123 48 : }
124 380 : }
125 :
126 474 : if (_isComposite) {
127 470 : if (_mode == CompositeMode::ADD) {
128 : // sanity check the arguments for the intended evaluation tree leaf operation
129 12 : if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
130 12 : || _rhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize()
131 12 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
132 12 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
133 4 : throw InvalidArgumentError(
134 4 : "LinearOperator::apply: incorrect input/output sizes for add leaf");
135 :
136 8 : _rhs->apply(x, Ax);
137 8 : Ax += _lhs->apply(x);
138 8 : return;
139 8 : }
140 :
141 458 : if (_mode == CompositeMode::MULT) {
142 : // sanity check the arguments for the intended evaluation tree leaf operation
143 402 : if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
144 402 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
145 8 : throw InvalidArgumentError(
146 8 : "LinearOperator::apply: incorrect input/output sizes for mult leaf");
147 :
148 394 : DataContainer<data_t> temp(_rhs->getRangeDescriptor());
149 394 : _rhs->apply(x, temp);
150 394 : _lhs->apply(temp, Ax);
151 394 : return;
152 394 : }
153 :
154 56 : if (_mode == CompositeMode::SCALAR_MULT) {
155 : // sanity check the arguments for the intended evaluation tree leaf operation
156 56 : if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize())
157 4 : throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
158 4 : "sizes for scalar mult. leaf");
159 : // sanity check the scalar in the optional
160 52 : if (!_scalar.has_value())
161 0 : throw InvalidArgumentError(
162 0 : "LinearOperator::apply: no value found in the scalar optional");
163 :
164 52 : _rhs->apply(x, Ax);
165 52 : Ax *= _scalar.value();
166 52 : return;
167 52 : }
168 56 : }
169 :
170 4 : throw LogicError("LinearOperator: apply called on ill-formed object");
171 4 : }
172 :
173 : template <typename data_t>
174 : DataContainer<data_t> LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y) const
175 731 : {
176 731 : DataContainer<data_t> result(*_domainDescriptor);
177 731 : applyAdjoint(y, result);
178 731 : return result;
179 731 : }
180 :
181 : template <typename data_t>
182 : void LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y,
183 : DataContainer<data_t>& Aty) const
184 5241 : {
185 5241 : applyAdjointImpl(y, Aty);
186 5241 : }
187 :
188 : template <typename data_t>
189 : void LinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
190 : DataContainer<data_t>& Aty) const
191 198 : {
192 198 : if (_isLeaf) {
193 20 : if (_isAdjoint) {
194 : // sanity check the arguments for the intended evaluation tree leaf operation
195 12 : if (_lhs->getDomainDescriptor().getNumberOfCoefficients() != y.getSize()
196 12 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Aty.getSize())
197 4 : throw InvalidArgumentError("LinearOperator::applyAdjoint: incorrect "
198 4 : "input/output sizes for adjoint leaf");
199 :
200 8 : _lhs->apply(y, Aty);
201 8 : return;
202 8 : } else {
203 : // sanity check the arguments for the intended evaluation tree leaf operation
204 8 : if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
205 8 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
206 4 : throw InvalidArgumentError(
207 4 : "LinearOperator::applyAdjoint: incorrect input/output sizes for leaf");
208 :
209 4 : _lhs->applyAdjoint(y, Aty);
210 4 : return;
211 4 : }
212 20 : }
213 :
214 178 : if (_isComposite) {
215 174 : if (_mode == CompositeMode::ADD) {
216 : // sanity check the arguments for the intended evaluation tree leaf operation
217 12 : if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
218 12 : || _rhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize()
219 12 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
220 12 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
221 4 : throw InvalidArgumentError(
222 4 : "LinearOperator::applyAdjoint: incorrect input/output sizes for add leaf");
223 :
224 8 : _rhs->applyAdjoint(y, Aty);
225 8 : Aty += _lhs->applyAdjoint(y);
226 8 : return;
227 8 : }
228 :
229 162 : if (_mode == CompositeMode::MULT) {
230 : // sanity check the arguments for the intended evaluation tree leaf operation
231 82 : if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
232 82 : || _rhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
233 8 : throw InvalidArgumentError(
234 8 : "LinearOperator::applyAdjoint: incorrect input/output sizes for mult leaf");
235 :
236 74 : DataContainer<data_t> temp(_lhs->getDomainDescriptor());
237 74 : _lhs->applyAdjoint(y, temp);
238 74 : _rhs->applyAdjoint(temp, Aty);
239 74 : return;
240 74 : }
241 :
242 80 : if (_mode == CompositeMode::SCALAR_MULT) {
243 : // sanity check the arguments for the intended evaluation tree leaf operation
244 80 : if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize())
245 4 : throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
246 4 : "sizes for scalar mult. leaf");
247 : // sanity check the scalar in the optional
248 76 : if (!_scalar.has_value())
249 0 : throw InvalidArgumentError(
250 0 : "LinearOperator::apply: no value found in the scalar optional");
251 :
252 76 : _rhs->applyAdjoint(y, Aty);
253 76 : Aty *= _scalar.value();
254 76 : return;
255 76 : }
256 80 : }
257 :
258 4 : throw LogicError("LinearOperator: applyAdjoint called on ill-formed object");
259 4 : }
260 :
261 : template <typename data_t>
262 : LinearOperator<data_t>* LinearOperator<data_t>::cloneImpl() const
263 1578 : {
264 1578 : if (_isLeaf)
265 534 : return new LinearOperator<data_t>(*_lhs, _isAdjoint);
266 :
267 1044 : if (_isComposite) {
268 1016 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
269 808 : return new LinearOperator<data_t>(*_lhs, *_rhs, _mode);
270 808 : }
271 :
272 208 : if (_mode == CompositeMode::SCALAR_MULT) {
273 208 : return new LinearOperator<data_t>(*this);
274 208 : }
275 28 : }
276 :
277 28 : return new LinearOperator<data_t>(*_domainDescriptor, *_rangeDescriptor);
278 28 : }
279 :
280 : template <typename data_t>
281 : bool LinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const
282 722 : {
283 722 : if (typeid(other) != typeid(*this))
284 8 : return false;
285 :
286 714 : if (*_domainDescriptor != *other._domainDescriptor
287 714 : || *_rangeDescriptor != *other._rangeDescriptor)
288 24 : return false;
289 :
290 690 : if (_isLeaf ^ other._isLeaf || _isComposite ^ other._isComposite)
291 0 : return false;
292 :
293 690 : if (_isLeaf)
294 148 : return (_isAdjoint == other._isAdjoint) && (*_lhs == *other._lhs);
295 :
296 542 : if (_isComposite) {
297 96 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
298 76 : return _mode == other._mode && (*_lhs == *other._lhs) && (*_rhs == *other._rhs);
299 76 : }
300 :
301 20 : if (_mode == CompositeMode::SCALAR_MULT) {
302 20 : return (_isAdjoint == other._isAdjoint) && (*_rhs == *other._rhs);
303 20 : }
304 446 : }
305 :
306 446 : return true;
307 446 : }
308 :
309 : template <typename data_t>
310 : LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& op, bool isAdjoint)
311 : : _domainDescriptor{(isAdjoint) ? op.getRangeDescriptor().clone()
312 : : op.getDomainDescriptor().clone()},
313 : _rangeDescriptor{(isAdjoint) ? op.getDomainDescriptor().clone()
314 : : op.getRangeDescriptor().clone()},
315 : _lhs{op.clone()},
316 : _scalar{op._scalar},
317 : _isLeaf{true},
318 : _isAdjoint{isAdjoint}
319 920 : {
320 920 : }
321 :
322 : template <typename data_t>
323 : LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& lhs,
324 : const LinearOperator<data_t>& rhs, CompositeMode mode)
325 : : _domainDescriptor{mode == CompositeMode::MULT
326 : ? rhs.getDomainDescriptor().clone()
327 : : bestCommon(*lhs._domainDescriptor, *rhs._domainDescriptor)},
328 : _rangeDescriptor{mode == CompositeMode::MULT
329 : ? lhs.getRangeDescriptor().clone()
330 : : bestCommon(*lhs._rangeDescriptor, *rhs._rangeDescriptor)},
331 : _lhs{lhs.clone()},
332 : _rhs{rhs.clone()},
333 : _isComposite{true},
334 : _mode{mode}
335 1248 : {
336 : // sanity check the descriptors
337 1248 : switch (_mode) {
338 72 : case CompositeMode::ADD:
339 : /// feasibility checked by bestCommon()
340 72 : break;
341 :
342 1176 : case CompositeMode::MULT:
343 : // for multiplication, domain of _lhs should match range of _rhs
344 1176 : if (_lhs->getDomainDescriptor().getNumberOfCoefficients()
345 1176 : != _rhs->getRangeDescriptor().getNumberOfCoefficients())
346 0 : throw InvalidArgumentError(
347 0 : "LinearOperator: composite mult domain/range mismatch");
348 1176 : break;
349 :
350 1176 : default:
351 0 : throw LogicError("LinearOperator: unknown composition mode");
352 1248 : }
353 1248 : }
354 :
355 : template <typename data_t>
356 : LinearOperator<data_t>::LinearOperator(data_t scalar, const LinearOperator<data_t>& rhs)
357 : : _domainDescriptor{rhs.getDomainDescriptor().clone()},
358 : _rangeDescriptor{rhs.getRangeDescriptor().clone()},
359 : _rhs{rhs.clone()},
360 : _scalar{scalar},
361 : _isComposite{true},
362 : _mode{CompositeMode::SCALAR_MULT}
363 48 : {
364 48 : }
365 :
366 : // ------------------------------------------
367 : // explicit template instantiation
368 : template class LinearOperator<float>;
369 : template class LinearOperator<complex<float>>;
370 : template class LinearOperator<double>;
371 : template class LinearOperator<complex<double>>;
372 :
373 : } // namespace elsa
|