Line data Source code
1 : #include "LinearResidual.h" 2 : #include "Identity.h" 3 : #include "TypeCasts.hpp" 4 : 5 : #include <stdexcept> 6 : 7 : namespace elsa 8 : { 9 : template <typename data_t> 10 : LinearResidual<data_t>::LinearResidual(const DataDescriptor& descriptor) 11 : : Residual<data_t>(descriptor, descriptor) 12 1621 : { 13 1621 : } 14 : 15 : template <typename data_t> 16 : LinearResidual<data_t>::LinearResidual(const DataContainer<data_t>& b) 17 : : Residual<data_t>(b.getDataDescriptor(), b.getDataDescriptor()), _dataVector{b} 18 208 : { 19 208 : } 20 : 21 : template <typename data_t> 22 : LinearResidual<data_t>::LinearResidual(const LinearOperator<data_t>& A) 23 : : Residual<data_t>(A.getDomainDescriptor(), A.getRangeDescriptor()), _operator{A.clone()} 24 230 : { 25 230 : } 26 : 27 : template <typename data_t> 28 : LinearResidual<data_t>::LinearResidual(const LinearOperator<data_t>& A, 29 : const DataContainer<data_t>& b) 30 : : Residual<data_t>(A.getDomainDescriptor(), A.getRangeDescriptor()), 31 : _operator{A.clone()}, 32 : _dataVector{b} 33 1825 : { 34 1825 : if (A.getRangeDescriptor().getNumberOfCoefficients() != b.getSize()) 35 0 : throw InvalidArgumentError("LinearResidual: A and b do not match"); 36 1825 : } 37 : 38 : template <typename data_t> 39 : bool LinearResidual<data_t>::hasOperator() const 40 50192 : { 41 50192 : return static_cast<bool>(_operator); 42 50192 : } 43 : 44 : template <typename data_t> 45 : bool LinearResidual<data_t>::hasDataVector() const 46 16568 : { 47 16568 : return _dataVector.has_value(); 48 16568 : } 49 : 50 : template <typename data_t> 51 : const LinearOperator<data_t>& LinearResidual<data_t>::getOperator() const 52 1574 : { 53 1574 : if (!_operator) 54 8 : throw Error("LinearResidual::getOperator: operator not present"); 55 : 56 1566 : return *_operator; 57 1566 : } 58 : 59 : template <typename data_t> 60 : const DataContainer<data_t>& LinearResidual<data_t>::getDataVector() const 61 1552 : { 62 1552 : if (!_dataVector) 63 8 : throw Error("LinearResidual::getDataVector: data vector not present"); 64 : 65 1544 : return *_dataVector; 66 1544 : } 67 : 68 : template <typename data_t> 69 : LinearResidual<data_t>* LinearResidual<data_t>::cloneImpl() const 70 1880 : { 71 1880 : if (hasOperator() && hasDataVector()) 72 1071 : return new LinearResidual<data_t>(getOperator(), getDataVector()); 73 : 74 809 : if (hasOperator()) 75 68 : return new LinearResidual<data_t>(getOperator()); 76 : 77 741 : if (hasDataVector()) 78 68 : return new LinearResidual<data_t>(getDataVector()); 79 : 80 673 : return new LinearResidual<data_t>(this->getDomainDescriptor()); 81 673 : } 82 : 83 : template <typename data_t> 84 : bool LinearResidual<data_t>::isEqual(const Residual<data_t>& other) const 85 269 : { 86 269 : if (!Residual<data_t>::isEqual(other)) 87 0 : return false; 88 : 89 269 : auto otherLinearResidual = downcast_safe<LinearResidual>(&other); 90 269 : if (!otherLinearResidual) 91 0 : return false; 92 : 93 269 : if (hasOperator() != otherLinearResidual->hasOperator() 94 269 : || hasDataVector() != otherLinearResidual->hasDataVector()) 95 0 : return false; 96 : 97 269 : if ((_operator && !otherLinearResidual->_operator) 98 269 : || (!_operator && otherLinearResidual->_operator) 99 269 : || (_dataVector && !otherLinearResidual->_dataVector) 100 269 : || (!_dataVector && otherLinearResidual->_dataVector)) 101 0 : return false; 102 : 103 269 : if (_operator && otherLinearResidual->_operator 104 269 : && *_operator != *otherLinearResidual->_operator) 105 0 : return false; 106 : 107 269 : if (_dataVector && otherLinearResidual->_dataVector 108 269 : && *_dataVector != *otherLinearResidual->_dataVector) 109 0 : return false; 110 : 111 269 : return true; 112 269 : } 113 : 114 : template <typename data_t> 115 : void LinearResidual<data_t>::evaluateImpl(const DataContainer<data_t>& x, 116 : DataContainer<data_t>& result) const 117 11212 : { 118 11212 : if (hasOperator()) 119 11192 : _operator->apply(x, result); 120 20 : else 121 20 : result = x; 122 : 123 11212 : if (hasDataVector()) 124 11186 : result -= *_dataVector; 125 11212 : } 126 : 127 : template <typename data_t> 128 : LinearOperator<data_t> 129 : LinearResidual<data_t>::getJacobianImpl([[maybe_unused]] const DataContainer<data_t>& x) 130 10915 : { 131 10915 : if (hasOperator()) 132 10907 : return leaf(*_operator); 133 8 : else 134 8 : return leaf(Identity<data_t>(this->getRangeDescriptor())); 135 10915 : } 136 : 137 : // ------------------------------------------ 138 : // explicit template instantiation 139 : template class LinearResidual<float>; 140 : template class LinearResidual<double>; 141 : template class LinearResidual<complex<float>>; 142 : template class LinearResidual<complex<double>>; 143 : 144 : } // namespace elsa