Line data Source code
1 : #include "MaskableSum.h" 2 : #include "DataContainer.h" 3 : #include "DataDescriptor.h" 4 : #include "Functional.h" 5 : #include "TypeCasts.hpp" 6 : 7 : #include <memory> 8 : #include <numeric> 9 : #include <vector> 10 : 11 : namespace elsa 12 : { 13 : template <typename data_t> 14 : MaskableSum<data_t>::MaskableSum(std::vector<std::unique_ptr<Functional<data_t>>> fs, 15 : std::optional<std::vector<bool>> mask) 16 : : Functional<data_t>(fs.at(0)->getDomainDescriptor()), 17 : mask{mask.value_or(std::vector<bool>(fs.size(), true))}, 18 : functions(std::move(fs)) 19 46 : { 20 : 21 46 : if (functions.empty()) { 22 : // Unreachable, Functional{} fails first 23 0 : throw Error{"MaskableSum: Must contain at least one Functional!"}; 24 0 : } 25 : 26 46 : if (mask.has_value() && mask->size() != functions.size()) { 27 : // Unreachable, Functional{} fails first 28 8 : throw Error{"MaskableSum: Mask size must equal number of Functionals!"}; 29 8 : } 30 : 31 38 : auto& firstDesc = functions.front()->getDomainDescriptor(); 32 : 33 : // Determine if all descriptors are equal 34 44 : bool allEqual = std::all_of(functions.begin(), functions.end(), [&](const auto& x) { 35 44 : return x->getDomainDescriptor() == firstDesc; 36 44 : }); 37 : 38 38 : if (!allEqual) { 39 4 : throw Error{"MaskableSum: Functionals must all have identical domains!"}; 40 4 : } 41 38 : } 42 : 43 : template <typename data_t> 44 : void MaskableSum<data_t>::setMask(const std::vector<bool>& mask) 45 6 : { 46 6 : if (mask.size() != functions.size()) { 47 4 : throw Error{"MaskableSum: setMask with invalid mask size"}; 48 4 : } 49 2 : this->mask = mask; 50 2 : } 51 : 52 : template <typename data_t> 53 : const std::vector<bool>& MaskableSum<data_t>::getMask() const 54 0 : { 55 0 : return mask; 56 0 : } 57 : 58 : template <typename data_t> 59 : const std::vector<std::unique_ptr<Functional<data_t>>>& 60 : MaskableSum<data_t>::getFunctions() const 61 0 : { 62 0 : return functions; 63 0 : } 64 : 65 : template <typename data_t> 66 : void MaskableSum<data_t>::setAll() 67 2 : { 68 2 : std::fill(mask.begin(), mask.end(), true); 69 2 : } 70 : 71 : template <typename data_t> 72 : size_t MaskableSum<data_t>::numFunctions() const 73 0 : { 74 0 : return functions.size(); 75 0 : } 76 : 77 : template <typename data_t> 78 : data_t MaskableSum<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const 79 4 : { 80 4 : if (Rx.getDataDescriptor() != this->getDomainDescriptor()) { 81 0 : throw Error{"MaskableSum: Descriptor of argument is unexpected"}; 82 0 : } 83 : 84 4 : return std::transform_reduce(functions.begin(), functions.end(), mask.begin(), data_t{0}, 85 4 : std::plus<>{}, 86 8 : [&](auto& fn, auto m) { return m ? fn->evaluate(Rx) : 0; }); 87 4 : } 88 : 89 : template <typename data_t> 90 : void MaskableSum<data_t>::getGradientImpl(const DataContainer<data_t>& Rx, 91 : DataContainer<data_t>& out) const 92 4 : { 93 4 : if (Rx.getDataDescriptor() != this->getDomainDescriptor()) { 94 0 : throw Error("MaskableSum: Descriptor of argument is unexpected"); 95 0 : } 96 : 97 4 : out.fill(0); 98 : 99 4 : auto grad = emptylike(Rx); 100 12 : for (std::size_t i = 0; i < functions.size(); ++i) { 101 8 : if (mask[i]) { 102 6 : functions[i]->getGradient(Rx, grad); 103 6 : out += grad; 104 6 : } 105 8 : } 106 4 : } 107 : 108 : template <typename data_t> 109 : LinearOperator<data_t> 110 : MaskableSum<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const 111 0 : { 112 0 : if (Rx.getDataDescriptor() != this->getDomainDescriptor()) { 113 0 : throw Error("MaskableSum: Descriptor of argument is unexpected"); 114 0 : } 115 0 : LinearOperator<data_t> H{this->getDomainDescriptor(), this->getDomainDescriptor()}; 116 0 : for (std::size_t i = 0; i < functions.size(); ++i) { 117 0 : if (mask[i]) { 118 0 : H = H + functions[i]->getHessian(Rx); 119 0 : } 120 0 : } 121 0 : return H; 122 0 : } 123 : 124 : template <typename data_t> 125 : MaskableSum<data_t>* MaskableSum<data_t>::cloneImpl() const 126 4 : { 127 4 : std::vector<std::unique_ptr<Functional<data_t>>> copies; 128 4 : copies.reserve(functions.size()); 129 4 : for (const auto& ptr : functions) { 130 4 : copies.push_back(ptr->clone()); 131 4 : } 132 : 133 4 : return new MaskableSum<data_t>(std::move(copies), mask); 134 4 : } 135 : 136 : template <typename data_t> 137 : bool MaskableSum<data_t>::isEqual(const Functional<data_t>& other) const 138 8 : { 139 8 : if (!Functional<data_t>::isEqual(other)) { 140 0 : return false; 141 0 : } 142 : 143 8 : const auto& fn = downcast<const MaskableSum<data_t>>(other); 144 : 145 8 : if (mask != fn.mask) { 146 0 : return false; 147 0 : } 148 : 149 8 : return std::equal(functions.begin(), functions.end(), fn.functions.begin(), 150 8 : [](const auto& l, const auto& r) { return (*l) == (*r); }); 151 8 : } 152 : 153 : // ------------------------------------------ 154 : // explicit template instantiation 155 : template class MaskableSum<float>; 156 : template class MaskableSum<double>; 157 : template class MaskableSum<complex<float>>; 158 : template class MaskableSum<complex<double>>; 159 : 160 : } // namespace elsa