Line data Source code
1 : #pragma once 2 : 3 : #include "elsaDefines.h" 4 : #include "Cloneable.h" 5 : #include "DataDescriptor.h" 6 : #include "DataContainer.h" 7 : 8 : #include <memory> 9 : 10 : namespace elsa 11 : { 12 : 13 : /** 14 : * @brief Base class representing a linear operator A. Also implements operator expression 15 : * functionality. 16 : * 17 : * @author Matthias Wieczorek - initial code 18 : * @author Maximilian Hornung - composite rewrite 19 : * @author Tobias Lasser - rewrite, modularization, modernization 20 : * 21 : * @tparam data_t data type for the domain and range of the operator, defaulting to real_t 22 : * 23 : * This class represents a linear operator A, expressed through its apply/applyAdjoint methods, 24 : * which implement Ax and A^ty for DataContainers x,y of appropriate sizes. Concrete 25 : * implementations of linear operators will derive from this class and override the 26 : * applyImpl/applyAdjointImpl methods. 27 : * 28 : * LinearOperator also provides functionality to support constructs like the operator expression 29 : * A^t*B+C, where A,B,C are linear operators. This operator composition is implemented via 30 : * evaluation trees. 31 : * 32 : * LinearOperator and all its derived classes are expected to be light-weight and easily 33 : * copyable/cloneable, due to the implementation of evaluation trees. Hence any 34 : * pre-computations/caching should only be done in a lazy manner (e.g. during the first call of 35 : * apply), and not in the constructor. 36 : */ 37 : template <typename data_t = real_t> 38 : class LinearOperator : public Cloneable<LinearOperator<data_t>> 39 : { 40 : public: 41 : /** 42 : * @brief Constructor for the linear operator A, mapping from domain to range 43 : * 44 : * @param[in] domainDescriptor DataDescriptor describing the domain of the operator 45 : * @param[in] rangeDescriptor DataDescriptor describing the range of the operator 46 : */ 47 : LinearOperator(const DataDescriptor& domainDescriptor, 48 : const DataDescriptor& rangeDescriptor); 49 : 50 : /// default destructor 51 40478 : ~LinearOperator() override = default; 52 : 53 : /// copy construction 54 : LinearOperator(const LinearOperator<data_t>& other); 55 : /// copy assignment 56 : LinearOperator<data_t>& operator=(const LinearOperator<data_t>& other); 57 : 58 : /// return the domain DataDescriptor 59 : const DataDescriptor& getDomainDescriptor() const; 60 : 61 : /// return the range DataDescriptor 62 : const DataDescriptor& getRangeDescriptor() const; 63 : 64 : /** 65 : * @brief apply the operator A to an element in the operator's domain 66 : * 67 : * @param[in] x input DataContainer (in the domain of the operator) 68 : * 69 : * @returns Ax DataContainer containing the application of operator A to data x, 70 : * i.e. in the range of the operator 71 : * 72 : * Please note: this method uses apply(x, Ax) to perform the actual operation. 73 : */ 74 : DataContainer<data_t> apply(const DataContainer<data_t>& x) const; 75 : 76 : /** 77 : * @brief apply the operator A to an element in the operator's domain 78 : * 79 : * @param[in] x input DataContainer (in the domain of the operator) 80 : * @param[out] Ax output DataContainer (in the range of the operator) 81 : * 82 : * Please note: this method calls the method applyImpl that has to be overridden in derived 83 : * classes. (Why is this method not virtual itself? Because you cannot have a non-virtual 84 : * function overloading a virtual one [apply with one vs. two arguments]). 85 : */ 86 : void apply(const DataContainer<data_t>& x, DataContainer<data_t>& Ax) const; 87 : 88 : /** 89 : * @brief apply the adjoint of operator A to an element of the operator's range 90 : * 91 : * @param[in] y input DataContainer (in the range of the operator) 92 : * 93 : * @returns A^ty DataContainer containing the application of A^t to data y, 94 : * i.e. in the domain of the operator 95 : * 96 : * Please note: this method uses applyAdjoint(y, Aty) to perform the actual operation. 97 : */ 98 : DataContainer<data_t> applyAdjoint(const DataContainer<data_t>& y) const; 99 : 100 : /** 101 : * @brief apply the adjoint of operator A to an element of the operator's range 102 : * 103 : * @param[in] y input DataContainer (in the range of the operator) 104 : * @param[out] Aty output DataContainer (in the domain of the operator) 105 : * 106 : * Please note: this method calls the method applyAdjointImpl that has to be overridden in 107 : * derived classes. (Why is this method not virtual itself? Because you cannot have a 108 : * non-virtual function overloading a virtual one [applyAdjoint with one vs. two args]). 109 : */ 110 : void applyAdjoint(const DataContainer<data_t>& y, DataContainer<data_t>& Aty) const; 111 : 112 : /// friend operator+ to support composition of LinearOperators (and its derivatives) 113 : friend LinearOperator<data_t> operator+(const LinearOperator<data_t>& lhs, 114 : const LinearOperator<data_t>& rhs) 115 138 : { 116 138 : return LinearOperator(lhs, rhs, CompositeMode::ADD); 117 138 : } 118 : 119 : /// friend operator* to support composition of LinearOperators (and its derivatives) 120 : friend LinearOperator<data_t> operator*(const LinearOperator<data_t>& lhs, 121 : const LinearOperator<data_t>& rhs) 122 341 : { 123 341 : return LinearOperator(lhs, rhs, CompositeMode::MULT); 124 341 : } 125 : 126 : /// friend operator* to support composition of a scalar and a LinearOperator 127 : friend LinearOperator<data_t> operator*(data_t scalar, const LinearOperator<data_t>& op) 128 24 : { 129 24 : return LinearOperator(scalar, op); 130 24 : } 131 : 132 : /// friend function to return the adjoint of a LinearOperator (and its derivatives) 133 : friend LinearOperator<data_t> adjoint(const LinearOperator<data_t>& op) 134 254 : { 135 254 : return LinearOperator(op, true); 136 254 : } 137 : 138 : /// friend function to return a leaf node of a LinearOperator (and its derivatives) 139 : friend LinearOperator<data_t> leaf(const LinearOperator<data_t>& op) 140 11184 : { 141 11184 : return LinearOperator(op, false); 142 11184 : } 143 : 144 : protected: 145 : /// the data descriptor of the domain of the operator 146 : std::unique_ptr<DataDescriptor> _domainDescriptor; 147 : 148 : /// the data descriptor of the range of the operator 149 : std::unique_ptr<DataDescriptor> _rangeDescriptor; 150 : 151 : /// implement the polymorphic clone operation 152 : LinearOperator<data_t>* cloneImpl() const override; 153 : 154 : /// implement the polymorphic comparison operation 155 : bool isEqual(const LinearOperator<data_t>& other) const override; 156 : 157 : /// the apply method that has to be overridden in derived classes 158 : virtual void applyImpl(const DataContainer<data_t>& x, DataContainer<data_t>& Ax) const; 159 : 160 : /// the applyAdjoint method that has to be overridden in derived classes 161 : virtual void applyAdjointImpl(const DataContainer<data_t>& y, 162 : DataContainer<data_t>& Aty) const; 163 : 164 : private: 165 : /// pointers to nodes in the evaluation tree 166 : std::unique_ptr<LinearOperator<data_t>> _lhs{}, _rhs{}; 167 : 168 : std::optional<data_t> _scalar = {}; 169 : 170 : /// flag whether this is a leaf-node 171 : bool _isLeaf{false}; 172 : 173 : /// flag whether this is a leaf-node to implement an adjoint operator 174 : bool _isAdjoint{false}; 175 : 176 : /// flag whether this is a composite (internal node) of the evaluation tree 177 : bool _isComposite{false}; 178 : 179 : /// enum class denoting the mode of composition (+, *) 180 : enum class CompositeMode { ADD, MULT, SCALAR_MULT }; 181 : 182 : /// variable storing the composition mode (+, *) 183 : CompositeMode _mode{CompositeMode::MULT}; 184 : 185 : /// constructor to produce an adjoint leaf node 186 : LinearOperator(const LinearOperator<data_t>& op, bool isAdjoint); 187 : 188 : /// constructor to produce a composite (internal node) of the evaluation tree 189 : LinearOperator(const LinearOperator<data_t>& lhs, const LinearOperator<data_t>& rhs, 190 : CompositeMode mode); 191 : 192 : /// constructor to produce a composite (internal node) of the evaluation tree 193 : LinearOperator(data_t scalar, const LinearOperator<data_t>& op); 194 : }; 195 : 196 : } // namespace elsa