Line data Source code
1 : #include "EmissionLogLikelihood.h" 2 : #include "Scaling.h" 3 : #include "TypeCasts.hpp" 4 : 5 : #include <cmath> 6 : #include <stdexcept> 7 : 8 : namespace elsa 9 : { 10 : template <typename data_t> 11 : EmissionLogLikelihood<data_t>::EmissionLogLikelihood(const DataDescriptor& domainDescriptor, 12 : const DataContainer<data_t>& y, 13 : const DataContainer<data_t>& r) 14 : : Functional<data_t>(domainDescriptor), 15 : _y{std::make_unique<DataContainer<data_t>>(y)}, 16 : _r{std::make_unique<DataContainer<data_t>>(r)} 17 4 : { 18 : // sanity check 19 4 : if (domainDescriptor != y.getDataDescriptor() || domainDescriptor != r.getDataDescriptor()) 20 0 : throw InvalidArgumentError( 21 0 : "EmissionLogLikelihood: descriptor and y/r not matching in size."); 22 4 : } 23 : 24 : template <typename data_t> 25 : EmissionLogLikelihood<data_t>::EmissionLogLikelihood(const DataDescriptor& domainDescriptor, 26 : const DataContainer<data_t>& y) 27 : : Functional<data_t>(domainDescriptor), _y{std::make_unique<DataContainer<data_t>>(y)} 28 6 : { 29 : // sanity check 30 6 : if (domainDescriptor != y.getDataDescriptor()) 31 0 : throw InvalidArgumentError( 32 0 : "EmissionLogLikelihood: descriptor and y not matching in size."); 33 6 : } 34 : 35 : template <typename data_t> 36 : EmissionLogLikelihood<data_t>::EmissionLogLikelihood(const Residual<data_t>& residual, 37 : const DataContainer<data_t>& y, 38 : const DataContainer<data_t>& r) 39 : : Functional<data_t>(residual), 40 : _y{std::make_unique<DataContainer<data_t>>(y)}, 41 : _r{std::make_unique<DataContainer<data_t>>(r)} 42 8 : { 43 : // sanity check 44 8 : if (residual.getRangeDescriptor() != y.getDataDescriptor() 45 8 : || residual.getRangeDescriptor() != r.getDataDescriptor()) 46 0 : throw InvalidArgumentError( 47 0 : "EmissionLogLikelihood: residual and y/r not matching in size."); 48 8 : } 49 : 50 : template <typename data_t> 51 : EmissionLogLikelihood<data_t>::EmissionLogLikelihood(const Residual<data_t>& residual, 52 : const DataContainer<data_t>& y) 53 : : Functional<data_t>(residual), _y{std::make_unique<DataContainer<data_t>>(y)} 54 10 : { 55 : // sanity check 56 10 : if (residual.getRangeDescriptor() != y.getDataDescriptor()) 57 0 : throw InvalidArgumentError( 58 0 : "EmissionLogLikelihood: residual and y not matching in size."); 59 10 : } 60 : 61 : template <typename data_t> 62 : data_t EmissionLogLikelihood<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) 63 8 : { 64 8 : auto result = static_cast<data_t>(0.0); 65 : 66 14048 : for (index_t i = 0; i < Rx.getSize(); ++i) { 67 14040 : data_t temp = Rx[i]; 68 14040 : if (_r) 69 7020 : temp += (*_r)[i]; 70 : 71 14040 : result += temp - (*_y)[i] * std::log(temp); 72 14040 : } 73 : 74 8 : return result; 75 8 : } 76 : 77 : template <typename data_t> 78 : void EmissionLogLikelihood<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx) 79 8 : { 80 14048 : for (index_t i = 0; i < Rx.getSize(); ++i) { 81 14040 : data_t temp = Rx[i]; 82 14040 : if (_r) 83 7020 : temp += (*_r)[i]; 84 : 85 14040 : Rx[i] = 1 - (*_y)[i] / temp; 86 14040 : } 87 8 : } 88 : 89 : template <typename data_t> 90 : LinearOperator<data_t> 91 : EmissionLogLikelihood<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) 92 8 : { 93 8 : DataContainer<data_t> scaleFactors(Rx.getDataDescriptor()); 94 14048 : for (index_t i = 0; i < Rx.getSize(); ++i) { 95 14040 : data_t temp = Rx[i]; 96 14040 : if (_r) 97 7020 : temp += (*_r)[i]; 98 : 99 14040 : scaleFactors[i] = (*_y)[i] / (temp * temp); 100 14040 : } 101 : 102 8 : return leaf(Scaling<data_t>(Rx.getDataDescriptor(), scaleFactors)); 103 8 : } 104 : 105 : template <typename data_t> 106 : EmissionLogLikelihood<data_t>* EmissionLogLikelihood<data_t>::cloneImpl() const 107 8 : { 108 8 : if (_r) 109 4 : return new EmissionLogLikelihood<data_t>(this->getResidual(), *_y, *_r); 110 4 : else 111 4 : return new EmissionLogLikelihood<data_t>(this->getResidual(), *_y); 112 8 : } 113 : 114 : template <typename data_t> 115 : bool EmissionLogLikelihood<data_t>::isEqual(const Functional<data_t>& other) const 116 8 : { 117 8 : if (!Functional<data_t>::isEqual(other)) 118 0 : return false; 119 : 120 8 : auto otherELL = downcast_safe<EmissionLogLikelihood>(&other); 121 8 : if (!otherELL) 122 0 : return false; 123 : 124 8 : if (*_y != *otherELL->_y) 125 0 : return false; 126 : 127 8 : if (_r && *_r != *otherELL->_r) 128 0 : return false; 129 : 130 8 : return true; 131 8 : } 132 : 133 : // ------------------------------------------ 134 : // explicit template instantiation 135 : template class EmissionLogLikelihood<float>; 136 : template class EmissionLogLikelihood<double>; 137 : // no complex instantiations, they make no sense 138 : 139 : } // namespace elsa