Line data Source code
1 : #include "LinearOperator.h"
2 :
3 : #include <stdexcept>
4 : #include <typeinfo>
5 :
6 : #include "DescriptorUtils.h"
7 :
8 : namespace elsa
9 : {
10 : template <typename data_t>
11 0 : LinearOperator<data_t>::LinearOperator(const DataDescriptor& domainDescriptor,
12 : const DataDescriptor& rangeDescriptor)
13 0 : : _domainDescriptor{domainDescriptor.clone()}, _rangeDescriptor{rangeDescriptor.clone()}
14 : {
15 0 : }
16 :
17 : template <typename data_t>
18 0 : LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& other)
19 : : Cloneable<LinearOperator<data_t>>(),
20 0 : _domainDescriptor{other._domainDescriptor->clone()},
21 0 : _rangeDescriptor{other._rangeDescriptor->clone()},
22 0 : _scalar{other._scalar},
23 0 : _isLeaf{other._isLeaf},
24 0 : _isAdjoint{other._isAdjoint},
25 0 : _isComposite{other._isComposite},
26 0 : _mode{other._mode}
27 : {
28 0 : if (_isLeaf)
29 0 : _lhs = other._lhs->clone();
30 :
31 0 : if (_isComposite) {
32 0 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
33 0 : _lhs = other._lhs->clone();
34 0 : _rhs = other._rhs->clone();
35 : }
36 :
37 0 : if (_mode == CompositeMode::SCALAR_MULT) {
38 0 : _rhs = other._rhs->clone();
39 : }
40 : }
41 0 : }
42 :
43 : template <typename data_t>
44 0 : LinearOperator<data_t>& LinearOperator<data_t>::operator=(const LinearOperator<data_t>& other)
45 : {
46 0 : if (*this != other) {
47 0 : _domainDescriptor = other._domainDescriptor->clone();
48 0 : _rangeDescriptor = other._rangeDescriptor->clone();
49 0 : _scalar = other._scalar;
50 0 : _isLeaf = other._isLeaf;
51 0 : _isAdjoint = other._isAdjoint;
52 0 : _isComposite = other._isComposite;
53 0 : _mode = other._mode;
54 :
55 0 : if (_isLeaf)
56 0 : _lhs = other._lhs->clone();
57 :
58 0 : if (_isComposite) {
59 0 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
60 0 : _lhs = other._lhs->clone();
61 0 : _rhs = other._rhs->clone();
62 : }
63 :
64 0 : if (_mode == CompositeMode::SCALAR_MULT) {
65 0 : _rhs = other._rhs->clone();
66 : }
67 : }
68 : }
69 :
70 0 : return *this;
71 : }
72 :
73 : template <typename data_t>
74 0 : const DataDescriptor& LinearOperator<data_t>::getDomainDescriptor() const
75 : {
76 0 : return *_domainDescriptor;
77 : }
78 :
79 : template <typename data_t>
80 0 : const DataDescriptor& LinearOperator<data_t>::getRangeDescriptor() const
81 : {
82 0 : return *_rangeDescriptor;
83 : }
84 :
85 : template <typename data_t>
86 0 : DataContainer<data_t> LinearOperator<data_t>::apply(const DataContainer<data_t>& x) const
87 : {
88 0 : DataContainer<data_t> result(*_rangeDescriptor, x.getDataHandlerType());
89 0 : apply(x, result);
90 0 : return result;
91 0 : }
92 :
93 : template <typename data_t>
94 0 : void LinearOperator<data_t>::apply(const DataContainer<data_t>& x,
95 : DataContainer<data_t>& Ax) const
96 : {
97 0 : applyImpl(x, Ax);
98 0 : }
99 :
100 : template <typename data_t>
101 0 : void LinearOperator<data_t>::applyImpl(const DataContainer<data_t>& x,
102 : DataContainer<data_t>& Ax) const
103 : {
104 0 : if (_isLeaf) {
105 0 : if (_isAdjoint) {
106 : // sanity check the arguments for the intended evaluation tree leaf operation
107 0 : if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != x.getSize()
108 0 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Ax.getSize())
109 0 : throw InvalidArgumentError(
110 : "LinearOperator::apply: incorrect input/output sizes for adjoint leaf");
111 :
112 0 : _lhs->applyAdjoint(x, Ax);
113 0 : return;
114 : } else {
115 : // sanity check the arguments for the intended evaluation tree leaf operation
116 0 : if (_lhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
117 0 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
118 0 : throw InvalidArgumentError(
119 : "LinearOperator::apply: incorrect input/output sizes for leaf");
120 :
121 0 : _lhs->apply(x, Ax);
122 0 : return;
123 : }
124 : }
125 :
126 0 : if (_isComposite) {
127 0 : if (_mode == CompositeMode::ADD) {
128 : // sanity check the arguments for the intended evaluation tree leaf operation
129 0 : if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
130 0 : || _rhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize()
131 0 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
132 0 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
133 0 : throw InvalidArgumentError(
134 : "LinearOperator::apply: incorrect input/output sizes for add leaf");
135 :
136 0 : _rhs->apply(x, Ax);
137 0 : Ax += _lhs->apply(x);
138 0 : return;
139 : }
140 :
141 0 : if (_mode == CompositeMode::MULT) {
142 : // sanity check the arguments for the intended evaluation tree leaf operation
143 0 : if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize()
144 0 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Ax.getSize())
145 0 : throw InvalidArgumentError(
146 : "LinearOperator::apply: incorrect input/output sizes for mult leaf");
147 :
148 0 : DataContainer<data_t> temp(_rhs->getRangeDescriptor(), x.getDataHandlerType());
149 0 : _rhs->apply(x, temp);
150 0 : _lhs->apply(temp, Ax);
151 0 : return;
152 0 : }
153 :
154 0 : if (_mode == CompositeMode::SCALAR_MULT) {
155 : // sanity check the arguments for the intended evaluation tree leaf operation
156 0 : if (_rhs->getDomainDescriptor().getNumberOfCoefficients() != x.getSize())
157 0 : throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
158 : "sizes for scalar mult. leaf");
159 : // sanity check the scalar in the optional
160 0 : if (!_scalar.has_value())
161 0 : throw InvalidArgumentError(
162 : "LinearOperator::apply: no value found in the scalar optional");
163 :
164 0 : _rhs->apply(x, Ax);
165 0 : Ax *= _scalar.value();
166 0 : return;
167 : }
168 : }
169 :
170 0 : throw LogicError("LinearOperator: apply called on ill-formed object");
171 : }
172 :
173 : template <typename data_t>
174 0 : DataContainer<data_t> LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y) const
175 : {
176 0 : DataContainer<data_t> result(*_domainDescriptor, y.getDataHandlerType());
177 0 : applyAdjoint(y, result);
178 0 : return result;
179 0 : }
180 :
181 : template <typename data_t>
182 0 : void LinearOperator<data_t>::applyAdjoint(const DataContainer<data_t>& y,
183 : DataContainer<data_t>& Aty) const
184 : {
185 0 : applyAdjointImpl(y, Aty);
186 0 : }
187 :
188 : template <typename data_t>
189 0 : void LinearOperator<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
190 : DataContainer<data_t>& Aty) const
191 : {
192 0 : if (_isLeaf) {
193 0 : if (_isAdjoint) {
194 : // sanity check the arguments for the intended evaluation tree leaf operation
195 0 : if (_lhs->getDomainDescriptor().getNumberOfCoefficients() != y.getSize()
196 0 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != Aty.getSize())
197 0 : throw InvalidArgumentError("LinearOperator::applyAdjoint: incorrect "
198 : "input/output sizes for adjoint leaf");
199 :
200 0 : _lhs->apply(y, Aty);
201 0 : return;
202 : } else {
203 : // sanity check the arguments for the intended evaluation tree leaf operation
204 0 : if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
205 0 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
206 0 : throw InvalidArgumentError(
207 : "LinearOperator::applyAdjoint: incorrect input/output sizes for leaf");
208 :
209 0 : _lhs->applyAdjoint(y, Aty);
210 0 : return;
211 : }
212 : }
213 :
214 0 : if (_isComposite) {
215 0 : if (_mode == CompositeMode::ADD) {
216 : // sanity check the arguments for the intended evaluation tree leaf operation
217 0 : if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
218 0 : || _rhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize()
219 0 : || _lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
220 0 : || _lhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
221 0 : throw InvalidArgumentError(
222 : "LinearOperator::applyAdjoint: incorrect input/output sizes for add leaf");
223 :
224 0 : _rhs->applyAdjoint(y, Aty);
225 0 : Aty += _lhs->applyAdjoint(y);
226 0 : return;
227 : }
228 :
229 0 : if (_mode == CompositeMode::MULT) {
230 : // sanity check the arguments for the intended evaluation tree leaf operation
231 0 : if (_lhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize()
232 0 : || _rhs->getDomainDescriptor().getNumberOfCoefficients() != Aty.getSize())
233 0 : throw InvalidArgumentError(
234 : "LinearOperator::applyAdjoint: incorrect input/output sizes for mult leaf");
235 :
236 0 : DataContainer<data_t> temp(_lhs->getDomainDescriptor(), y.getDataHandlerType());
237 0 : _lhs->applyAdjoint(y, temp);
238 0 : _rhs->applyAdjoint(temp, Aty);
239 0 : return;
240 0 : }
241 :
242 0 : if (_mode == CompositeMode::SCALAR_MULT) {
243 : // sanity check the arguments for the intended evaluation tree leaf operation
244 0 : if (_rhs->getRangeDescriptor().getNumberOfCoefficients() != y.getSize())
245 0 : throw InvalidArgumentError("LinearOperator::apply: incorrect input/output "
246 : "sizes for scalar mult. leaf");
247 : // sanity check the scalar in the optional
248 0 : if (!_scalar.has_value())
249 0 : throw InvalidArgumentError(
250 : "LinearOperator::apply: no value found in the scalar optional");
251 :
252 0 : _rhs->applyAdjoint(y, Aty);
253 0 : Aty *= _scalar.value();
254 0 : return;
255 : }
256 : }
257 :
258 0 : throw LogicError("LinearOperator: applyAdjoint called on ill-formed object");
259 : }
260 :
261 : template <typename data_t>
262 0 : LinearOperator<data_t>* LinearOperator<data_t>::cloneImpl() const
263 : {
264 0 : if (_isLeaf)
265 0 : return new LinearOperator<data_t>(*_lhs, _isAdjoint);
266 :
267 0 : if (_isComposite) {
268 0 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
269 0 : return new LinearOperator<data_t>(*_lhs, *_rhs, _mode);
270 : }
271 :
272 0 : if (_mode == CompositeMode::SCALAR_MULT) {
273 0 : return new LinearOperator<data_t>(*this);
274 : }
275 : }
276 :
277 0 : return new LinearOperator<data_t>(*_domainDescriptor, *_rangeDescriptor);
278 : }
279 :
280 : template <typename data_t>
281 0 : bool LinearOperator<data_t>::isEqual(const LinearOperator<data_t>& other) const
282 : {
283 0 : if (typeid(other) != typeid(*this))
284 0 : return false;
285 :
286 0 : if (*_domainDescriptor != *other._domainDescriptor
287 0 : || *_rangeDescriptor != *other._rangeDescriptor)
288 0 : return false;
289 :
290 0 : if (_isLeaf ^ other._isLeaf || _isComposite ^ other._isComposite)
291 0 : return false;
292 :
293 0 : if (_isLeaf)
294 0 : return (_isAdjoint == other._isAdjoint) && (*_lhs == *other._lhs);
295 :
296 0 : if (_isComposite) {
297 0 : if (_mode == CompositeMode::ADD || _mode == CompositeMode::MULT) {
298 0 : return _mode == other._mode && (*_lhs == *other._lhs) && (*_rhs == *other._rhs);
299 : }
300 :
301 0 : if (_mode == CompositeMode::SCALAR_MULT) {
302 0 : return (_isAdjoint == other._isAdjoint) && (*_rhs == *other._rhs);
303 : }
304 : }
305 :
306 0 : return true;
307 : }
308 :
309 : template <typename data_t>
310 0 : LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& op, bool isAdjoint)
311 0 : : _domainDescriptor{(isAdjoint) ? op.getRangeDescriptor().clone()
312 0 : : op.getDomainDescriptor().clone()},
313 0 : _rangeDescriptor{(isAdjoint) ? op.getDomainDescriptor().clone()
314 0 : : op.getRangeDescriptor().clone()},
315 0 : _lhs{op.clone()},
316 0 : _scalar{op._scalar},
317 0 : _isLeaf{true},
318 0 : _isAdjoint{isAdjoint}
319 : {
320 0 : }
321 :
322 : template <typename data_t>
323 0 : LinearOperator<data_t>::LinearOperator(const LinearOperator<data_t>& lhs,
324 : const LinearOperator<data_t>& rhs, CompositeMode mode)
325 0 : : _domainDescriptor{mode == CompositeMode::MULT
326 0 : ? rhs.getDomainDescriptor().clone()
327 0 : : bestCommon(*lhs._domainDescriptor, *rhs._domainDescriptor)},
328 0 : _rangeDescriptor{mode == CompositeMode::MULT
329 0 : ? lhs.getRangeDescriptor().clone()
330 0 : : bestCommon(*lhs._rangeDescriptor, *rhs._rangeDescriptor)},
331 0 : _lhs{lhs.clone()},
332 0 : _rhs{rhs.clone()},
333 0 : _isComposite{true},
334 0 : _mode{mode}
335 : {
336 : // sanity check the descriptors
337 0 : switch (_mode) {
338 0 : case CompositeMode::ADD:
339 : /// feasibility checked by bestCommon()
340 0 : break;
341 :
342 0 : case CompositeMode::MULT:
343 : // for multiplication, domain of _lhs should match range of _rhs
344 0 : if (_lhs->getDomainDescriptor().getNumberOfCoefficients()
345 0 : != _rhs->getRangeDescriptor().getNumberOfCoefficients())
346 0 : throw InvalidArgumentError(
347 : "LinearOperator: composite mult domain/range mismatch");
348 0 : break;
349 :
350 0 : default:
351 0 : throw LogicError("LinearOperator: unknown composition mode");
352 : }
353 0 : }
354 :
355 : template <typename data_t>
356 0 : LinearOperator<data_t>::LinearOperator(data_t scalar, const LinearOperator<data_t>& rhs)
357 0 : : _domainDescriptor{rhs.getDomainDescriptor().clone()},
358 0 : _rangeDescriptor{rhs.getRangeDescriptor().clone()},
359 0 : _rhs{rhs.clone()},
360 0 : _scalar{scalar},
361 0 : _isComposite{true},
362 0 : _mode{CompositeMode::SCALAR_MULT}
363 : {
364 0 : }
365 :
366 : // ------------------------------------------
367 : // explicit template instantiation
368 : template class LinearOperator<float>;
369 : template class LinearOperator<complex<float>>;
370 : template class LinearOperator<double>;
371 : template class LinearOperator<complex<double>>;
372 :
373 : } // namespace elsa
|