LCOV - code coverage report
Current view: top level - elsa/functionals - MaskableSum.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 54 76 71.1 %
Date: 2024-05-16 04:22:26 Functions: 36 56 64.3 %

          Line data    Source code
       1             : #include "MaskableSum.h"
       2             : #include "DataContainer.h"
       3             : #include "DataDescriptor.h"
       4             : #include "Functional.h"
       5             : #include "TypeCasts.hpp"
       6             : 
       7             : #include <memory>
       8             : #include <numeric>
       9             : #include <vector>
      10             : 
      11             : namespace elsa
      12             : {
      13             :     template <typename data_t>
      14             :     MaskableSum<data_t>::MaskableSum(std::vector<std::unique_ptr<Functional<data_t>>> fs,
      15             :                                      std::optional<std::vector<bool>> mask)
      16             :         : Functional<data_t>(fs.at(0)->getDomainDescriptor()),
      17             :           mask{mask.value_or(std::vector<bool>(fs.size(), true))},
      18             :           functions(std::move(fs))
      19          46 :     {
      20             : 
      21          46 :         if (functions.empty()) {
      22             :             // Unreachable, Functional{} fails first
      23           0 :             throw Error{"MaskableSum: Must contain at least one Functional!"};
      24           0 :         }
      25             : 
      26          46 :         if (mask.has_value() && mask->size() != functions.size()) {
      27             :             // Unreachable, Functional{} fails first
      28           8 :             throw Error{"MaskableSum: Mask size must equal number of Functionals!"};
      29           8 :         }
      30             : 
      31          38 :         auto& firstDesc = functions.front()->getDomainDescriptor();
      32             : 
      33             :         // Determine if all descriptors are equal
      34          44 :         bool allEqual = std::all_of(functions.begin(), functions.end(), [&](const auto& x) {
      35          44 :             return x->getDomainDescriptor() == firstDesc;
      36          44 :         });
      37             : 
      38          38 :         if (!allEqual) {
      39           4 :             throw Error{"MaskableSum: Functionals must all have identical domains!"};
      40           4 :         }
      41          38 :     }
      42             : 
      43             :     template <typename data_t>
      44             :     void MaskableSum<data_t>::setMask(const std::vector<bool>& mask)
      45           6 :     {
      46           6 :         if (mask.size() != functions.size()) {
      47           4 :             throw Error{"MaskableSum: setMask with invalid mask size"};
      48           4 :         }
      49           2 :         this->mask = mask;
      50           2 :     }
      51             : 
      52             :     template <typename data_t>
      53             :     const std::vector<bool>& MaskableSum<data_t>::getMask() const
      54           0 :     {
      55           0 :         return mask;
      56           0 :     }
      57             : 
      58             :     template <typename data_t>
      59             :     const std::vector<std::unique_ptr<Functional<data_t>>>&
      60             :         MaskableSum<data_t>::getFunctions() const
      61           0 :     {
      62           0 :         return functions;
      63           0 :     }
      64             : 
      65             :     template <typename data_t>
      66             :     void MaskableSum<data_t>::setAll()
      67           2 :     {
      68           2 :         std::fill(mask.begin(), mask.end(), true);
      69           2 :     }
      70             : 
      71             :     template <typename data_t>
      72             :     size_t MaskableSum<data_t>::numFunctions() const
      73           0 :     {
      74           0 :         return functions.size();
      75           0 :     }
      76             : 
      77             :     template <typename data_t>
      78             :     data_t MaskableSum<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const
      79           4 :     {
      80           4 :         if (Rx.getDataDescriptor() != this->getDomainDescriptor()) {
      81           0 :             throw Error{"MaskableSum: Descriptor of argument is unexpected"};
      82           0 :         }
      83             : 
      84           4 :         return std::transform_reduce(functions.begin(), functions.end(), mask.begin(), data_t{0},
      85           4 :                                      std::plus<>{},
      86           8 :                                      [&](auto& fn, auto m) { return m ? fn->evaluate(Rx) : 0; });
      87           4 :     }
      88             : 
      89             :     template <typename data_t>
      90             :     void MaskableSum<data_t>::getGradientImpl(const DataContainer<data_t>& Rx,
      91             :                                               DataContainer<data_t>& out) const
      92           4 :     {
      93           4 :         if (Rx.getDataDescriptor() != this->getDomainDescriptor()) {
      94           0 :             throw Error("MaskableSum: Descriptor of argument is unexpected");
      95           0 :         }
      96             : 
      97           4 :         out.fill(0);
      98             : 
      99           4 :         auto grad = emptylike(Rx);
     100          12 :         for (std::size_t i = 0; i < functions.size(); ++i) {
     101           8 :             if (mask[i]) {
     102           6 :                 functions[i]->getGradient(Rx, grad);
     103           6 :                 out += grad;
     104           6 :             }
     105           8 :         }
     106           4 :     }
     107             : 
     108             :     template <typename data_t>
     109             :     LinearOperator<data_t> MaskableSum<data_t>::getHessianImpl(const DataContainer<data_t>&) const
     110           0 :     {
     111           0 :         throw NotImplementedError("MaskableSum: Hessian not implemented");
     112           0 :     }
     113             : 
     114             :     template <typename data_t>
     115             :     MaskableSum<data_t>* MaskableSum<data_t>::cloneImpl() const
     116           4 :     {
     117           4 :         std::vector<std::unique_ptr<Functional<data_t>>> copies;
     118           4 :         copies.reserve(functions.size());
     119           4 :         for (const auto& ptr : functions) {
     120           4 :             copies.push_back(ptr->clone());
     121           4 :         }
     122             : 
     123           4 :         return new MaskableSum<data_t>(std::move(copies), mask);
     124           4 :     }
     125             : 
     126             :     template <typename data_t>
     127             :     bool MaskableSum<data_t>::isEqual(const Functional<data_t>& other) const
     128           8 :     {
     129           8 :         if (!Functional<data_t>::isEqual(other)) {
     130           0 :             return false;
     131           0 :         }
     132             : 
     133           8 :         const auto& fn = downcast<const MaskableSum<data_t>>(other);
     134             : 
     135           8 :         if (mask != fn.mask) {
     136           0 :             return false;
     137           0 :         }
     138             : 
     139           8 :         return std::equal(functions.begin(), functions.end(), fn.functions.begin(),
     140           8 :                           [](const auto& l, const auto& r) { return (*l) == (*r); });
     141           8 :     }
     142             : 
     143             :     // ------------------------------------------
     144             :     // explicit template instantiation
     145             :     template class MaskableSum<float>;
     146             :     template class MaskableSum<double>;
     147             :     template class MaskableSum<complex<float>>;
     148             :     template class MaskableSum<complex<double>>;
     149             : 
     150             : } // namespace elsa

Generated by: LCOV version 1.14