Line data Source code
1 : #include "ProximalMixedL21Norm.h" 2 : #include "DataContainer.h" 3 : #include "BlockDescriptor.h" 4 : #include "spdlog/fmt/bundled/core.h" 5 : 6 : namespace elsa 7 : { 8 : template <class data_t> 9 : ProximalMixedL21Norm<data_t>::ProximalMixedL21Norm(data_t sigma) : sigma_(sigma) 10 20 : { 11 20 : } 12 : 13 : template <class data_t> 14 : DataContainer<data_t> ProximalMixedL21Norm<data_t>::apply(const DataContainer<data_t>& v, 15 : SelfType_t<data_t> t) const 16 20 : { 17 20 : DataContainer<data_t> out{v.getDataDescriptor()}; 18 20 : apply(v, t, out); 19 20 : return out; 20 20 : } 21 : 22 : template <class data_t> 23 : void ProximalMixedL21Norm<data_t>::apply(const DataContainer<data_t>& v, SelfType_t<data_t> t, 24 : DataContainer<data_t>& prox) const 25 20 : { 26 20 : if (!is<BlockDescriptor>(v.getDataDescriptor())) { 27 0 : throw Error("ProximalMixedL21Norm: Blocked DataContainer expected"); 28 0 : } 29 : 30 20 : auto p21norm = v.pL2Norm(); 31 : 32 : // set each block of prox to be tmp 33 64 : for (int i = 0; i < v.getNumberOfBlocks(); ++i) { 34 44 : prox.getBlock(i) = p21norm; 35 44 : } 36 : 37 20 : auto tau = t * sigma_; 38 20 : prox = (1 - tau / maximum(prox, tau)) * v; 39 20 : } 40 : 41 : // ------------------------------------------ 42 : // explicit template instantiation 43 : template class ProximalMixedL21Norm<float>; 44 : template class ProximalMixedL21Norm<double>; 45 : } // namespace elsa