LCOV - code coverage report
Current view: top level - elsa/functionals - SeparableSum.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 40 107 37.4 %
Date: 2025-01-14 06:38:49 Functions: 20 32 62.5 %

          Line data    Source code
       1             : #include "SeparableSum.h"
       2             : #include "DataContainer.h"
       3             : #include "Functional.h"
       4             : #include "IdenticalBlocksDescriptor.h"
       5             : #include "TypeCasts.hpp"
       6             : #include "CombinedProximal.h"
       7             : 
       8             : #include <memory>
       9             : 
      10             : namespace elsa
      11             : {
      12             :     template <class data_t>
      13             :     SeparableSum<data_t>::SeparableSum(std::vector<std::unique_ptr<Functional<data_t>>> fns)
      14             :         : Functional<data_t>(*detail::determineDescriptor(fns)), fns_(std::move(fns))
      15          14 :     {
      16          14 :     }
      17             : 
      18             :     template <class data_t>
      19             :     SeparableSum<data_t>::SeparableSum(const Functional<data_t>& fn)
      20             :         : Functional<data_t>(IdenticalBlocksDescriptor(1, fn.getDomainDescriptor()))
      21           2 :     {
      22           2 :         fns_.push_back(fn.clone());
      23           2 :     }
      24             : 
      25             :     template <class data_t>
      26             :     SeparableSum<data_t>::SeparableSum(const Functional<data_t>& fn1, const Functional<data_t>& fn2)
      27             :         : SeparableSum<data_t>(detail::make_vector<data_t>(fn1, fn2))
      28           2 :     {
      29           2 :     }
      30             : 
      31             :     template <class data_t>
      32             :     SeparableSum<data_t>::SeparableSum(const Functional<data_t>& fn1, const Functional<data_t>& fn2,
      33             :                                        const Functional<data_t>& fn3)
      34             :         : SeparableSum<data_t>(detail::make_vector<data_t>(fn1, fn2, fn3))
      35           2 :     {
      36           2 :     }
      37             : 
      38             :     template <typename data_t>
      39             :     bool SeparableSum<data_t>::isProxFriendly() const
      40           0 :     {
      41           0 :         return true;
      42           0 :     }
      43             : 
      44             :     template <class data_t>
      45             :     data_t SeparableSum<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const
      46           8 :     {
      47           8 :         if (!is<BlockDescriptor>(Rx.getDataDescriptor())) {
      48           0 :             throw Error("SeparableSum: Blocked DataContainer expected");
      49           0 :         }
      50             : 
      51           8 :         if (Rx.getDataDescriptor() != this->getDomainDescriptor()) {
      52           0 :             throw Error("SeparableSum: Descriptor of argument is unexpected");
      53           0 :         }
      54             : 
      55           8 :         auto& blockdesc = downcast_safe<BlockDescriptor>(Rx.getDataDescriptor());
      56             : 
      57           8 :         data_t sum{0};
      58          28 :         for (int i = 0; i < blockdesc.getNumberOfBlocks(); ++i) {
      59          20 :             sum += fns_[asUnsigned(i)]->evaluate(Rx.getBlock(i));
      60          20 :         }
      61           8 :         return sum;
      62           8 :     }
      63             : 
      64             :     template <typename data_t>
      65             :     data_t SeparableSum<data_t>::convexConjugate(const DataContainer<data_t>& x) const
      66           0 :     {
      67           0 :         if (!is<BlockDescriptor>(x.getDataDescriptor())) {
      68           0 :             throw Error("SeparableSum: Input to convex conjugate needs to be blocked");
      69           0 :         }
      70             : 
      71           0 :         auto& blockedDesc = downcast_safe<BlockDescriptor>(x.getDataDescriptor());
      72             : 
      73           0 :         if (blockedDesc.getNumberOfBlocks() != asSigned(fns_.size())) {
      74           0 :             throw Error("SeparableSum: unequal number of blocks ({}) and number of functions ({}) ",
      75           0 :                         blockedDesc.getNumberOfBlocks(), fns_.size());
      76           0 :         }
      77             : 
      78           0 :         data_t sum = 0;
      79             : 
      80           0 :         for (int i = 0; i < blockedDesc.getNumberOfBlocks(); ++i) {
      81           0 :             sum += fns_[asUnsigned(i)]->convexConjugate(x.getBlock(i));
      82           0 :         }
      83           0 :         return sum;
      84           0 :     }
      85             : 
      86             :     template <class data_t>
      87             :     void SeparableSum<data_t>::getGradientImpl(const DataContainer<data_t>& Rx,
      88             :                                                DataContainer<data_t>& out) const
      89           0 :     {
      90           0 :         if (!is<BlockDescriptor>(Rx.getDataDescriptor())) {
      91           0 :             throw Error("SeparableSum: Blocked DataContainer expected for gradient");
      92           0 :         }
      93             : 
      94           0 :         if (Rx.getDataDescriptor() != this->getDomainDescriptor()) {
      95           0 :             throw Error("SeparableSum: Descriptor of argument is unexpected");
      96           0 :         }
      97             : 
      98           0 :         auto& blockdesc = downcast_safe<BlockDescriptor>(Rx.getDataDescriptor());
      99             : 
     100           0 :         for (int i = 0; i < blockdesc.getNumberOfBlocks(); ++i) {
     101           0 :             auto outview = out.getBlock(i);
     102           0 :             fns_[asUnsigned(i)]->getGradient(Rx.getBlock(i), outview);
     103           0 :         }
     104           0 :     }
     105             : 
     106             :     template <class data_t>
     107             :     LinearOperator<data_t> SeparableSum<data_t>::getHessianImpl(const DataContainer<data_t>&) const
     108           0 :     {
     109           0 :         throw NotImplementedError("SeparableSum: Hessian not implemented");
     110           0 :     }
     111             : 
     112             :     template <typename data_t>
     113             :     DataContainer<data_t> SeparableSum<data_t>::proximal(const DataContainer<data_t>& v,
     114             :                                                          SelfType_t<data_t> tau) const
     115           0 :     {
     116           0 :         auto out = emptylike(v);
     117           0 :         proximal(v, tau, out);
     118           0 :         return out;
     119           0 :     }
     120             : 
     121             :     template <typename data_t>
     122             :     void SeparableSum<data_t>::proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t,
     123             :                                         DataContainer<data_t>& out) const
     124           0 :     {
     125           0 :         if (!is<BlockDescriptor>(v.getDataDescriptor())) {
     126           0 :             throw Error("SeparableSum: Input to proximal needs to be blocked");
     127           0 :         }
     128             : 
     129           0 :         auto& blockedDesc = downcast_safe<BlockDescriptor>(v.getDataDescriptor());
     130             : 
     131           0 :         if (blockedDesc.getNumberOfBlocks() != asSigned(fns_.size())) {
     132           0 :             throw Error("SeparableSum: unequal number of blocks ({}) and number of functions ({}) ",
     133           0 :                         blockedDesc.getNumberOfBlocks(), fns_.size());
     134           0 :         }
     135             : 
     136           0 :         for (int i = 0; i < blockedDesc.getNumberOfBlocks(); ++i) {
     137           0 :             auto outview = out.getBlock(i);
     138           0 :             auto inview = v.getBlock(i);
     139             : 
     140           0 :             fns_[i]->proximal(inview, t, outview);
     141           0 :         }
     142           0 :     }
     143             : 
     144             :     template <class data_t>
     145             :     SeparableSum<data_t>* SeparableSum<data_t>::cloneImpl() const
     146           8 :     {
     147           8 :         std::vector<std::unique_ptr<Functional<data_t>>> copyfns;
     148          28 :         for (std::size_t i = 0; i < fns_.size(); ++i) {
     149          20 :             copyfns.push_back(fns_[i]->clone());
     150          20 :         }
     151             : 
     152           8 :         return new SeparableSum<data_t>(std::move(copyfns));
     153           8 :     }
     154             : 
     155             :     template <class data_t>
     156             :     bool SeparableSum<data_t>::isEqual(const Functional<data_t>& other) const
     157           8 :     {
     158           8 :         if (!Functional<data_t>::isEqual(other)) {
     159           0 :             return false;
     160           0 :         }
     161             : 
     162           8 :         const auto& fn = downcast<const SeparableSum<data_t>>(other);
     163           8 :         return std::equal(fns_.begin(), fns_.end(), fn.fns_.begin(),
     164          20 :                           [](const auto& l, const auto& r) { return (*l) == (*r); });
     165           8 :     }
     166             : 
     167             :     namespace detail
     168             :     {
     169             :         template <class data_t>
     170             :         std::unique_ptr<BlockDescriptor>
     171             :             determineDescriptor(const std::vector<std::unique_ptr<Functional<data_t>>>& fns)
     172          14 :         {
     173             :             // For now assume non empty
     174          14 :             auto& firstDesc = fns.front()->getDomainDescriptor();
     175             : 
     176             :             // Determine if all descriptors are equal
     177          38 :             bool allEqual = std::all_of(fns.begin(), fns.end(), [&](const auto& x) {
     178          38 :                 return x->getDomainDescriptor() == firstDesc;
     179          38 :             });
     180             : 
     181             :             // Then we can return an identical block descriptor
     182          14 :             if (allEqual) {
     183          14 :                 return std::make_unique<IdenticalBlocksDescriptor>(fns.size(), firstDesc);
     184          14 :             }
     185             : 
     186             :             // There are different descriptors, so extract them from the vector of functionals
     187           0 :             std::vector<std::unique_ptr<DataDescriptor>> descriptors;
     188           0 :             descriptors.reserve(fns.size());
     189           0 :             for (const auto& f : fns) {
     190           0 :                 descriptors.push_back(f->getDomainDescriptor().clone());
     191           0 :             }
     192             : 
     193           0 :             return std::make_unique<RandomBlocksDescriptor>(std::move(descriptors));
     194           0 :         }
     195             :     } // namespace detail
     196             : 
     197             :     // ------------------------------------------
     198             :     // explicit template instantiation
     199             :     template class SeparableSum<float>;
     200             :     template class SeparableSum<double>;
     201             : 
     202             :     template std::unique_ptr<BlockDescriptor> detail::determineDescriptor<float>(
     203             :         const std::vector<std::unique_ptr<Functional<float>>>& fns);
     204             :     template std::unique_ptr<BlockDescriptor> detail::determineDescriptor<double>(
     205             :         const std::vector<std::unique_ptr<Functional<double>>>& fns);
     206             : } // namespace elsa

Generated by: LCOV version 1.14