LCOV - code coverage report
Current view: top level - elsa/core - LinearOperator.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 16 16 100.0 %
Date: 2025-01-02 06:42:49 Functions: 24 24 100.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "elsaDefines.h"
       4             : #include "Cloneable.h"
       5             : #include "DataDescriptor.h"
       6             : #include "DataContainer.h"
       7             : 
       8             : #include <memory>
       9             : #include <optional>
      10             : 
      11             : namespace elsa
      12             : {
      13             : 
      14             :     /**
      15             :      * @brief Base class representing a linear operator A. Also implements operator expression
      16             :      * functionality.
      17             :      *
      18             :      * @author Matthias Wieczorek - initial code
      19             :      * @author Maximilian Hornung - composite rewrite
      20             :      * @author Tobias Lasser - rewrite, modularization, modernization
      21             :      *
      22             :      * @tparam data_t data type for the domain and range of the operator, defaulting to real_t
      23             :      *
      24             :      * This class represents a linear operator A, expressed through its apply/applyAdjoint methods,
      25             :      * which implement Ax and A^ty for DataContainers x,y of appropriate sizes. Concrete
      26             :      * implementations of linear operators will derive from this class and override the
      27             :      * applyImpl/applyAdjointImpl methods.
      28             :      *
      29             :      * LinearOperator also provides functionality to support constructs like the operator expression
      30             :      * A^t*B+C, where A,B,C are linear operators. This operator composition is implemented via
      31             :      * evaluation trees.
      32             :      *
      33             :      * LinearOperator and all its derived classes are expected to be light-weight and easily
      34             :      * copyable/cloneable, due to the implementation of evaluation trees. Hence any
      35             :      * pre-computations/caching should only be done in a lazy manner (e.g. during the first call of
      36             :      * apply), and not in the constructor.
      37             :      */
      38             :     template <typename data_t = real_t>
      39             :     class LinearOperator : public Cloneable<LinearOperator<data_t>>
      40             :     {
      41             :     public:
      42             :         /**
      43             :          * @brief Constructor for the linear operator A, mapping from domain to range
      44             :          *
      45             :          * @param[in] domainDescriptor DataDescriptor describing the domain of the operator
      46             :          * @param[in] rangeDescriptor DataDescriptor describing the range of the operator
      47             :          */
      48             :         LinearOperator(const DataDescriptor& domainDescriptor,
      49             :                        const DataDescriptor& rangeDescriptor);
      50             : 
      51             :         /// default destructor
      52       14904 :         ~LinearOperator() override = default;
      53             : 
      54             :         /// copy construction
      55             :         LinearOperator(const LinearOperator<data_t>& other);
      56             :         /// copy assignment
      57             :         LinearOperator<data_t>& operator=(const LinearOperator<data_t>& other);
      58             : 
      59             :         /// return the domain DataDescriptor
      60             :         const DataDescriptor& getDomainDescriptor() const;
      61             : 
      62             :         /// return the range DataDescriptor
      63             :         const DataDescriptor& getRangeDescriptor() const;
      64             : 
      65             :         /**
      66             :          * @brief apply the operator A to an element in the operator's domain
      67             :          *
      68             :          * @param[in] x input DataContainer (in the domain of the operator)
      69             :          *
      70             :          * @returns Ax DataContainer containing the application of operator A to data x,
      71             :          * i.e. in the range of the operator
      72             :          *
      73             :          * Please note: this method uses apply(x, Ax) to perform the actual operation.
      74             :          */
      75             :         DataContainer<data_t> apply(const DataContainer<data_t>& x) const;
      76             : 
      77             :         /**
      78             :          * @brief apply the operator A to an element in the operator's domain
      79             :          *
      80             :          * @param[in] x input DataContainer (in the domain of the operator)
      81             :          * @param[out] Ax output DataContainer (in the range of the operator)
      82             :          *
      83             :          * Please note: this method calls the method applyImpl that has to be overridden in derived
      84             :          * classes. (Why is this method not virtual itself? Because you cannot have a non-virtual
      85             :          * function overloading a virtual one [apply with one vs. two arguments]).
      86             :          */
      87             :         void apply(const DataContainer<data_t>& x, DataContainer<data_t>& Ax) const;
      88             : 
      89             :         /**
      90             :          * @brief apply the adjoint of operator A to an element of the operator's range
      91             :          *
      92             :          * @param[in] y input DataContainer (in the range of the operator)
      93             :          *
      94             :          * @returns A^ty DataContainer containing the application of A^t to data y,
      95             :          * i.e. in the domain of the operator
      96             :          *
      97             :          * Please note: this method uses applyAdjoint(y, Aty) to perform the actual operation.
      98             :          */
      99             :         DataContainer<data_t> applyAdjoint(const DataContainer<data_t>& y) const;
     100             : 
     101             :         /**
     102             :          * @brief apply the adjoint of operator A to an element of the operator's range
     103             :          *
     104             :          * @param[in] y input DataContainer (in the range of the operator)
     105             :          * @param[out] Aty output DataContainer (in the domain of the operator)
     106             :          *
     107             :          * Please note: this method calls the method applyAdjointImpl that has to be overridden in
     108             :          * derived classes. (Why is this method not virtual itself? Because you cannot have a
     109             :          * non-virtual function overloading a virtual one [applyAdjoint with one vs. two args]).
     110             :          */
     111             :         void applyAdjoint(const DataContainer<data_t>& y, DataContainer<data_t>& Aty) const;
     112             : 
     113             :         /// friend operator+ composing two LinearOperators (or classes derived from it) in ADD mode
     114             :         friend LinearOperator<data_t> operator+(const LinearOperator<data_t>& lhs,
     115             :                                                 const LinearOperator<data_t>& rhs)
     116          32 :         {
     117          32 :             return LinearOperator(lhs, rhs, CompositeMode::ADD);
     118          32 :         }
     119             : 
     120             :         /// friend operator* composing two LinearOperators (or classes derived from it) in MULT mode
     121             :         friend LinearOperator<data_t> operator*(const LinearOperator<data_t>& lhs,
     122             :                                                 const LinearOperator<data_t>& rhs)
     123         408 :         {
     124         408 :             return LinearOperator(lhs, rhs, CompositeMode::MULT);
     125         408 :         }
     126             : 
     127             :         /// friend operator* to support composition of a scalar and a LinearOperator
     128             :         friend LinearOperator<data_t> operator*(data_t scalar, const LinearOperator<data_t>& op)
     129          48 :         {
     130          48 :             return LinearOperator(scalar, op);
     131          48 :         }
     132             : 
     133             :         /// friend function to return the adjoint of a LinearOperator (or a class derived from it)
     134             :         friend LinearOperator<data_t> adjoint(const LinearOperator<data_t>& op)
     135         194 :         {
     136         194 :             return LinearOperator(op, true);
     137         194 :         }
     138             : 
     139             :         /// friend function to return a (non-adjoint) leaf node of a LinearOperator (or a class derived from it)
     140             :         friend LinearOperator<data_t> leaf(const LinearOperator<data_t>& op)
     141         192 :         {
     142         192 :             return LinearOperator(op, false);
     143         192 :         }
     144             : 
     145             :     protected:
     146             :         /// the data descriptor of the domain of the operator
     147             :         std::unique_ptr<DataDescriptor> _domainDescriptor;
     148             : 
     149             :         /// the data descriptor of the range of the operator
     150             :         std::unique_ptr<DataDescriptor> _rangeDescriptor;
     151             : 
     152             :         /// implement the polymorphic clone operation
     153             :         LinearOperator<data_t>* cloneImpl() const override;
     154             : 
     155             :         /// implement the polymorphic comparison operation
     156             :         bool isEqual(const LinearOperator<data_t>& other) const override;
     157             : 
     158             :         /// the implementation behind apply(x, Ax); has to be overridden in derived classes
     159             :         virtual void applyImpl(const DataContainer<data_t>& x, DataContainer<data_t>& Ax) const;
     160             : 
     161             :         /// the implementation behind applyAdjoint(y, Aty); has to be overridden in derived classes
     162             :         virtual void applyAdjointImpl(const DataContainer<data_t>& y,
     163             :                                       DataContainer<data_t>& Aty) const;
     164             : 
     165             :     private:
     166             :         /// pointers to nodes in the evaluation tree
     167             :         std::unique_ptr<LinearOperator<data_t>> _lhs{}, _rhs{};
     168             : 
     169             :         std::optional<data_t> _scalar = {}; ///< optional scalar factor, empty by default (presumably set by the scalar constructor -- TODO confirm in the .cpp)
     170             : 
     171             :         /// flag whether this is a leaf-node
     172             :         bool _isLeaf{false};
     173             : 
     174             :         /// flag whether this is a leaf-node to implement an adjoint operator
     175             :         bool _isAdjoint{false};
     176             : 
     177             :         /// flag whether this is a composite (internal node) of the evaluation tree
     178             :         bool _isComposite{false};
     179             : 
     180             :         /// enum class denoting the mode of composition (+, *, scalar *)
     181             :         enum class CompositeMode { ADD, MULT, SCALAR_MULT };
     182             : 
     183             :         /// variable storing the composition mode (+, *, scalar *), MULT by default
     184             :         CompositeMode _mode{CompositeMode::MULT};
     185             : 
     186             :         /// constructor to produce a leaf node; adjoint() passes isAdjoint = true, leaf() passes false
     187             :         LinearOperator(const LinearOperator<data_t>& op, bool isAdjoint);
     188             : 
     189             :         /// constructor to produce a composite (internal node) of the evaluation tree
     190             :         LinearOperator(const LinearOperator<data_t>& lhs, const LinearOperator<data_t>& rhs,
     191             :                        CompositeMode mode);
     192             : 
     193             :         /// constructor combining a scalar with an operator, used by operator*(scalar, op)
     194             :         LinearOperator(data_t scalar, const LinearOperator<data_t>& op);
     195             :     };
     196             : 
     197             : } // namespace elsa

Generated by: LCOV version 1.14