LCOV - code coverage report
Current view: top level - elsa/functionals - LinearResidual.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 12 20 60.0 %
Date: 2024-05-16 04:22:26 Functions: 12 12 100.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
#include <memory>
#include <optional>

#include "DataContainer.h"
#include "DataDescriptor.h"
#include "LinearOperator.h"
       8             : 
       9             : namespace elsa
      10             : {
      11             :     /**
      12             :      * @brief Class representing a linear residual, i.e. Ax - b with operator A and vectors x, b.
      13             :      *
      14             :      * A linear residual is a vector-valued mapping \f$ \mathbb{R}^n\to\mathbb{R}^m \f$, namely
      15             :      * \f$ x \mapsto  Ax - b \f$, where A is a LinearOperator, b a constant data vector
      16             :      * (DataContainer) and x a variable (DataContainer). This linear residual can be used as input
      17             :      * to a Functional.
      18             :      *
      19             :      * @tparam data_t data type for the domain and range of the operator, default to real_t
      20             :      *
      21             :      * @author
      22             :      * * Matthias Wieczorek - initial code
      23             :      * * Tobias Lasser - modularization, modernization
      24             :      */
      25             :     template <typename data_t = real_t>
      26             :     class LinearResidual
      27             :     {
      28             :     public:
      29             :         /**
      30             :          * @brief Constructor for a trivial residual \f$ x \mapsto x \f$
      31             :          *
      32             :          * @param[in] descriptor describing the domain = range of the residual
      33             :          */
      34             :         explicit LinearResidual(const DataDescriptor& descriptor);
      35             : 
      36             :         /**
      37             :          * @brief Constructor for a simple residual \f$ x \mapsto x - b \f$
      38             :          *
      39             :          * @param[in] b a vector (DataContainer) that will be subtracted from x
      40             :          */
      41             :         explicit LinearResidual(const DataContainer<data_t>& b);
      42             : 
      43             :         /** @brief Constructor for a residual \f$ x \mapsto Ax \f$
      44             :          *
      45             :          * @param[in] A a LinearOperator
      46             :          */
      47             :         explicit LinearResidual(const LinearOperator<data_t>& A);
      48             : 
      49             :         /**
      50             :          * @brief Constructor for a residual \f$ x \mapsto Ax - b \f$
      51             :          *
      52             :          * @param[in] A a LinearOperator
      53             :          * @param[in] b a vector (DataContainer)
      54             :          */
      55             :         LinearResidual(const LinearOperator<data_t>& A, const DataContainer<data_t>& b);
      56             : 
      57             :         // Copy constructor
      58             :         LinearResidual(const LinearResidual<data_t>&);
      59             : 
      60             :         // Copy assignment
      61             :         LinearResidual& operator=(const LinearResidual<data_t>&);
      62             : 
      63             :         // Move constructor
      64             :         LinearResidual(LinearResidual<data_t>&&) noexcept;
      65             : 
      66             :         // Move assignment
      67             :         LinearResidual& operator=(LinearResidual<data_t>&&) noexcept;
      68             : 
      69             :         /// default destructor
      70         128 :         ~LinearResidual() = default;
      71             : 
      72             :         /// return the domain descriptor of the residual
      73             :         const DataDescriptor& getDomainDescriptor() const;
      74             : 
      75             :         /// return the range descriptor of the residual
      76             :         const DataDescriptor& getRangeDescriptor() const;
      77             : 
      78             :         /// return true if the residual has an operator A
      79             :         bool hasOperator() const;
      80             : 
      81             :         /// return true if the residual has a data vector b
      82             :         bool hasDataVector() const;
      83             : 
      84             :         /// return the operator A (throws if the residual has none)
      85             :         const LinearOperator<data_t>& getOperator() const;
      86             : 
      87             :         /// return the data vector b (throws if the residual has none)
      88             :         const DataContainer<data_t>& getDataVector() const;
      89             : 
      90             :         /**
      91             :          * @brief evaluate the residual at x and return the result
      92             :          *
      93             :          * @param[in] x input DataContainer (in the domain of the residual)
      94             :          *
      95             :          * @returns result DataContainer (in the range of the residual) containing the result of
      96             :          * the evaluation of the residual at x
      97             :          */
      98             :         DataContainer<data_t> evaluate(const DataContainer<data_t>& x) const;
      99             : 
     100             :         /**
     101             :          * @brief evaluate the residual at x and store in result
     102             :          *
     103             :          * @param[in] x input DataContainer (in the domain of the residual)
     104             :          * @param[out] result output DataContainer (in the range of the residual)
     105             :          */
     106             :         void evaluate(const DataContainer<data_t>& x, DataContainer<data_t>& result) const;
     107             : 
     108             :         /**
     109             :          * @brief return the Jacobian (first derivative) of the linear residual at x.
     110             :          * If A is set, then the Jacobian is A and this returns a copy of A.
     111             :          * If A is not set, then an Identity operator is returned.
     112             :          *
     113             :          * @param x input DataContainer (in the domain of the residual)
     114             :          *
     115             :          * @returns  a LinearOperator (the Jacobian)
     116             :          */
     117             :         LinearOperator<data_t> getJacobian(const DataContainer<data_t>& x);
     118             : 
     119             :     private:
     120             :         /// Descriptor of domain
     121             :         std::unique_ptr<DataDescriptor> domainDesc_;
     122             : 
     123             :         /// Descriptor of range
     124             :         std::unique_ptr<DataDescriptor> rangeDesc_;
     125             : 
     126             :         /// the operator A, nullptr implies no operator present
     127             :         std::unique_ptr<LinearOperator<data_t>> _operator{};
     128             : 
     129             :         /// optional  data vector b
     130             :         std::optional<DataContainer<data_t>> _dataVector{};
     131             :     };
     132             : 
     133             :     template <class data_t>
     134             :     bool operator==(const LinearResidual<data_t>& lhs, const LinearResidual<data_t>& rhs)
     135          32 :     {
     136          32 :         if (lhs.getDomainDescriptor() != rhs.getDomainDescriptor()) {
     137           0 :             return false;
     138           0 :         }
     139             : 
     140          32 :         if (lhs.getRangeDescriptor() != rhs.getRangeDescriptor()) {
     141           0 :             return false;
     142           0 :         }
     143             : 
     144          32 :         if (lhs.hasOperator() && rhs.hasOperator() && lhs.getOperator() != rhs.getOperator()) {
     145           0 :             return false;
     146           0 :         }
     147             : 
     148          32 :         if (lhs.hasDataVector() && rhs.hasDataVector()
     149          32 :             && lhs.getDataVector() != rhs.getDataVector()) {
     150           0 :             return false;
     151           0 :         }
     152             : 
     153          32 :         return true;
     154          32 :     }
     155             : 
     156             :     template <class data_t>
     157             :     bool operator!=(const LinearResidual<data_t>& lhs, const LinearResidual<data_t>& rhs)
     158          16 :     {
     159          16 :         return !(lhs == rhs);
     160          16 :     }
     161             : } // namespace elsa

Generated by: LCOV version 1.14