Line data Source code
1 : #include "TransmissionLogLikelihood.h" 2 : #include "Scaling.h" 3 : #include "Error.h" 4 : #include "TypeCasts.hpp" 5 : 6 : #include <cmath> 7 : 8 : namespace elsa 9 : { 10 : 11 : template <typename data_t> 12 : TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood( 13 : const DataDescriptor& domainDescriptor, const DataContainer<data_t>& y, 14 : const DataContainer<data_t>& b, const DataContainer<data_t>& r) 15 : : Functional<data_t>(domainDescriptor), 16 : _y{std::make_unique<DataContainer<data_t>>(y)}, 17 : _b{std::make_unique<DataContainer<data_t>>(b)}, 18 : _r{std::make_unique<DataContainer<data_t>>(r)} 19 4 : { 20 : // sanity check 21 4 : if (domainDescriptor != y.getDataDescriptor() || domainDescriptor != b.getDataDescriptor() 22 4 : || domainDescriptor != r.getDataDescriptor()) 23 0 : throw InvalidArgumentError( 24 0 : "TransmissionLogLikelihood: descriptor and y/b/r not matching in size."); 25 4 : } 26 : 27 : template <typename data_t> 28 : TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood( 29 : const DataDescriptor& domainDescriptor, const DataContainer<data_t>& y, 30 : const DataContainer<data_t>& b) 31 : : Functional<data_t>(domainDescriptor), 32 : _y{std::make_unique<DataContainer<data_t>>(y)}, 33 : _b{std::make_unique<DataContainer<data_t>>(b)} 34 6 : { 35 : // sanity check 36 6 : if (domainDescriptor != y.getDataDescriptor() || domainDescriptor != b.getDataDescriptor()) 37 0 : throw InvalidArgumentError( 38 0 : "TransmissionLogLikelihood: descriptor and y/b not matching in size."); 39 6 : } 40 : 41 : template <typename data_t> 42 : TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(const Residual<data_t>& residual, 43 : const DataContainer<data_t>& y, 44 : const DataContainer<data_t>& b, 45 : const DataContainer<data_t>& r) 46 : : Functional<data_t>(residual), 47 : _y{std::make_unique<DataContainer<data_t>>(y)}, 48 : _b{std::make_unique<DataContainer<data_t>>(b)}, 49 : _r{std::make_unique<DataContainer<data_t>>(r)} 50 8 : { 51 : // sanity check 52 8 : if (residual.getRangeDescriptor() != y.getDataDescriptor() 53 8 : || residual.getRangeDescriptor() != b.getDataDescriptor() 54 8 : || residual.getRangeDescriptor() != r.getDataDescriptor()) 55 0 : throw InvalidArgumentError( 56 0 : "TransmissionLogLikelihood: residual and y/b/r not matching in size."); 57 8 : } 58 : 59 : template <typename data_t> 60 : TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(const Residual<data_t>& residual, 61 : const DataContainer<data_t>& y, 62 : const DataContainer<data_t>& b) 63 : : Functional<data_t>(residual), 64 : _y{std::make_unique<DataContainer<data_t>>(y)}, 65 : _b{std::make_unique<DataContainer<data_t>>(b)} 66 10 : { 67 : // sanity check 68 10 : if (residual.getRangeDescriptor() != y.getDataDescriptor() 69 10 : || residual.getRangeDescriptor() != b.getDataDescriptor()) 70 0 : throw InvalidArgumentError( 71 0 : "TransmissionLogLikelihood: residual and y/b not matching in size."); 72 10 : } 73 : 74 : template <typename data_t> 75 : data_t TransmissionLogLikelihood<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) 76 8 : { 77 8 : auto result = static_cast<data_t>(0.0); 78 : 79 1984 : for (index_t i = 0; i < Rx.getSize(); ++i) { 80 1976 : data_t temp = (*_b)[i] * std::exp(-Rx[i]); 81 1976 : if (_r) 82 988 : temp += (*_r)[i]; 83 : 84 1976 : result += temp - (*_y)[i] * std::log(temp); 85 1976 : } 86 : 87 8 : return result; 88 8 : } 89 : 90 : template <typename data_t> 91 : void TransmissionLogLikelihood<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx) 92 8 : { 93 1984 : for (index_t i = 0; i < Rx.getSize(); ++i) { 94 1976 : data_t temp = (*_b)[i] * std::exp(-Rx[i]); 95 1976 : Rx[i] = -temp; 96 : 97 1976 : if (_r) 98 988 : Rx[i] += (*_y)[i] * temp / (temp + (*_r)[i]); 99 988 : else 100 988 : Rx[i] += (*_y)[i]; 101 1976 : } 102 8 : } 103 : 104 : template <typename data_t> 105 : LinearOperator<data_t> 106 : TransmissionLogLikelihood<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) 107 8 : { 108 8 : DataContainer<data_t> scaleFactors(Rx.getDataDescriptor()); 109 1984 : for (index_t i = 0; i < Rx.getSize(); ++i) { 110 1976 : data_t temp = (*_b)[i] * std::exp(-Rx[i]); 111 1976 : scaleFactors[i] = temp; 112 1976 : if (_r) { 113 988 : data_t tempR = temp + (*_r)[i]; 114 988 : scaleFactors[i] += (*_r)[i] * (*_y)[i] * temp / (tempR * tempR); 115 988 : } 116 1976 : } 117 : 118 8 : return leaf(Scaling<data_t>(Rx.getDataDescriptor(), scaleFactors)); 119 8 : } 120 : 121 : template <typename data_t> 122 : TransmissionLogLikelihood<data_t>* TransmissionLogLikelihood<data_t>::cloneImpl() const 123 8 : { 124 8 : if (_r) 125 4 : return new TransmissionLogLikelihood<data_t>(this->getResidual(), *_y, *_b, *_r); 126 4 : else 127 4 : return new TransmissionLogLikelihood<data_t>(this->getResidual(), *_y, *_b); 128 8 : } 129 : 130 : template <typename data_t> 131 : bool TransmissionLogLikelihood<data_t>::isEqual(const Functional<data_t>& other) const 132 8 : { 133 8 : if (!Functional<data_t>::isEqual(other)) 134 0 : return false; 135 : 136 8 : auto otherTLL = downcast_safe<TransmissionLogLikelihood>(&other); 137 8 : if (!otherTLL) 138 0 : return false; 139 : 140 8 : if (*_y != *otherTLL->_y || *_b != *otherTLL->_b) 141 0 : return false; 142 : 143 8 : if (_r && *_r != *otherTLL->_r) 144 0 : return false; 145 : 146 8 : return true; 147 8 : } 148 : 149 : // ------------------------------------------ 150 : // explicit template instantiation 151 : template class TransmissionLogLikelihood<float>; 152 : template class TransmissionLogLikelihood<double>; 153 : // no complex instantiations, they make no sense 154 : 155 : } // namespace elsa