LCOV - code coverage report
Current view: top level - elsa/functionals - TransmissionLogLikelihood.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 62 74 83.8 %
Date: 2022-08-25 03:05:39 Functions: 18 18 100.0 %

          Line data    Source code
       1             : #include "TransmissionLogLikelihood.h"
       2             : #include "Scaling.h"
       3             : #include "Error.h"
       4             : #include "TypeCasts.hpp"
       5             : 
       6             : #include <cmath>
       7             : 
       8             : namespace elsa
       9             : {
      10             : 
      11             :     template <typename data_t>
      12             :     TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(
      13             :         const DataDescriptor& domainDescriptor, const DataContainer<data_t>& y,
      14             :         const DataContainer<data_t>& b, const DataContainer<data_t>& r)
      15             :         : Functional<data_t>(domainDescriptor),
      16             :           _y{std::make_unique<DataContainer<data_t>>(y)},
      17             :           _b{std::make_unique<DataContainer<data_t>>(b)},
      18             :           _r{std::make_unique<DataContainer<data_t>>(r)}
      19           4 :     {
      20             :         // sanity check
      21           4 :         if (domainDescriptor != y.getDataDescriptor() || domainDescriptor != b.getDataDescriptor()
      22           4 :             || domainDescriptor != r.getDataDescriptor())
      23           0 :             throw InvalidArgumentError(
      24           0 :                 "TransmissionLogLikelihood: descriptor and y/b/r not matching in size.");
      25           4 :     }
      26             : 
      27             :     template <typename data_t>
      28             :     TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(
      29             :         const DataDescriptor& domainDescriptor, const DataContainer<data_t>& y,
      30             :         const DataContainer<data_t>& b)
      31             :         : Functional<data_t>(domainDescriptor),
      32             :           _y{std::make_unique<DataContainer<data_t>>(y)},
      33             :           _b{std::make_unique<DataContainer<data_t>>(b)}
      34           6 :     {
      35             :         // sanity check
      36           6 :         if (domainDescriptor != y.getDataDescriptor() || domainDescriptor != b.getDataDescriptor())
      37           0 :             throw InvalidArgumentError(
      38           0 :                 "TransmissionLogLikelihood: descriptor and y/b not matching in size.");
      39           6 :     }
      40             : 
      41             :     template <typename data_t>
      42             :     TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(const Residual<data_t>& residual,
      43             :                                                                  const DataContainer<data_t>& y,
      44             :                                                                  const DataContainer<data_t>& b,
      45             :                                                                  const DataContainer<data_t>& r)
      46             :         : Functional<data_t>(residual),
      47             :           _y{std::make_unique<DataContainer<data_t>>(y)},
      48             :           _b{std::make_unique<DataContainer<data_t>>(b)},
      49             :           _r{std::make_unique<DataContainer<data_t>>(r)}
      50           8 :     {
      51             :         // sanity check
      52           8 :         if (residual.getRangeDescriptor() != y.getDataDescriptor()
      53           8 :             || residual.getRangeDescriptor() != b.getDataDescriptor()
      54           8 :             || residual.getRangeDescriptor() != r.getDataDescriptor())
      55           0 :             throw InvalidArgumentError(
      56           0 :                 "TransmissionLogLikelihood: residual and y/b/r not matching in size.");
      57           8 :     }
      58             : 
      59             :     template <typename data_t>
      60             :     TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(const Residual<data_t>& residual,
      61             :                                                                  const DataContainer<data_t>& y,
      62             :                                                                  const DataContainer<data_t>& b)
      63             :         : Functional<data_t>(residual),
      64             :           _y{std::make_unique<DataContainer<data_t>>(y)},
      65             :           _b{std::make_unique<DataContainer<data_t>>(b)}
      66          10 :     {
      67             :         // sanity check
      68          10 :         if (residual.getRangeDescriptor() != y.getDataDescriptor()
      69          10 :             || residual.getRangeDescriptor() != b.getDataDescriptor())
      70           0 :             throw InvalidArgumentError(
      71           0 :                 "TransmissionLogLikelihood: residual and y/b not matching in size.");
      72          10 :     }
      73             : 
      74             :     template <typename data_t>
      75             :     data_t TransmissionLogLikelihood<data_t>::evaluateImpl(const DataContainer<data_t>& Rx)
      76           8 :     {
      77           8 :         auto result = static_cast<data_t>(0.0);
      78             : 
      79        1984 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      80        1976 :             data_t temp = (*_b)[i] * std::exp(-Rx[i]);
      81        1976 :             if (_r)
      82         988 :                 temp += (*_r)[i];
      83             : 
      84        1976 :             result += temp - (*_y)[i] * std::log(temp);
      85        1976 :         }
      86             : 
      87           8 :         return result;
      88           8 :     }
      89             : 
      90             :     template <typename data_t>
      91             :     void TransmissionLogLikelihood<data_t>::getGradientInPlaceImpl(DataContainer<data_t>& Rx)
      92           8 :     {
      93        1984 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      94        1976 :             data_t temp = (*_b)[i] * std::exp(-Rx[i]);
      95        1976 :             Rx[i] = -temp;
      96             : 
      97        1976 :             if (_r)
      98         988 :                 Rx[i] += (*_y)[i] * temp / (temp + (*_r)[i]);
      99         988 :             else
     100         988 :                 Rx[i] += (*_y)[i];
     101        1976 :         }
     102           8 :     }
     103             : 
     104             :     template <typename data_t>
     105             :     LinearOperator<data_t>
     106             :         TransmissionLogLikelihood<data_t>::getHessianImpl(const DataContainer<data_t>& Rx)
     107           8 :     {
     108           8 :         DataContainer<data_t> scaleFactors(Rx.getDataDescriptor());
     109        1984 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
     110        1976 :             data_t temp = (*_b)[i] * std::exp(-Rx[i]);
     111        1976 :             scaleFactors[i] = temp;
     112        1976 :             if (_r) {
     113         988 :                 data_t tempR = temp + (*_r)[i];
     114         988 :                 scaleFactors[i] += (*_r)[i] * (*_y)[i] * temp / (tempR * tempR);
     115         988 :             }
     116        1976 :         }
     117             : 
     118           8 :         return leaf(Scaling<data_t>(Rx.getDataDescriptor(), scaleFactors));
     119           8 :     }
     120             : 
     121             :     template <typename data_t>
     122             :     TransmissionLogLikelihood<data_t>* TransmissionLogLikelihood<data_t>::cloneImpl() const
     123           8 :     {
     124           8 :         if (_r)
     125           4 :             return new TransmissionLogLikelihood<data_t>(this->getResidual(), *_y, *_b, *_r);
     126           4 :         else
     127           4 :             return new TransmissionLogLikelihood<data_t>(this->getResidual(), *_y, *_b);
     128           8 :     }
     129             : 
     130             :     template <typename data_t>
     131             :     bool TransmissionLogLikelihood<data_t>::isEqual(const Functional<data_t>& other) const
     132           8 :     {
     133           8 :         if (!Functional<data_t>::isEqual(other))
     134           0 :             return false;
     135             : 
     136           8 :         auto otherTLL = downcast_safe<TransmissionLogLikelihood>(&other);
     137           8 :         if (!otherTLL)
     138           0 :             return false;
     139             : 
     140           8 :         if (*_y != *otherTLL->_y || *_b != *otherTLL->_b)
     141           0 :             return false;
     142             : 
     143           8 :         if (_r && *_r != *otherTLL->_r)
     144           0 :             return false;
     145             : 
     146           8 :         return true;
     147           8 :     }
     148             : 
     149             :     // ------------------------------------------
     150             :     // explicit template instantiation
     151             :     template class TransmissionLogLikelihood<float>;
     152             :     template class TransmissionLogLikelihood<double>;
     153             :     // no complex instantiations, they make no sense
     154             : 
     155             : } // namespace elsa

Generated by: LCOV version 1.14