LCOV - code coverage report
Current view: top level - elsa/functionals - SeparableSum.h (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 8 8 100.0 %
Date: 2024-05-16 04:22:26 Functions: 8 8 100.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
#include "BlockDescriptor.h"
#include "DataContainer.h"
#include "Functional.h"
#include "IdenticalBlocksDescriptor.h"
#include "RandomBlocksDescriptor.h"
#include "TypeCasts.hpp"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
      11             : 
      12             : namespace elsa
      13             : {
      14             :     namespace detail
      15             :     {
      16             :         /// Helper to create a vector of unique_ptrs from references with a clone method.
      17             :         template <class data_t, class... Ts>
      18             :         std::vector<std::unique_ptr<Functional<data_t>>> make_vector(Ts&&... ts)
      19           6 :         {
      20           6 :             std::vector<std::unique_ptr<Functional<data_t>>> v;
      21           6 :             v.reserve(sizeof...(ts));
      22             : 
      23           6 :             (v.emplace_back(std::forward<Ts>(ts).clone()), ...);
      24           6 :             return v;
      25           6 :         }
      26             : 
      27             :         /// Create a BlockDescriptor given a list of functionals. If all functionals have the same
      28             :         /// data descriptor, a `IdenticalBlocksDescriptor` returned, else a `RandomBlocksDescriptor`
      29             :         /// is returned.
      30             :         template <class data_t>
      31             :         std::unique_ptr<BlockDescriptor>
      32             :             determineDescriptor(const std::vector<std::unique_ptr<Functional<data_t>>>& fns);
      33             :     } // namespace detail
      34             : 
      35             :     /**
      36             :      * @brief Class representing a separable sum of functionals. Given a sequence
      37             :      * of \f$k\f$ functions \f$ ( f_i )_{i=1}^k \f$, where \f$f_{i}: X_{i} \rightarrow (-\infty,
      38             :      * \infty]\f$, the separable sum \f$F\f$ is defined as:
      39             :      *
      40             :      * \f[
      41             :      * F:X_{1}\times X_{2}\cdots\times X_{m} \rightarrow (-\infty, \infty] \\
      42             :      * F(x_{1}, x_{2}, \cdots, x_{k}) = \sum_{i=1}^k f_{i}(x_{i})
      43             :      * \f]
      44             :      *
      45             :      * The great benefit of the separable sum, is that its proximal is easily derived.
      46             :      *
      47             :      * @see CombinedProximal
      48             :      */
      49             :     template <class data_t>
      50             :     class SeparableSum final : public Functional<data_t>
      51             :     {
      52             :     public:
      53             :         /// Create a separable sum from a vector of unique_ptrs to functionals
      54             :         explicit SeparableSum(std::vector<std::unique_ptr<Functional<data_t>>> fns);
      55             : 
      56             :         /// Create a separable sum from a single functional
      57             :         explicit SeparableSum(const Functional<data_t>& fn);
      58             : 
      59             :         /// Create a separable sum from two functionals
      60             :         SeparableSum(const Functional<data_t>& fn1, const Functional<data_t>& fn2);
      61             : 
      62             :         /// Create a separable sum from three functionals
      63             :         SeparableSum(const Functional<data_t>& fn1, const Functional<data_t>& fn2,
      64             :                      const Functional<data_t>& fn3);
      65             : 
      66             :         /// Create a separable sum from variadic number of functionals
      67             :         template <class... Args>
      68             :         SeparableSum(const Functional<data_t>& fn1, const Functional<data_t>& fn2,
      69             :                      const Functional<data_t>& fn3, const Functional<data_t>& fn4, Args&&... fns)
      70             :             : SeparableSum<data_t>(
      71             :                 detail::make_vector<data_t>(fn1, fn2, fn3, fn4, std::forward<Args>(fns)...))
      72           2 :         {
      73           2 :         }
      74             : 
      75             :         /// @brief Indicate if the functional has a simple to compute proximal
      76             :         bool isProxFriendly() const override;
      77             : 
      78             :         DataContainer<data_t> proximal(const DataContainer<data_t>& v,
      79             :                                        SelfType_t<data_t> t) const override;
      80             : 
      81             :         void proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t,
      82             :                       DataContainer<data_t>& out) const override;
      83             : 
      84             :         /**
      85             :          * @brief The convex conjugate of a separable sum is given as:
      86             :          * @f[
      87             :          * f^*(x) = \sum_{i=0}^m f_i^*(x_i)
      88             :          * @f]
      89             :          */
      90             :         data_t convexConjugate(const DataContainer<data_t>& x) const override;
      91             : 
      92             :     private:
      93             :         /// Evaluate functional. Requires `Rx` to be a blocked `DataContainer`
      94             :         /// (i.e. its descriptor is of type `BlockDescriptor`), the functions
      95             :         /// throws if not meet.
      96             :         data_t evaluateImpl(const DataContainer<data_t>& Rx) const override;
      97             : 
      98             :         /// The derivative of the sum of functions, is the sum of the derivatives
      99             :         void getGradientImpl(const DataContainer<data_t>& Rx,
     100             :                              DataContainer<data_t>& out) const override;
     101             : 
     102             :         /// Not yet implemented
     103             :         LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const override;
     104             : 
     105             :         /// Polymorphic clone implementations
     106             :         SeparableSum<data_t>* cloneImpl() const override;
     107             : 
     108             :         /// Polymorphic equalty implementations
     109             :         bool isEqual(const Functional<data_t>& other) const override;
     110             : 
     111             :         std::vector<std::unique_ptr<Functional<data_t>>> fns_{};
     112             :     };
     113             : } // namespace elsa

Generated by: LCOV version 1.14