Line data Source code
1 : #pragma once 2 : 3 : #include <optional> 4 : 5 : #include "DataContainer.h" 6 : #include "DataDescriptor.h" 7 : #include "LinearOperator.h" 8 : 9 : namespace elsa 10 : { 11 : /** 12 : * @brief Class representing a linear residual, i.e. Ax - b with operator A and vectors x, b. 13 : * 14 : * A linear residual is a vector-valued mapping \f$ \mathbb{R}^n\to\mathbb{R}^m \f$, namely 15 : * \f$ x \mapsto Ax - b \f$, where A is a LinearOperator, b a constant data vector 16 : * (DataContainer) and x a variable (DataContainer). This linear residual can be used as input 17 : * to a Functional. 18 : * 19 : * @tparam data_t data type for the domain and range of the operator, default to real_t 20 : * 21 : * @author 22 : * * Matthias Wieczorek - initial code 23 : * * Tobias Lasser - modularization, modernization 24 : */ 25 : template <typename data_t = real_t> 26 : class LinearResidual 27 : { 28 : public: 29 : /** 30 : * @brief Constructor for a trivial residual \f$ x \mapsto x \f$ 31 : * 32 : * @param[in] descriptor describing the domain = range of the residual 33 : */ 34 : explicit LinearResidual(const DataDescriptor& descriptor); 35 : 36 : /** 37 : * @brief Constructor for a simple residual \f$ x \mapsto x - b \f$ 38 : * 39 : * @param[in] b a vector (DataContainer) that will be subtracted from x 40 : */ 41 : explicit LinearResidual(const DataContainer<data_t>& b); 42 : 43 : /** @brief Constructor for a residual \f$ x \mapsto Ax \f$ 44 : * 45 : * @param[in] A a LinearOperator 46 : */ 47 : explicit LinearResidual(const LinearOperator<data_t>& A); 48 : 49 : /** 50 : * @brief Constructor for a residual \f$ x \mapsto Ax - b \f$ 51 : * 52 : * @param[in] A a LinearOperator 53 : * @param[in] b a vector (DataContainer) 54 : */ 55 : LinearResidual(const LinearOperator<data_t>& A, const DataContainer<data_t>& b); 56 : 57 : // Copy constructor 58 : LinearResidual(const LinearResidual<data_t>&); 59 : 60 : // Copy assignment 61 : LinearResidual& operator=(const LinearResidual<data_t>&); 62 : 63 : // Move constructor 64 : LinearResidual(LinearResidual<data_t>&&) noexcept; 65 : 66 : // Move assignment 67 : LinearResidual& operator=(LinearResidual<data_t>&&) noexcept; 68 : 69 : /// default destructor 70 128 : ~LinearResidual() = default; 71 : 72 : /// return the domain descriptor of the residual 73 : const DataDescriptor& getDomainDescriptor() const; 74 : 75 : /// return the range descriptor of the residual 76 : const DataDescriptor& getRangeDescriptor() const; 77 : 78 : /// return true if the residual has an operator A 79 : bool hasOperator() const; 80 : 81 : /// return true if the residual has a data vector b 82 : bool hasDataVector() const; 83 : 84 : /// return the operator A (throws if the residual has none) 85 : const LinearOperator<data_t>& getOperator() const; 86 : 87 : /// return the data vector b (throws if the residual has none) 88 : const DataContainer<data_t>& getDataVector() const; 89 : 90 : /** 91 : * @brief evaluate the residual at x and return the result 92 : * 93 : * @param[in] x input DataContainer (in the domain of the residual) 94 : * 95 : * @returns result DataContainer (in the range of the residual) containing the result of 96 : * the evaluation of the residual at x 97 : */ 98 : DataContainer<data_t> evaluate(const DataContainer<data_t>& x) const; 99 : 100 : /** 101 : * @brief evaluate the residual at x and store in result 102 : * 103 : * @param[in] x input DataContainer (in the domain of the residual) 104 : * @param[out] result output DataContainer (in the range of the residual) 105 : */ 106 : void evaluate(const DataContainer<data_t>& x, DataContainer<data_t>& result) const; 107 : 108 : /** 109 : * @brief return the Jacobian (first derivative) of the linear residual at x. 110 : * If A is set, then the Jacobian is A and this returns a copy of A. 111 : * If A is not set, then an Identity operator is returned. 112 : * 113 : * @param x input DataContainer (in the domain of the residual) 114 : * 115 : * @returns a LinearOperator (the Jacobian) 116 : */ 117 : LinearOperator<data_t> getJacobian(const DataContainer<data_t>& x); 118 : 119 : private: 120 : /// Descriptor of domain 121 : std::unique_ptr<DataDescriptor> domainDesc_; 122 : 123 : /// Descriptor of range 124 : std::unique_ptr<DataDescriptor> rangeDesc_; 125 : 126 : /// the operator A, nullptr implies no operator present 127 : std::unique_ptr<LinearOperator<data_t>> _operator{}; 128 : 129 : /// optional data vector b 130 : std::optional<DataContainer<data_t>> _dataVector{}; 131 : }; 132 : 133 : template <class data_t> 134 : bool operator==(const LinearResidual<data_t>& lhs, const LinearResidual<data_t>& rhs) 135 32 : { 136 32 : if (lhs.getDomainDescriptor() != rhs.getDomainDescriptor()) { 137 0 : return false; 138 0 : } 139 : 140 32 : if (lhs.getRangeDescriptor() != rhs.getRangeDescriptor()) { 141 0 : return false; 142 0 : } 143 : 144 32 : if (lhs.hasOperator() && rhs.hasOperator() && lhs.getOperator() != rhs.getOperator()) { 145 0 : return false; 146 0 : } 147 : 148 32 : if (lhs.hasDataVector() && rhs.hasDataVector() 149 32 : && lhs.getDataVector() != rhs.getDataVector()) { 150 0 : return false; 151 0 : } 152 : 153 32 : return true; 154 32 : } 155 : 156 : template <class data_t> 157 : bool operator!=(const LinearResidual<data_t>& lhs, const LinearResidual<data_t>& rhs) 158 16 : { 159 16 : return !(lhs == rhs); 160 16 : } 161 : } // namespace elsa