LCOV - code coverage report
Current view: top level - elsa/proximal_operators - ProximalOperator.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 20 21 95.2 %
Date: 2024-05-16 04:22:26 Functions: 52 62 83.9 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "elsaDefines.h"
       4             : #include "Cloneable.h"
       5             : #include "DataContainer.h"
       6             : #include "DataDescriptor.h"
       7             : #include "StrongTypes.h"
       8             : 
       9             : namespace elsa
      10             : {
      11             :     /// Customization point for ProximalOperators. For an object to be type erased by the
      12             :     /// ProximalOperator interface, you either need to provide a member function `apply` for the
      13             :     /// type with the given parameters, or you can overload this function.
      14             :     template <class T, class data_t>
      15             :     void applyProximal(const T& proximal, const DataContainer<data_t>& v, SelfType_t<data_t> tau,
      16             :                        DataContainer<data_t>& out)
      17          20 :     {
      18          20 :         proximal.apply(v, tau, out);
      19          20 :     }
      20             : 
      21             :     template <class T, class data_t>
      22             :     DataContainer<data_t> applyProximal(const T& proximal, const DataContainer<data_t>& v,
      23             :                                         SelfType_t<data_t> tau)
      24         116 :     {
      25         116 :         return proximal.apply(v, tau);
      26         116 :     }
      27             : 
      28             :     /**
      29             :      * @brief Base class representing a proximal operator prox.
      30             :      *
      31             :      * This class represents a proximal operator prox, expressed through its apply methods,
      32             :      * which implement the proximal operator of f with penalty rho, i.e.
      33             :      * @f$ prox_{f,\rho}(v) = argmin_{x}(f(x) + (\rho/2) \cdot \| x - v \|^2_2). @f$
      34             :      *
      35             :      * The class implements type erasure. Classes that bind to this wrapper should either
      36             :      * overload `applyProximal` or implement the `apply` member function with the signature
      37             :      * given in this class. Base types are also assumed to be Semi-Regular.
      38             :      *
      39             :      * References:
      40             :      * https://stanford.edu/~boyd/papers/pdf/admm_distr_stats.pdf
      41             :      *
      42             :      * @tparam data_t data type for the values of the operator, defaulting to real_t
      43             :      *
      44             :      * @author
      45             :      * - Andi Braimllari - initial code
      46             :      * - David Frank - Type Erasure
      47             :      */
      48             :     template <typename data_t = real_t>
      49             :     class ProximalOperator
      50             :     {
      51             :     private:
      52             :         /// Concept for proximal operators, they should be cloneable, and have an apply method
      53             :         struct ProxConcept {
      54          86 :             virtual ~ProxConcept() = default;
      55             :             virtual std::unique_ptr<ProxConcept> clone() const = 0;
      56             :             virtual DataContainer<data_t> apply(const DataContainer<data_t>& v,
      57             :                                                 SelfType_t<data_t> t) const = 0;
      58             :             virtual void apply(const DataContainer<data_t>& v, SelfType_t<data_t> t,
      59             :                                DataContainer<data_t>& out) const = 0;
      60             :         };
      61             : 
      62             :         /// Bridge wrapper for concrete types
      63             :         template <class T>
      64             :         struct ProxModel : public ProxConcept {
      65          86 :             ProxModel(T self) : self_(std::move(self)) {}
      66             : 
      67             :             /// Just clone the type (assumes regularity, i.e. copy constructible)
      68             :             std::unique_ptr<ProxConcept> clone() const override
      69          52 :             {
      70          52 :                 return std::make_unique<ProxModel<T>>(self_);
      71          52 :             }
      72             : 
      73             :             /// Apply proximal by calling `applyProximal`; this enables flexible extension without
      74             :             /// classes as well. The default implementation of `applyProximal`
      75             :             /// just calls the member function.
      76             :             void apply(const DataContainer<data_t>& v, SelfType_t<data_t> t,
      77             :                        DataContainer<data_t>& out) const override
      78          20 :             {
      79          20 :                 applyProximal(self_, v, t, out);
      80          20 :             }
      81             : 
      82             :             DataContainer<data_t> apply(const DataContainer<data_t>& v,
      83             :                                         SelfType_t<data_t> t) const override
      84         116 :             {
      85         116 :                 return applyProximal(self_, v, t);
      86         116 :             }
      87             : 
      88             :         private:
      89             :             T self_;
      90             :         };
      91             : 
      92             :     public:
      93             :         /// defaulted default constructor for base-class (will point to nothing)
      94             :         ProximalOperator() = default;
      95             : 
      96             :         /// Type erasure constructor, taking everything that can bind to the above provided
      97             :         /// interface
      98             :         template <typename T>
      99             :         ProximalOperator(T proxOp) : ptr_(std::make_unique<ProxModel<T>>(std::move(proxOp)))
     100          34 :         {
     101          34 :         }
     102             : 
     103             :         template <typename T>
     104             :         ProximalOperator& operator=(T proxOp)
     105             :         {
     106             :             ptr_ = std::make_unique<ProxModel<T>>(std::move(proxOp));
     107             :             return *this;
     108             :         }
     109             : 
     110             :         /// Copy constructor
     111             :         ProximalOperator(const ProximalOperator& other);
     112             : 
     113             :         /// Default move constructor
     114           0 :         ProximalOperator(ProximalOperator&& other) noexcept = default;
     115             : 
     116             :         /// Copy assignment
     117             :         ProximalOperator& operator=(const ProximalOperator& other);
     118             : 
     119             :         /// Default move assignment
     120             :         ProximalOperator& operator=(ProximalOperator&& other) noexcept = default;
     121             : 
     122             :         /// default destructor
     123          86 :         ~ProximalOperator() = default;
     124             : 
     125             :         /**
     126             :          * @brief apply the proximal operator to an element in the operator's domain
     127             :          *
     128             :          * @param[in] v input DataContainer
     129             :          * @param[in] t input Threshold
     130             :          *
     131             :          * @returns prox DataContainer containing the application of the proximal operator to
     132             :          * data v, i.e. in the range of the operator
     133             :          *
     134             :          * Please note: this method uses apply(v, t, prox(v)) to perform the actual operation.
     135             :          */
     136             :         auto apply(const DataContainer<data_t>& v, SelfType_t<data_t> t) const
     137             :             -> DataContainer<data_t>;
     138             : 
     139             :         /**
     140             :          * @brief apply the proximal operator to an element in the operator's domain
     141             :          *
     142             :          * @param[in] v input DataContainer
     143             :          * @param[in] t input Threshold
     144             :          * @param[out] prox output DataContainer
     145             :          *
     146             :          * Please note: this method calls the method applyImpl that has to be overridden in derived
     147             :          * classes. (Why is this method not virtual itself? Because you cannot have a non-virtual
     148             :          * function overloading a virtual one [apply with one vs. two arguments]).
     149             :          */
     150             :         void apply(const DataContainer<data_t>& v, SelfType_t<data_t> t,
     151             :                    DataContainer<data_t>& prox) const;
     152             : 
     153             :     private:
     154             :         std::unique_ptr<ProxConcept> ptr_ = {};
     155             :     };
     156             : } // namespace elsa

Generated by: LCOV version 1.14