Line data Source code
1 : #include "L2Squared.h" 2 : 3 : #include "DataContainer.h" 4 : #include "DataDescriptor.h" 5 : #include "Identity.h" 6 : #include "LinearOperator.h" 7 : #include "TypeCasts.hpp" 8 : #include "ProximalL2Squared.h" 9 : 10 : namespace elsa 11 : { 12 : template <typename data_t> 13 : L2Squared<data_t>::L2Squared(const DataDescriptor& domainDescriptor) 14 : : Functional<data_t>(domainDescriptor) 15 96 : { 16 96 : } 17 : 18 : template <typename data_t> 19 : L2Squared<data_t>::L2Squared(const DataContainer<data_t>& b) 20 : : Functional<data_t>(b.getDataDescriptor()), b_(b) 21 16 : { 22 16 : } 23 : 24 : template <typename data_t> 25 : bool L2Squared<data_t>::isDifferentiable() const 26 0 : { 27 0 : return true; 28 0 : } 29 : 30 : template <typename data_t> 31 : bool L2Squared<data_t>::isProxFriendly() const 32 4 : { 33 4 : return true; 34 4 : } 35 : 36 : template <typename data_t> 37 : bool L2Squared<data_t>::hasDataVector() const 38 448 : { 39 448 : return b_.has_value(); 40 448 : } 41 : 42 : template <typename data_t> 43 : const DataContainer<data_t>& L2Squared<data_t>::getDataVector() const 44 0 : { 45 0 : if (!hasDataVector()) { 46 0 : throw Error("L2Squared: No data vector present"); 47 0 : } 48 : 49 0 : return *b_; 50 0 : } 51 : 52 : template <typename data_t> 53 : data_t L2Squared<data_t>::evaluateImpl(const DataContainer<data_t>& x) const 54 124 : { 55 124 : if (!hasDataVector()) { 56 122 : return data_t{0.5} * x.squaredL2Norm(); 57 122 : } 58 : 59 2 : return data_t{0.5} * (x - *b_).squaredL2Norm(); 60 2 : } 61 : 62 : template <typename data_t> 63 : data_t L2Squared<data_t>::convexConjugate(const DataContainer<data_t>& x) const 64 4 : { 65 4 : auto res = 0.25 * x.squaredL2Norm(); 66 : 67 4 : if (hasDataVector()) { 68 2 : res += x.dot(*b_); 69 2 : } 70 : 71 4 : return res; 72 4 : } 73 : 74 : template <typename data_t> 75 : void L2Squared<data_t>::getGradientImpl(const DataContainer<data_t>& x, 76 : DataContainer<data_t>& out) const 77 244 : { 78 244 : if (!hasDataVector()) { 79 242 : out = x; 80 242 : } else { 81 2 : out = x - *b_; 82 2 : } 83 244 : } 84 : 85 : template <typename data_t> 86 : LinearOperator<data_t> L2Squared<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const 87 4 : { 88 4 : return leaf(Identity<data_t>(Rx.getDataDescriptor())); 89 4 : } 90 : 91 : template <typename data_t> 92 : DataContainer<data_t> L2Squared<data_t>::proximal(const DataContainer<data_t>& v, 93 : SelfType_t<data_t> tau) const 94 8 : { 95 8 : auto out = emptylike(v); 96 8 : proximal(v, tau, out); 97 8 : return out; 98 8 : } 99 : 100 : template <typename data_t> 101 : void L2Squared<data_t>::proximal(const DataContainer<data_t>& v, SelfType_t<data_t> t, 102 : DataContainer<data_t>& out) const 103 8 : { 104 8 : if (!hasDataVector()) { 105 4 : ProximalL2Squared<data_t> prox; 106 4 : prox.apply(v, t, out); 107 4 : } else { 108 4 : ProximalL2Squared<data_t> prox(*b_); 109 4 : prox.apply(v, t, out); 110 4 : } 111 8 : } 112 : 113 : template <typename data_t> 114 : L2Squared<data_t>* L2Squared<data_t>::cloneImpl() const 115 68 : { 116 68 : if (!hasDataVector()) { 117 66 : return new L2Squared(this->getDomainDescriptor()); 118 66 : } 119 2 : return new L2Squared(*b_); 120 2 : } 121 : 122 : template <typename data_t> 123 : bool L2Squared<data_t>::isEqual(const Functional<data_t>& other) const 124 4 : { 125 4 : if (!Functional<data_t>::isEqual(other)) 126 0 : return false; 127 : 128 4 : auto fn = downcast_safe<L2Squared<data_t>>(&other); 129 4 : if (!fn) { 130 0 : return false; 131 0 : } 132 : 133 4 : if (b_ && fn->b_) { 134 2 : return *b_ == *fn->b_; 135 2 : } 136 : 137 2 : return b_ == fn->b_; 138 2 : } 139 : 140 : // ------------------------------------------ 141 : // explicit template instantiation 142 : template class L2Squared<float>; 143 : template class L2Squared<double>; 144 : } // namespace elsa