LCOV - code coverage report
Current view: top level - elsa/functionals - TransmissionLogLikelihood.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 67 81 82.7 %
Date: 2024-05-16 04:22:26 Functions: 18 20 90.0 %

          Line data    Source code
       1             : #include "TransmissionLogLikelihood.h"
       2             : #include "DataContainer.h"
       3             : #include "LinearOperator.h"
       4             : #include "Scaling.h"
       5             : #include "Error.h"
       6             : #include "TypeCasts.hpp"
       7             : 
       8             : #include <cmath>
       9             : 
      10             : namespace elsa
      11             : {
      12             :     template <typename data_t>
      13             :     TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(const LinearOperator<data_t>& A,
      14             :                                                                  const DataContainer<data_t>& y,
      15             :                                                                  const DataContainer<data_t>& b,
      16             :                                                                  const DataContainer<data_t>& r)
      17             :         : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), y_{y}, b_{b}, r_{r}
      18           6 :     {
      19             :         // sanity check
      20           6 :         if (A.getRangeDescriptor() != y.getDataDescriptor()
      21           6 :             || A.getRangeDescriptor() != b.getDataDescriptor()
      22           6 :             || A.getRangeDescriptor() != r.getDataDescriptor())
      23           0 :             throw InvalidArgumentError(
      24           0 :                 "TransmissionLogLikelihood: operator and y/b/r not matching in size.");
      25           6 :     }
      26             : 
      27             :     template <typename data_t>
      28             :     TransmissionLogLikelihood<data_t>::TransmissionLogLikelihood(const LinearOperator<data_t>& A,
      29             :                                                                  const DataContainer<data_t>& y,
      30             :                                                                  const DataContainer<data_t>& b)
      31             :         : Functional<data_t>(A.getDomainDescriptor()), A_(A.clone()), y_{y}, b_{b}
      32           8 :     {
      33             :         // sanity check
      34           8 :         if (A.getRangeDescriptor() != y.getDataDescriptor()
      35           8 :             || A.getRangeDescriptor() != b.getDataDescriptor())
      36           0 :             throw InvalidArgumentError(
      37           0 :                 "TransmissionLogLikelihood: operator and y/b not matching in size.");
      38           8 :     }
      39             : 
      40             :     template <typename data_t>
      41             :     bool TransmissionLogLikelihood<data_t>::isDifferentiable() const
      42           0 :     {
      43           0 :         return true;
      44           0 :     }
      45             : 
      46             :     template <typename data_t>
      47             :     data_t TransmissionLogLikelihood<data_t>::evaluateImpl(const DataContainer<data_t>& x) const
      48           4 :     {
      49           4 :         if (x.getDataDescriptor() != A_->getDomainDescriptor()) {
      50           0 :             throw InvalidArgumentError(
      51           0 :                 "TransmissionLogLikelihood: given x is not the correct size");
      52           0 :         }
      53             : 
      54           4 :         auto result = static_cast<data_t>(0.0);
      55             : 
      56           4 :         auto Rx = A_->apply(x);
      57        1096 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      58        1092 :             data_t temp = b_[i] * std::exp(-Rx[i]);
      59        1092 :             if (r_.has_value())
      60         546 :                 temp += (*r_)[i];
      61             : 
      62        1092 :             result += temp - y_[i] * std::log(temp);
      63        1092 :         }
      64             : 
      65           4 :         return result;
      66           4 :     }
      67             : 
      68             :     template <typename data_t>
      69             :     void TransmissionLogLikelihood<data_t>::getGradientImpl(const DataContainer<data_t>& x,
      70             :                                                             DataContainer<data_t>& out) const
      71           4 :     {
      72             :         // Actual computation of the functional
      73        1092 :         auto translog = [&](auto in, auto i) {
      74        1092 :             data_t temp = b_[i] * std::exp(-in);
      75        1092 :             in = -temp;
      76             : 
      77        1092 :             if (r_.has_value())
      78         546 :                 in += y_[i] * temp / (temp + (*r_)[i]);
      79         546 :             else
      80         546 :                 in += y_[i];
      81             : 
      82        1092 :             return in;
      83        1092 :         };
      84             : 
      85           4 :         auto Rx = A_->apply(x);
      86        1096 :         for (index_t i = 0; i < Rx.getSize(); ++i) {
      87        1092 :             Rx[i] = translog(Rx[i], i);
      88        1092 :         }
      89           4 :         A_->applyAdjoint(Rx, out);
      90           4 :     }
      91             : 
      92             :     template <typename data_t>
      93             :     LinearOperator<data_t>
      94             :         TransmissionLogLikelihood<data_t>::getHessianImpl(const DataContainer<data_t>& x) const
      95           4 :     {
      96           4 :         auto scale = [&](auto in) {
      97           4 :             DataContainer<data_t> s(in.getDataDescriptor());
      98        1096 :             for (index_t i = 0; i < in.getSize(); ++i) {
      99        1092 :                 s[i] = b_[i] * std::exp(-in[i]);
     100        1092 :                 if (r_.has_value()) {
     101         546 :                     data_t tempR = s[i] + (*r_)[i];
     102         546 :                     s[i] += (*r_)[i] * y_[i] * s[i] / (tempR * tempR);
     103         546 :                 }
     104        1092 :             }
     105             : 
     106           4 :             return leaf(Scaling<data_t>(s));
     107           4 :         };
     108             : 
     109           4 :         auto Rx = A_->apply(x);
     110           4 :         auto s = scale(Rx);
     111             : 
     112             :         // Jacobian is the operator, plus chain rule
     113           4 :         return adjoint(*A_) * s * (*A_);
     114           4 :     }
     115             : 
     116             :     template <typename data_t>
     117             :     TransmissionLogLikelihood<data_t>* TransmissionLogLikelihood<data_t>::cloneImpl() const
     118           4 :     {
     119           4 :         if (r_.has_value()) {
     120           2 :             return new TransmissionLogLikelihood<data_t>(*A_, y_, b_, *r_);
     121           2 :         }
     122           2 :         return new TransmissionLogLikelihood<data_t>(*A_, y_, b_);
     123           2 :     }
     124             : 
     125             :     template <typename data_t>
     126             :     bool TransmissionLogLikelihood<data_t>::isEqual(const Functional<data_t>& other) const
     127           4 :     {
     128           4 :         if (!Functional<data_t>::isEqual(other))
     129           0 :             return false;
     130             : 
     131           4 :         auto otherTLL = downcast_safe<TransmissionLogLikelihood>(&other);
     132           4 :         if (!otherTLL)
     133           0 :             return false;
     134             : 
     135           4 :         if (y_ != otherTLL->y_ || b_ != otherTLL->b_)
     136           0 :             return false;
     137             : 
     138           4 :         if (r_.has_value() && otherTLL->r_.has_value() && *r_ != *otherTLL->r_)
     139           0 :             return false;
     140             : 
     141           4 :         return true;
     142           4 :     }
     143             : 
     144             :     // ------------------------------------------
     145             :     // explicit template instantiation
     146             :     template class TransmissionLogLikelihood<float>;
     147             :     template class TransmissionLogLikelihood<double>;
     148             :     // no complex instantiations, they make no sense
     149             : 
     150             : } // namespace elsa

Generated by: LCOV version 1.14