Line data Source code
1 : #include "LinearResidual.h" 2 : #include "DataContainer.h" 3 : #include "DataDescriptor.h" 4 : #include "Identity.h" 5 : #include "LinearOperator.h" 6 : #include "TypeCasts.hpp" 7 : 8 : #include <optional> 9 : #include <stdexcept> 10 : 11 : namespace elsa 12 : { 13 : template <typename data_t> 14 : LinearResidual<data_t>::LinearResidual(const DataDescriptor& descriptor) 15 : : domainDesc_(descriptor.clone()), rangeDesc_(descriptor.clone()) 16 28 : { 17 28 : } 18 : 19 : template <typename data_t> 20 : LinearResidual<data_t>::LinearResidual(const DataContainer<data_t>& b) 21 : : domainDesc_(b.getDataDescriptor().clone()), 22 : rangeDesc_(b.getDataDescriptor().clone()), 23 : _dataVector(b) 24 28 : { 25 28 : } 26 : 27 : template <typename data_t> 28 : LinearResidual<data_t>::LinearResidual(const LinearOperator<data_t>& A) 29 : : domainDesc_(A.getDomainDescriptor().clone()), 30 : rangeDesc_(A.getRangeDescriptor().clone()), 31 : _operator(A.clone()) 32 28 : { 33 28 : } 34 : 35 : template <typename data_t> 36 : LinearResidual<data_t>::LinearResidual(const LinearOperator<data_t>& A, 37 : const DataContainer<data_t>& b) 38 : : domainDesc_(A.getDomainDescriptor().clone()), 39 : rangeDesc_(A.getRangeDescriptor().clone()), 40 : _operator(A.clone()), 41 : _dataVector{b} 42 28 : { 43 28 : if (A.getRangeDescriptor() != b.getDataDescriptor()) 44 0 : throw InvalidArgumentError("LinearResidual: A and b do not match"); 45 28 : } 46 : 47 : namespace detail 48 : { 49 : template <class data_t> 50 : std::unique_ptr<LinearOperator<data_t>> extractOp(LinearOperator<data_t>* op) 51 16 : { 52 16 : if (op) { 53 8 : return op->clone(); 54 8 : } else { 55 8 : return nullptr; 56 8 : } 57 16 : } 58 : } // namespace detail 59 : 60 : template <typename data_t> 61 : LinearResidual<data_t>::LinearResidual(const LinearResidual<data_t>& other) 62 : : domainDesc_(other.domainDesc_->clone()), 63 : rangeDesc_(other.rangeDesc_->clone()), 64 : _operator(detail::extractOp(other._operator.get())), 65 : _dataVector(other._dataVector) 66 16 : { 67 16 : } 68 : 69 : template <typename data_t> 70 : LinearResidual<data_t>& LinearResidual<data_t>::operator=(const LinearResidual<data_t>& other) 71 0 : { 72 0 : domainDesc_ = other.domainDesc_->clone(); 73 0 : rangeDesc_ = other.rangeDesc_->clone(); 74 : 75 0 : if (other.hasOperator()) { 76 0 : _operator = other._operator->clone(); 77 0 : } else { 78 0 : _operator = nullptr; 79 0 : } 80 : 81 0 : if (other.hasDataVector()) { 82 0 : _dataVector = other.getDataVector(); 83 0 : } else { 84 0 : _dataVector = std::nullopt; 85 0 : } 86 0 : return *this; 87 0 : } 88 : 89 : template <typename data_t> 90 : LinearResidual<data_t>::LinearResidual(LinearResidual<data_t>&& other) noexcept 91 : : domainDesc_(std::move(other.domainDesc_)), 92 : rangeDesc_(std::move(other.rangeDesc_)), 93 : _operator(std::move(other._operator)), 94 : _dataVector(std::move(other._dataVector)) 95 0 : { 96 0 : } 97 : 98 : template <typename data_t> 99 : LinearResidual<data_t>& 100 : LinearResidual<data_t>::operator=(LinearResidual<data_t>&& other) noexcept 101 0 : { 102 0 : domainDesc_ = std::move(other.domainDesc_); 103 0 : rangeDesc_ = std::move(other.rangeDesc_); 104 0 : _operator = std::move(other._operator); 105 0 : _dataVector = std::move(other._dataVector); 106 : 107 0 : return *this; 108 0 : } 109 : 110 : template <typename data_t> 111 : const DataDescriptor& LinearResidual<data_t>::getDomainDescriptor() const 112 80 : { 113 80 : return *domainDesc_; 114 80 : } 115 : 116 : template <typename data_t> 117 : const DataDescriptor& LinearResidual<data_t>::getRangeDescriptor() const 118 120 : { 119 120 : return *rangeDesc_; 120 120 : } 121 : 122 : template <typename data_t> 123 : bool LinearResidual<data_t>::hasOperator() const 124 188 : { 125 188 : return static_cast<bool>(_operator); 126 188 : } 127 : 128 : template <typename data_t> 129 : bool LinearResidual<data_t>::hasDataVector() const 130 140 : { 131 140 : return _dataVector.has_value(); 132 140 : } 133 : 134 : template <typename data_t> 135 : const LinearOperator<data_t>& LinearResidual<data_t>::getOperator() const 136 80 : { 137 80 : if (!_operator) 138 8 : throw Error("LinearResidual::getOperator: operator not present"); 139 : 140 72 : return *_operator; 141 72 : } 142 : 143 : template <typename data_t> 144 : const DataContainer<data_t>& LinearResidual<data_t>::getDataVector() const 145 72 : { 146 72 : if (!_dataVector) 147 8 : throw Error("LinearResidual::getDataVector: data vector not present"); 148 : 149 64 : return *_dataVector; 150 64 : } 151 : 152 : template <typename data_t> 153 : DataContainer<data_t> LinearResidual<data_t>::evaluate(const DataContainer<data_t>& x) const 154 32 : { 155 32 : DataContainer<data_t> out(this->getRangeDescriptor()); 156 32 : evaluate(x, out); 157 32 : return out; 158 32 : } 159 : 160 : template <typename data_t> 161 : void LinearResidual<data_t>::evaluate(const DataContainer<data_t>& x, 162 : DataContainer<data_t>& result) const 163 32 : { 164 32 : if (hasOperator()) 165 16 : _operator->apply(x, result); 166 16 : else 167 16 : result = x; 168 : 169 32 : if (hasDataVector()) { 170 16 : result -= *_dataVector; 171 16 : } 172 32 : } 173 : 174 : template <typename data_t> 175 : LinearOperator<data_t> 176 : LinearResidual<data_t>::getJacobian([[maybe_unused]] const DataContainer<data_t>& x) 177 16 : { 178 16 : if (hasOperator()) 179 8 : return leaf(*_operator); 180 8 : else 181 8 : return leaf(Identity<data_t>(this->getRangeDescriptor())); 182 16 : } 183 : 184 : // ------------------------------------------ 185 : // explicit template instantiation 186 : template class LinearResidual<float>; 187 : template class LinearResidual<double>; 188 : template class LinearResidual<complex<float>>; 189 : template class LinearResidual<complex<double>>; 190 : 191 : } // namespace elsa