Line data Source code
1 : #pragma once 2 : 3 : #include "BlockDescriptor.h" 4 : #include "DataContainer.h" 5 : #include "Functional.h" 6 : #include "IdenticalBlocksDescriptor.h" 7 : #include "RandomBlocksDescriptor.h" 8 : #include "TypeCasts.hpp" 9 : #include <algorithm> 10 : #include <memory> 11 : 12 : namespace elsa 13 : { 14 : namespace detail 15 : { 16 : /// Helper to create a vector of unique_ptrs from references with a clone method. 17 : template <class data_t, class... Ts> 18 : std::vector<std::unique_ptr<Functional<data_t>>> make_vector(Ts&&... ts) 19 6 : { 20 6 : std::vector<std::unique_ptr<Functional<data_t>>> v; 21 6 : v.reserve(sizeof...(ts)); 22 : 23 6 : (v.emplace_back(std::forward<Ts>(ts).clone()), ...); 24 6 : return v; 25 6 : } 26 : 27 : /// Create a BlockDescriptor given a list of functionals. If all functionals have the same 28 : /// data descriptor, a `IdenticalBlocksDescriptor` returned, else a `RandomBlocksDescriptor` 29 : /// is returned. 30 : template <class data_t> 31 : std::unique_ptr<BlockDescriptor> 32 : determineDescriptor(const std::vector<std::unique_ptr<Functional<data_t>>>& fns); 33 : } // namespace detail 34 : 35 : /** 36 : * @brief Class representing a separable sum of functionals. Given a sequence 37 : * of \f$k\f$ functions \f$ ( f_i )_{i=1}^k \f$, where \f$f_{i}: X_{i} \rightarrow (-\infty, 38 : * \infty]\f$, the separable sum \f$F\f$ is defined as: 39 : * 40 : * \f[ 41 : * F:X_{1}\times X_{2}\cdots\times X_{m} \rightarrow (-\infty, \infty] \\ 42 : * F(x_{1}, x_{2}, \cdots, x_{k}) = \sum_{i=1}^k f_{i}(x_{i}) 43 : * \f] 44 : * 45 : * The great benefit of the separable sum, is that its proximal is easily derived. 46 : * 47 : * @see CombinedProximal 48 : */ 49 : template <class data_t> 50 : class SeparableSum final : public Functional<data_t> 51 : { 52 : public: 53 : /// Create a separable sum from a vector of unique_ptrs to functionals 54 : explicit SeparableSum(std::vector<std::unique_ptr<Functional<data_t>>> fns); 55 : 56 : /// Create a separable sum from a single functional 57 : explicit SeparableSum(const Functional<data_t>& fn); 58 : 59 : /// Create a separable sum from two functionals 60 : SeparableSum(const Functional<data_t>& fn1, const Functional<data_t>& fn2); 61 : 62 : /// Create a separable sum from three functionals 63 : SeparableSum(const Functional<data_t>& fn1, const Functional<data_t>& fn2, 64 : const Functional<data_t>& fn3); 65 : 66 : /// Create a separable sum from variadic number of functionals 67 : template <class... Args> 68 : SeparableSum(const Functional<data_t>& fn1, const Functional<data_t>& fn2, 69 : const Functional<data_t>& fn3, const Functional<data_t>& fn4, Args&&... fns) 70 : : SeparableSum<data_t>( 71 : detail::make_vector<data_t>(fn1, fn2, fn3, fn4, std::forward<Args>(fns)...)) 72 2 : { 73 2 : } 74 : 75 : /// @brief Indicate if the functional has a simple to compute proximal 76 : bool isProxFriendly() const override; 77 : 78 : DataContainer<data_t> proximal(const DataContainer<data_t>& v, 79 : SelfType_t<data_t> t) const override; 80 : 81 : void proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t, 82 : DataContainer<data_t>& out) const override; 83 : 84 : /** 85 : * @brief The convex conjugate of a separable sum is given as: 86 : * @f[ 87 : * f^*(x) = \sum_{i=0}^m f_i^*(x_i) 88 : * @f] 89 : */ 90 : data_t convexConjugate(const DataContainer<data_t>& x) const override; 91 : 92 : private: 93 : /// Evaluate functional. Requires `Rx` to be a blocked `DataContainer` 94 : /// (i.e. its descriptor is of type `BlockDescriptor`), the functions 95 : /// throws if not meet. 96 : data_t evaluateImpl(const DataContainer<data_t>& Rx) const override; 97 : 98 : /// The derivative of the sum of functions, is the sum of the derivatives 99 : void getGradientImpl(const DataContainer<data_t>& Rx, 100 : DataContainer<data_t>& out) const override; 101 : 102 : /// Not yet implemented 103 : LinearOperator<data_t> getHessianImpl(const DataContainer<data_t>& Rx) const override; 104 : 105 : /// Polymorphic clone implementations 106 : SeparableSum<data_t>* cloneImpl() const override; 107 : 108 : /// Polymorphic equalty implementations 109 : bool isEqual(const Functional<data_t>& other) const override; 110 : 111 : std::vector<std::unique_ptr<Functional<data_t>>> fns_{}; 112 : }; 113 : } // namespace elsa