Line data Source code
1 : #pragma once 2 : 3 : #include "elsaDefines.h" 4 : #include "Cloneable.h" 5 : #include "DataDescriptor.h" 6 : #include "DataContainer.h" 7 : 8 : #include <memory> 9 : #include <optional> 10 : 11 : namespace elsa 12 : { 13 : 14 : /** 15 : * @brief Base class representing a linear operator A. Also implements operator expression 16 : * functionality. 17 : * 18 : * @author Matthias Wieczorek - initial code 19 : * @author Maximilian Hornung - composite rewrite 20 : * @author Tobias Lasser - rewrite, modularization, modernization 21 : * 22 : * @tparam data_t data type for the domain and range of the operator, defaulting to real_t 23 : * 24 : * This class represents a linear operator A, expressed through its apply/applyAdjoint methods, 25 : * which implement Ax and A^ty for DataContainers x,y of appropriate sizes. Concrete 26 : * implementations of linear operators will derive from this class and override the 27 : * applyImpl/applyAdjointImpl methods. 28 : * 29 : * LinearOperator also provides functionality to support constructs like the operator expression 30 : * A^t*B+C, where A,B,C are linear operators. This operator composition is implemented via 31 : * evaluation trees. 32 : * 33 : * LinearOperator and all its derived classes are expected to be light-weight and easily 34 : * copyable/cloneable, due to the implementation of evaluation trees. Hence any 35 : * pre-computations/caching should only be done in a lazy manner (e.g. during the first call of 36 : * apply), and not in the constructor. 37 : */ 38 : template <typename data_t = real_t> 39 : class LinearOperator : public Cloneable<LinearOperator<data_t>> 40 : { 41 : public: 42 : /** 43 : * @brief Constructor for the linear operator A, mapping from domain to range 44 : * 45 : * @param[in] domainDescriptor DataDescriptor describing the domain of the operator 46 : * @param[in] rangeDescriptor DataDescriptor describing the range of the operator 47 : */ 48 : LinearOperator(const DataDescriptor& domainDescriptor, 49 : const DataDescriptor& rangeDescriptor); 50 : 51 : /// default destructor 52 14904 : ~LinearOperator() override = default; 53 : 54 : /// copy construction 55 : LinearOperator(const LinearOperator<data_t>& other); 56 : /// copy assignment 57 : LinearOperator<data_t>& operator=(const LinearOperator<data_t>& other); 58 : 59 : /// return the domain DataDescriptor 60 : const DataDescriptor& getDomainDescriptor() const; 61 : 62 : /// return the range DataDescriptor 63 : const DataDescriptor& getRangeDescriptor() const; 64 : 65 : /** 66 : * @brief apply the operator A to an element in the operator's domain 67 : * 68 : * @param[in] x input DataContainer (in the domain of the operator) 69 : * 70 : * @returns Ax DataContainer containing the application of operator A to data x, 71 : * i.e. in the range of the operator 72 : * 73 : * Please note: this method uses apply(x, Ax) to perform the actual operation. 74 : */ 75 : DataContainer<data_t> apply(const DataContainer<data_t>& x) const; 76 : 77 : /** 78 : * @brief apply the operator A to an element in the operator's domain 79 : * 80 : * @param[in] x input DataContainer (in the domain of the operator) 81 : * @param[out] Ax output DataContainer (in the range of the operator) 82 : * 83 : * Please note: this method calls the method applyImpl that has to be overridden in derived 84 : * classes. (Why is this method not virtual itself? Because you cannot have a non-virtual 85 : * function overloading a virtual one [apply with one vs. two arguments]). 86 : */ 87 : void apply(const DataContainer<data_t>& x, DataContainer<data_t>& Ax) const; 88 : 89 : /** 90 : * @brief apply the adjoint of operator A to an element of the operator's range 91 : * 92 : * @param[in] y input DataContainer (in the range of the operator) 93 : * 94 : * @returns A^ty DataContainer containing the application of A^t to data y, 95 : * i.e. in the domain of the operator 96 : * 97 : * Please note: this method uses applyAdjoint(y, Aty) to perform the actual operation. 98 : */ 99 : DataContainer<data_t> applyAdjoint(const DataContainer<data_t>& y) const; 100 : 101 : /** 102 : * @brief apply the adjoint of operator A to an element of the operator's range 103 : * 104 : * @param[in] y input DataContainer (in the range of the operator) 105 : * @param[out] Aty output DataContainer (in the domain of the operator) 106 : * 107 : * Please note: this method calls the method applyAdjointImpl that has to be overridden in 108 : * derived classes. (Why is this method not virtual itself? Because you cannot have a 109 : * non-virtual function overloading a virtual one [applyAdjoint with one vs. two args]). 110 : */ 111 : void applyAdjoint(const DataContainer<data_t>& y, DataContainer<data_t>& Aty) const; 112 : 113 : /// friend operator+ to support composition of LinearOperators (and its derivatives) 114 : friend LinearOperator<data_t> operator+(const LinearOperator<data_t>& lhs, 115 : const LinearOperator<data_t>& rhs) 116 32 : { 117 32 : return LinearOperator(lhs, rhs, CompositeMode::ADD); 118 32 : } 119 : 120 : /// friend operator* to support composition of LinearOperators (and its derivatives) 121 : friend LinearOperator<data_t> operator*(const LinearOperator<data_t>& lhs, 122 : const LinearOperator<data_t>& rhs) 123 408 : { 124 408 : return LinearOperator(lhs, rhs, CompositeMode::MULT); 125 408 : } 126 : 127 : /// friend operator* to support composition of a scalar and a LinearOperator 128 : friend LinearOperator<data_t> operator*(data_t scalar, const LinearOperator<data_t>& op) 129 48 : { 130 48 : return LinearOperator(scalar, op); 131 48 : } 132 : 133 : /// friend function to return the adjoint of a LinearOperator (and its derivatives) 134 : friend LinearOperator<data_t> adjoint(const LinearOperator<data_t>& op) 135 194 : { 136 194 : return LinearOperator(op, true); 137 194 : } 138 : 139 : /// friend function to return a leaf node of a LinearOperator (and its derivatives) 140 : friend LinearOperator<data_t> leaf(const LinearOperator<data_t>& op) 141 192 : { 142 192 : return LinearOperator(op, false); 143 192 : } 144 : 145 : protected: 146 : /// the data descriptor of the domain of the operator 147 : std::unique_ptr<DataDescriptor> _domainDescriptor; 148 : 149 : /// the data descriptor of the range of the operator 150 : std::unique_ptr<DataDescriptor> _rangeDescriptor; 151 : 152 : /// implement the polymorphic clone operation 153 : LinearOperator<data_t>* cloneImpl() const override; 154 : 155 : /// implement the polymorphic comparison operation 156 : bool isEqual(const LinearOperator<data_t>& other) const override; 157 : 158 : /// the apply method that has to be overridden in derived classes 159 : virtual void applyImpl(const DataContainer<data_t>& x, DataContainer<data_t>& Ax) const; 160 : 161 : /// the applyAdjoint method that has to be overridden in derived classes 162 : virtual void applyAdjointImpl(const DataContainer<data_t>& y, 163 : DataContainer<data_t>& Aty) const; 164 : 165 : private: 166 : /// pointers to nodes in the evaluation tree 167 : std::unique_ptr<LinearOperator<data_t>> _lhs{}, _rhs{}; 168 : 169 : std::optional<data_t> _scalar = {}; 170 : 171 : /// flag whether this is a leaf-node 172 : bool _isLeaf{false}; 173 : 174 : /// flag whether this is a leaf-node to implement an adjoint operator 175 : bool _isAdjoint{false}; 176 : 177 : /// flag whether this is a composite (internal node) of the evaluation tree 178 : bool _isComposite{false}; 179 : 180 : /// enum class denoting the mode of composition (+, *) 181 : enum class CompositeMode { ADD, MULT, SCALAR_MULT }; 182 : 183 : /// variable storing the composition mode (+, *) 184 : CompositeMode _mode{CompositeMode::MULT}; 185 : 186 : /// constructor to produce an adjoint leaf node 187 : LinearOperator(const LinearOperator<data_t>& op, bool isAdjoint); 188 : 189 : /// constructor to produce a composite (internal node) of the evaluation tree 190 : LinearOperator(const LinearOperator<data_t>& lhs, const LinearOperator<data_t>& rhs, 191 : CompositeMode mode); 192 : 193 : /// constructor to produce a composite (internal node) of the evaluation tree 194 : LinearOperator(data_t scalar, const LinearOperator<data_t>& op); 195 : }; 196 : 197 : } // namespace elsa