Line data Source code
1 : #include "TransmissionLogLikelihood.h" 2 : #include "DataContainer.h" 3 : #include "LinearOperator.h" 4 : #include "Scaling.h" 5 : #include "Error.h" 6 : #include "TypeCasts.hpp" 7 : 8 : #include <cmath> 9 : 10 : namespace elsa 11 : { 12 : template <typename data_t> 13 : TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(const LinearOperator<data_t>& A, 14 : const DataContainer<data_t>& y, 15 : const DataContainer<data_t>& b, 16 : const DataContainer<data_t>& r) 17 : : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), y_{y}, b_{b}, r_{r} 18 6 : { 19 : // sanity check 20 6 : if (A.getRangeDescriptor() != y.getDataDescriptor() 21 6 : || A.getRangeDescriptor() != b.getDataDescriptor() 22 6 : || A.getRangeDescriptor() != r.getDataDescriptor()) 23 0 : throw InvalidArgumentError( 24 0 : "TransmissionLogLikelihood: operator and y/b/r not matching in size."); 25 6 : } 26 : 27 : template <typename data_t> 28 : TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(const LinearOperator<data_t>& A, 29 : const DataContainer<data_t>& y, 30 : const DataContainer<data_t>& b) 31 : : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), y_{y}, b_{b} 32 8 : { 33 : // sanity check 34 8 : if (A.getRangeDescriptor() != y.getDataDescriptor() 35 8 : || A.getRangeDescriptor() != b.getDataDescriptor()) 36 0 : throw InvalidArgumentError( 37 0 : "TransmissionLogLikelihood: operator and y/b not matching in size."); 38 8 : } 39 : 40 : template <typename data_t> 41 : bool TransmissionLogLikelihood<data_t>::isDifferentiable() const 42 0 : { 43 0 : return true; 44 0 : } 45 : 46 : template <typename data_t> 47 : data_t TransmissionLogLikelihood<data_t>::evaluateImpl(const DataContainer<data_t>& x) const 48 4 : { 49 4 : if (x.getDataDescriptor() != A_->getDomainDescriptor()) { 50 0 : throw InvalidArgumentError( 51 0 : "TransmissionLogLikelihood: given x is not the correct size"); 52 0 : } 53 : 54 4 : auto result = static_cast<data_t>(0.0); 55 : 56 4 : auto Rx = A_->apply(x); 57 1096 : for (index_t i = 0; i < Rx.getSize(); ++i) { 58 1092 : data_t temp = b_[i] * std::exp(-Rx[i]); 59 1092 : if (r_.has_value()) 60 546 : temp += (*r_)[i]; 61 : 62 1092 : result += temp - y_[i] * std::log(temp); 63 1092 : } 64 : 65 4 : return result; 66 4 : } 67 : 68 : template <typename data_t> 69 : void TransmissionLogLikelihood<data_t>::getGradientImpl(const DataContainer<data_t>& x, 70 : DataContainer<data_t>& out) const 71 4 : { 72 : // Actual computation of the functional 73 1092 : auto translog = [&](auto in, auto i) { 74 1092 : data_t temp = b_[i] * std::exp(-in); 75 1092 : in = -temp; 76 : 77 1092 : if (r_.has_value()) 78 546 : in += y_[i] * temp / (temp + (*r_)[i]); 79 546 : else 80 546 : in += y_[i]; 81 : 82 1092 : return in; 83 1092 : }; 84 : 85 4 : auto Rx = A_->apply(x); 86 1096 : for (index_t i = 0; i < Rx.getSize(); ++i) { 87 1092 : Rx[i] = translog(Rx[i], i); 88 1092 : } 89 4 : A_->applyAdjoint(Rx, out); 90 4 : } 91 : 92 : template <typename data_t> 93 : LinearOperator<data_t> 94 : TransmissionLogLikelihood<data_t>::getHessianImpl(const DataContainer<data_t>& x) const 95 4 : { 96 4 : auto scale = [&](auto in) { 97 4 : DataContainer<data_t> s(in.getDataDescriptor()); 98 1096 : for (index_t i = 0; i < in.getSize(); ++i) { 99 1092 : s[i] = b_[i] * std::exp(-in[i]); 100 1092 : if (r_.has_value()) { 101 546 : data_t tempR = s[i] + (*r_)[i]; 102 546 : s[i] += (*r_)[i] * y_[i] * s[i] / (tempR * tempR); 103 546 : } 104 1092 : } 105 : 106 4 : return leaf(Scaling<data_t>(s)); 107 4 : }; 108 : 109 4 : auto Rx = A_->apply(x); 110 4 : auto s = scale(Rx); 111 : 112 : // Jacobian is the operator, plus chain rule 113 4 : return adjoint(*A_) * s * (*A_); 114 4 : } 115 : 116 : template <typename data_t> 117 : TransmissionLogLikelihood<data_t>* TransmissionLogLikelihood<data_t>::cloneImpl() const 118 4 : { 119 4 : if (r_.has_value()) { 120 2 : return new TransmissionLogLikelihood<data_t>(*A_, y_, b_, *r_); 121 2 : } 122 2 : return new TransmissionLogLikelihood<data_t>(*A_, y_, b_); 123 2 : } 124 : 125 : template <typename data_t> 126 : bool TransmissionLogLikelihood<data_t>::isEqual(const Functional<data_t>& other) const 127 4 : { 128 4 : if (!Functional<data_t>::isEqual(other)) 129 0 : return false; 130 : 131 4 : auto otherTLL = downcast_safe<TransmissionLogLikelihood>(&other); 132 4 : if (!otherTLL) 133 0 : return false; 134 : 135 4 : if (y_ != otherTLL->y_ || b_ != otherTLL->b_) 136 0 : return false; 137 : 138 4 : if (r_.has_value() && otherTLL->r_.has_value() && *r_ != *otherTLL->r_) 139 0 : return false; 140 : 141 4 : return true; 142 4 : } 143 : 144 : // ------------------------------------------ 145 : // explicit template instantiation 146 : template class TransmissionLogLikelihood<float>; 147 : template class TransmissionLogLikelihood<double>; 148 : // no complex instantiations, they make no sense 149 : 150 : } // namespace elsa