Line data Source code
1 : #include "TransmissionLogLikelihood.h" 2 : #include "Scaling.h" 3 : #include "Error.h" 4 : #include "TypeCasts.hpp" 5 : 6 : #include <cmath> 7 : 8 : namespace elsa 9 : { 10 : 11 : template <typename data_t> 12 0 : TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood( 13 : const DataDescriptor& domainDescriptor, const DataContainer<data_t>& y, 14 : const DataContainer<data_t>& b, const DataContainer<data_t>& r) 15 : : Functional<data_t>(domainDescriptor), 16 0 : _y{std::make_unique<DataContainer<data_t>>(y)}, 17 0 : _b{std::make_unique<DataContainer<data_t>>(b)}, 18 0 : _r{std::make_unique<DataContainer<data_t>>(r)} 19 : { 20 : // sanity check 21 0 : if (domainDescriptor != y.getDataDescriptor() || domainDescriptor != b.getDataDescriptor() 22 0 : || domainDescriptor != r.getDataDescriptor()) 23 0 : throw InvalidArgumentError( 24 : "TransmissionLogLikelihood: descriptor and y/b/r not matching in size."); 25 0 : } 26 : 27 : template <typename data_t> 28 0 : TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood( 29 : const DataDescriptor& domainDescriptor, const DataContainer<data_t>& y, 30 : const DataContainer<data_t>& b) 31 : : Functional<data_t>(domainDescriptor), 32 0 : _y{std::make_unique<DataContainer<data_t>>(y)}, 33 0 : _b{std::make_unique<DataContainer<data_t>>(b)} 34 : { 35 : // sanity check 36 0 : if (domainDescriptor != y.getDataDescriptor() || domainDescriptor != b.getDataDescriptor()) 37 0 : throw InvalidArgumentError( 38 : "TransmissionLogLikelihood: descriptor and y/b not matching in size."); 39 0 : } 40 : 41 : template <typename data_t> 42 0 : TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(const Residual<data_t>& residual, 43 : const DataContainer<data_t>& y, 44 : const DataContainer<data_t>& b, 45 : const DataContainer<data_t>& r) 46 : : Functional<data_t>(residual), 47 0 : _y{std::make_unique<DataContainer<data_t>>(y)}, 48 0 : _b{std::make_unique<DataContainer<data_t>>(b)}, 49 0 : _r{std::make_unique<DataContainer<data_t>>(r)} 50 : { 51 : // sanity check 52 0 : if (residual.getRangeDescriptor() != y.getDataDescriptor() 53 0 : || residual.getRangeDescriptor() != b.getDataDescriptor() 54 0 : || residual.getRangeDescriptor() != r.getDataDescriptor()) 55 0 : throw InvalidArgumentError( 56 : "TransmissionLogLikelihood: residual and y/b/r not matching in size."); 57 0 : } 58 : 59 : template <typename data_t> 60 0 : TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(const Residual<data_t>& residual, 61 : const DataContainer<data_t>& y, 62 : const DataContainer<data_t>& b) 63 : : Functional<data_t>(residual), 64 0 : _y{std::make_unique<DataContainer<data_t>>(y)}, 65 0 : _b{std::make_unique<DataContainer<data_t>>(b)} 66 : { 67 : // sanity check 68 0 : if (residual.getRangeDescriptor() != y.getDataDescriptor() 69 0 : || residual.getRangeDescriptor() != b.getDataDescriptor()) 70 0 : throw InvalidArgumentError( 71 : "TransmissionLogLikelihood: residual and y/b not matching in size."); 72 0 : } 73 : 74 : template <typename data_t> 75 0 : data_t TransmissionLogLikelihood<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) 76 : { 77 0 : auto result = static_cast<data_t>(0.0); 78 : 79 0 : for (index_t i = 0; i < Rx.getSize(); ++i) { 80 0 : data_t temp = (*_b)[i] * std::exp(-Rx[i]); 81 0 : if (_r) 82 0 : temp += (*_r)[i]; 83 : 84 0 : result += temp - (*_y)[i] * std::log(temp); 85 : } 86 : 87 0 : return result; 88 : } 89 : 90 : template <typename data_t> 91 0 : void TransmissionLogLikelihood<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx) 92 : { 93 0 : for (index_t i = 0; i < Rx.getSize(); ++i) { 94 0 : data_t temp = (*_b)[i] * std::exp(-Rx[i]); 95 0 : Rx[i] = -temp; 96 : 97 0 : if (_r) 98 0 : Rx[i] += (*_y)[i] * temp / (temp + (*_r)[i]); 99 : else 100 0 : Rx[i] += (*_y)[i]; 101 : } 102 0 : } 103 : 104 : template <typename data_t> 105 : LinearOperator<data_t> 106 0 : TransmissionLogLikelihood<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) 107 : { 108 0 : DataContainer<data_t> scaleFactors(Rx.getDataDescriptor()); 109 0 : for (index_t i = 0; i < Rx.getSize(); ++i) { 110 0 : data_t temp = (*_b)[i] * std::exp(-Rx[i]); 111 0 : scaleFactors[i] = temp; 112 0 : if (_r) { 113 0 : data_t tempR = temp + (*_r)[i]; 114 0 : scaleFactors[i] += (*_r)[i] * (*_y)[i] * temp / (tempR * tempR); 115 : } 116 : } 117 : 118 0 : return leaf(Scaling<data_t>(Rx.getDataDescriptor(), scaleFactors)); 119 0 : } 120 : 121 : template <typename data_t> 122 0 : TransmissionLogLikelihood<data_t>* TransmissionLogLikelihood<data_t>::cloneImpl() const 123 : { 124 0 : if (_r) 125 0 : return new TransmissionLogLikelihood<data_t>(this->getResidual(), *_y, *_b, *_r); 126 : else 127 0 : return new TransmissionLogLikelihood<data_t>(this->getResidual(), *_y, *_b); 128 : } 129 : 130 : template <typename data_t> 131 0 : bool TransmissionLogLikelihood<data_t>::isEqual(const Functional<data_t>& other) const 132 : { 133 0 : if (!Functional<data_t>::isEqual(other)) 134 0 : return false; 135 : 136 0 : auto otherTLL = downcast_safe<TransmissionLogLikelihood>(&other); 137 0 : if (!otherTLL) 138 0 : return false; 139 : 140 0 : if (*_y != *otherTLL->_y || *_b != *otherTLL->_b) 141 0 : return false; 142 : 143 0 : if (_r && *_r != *otherTLL->_r) 144 0 : return false; 145 : 146 0 : return true; 147 : } 148 : 149 : // ------------------------------------------ 150 : // explicit template instantiation 151 : template class TransmissionLogLikelihood<float>; 152 : template class TransmissionLogLikelihood<double>; 153 : // no complex instantiations, they make no sense 154 : 155 : } // namespace elsa