Line data Source code
1 : #include "SeparableSum.h" 2 : #include "DataContainer.h" 3 : #include "Functional.h" 4 : #include "IdenticalBlocksDescriptor.h" 5 : #include "TypeCasts.hpp" 6 : #include "CombinedProximal.h" 7 : 8 : #include <memory> 9 : 10 : namespace elsa 11 : { 12 : template <class data_t> 13 : SeparableSum<data_t>::SeparableSum(std::vector<std::unique_ptr<Functional<data_t>>> fns) 14 : : Functional<data_t>(*detail::determineDescriptor(fns)), fns_(std::move(fns)) 15 14 : { 16 14 : } 17 : 18 : template <class data_t> 19 : SeparableSum<data_t>::SeparableSum(const Functional<data_t>& fn) 20 : : Functional<data_t>(IdenticalBlocksDescriptor(1, fn.getDomainDescriptor())) 21 2 : { 22 2 : fns_.push_back(fn.clone()); 23 2 : } 24 : 25 : template <class data_t> 26 : SeparableSum<data_t>::SeparableSum(const Functional<data_t>& fn1, const Functional<data_t>& fn2) 27 : : SeparableSum<data_t>(detail::make_vector<data_t>(fn1, fn2)) 28 2 : { 29 2 : } 30 : 31 : template <class data_t> 32 : SeparableSum<data_t>::SeparableSum(const Functional<data_t>& fn1, const Functional<data_t>& fn2, 33 : const Functional<data_t>& fn3) 34 : : SeparableSum<data_t>(detail::make_vector<data_t>(fn1, fn2, fn3)) 35 2 : { 36 2 : } 37 : 38 : template <typename data_t> 39 : bool SeparableSum<data_t>::isProxFriendly() const 40 0 : { 41 0 : return true; 42 0 : } 43 : 44 : template <class data_t> 45 : data_t SeparableSum<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const 46 8 : { 47 8 : if (!is<BlockDescriptor>(Rx.getDataDescriptor())) { 48 0 : throw Error("SeparableSum: Blocked DataContainer expected"); 49 0 : } 50 : 51 8 : if (Rx.getDataDescriptor() != this->getDomainDescriptor()) { 52 0 : throw Error("SeparableSum: Descriptor of argument is unexpected"); 53 0 : } 54 : 55 8 : auto& blockdesc = downcast_safe<BlockDescriptor>(Rx.getDataDescriptor()); 56 : 57 8 : data_t sum{0}; 58 28 : for (int i = 0; i < blockdesc.getNumberOfBlocks(); ++i) { 59 20 : sum += fns_[asUnsigned(i)]->evaluate(Rx.getBlock(i)); 60 20 : } 61 8 : return sum; 62 8 : } 63 : 64 : template <typename data_t> 65 : data_t SeparableSum<data_t>::convexConjugate(const DataContainer<data_t>& x) const 66 0 : { 67 0 : if (!is<BlockDescriptor>(x.getDataDescriptor())) { 68 0 : throw Error("SeparableSum: Input to convex conjugate needs to be blocked"); 69 0 : } 70 : 71 0 : auto& blockedDesc = downcast_safe<BlockDescriptor>(x.getDataDescriptor()); 72 : 73 0 : if (blockedDesc.getNumberOfBlocks() != asSigned(fns_.size())) { 74 0 : throw Error("SeparableSum: unequal number of blocks ({}) and number of functions ({}) ", 75 0 : blockedDesc.getNumberOfBlocks(), fns_.size()); 76 0 : } 77 : 78 0 : data_t sum = 0; 79 : 80 0 : for (int i = 0; i < blockedDesc.getNumberOfBlocks(); ++i) { 81 0 : sum += fns_[asUnsigned(i)]->convexConjugate(x.getBlock(i)); 82 0 : } 83 0 : return sum; 84 0 : } 85 : 86 : template <class data_t> 87 : void SeparableSum<data_t>::getGradientImpl(const DataContainer<data_t>& Rx, 88 : DataContainer<data_t>& out) const 89 0 : { 90 0 : if (!is<BlockDescriptor>(Rx.getDataDescriptor())) { 91 0 : throw Error("SeparableSum: Blocked DataContainer expected for gradient"); 92 0 : } 93 : 94 0 : if (Rx.getDataDescriptor() != this->getDomainDescriptor()) { 95 0 : throw Error("SeparableSum: Descriptor of argument is unexpected"); 96 0 : } 97 : 98 0 : auto& blockdesc = downcast_safe<BlockDescriptor>(Rx.getDataDescriptor()); 99 : 100 0 : for (int i = 0; i < blockdesc.getNumberOfBlocks(); ++i) { 101 0 : auto outview = out.getBlock(i); 102 0 : fns_[asUnsigned(i)]->getGradient(Rx.getBlock(i), outview); 103 0 : } 104 0 : } 105 : 106 : template <class data_t> 107 : LinearOperator<data_t> SeparableSum<data_t>::getHessianImpl(const DataContainer<data_t>&) const 108 0 : { 109 0 : throw NotImplementedError("SeparableSum: Hessian not implemented"); 110 0 : } 111 : 112 : template <typename data_t> 113 : DataContainer<data_t> SeparableSum<data_t>::proximal(const DataContainer<data_t>& v, 114 : SelfType_t<data_t> tau) const 115 0 : { 116 0 : auto out = emptylike(v); 117 0 : proximal(v, tau, out); 118 0 : return out; 119 0 : } 120 : 121 : template <typename data_t> 122 : void SeparableSum<data_t>::proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t, 123 : DataContainer<data_t>& out) const 124 0 : { 125 0 : if (!is<BlockDescriptor>(v.getDataDescriptor())) { 126 0 : throw Error("SeparableSum: Input to proximal needs to be blocked"); 127 0 : } 128 : 129 0 : auto& blockedDesc = downcast_safe<BlockDescriptor>(v.getDataDescriptor()); 130 : 131 0 : if (blockedDesc.getNumberOfBlocks() != asSigned(fns_.size())) { 132 0 : throw Error("SeparableSum: unequal number of blocks ({}) and number of functions ({}) ", 133 0 : blockedDesc.getNumberOfBlocks(), fns_.size()); 134 0 : } 135 : 136 0 : for (int i = 0; i < blockedDesc.getNumberOfBlocks(); ++i) { 137 0 : auto outview = out.getBlock(i); 138 0 : auto inview = v.getBlock(i); 139 : 140 0 : fns_[i]->proximal(inview, t, outview); 141 0 : } 142 0 : } 143 : 144 : template <class data_t> 145 : SeparableSum<data_t>* SeparableSum<data_t>::cloneImpl() const 146 8 : { 147 8 : std::vector<std::unique_ptr<Functional<data_t>>> copyfns; 148 28 : for (std::size_t i = 0; i < fns_.size(); ++i) { 149 20 : copyfns.push_back(fns_[i]->clone()); 150 20 : } 151 : 152 8 : return new SeparableSum<data_t>(std::move(copyfns)); 153 8 : } 154 : 155 : template <class data_t> 156 : bool SeparableSum<data_t>::isEqual(const Functional<data_t>& other) const 157 8 : { 158 8 : if (!Functional<data_t>::isEqual(other)) { 159 0 : return false; 160 0 : } 161 : 162 8 : const auto& fn = downcast<const SeparableSum<data_t>>(other); 163 8 : return std::equal(fns_.begin(), fns_.end(), fn.fns_.begin(), 164 20 : [](const auto& l, const auto& r) { return (*l) == (*r); }); 165 8 : } 166 : 167 : namespace detail 168 : { 169 : template <class data_t> 170 : std::unique_ptr<BlockDescriptor> 171 : determineDescriptor(const std::vector<std::unique_ptr<Functional<data_t>>>& fns) 172 14 : { 173 : // For now assume non empty 174 14 : auto& firstDesc = fns.front()->getDomainDescriptor(); 175 : 176 : // Determine if all descriptors are equal 177 38 : bool allEqual = std::all_of(fns.begin(), fns.end(), [&](const auto& x) { 178 38 : return x->getDomainDescriptor() == firstDesc; 179 38 : }); 180 : 181 : // Then we can return an identical block descriptor 182 14 : if (allEqual) { 183 14 : return std::make_unique<IdenticalBlocksDescriptor>(fns.size(), firstDesc); 184 14 : } 185 : 186 : // There are different descriptors, so extract them from the vector of functionals 187 0 : std::vector<std::unique_ptr<DataDescriptor>> descriptors; 188 0 : descriptors.reserve(fns.size()); 189 0 : for (const auto& f : fns) { 190 0 : descriptors.push_back(f->getDomainDescriptor().clone()); 191 0 : } 192 : 193 0 : return std::make_unique<RandomBlocksDescriptor>(std::move(descriptors)); 194 0 : } 195 : } // namespace detail 196 : 197 : // ------------------------------------------ 198 : // explicit template instantiation 199 : template class SeparableSum<float>; 200 : template class SeparableSum<double>; 201 : 202 : template std::unique_ptr<BlockDescriptor> detail::determineDescriptor<float>( 203 : const std::vector<std::unique_ptr<Functional<float>>>& fns); 204 : template std::unique_ptr<BlockDescriptor> detail::determineDescriptor<double>( 205 : const std::vector<std::unique_ptr<Functional<double>>>& fns); 206 : } // namespace elsa