LCOV - code coverage report
Current view: top level - elsa/functionals - EmissionLogLikelihood.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 59 72 81.9 %
Date: 2024-05-16 04:22:26 Functions: 18 20 90.0 %

          Line data    Source code
       1             : #include "EmissionLogLikelihood.h"
       2             : #include "DataContainer.h"
       3             : #include "DataDescriptor.h"
       4             : #include "LinearOperator.h"
       5             : #include "Scaling.h"
       6             : #include "TypeCasts.hpp"
       7             : 
       8             : #include <cmath>
       9             : #include <stdexcept>
      10             : 
      11             : namespace elsa
      12             : {
      13             :     template <typename data_t>
      14             :     EmissionLogLikelihood<data_t>::EmissionLogLikelihood(const LinearOperator<data_t>& A,
      15             :                                                          const DataContainer<data_t>& y,
      16             :                                                          const DataContainer<data_t>& r)
      17             :         : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), y_{y}, r_{r}
      18           6 :     {
      19             :         // sanity check
      20           6 :         if (A.getRangeDescriptor() != y.getDataDescriptor()
      21           6 :             || A.getRangeDescriptor() != r.getDataDescriptor())
      22           0 :             throw InvalidArgumentError(
      23           0 :                 "EmissionLogLikelihood: residual and y/r not matching in size.");
      24           6 :     }
      25             : 
      26             :     template <typename data_t>
      27             :     EmissionLogLikelihood<data_t>::EmissionLogLikelihood(const LinearOperator<data_t>& A,
      28             :                                                          const DataContainer<data_t>& y)
      29             :         : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), y_{y}
      30           8 :     {
      31             :         // sanity check
      32           8 :         if (A.getRangeDescriptor() != y.getDataDescriptor())
      33           0 :             throw InvalidArgumentError(
      34           0 :                 "EmissionLogLikelihood: residual and y not matching in size.");
      35           8 :     }
      36             : 
      37             :     template <typename data_t>
      38             :     bool EmissionLogLikelihood<data_t>::isDifferentiable() const
      39           0 :     {
      40           0 :         return true;
      41           0 :     }
      42             : 
      43             :     template <typename data_t>
      44             :     data_t EmissionLogLikelihood<data_t>::evaluateImpl(const DataContainer<data_t>& x) const
      45           4 :     {
      46           4 :         if (x.getDataDescriptor() != A_->getDomainDescriptor()) {
      47           0 :             throw InvalidArgumentError("EmissionLogLikelihood: given x is not the correct size");
      48           0 :         }
      49             : 
      50           4 :         auto result = static_cast<data_t>(0.0);
      51             : 
      52           4 :         auto Rx = A_->apply(x);
      53       10264 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      54       10260 :             data_t temp = Rx[i];
      55       10260 :             if (r_)
      56        5130 :                 temp += (*r_)[i];
      57             : 
      58       10260 :             result += temp - y_[i] * std::log(temp);
      59       10260 :         }
      60             : 
      61           4 :         return result;
      62           4 :     }
      63             : 
      64             :     template <typename data_t>
      65             :     void EmissionLogLikelihood<data_t>::getGradientImpl(const DataContainer<data_t>& x,
      66             :                                                         DataContainer<data_t>& out) const
      67           4 :     {
      68       10260 :         auto emissionlog = [&](auto in, auto i) {
      69       10260 :             data_t temp = in;
      70       10260 :             if (r_)
      71        5130 :                 temp += (*r_)[i];
      72             : 
      73       10260 :             return 1 - y_[i] / temp;
      74       10260 :         };
      75             : 
      76           4 :         auto Rx = A_->apply(x);
      77       10264 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      78       10260 :             Rx[i] = emissionlog(Rx[i], i);
      79       10260 :         }
      80           4 :         A_->applyAdjoint(Rx, out);
      81           4 :     }
      82             : 
      83             :     template <typename data_t>
      84             :     LinearOperator<data_t>
      85             :         EmissionLogLikelihood<data_t>::getHessianImpl(const DataContainer<data_t>& x) const
      86           4 :     {
      87           4 :         auto scale = [&](auto in) {
      88           4 :             DataContainer<data_t> s(in.getDataDescriptor());
      89       10264 :             for (index_t i = 0; i < in.getSize(); ++i) {
      90       10260 :                 data_t temp = in[i];
      91       10260 :                 if (r_)
      92        5130 :                     temp += (*r_)[i];
      93             : 
      94       10260 :                 s[i] = y_[i] / (temp * temp);
      95       10260 :             }
      96             : 
      97           4 :             return leaf(Scaling<data_t>(s));
      98           4 :         };
      99             : 
     100           4 :         auto Rx = A_->apply(x);
     101             : 
     102             :         // Jacobian is the operator, plus chain rule
     103           4 :         return adjoint(*A_) * scale(Rx) * (*A_);
     104           4 :     }
     105             : 
     106             :     template <typename data_t>
     107             :     EmissionLogLikelihood<data_t>* EmissionLogLikelihood<data_t>::cloneImpl() const
     108           4 :     {
     109           4 :         if (r_.has_value()) {
     110           2 :             return new EmissionLogLikelihood<data_t>(*A_, y_, *r_);
     111           2 :         }
     112           2 :         return new EmissionLogLikelihood<data_t>(*A_, y_);
     113           2 :     }
     114             : 
     115             :     template <typename data_t>
     116             :     bool EmissionLogLikelihood<data_t>::isEqual(const Functional<data_t>& other) const
     117           4 :     {
     118           4 :         if (!Functional<data_t>::isEqual(other))
     119           0 :             return false;
     120             : 
     121           4 :         auto otherELL = downcast_safe<EmissionLogLikelihood>(&other);
     122           4 :         if (!otherELL)
     123           0 :             return false;
     124             : 
     125           4 :         if (r_ && otherELL->r_ && *r_ != *otherELL->r_) {
     126           0 :             return false;
     127           0 :         }
     128             : 
     129           4 :         return y_ == otherELL->y_;
     130           4 :     }
     131             : 
     132             :     // ------------------------------------------
     133             :     // Explicit template instantiations for the scalar types float and double.
     134             :     template class EmissionLogLikelihood<float>;
     135             :     template class EmissionLogLikelihood<double>;
     136             :     // Complex scalar types are deliberately not instantiated; they make no sense here.
     137             : 
     138             : } // namespace elsa

Generated by: LCOV version 1.14