Line data Source code
1 : #include "EmissionLogLikelihood.h" 2 : #include "DataContainer.h" 3 : #include "DataDescriptor.h" 4 : #include "LinearOperator.h" 5 : #include "Scaling.h" 6 : #include "TypeCasts.hpp" 7 : 8 : #include <cmath> 9 : #include <stdexcept> 10 : 11 : namespace elsa 12 : { 13 : template <typename data_t> 14 : EmissionLogLikelihood<data_t>::EmissionLogLikelihood(const LinearOperator<data_t>& A, 15 : const DataContainer<data_t>& y, 16 : const DataContainer<data_t>& r) 17 : : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), y_{y}, r_{r} 18 6 : { 19 : // sanity check 20 6 : if (A.getRangeDescriptor() != y.getDataDescriptor() 21 6 : || A.getRangeDescriptor() != r.getDataDescriptor()) 22 0 : throw InvalidArgumentError( 23 0 : "EmissionLogLikelihood: residual and y/r not matching in size."); 24 6 : } 25 : 26 : template <typename data_t> 27 : EmissionLogLikelihood<data_t>::EmissionLogLikelihood(const LinearOperator<data_t>& A, 28 : const DataContainer<data_t>& y) 29 : : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), y_{y} 30 8 : { 31 : // sanity check 32 8 : if (A.getRangeDescriptor() != y.getDataDescriptor()) 33 0 : throw InvalidArgumentError( 34 0 : "EmissionLogLikelihood: residual and y not matching in size."); 35 8 : } 36 : 37 : template <typename data_t> 38 : bool EmissionLogLikelihood<data_t>::isDifferentiable() const 39 0 : { 40 0 : return true; 41 0 : } 42 : 43 : template <typename data_t> 44 : data_t EmissionLogLikelihood<data_t>::evaluateImpl(const DataContainer<data_t>& x) const 45 4 : { 46 4 : if (x.getDataDescriptor() != A_->getDomainDescriptor()) { 47 0 : throw InvalidArgumentError("EmissionLogLikelihood: given x is not the correct size"); 48 0 : } 49 : 50 4 : auto result = static_cast<data_t>(0.0); 51 : 52 4 : auto Rx = A_->apply(x); 53 10264 : for (index_t i = 0; i < Rx.getSize(); ++i) { 54 10260 : data_t temp = Rx[i]; 55 10260 : if (r_) 56 5130 : temp += (*r_)[i]; 57 : 58 10260 : result += temp - y_[i] * std::log(temp); 59 10260 : } 60 : 61 4 : return result; 62 4 : } 63 : 64 : template <typename data_t> 65 : void EmissionLogLikelihood<data_t>::getGradientImpl(const DataContainer<data_t>& x, 66 : DataContainer<data_t>& out) const 67 4 : { 68 10260 : auto emissionlog = [&](auto in, auto i) { 69 10260 : data_t temp = in; 70 10260 : if (r_) 71 5130 : temp += (*r_)[i]; 72 : 73 10260 : return 1 - y_[i] / temp; 74 10260 : }; 75 : 76 4 : auto Rx = A_->apply(x); 77 10264 : for (index_t i = 0; i < Rx.getSize(); ++i) { 78 10260 : Rx[i] = emissionlog(Rx[i], i); 79 10260 : } 80 4 : A_->applyAdjoint(Rx, out); 81 4 : } 82 : 83 : template <typename data_t> 84 : LinearOperator<data_t> 85 : EmissionLogLikelihood<data_t>::getHessianImpl(const DataContainer<data_t>& x) const 86 4 : { 87 4 : auto scale = [&](auto in) { 88 4 : DataContainer<data_t> s(in.getDataDescriptor()); 89 10264 : for (index_t i = 0; i < in.getSize(); ++i) { 90 10260 : data_t temp = in[i]; 91 10260 : if (r_) 92 5130 : temp += (*r_)[i]; 93 : 94 10260 : s[i] = y_[i] / (temp * temp); 95 10260 : } 96 : 97 4 : return leaf(Scaling<data_t>(s)); 98 4 : }; 99 : 100 4 : auto Rx = A_->apply(x); 101 : 102 : // Jacobian is the operator, plus chain rule 103 4 : return adjoint(*A_) * scale(Rx) * (*A_); 104 4 : } 105 : 106 : template <typename data_t> 107 : EmissionLogLikelihood<data_t>* EmissionLogLikelihood<data_t>::cloneImpl() const 108 4 : { 109 4 : if (r_.has_value()) { 110 2 : return new EmissionLogLikelihood<data_t>(*A_, y_, *r_); 111 2 : } 112 2 : return new EmissionLogLikelihood<data_t>(*A_, y_); 113 2 : } 114 : 115 : template <typename data_t> 116 : bool EmissionLogLikelihood<data_t>::isEqual(const Functional<data_t>& other) const 117 4 : { 118 4 : if (!Functional<data_t>::isEqual(other)) 119 0 : return false; 120 : 121 4 : auto otherELL = downcast_safe<EmissionLogLikelihood>(&other); 122 4 : if (!otherELL) 123 0 : return false; 124 : 125 4 : if (r_ && otherELL->r_ && *r_ != *otherELL->r_) { 126 0 : return false; 127 0 : } 128 : 129 4 : return y_ == otherELL->y_; 130 4 : } 131 : 132 : // ------------------------------------------ 133 : // explicit template instantiation 134 : template class EmissionLogLikelihood<float>; 135 : template class EmissionLogLikelihood<double>; 136 : // no complex instantiations, they make no sense 137 : 138 : } // namespace elsa