LCOV - code coverage report
Current view: top level - elsa/functionals - RicianLoss.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 125 138 90.6 %
Date: 2024-05-14 04:15:33 Functions: 18 22 81.8 %

          Line data    Source code
       1             : #include "RicianLoss.h"
       2             : #include "DataContainer.h"
       3             : #include "TypeCasts.hpp"
       4             : #include "IdenticalBlocksDescriptor.h"
       5             : #include "Identity.h"
       6             : #include "LinearOperator.h"
       7             : #include "Scaling.h"
       8             : #include "BlockLinearOperator.h"
       9             : #include "ZeroOperator.h"
      10             : #include "Timer.h"
      11             : #include "Assertions.h"
      12             : 
      13             : namespace elsa
      14             : {
      15             :     template <typename data_t>
      16             :     std::unique_ptr<DataDescriptor> RicianLoss<data_t>::generate_placeholder_descriptor()
      17           0 :     {
      18           0 :         index_t numOfDims{1};
      19           0 :         IndexVector_t dims(numOfDims);
      20           0 :         dims << 1;
      21           0 :         return std::make_unique<VolumeDescriptor>(dims);
      22           0 :     }
      23             : 
      24             :     template <typename data_t>
      25             :     RandomBlocksDescriptor RicianLoss<data_t>::generate_descriptors(const DataDescriptor& desc1,
      26             :                                                                     const DataDescriptor& desc2)
      27          24 :     {
      28          24 :         std::vector<std::unique_ptr<DataDescriptor>> descs;
      29             : 
      30          24 :         descs.emplace_back(desc1.clone());
      31          24 :         descs.emplace_back(desc2.clone());
      32          24 :         return RandomBlocksDescriptor(descs);
      33          24 :     }
      34             : 
      35             :     template <typename data_t>
      36             :     RicianLoss<data_t>::RicianLoss(const DataContainer<data_t>& ffa,
      37             :                                    const DataContainer<data_t>& ffb, const DataContainer<data_t>& a,
      38             :                                    const DataContainer<data_t>& b,
      39             :                                    const LinearOperator<data_t>& absorp_op,
      40             :                                    const LinearOperator<data_t>& axdt_op, index_t N,
      41             :                                    bool approximate)
      42             :         : Functional<data_t>(
      43             :             generate_descriptors(absorp_op.getDomainDescriptor(), axdt_op.getDomainDescriptor())),
      44             :           ffa_(ffa),
      45             :           ffb_(ffb),
      46             :           a_tilde_(a),
      47             :           b_tilde_(b),
      48             :           absorp_op_(absorp_op.clone()),
      49             :           axdt_op_(axdt_op.clone()),
      50             :           alpha_(ffb / ffa),
      51             :           d_tilde_(b_tilde_ / a_tilde_ / alpha_),
      52             :           N_(static_cast<data_t>(N)),
      53             :           approximate_(approximate)
      54          24 :     {
      55          24 :     }
      56             : 
      57             :     template <typename data_t>
      58             :     bool RicianLoss<data_t>::isDifferentiable() const
      59           0 :     {
      60           0 :         return true;
      61           0 :     }
      62             : 
      63             :     template <typename data_t>
      64             :     data_t RicianLoss<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const
      65           4 :     {
      66           4 :         Timer timeguard("AXDTStatRecon", "evaluate");
      67             : 
      68           4 :         const auto mu = materialize(Rx.getBlock(0));
      69           4 :         const auto eta = materialize(Rx.getBlock(1));
      70             : 
      71           4 :         auto log_d = -axdt_op_->apply(eta);
      72           4 :         auto d = exp(log_d);
      73             : 
      74           4 :         auto a = exp(-absorp_op_->apply(mu)) * ffa_;
      75           4 :         auto log_a = log(a);
      76             : 
      77           4 :         auto alpha_d = alpha_ * d;
      78             : 
      79           4 :         if (approximate_) {
      80           2 :             auto reorder = sq(a_tilde_ - a);
      81           2 :             reorder *= 2.0f;
      82           2 :             reorder += sq(b_tilde_ - (a * alpha_d));
      83           2 :             reorder /= a;
      84             : 
      85           2 :             return log_a.sum() + (N_ / 4.0f) * reorder.sum();
      86           2 :         } else {
      87             : 
      88           2 :             auto reorder = sq(a_tilde_);
      89           2 :             reorder *= 2.0f;
      90           2 :             reorder += sq(b_tilde_);
      91           2 :             reorder /= a;
      92             : 
      93           2 :             return 1.5f * log_a.sum() + (N_ * 0.5f) * a.sum()
      94           2 :                    + (N_ / 4.0f) * (reorder.sum() + a.dot(sq(alpha_d)))
      95           2 :                    - axdt::log_bessel_0((N_ / 2) * b_tilde_ * alpha_d).sum();
      96           2 :         }
      97           4 :     }
      98             : 
      99             :     template <typename data_t>
     100             :     void RicianLoss<data_t>::getGradientImpl(const DataContainer<data_t>& Rx,
     101             :                                              DataContainer<data_t>& out) const
     102           4 :     {
     103           4 :         Timer timeguard("AXDTStatRecon", "getGradient");
     104             : 
     105           4 :         const auto mu = materialize(Rx.getBlock(0));
     106           4 :         const auto eta = materialize(Rx.getBlock(1));
     107             : 
     108           4 :         auto log_d = -axdt_op_->apply(eta);
     109           4 :         auto d = exp(log_d);
     110             : 
     111           4 :         auto grad_eta = empty<data_t>(out.getBlock(1).getDataDescriptor());
     112             : 
     113           4 :         auto a = exp(-absorp_op_->apply(mu)) * ffa_;
     114           4 :         auto log_a = log(a);
     115             : 
     116             :         // Some temporaries, potential improvement? First measure!
     117           4 :         auto sqa = sq(a);
     118           4 :         auto sqa_tilde = sq(a_tilde_);
     119             : 
     120           4 :         auto alpha_d = alpha_ * d;
     121             : 
     122           4 :         auto tmp_a_alpha_d = a * alpha_d; // Might reuse this later in approximation codepath
     123             : 
     124           4 :         auto tmp_mu = 2.0f * (sqa - sqa_tilde);
     125           4 :         tmp_mu += sq(tmp_a_alpha_d);
     126           4 :         tmp_mu -= sq(b_tilde_);
     127           4 :         tmp_mu /= a;
     128           4 :         tmp_mu *= -N_ / 4.0f;
     129           4 :         tmp_mu -= approximate_ ? 1.0f : 1.5f; // only part that differs
     130           4 :         auto grad_mu = absorp_op_->applyAdjoint(tmp_mu);
     131             : 
     132           4 :         if (approximate_) {
     133           2 :             tmp_a_alpha_d -= b_tilde_;
     134           2 :             tmp_a_alpha_d *= alpha_d;
     135           2 :             tmp_a_alpha_d *= -N_ * 0.5f;
     136             : 
     137           2 :             grad_eta = axdt_op_->applyAdjoint(tmp_a_alpha_d);
     138           2 :         } else {
     139           2 :             auto tmp = a * sq(alpha_d);
     140           2 :             tmp *= -N_ * 0.5f;
     141           2 :             auto grad_eta_tmp_bessel = b_tilde_ * alpha_d;
     142           2 :             grad_eta_tmp_bessel *= N_ * 0.5f;
     143           2 :             tmp += grad_eta_tmp_bessel * axdt::quot_bessel_1_0(grad_eta_tmp_bessel);
     144             : 
     145           2 :             grad_eta = axdt_op_->applyAdjoint(tmp);
     146           2 :         }
     147             : 
     148           4 :         out.getBlock(0) = grad_mu;
     149           4 :         out.getBlock(1) = grad_eta;
     150           4 :     }
     151             : 
     152             :     template <typename data_t>
     153             :     LinearOperator<data_t> RicianLoss<data_t>::getHessianImpl(const DataContainer<data_t>& Rx) const
     154           4 :     {
     155           4 :         Timer timeguard("AXDTStatRecon", "getHessian");
     156             : 
     157           4 :         const auto mu = materialize(Rx.getBlock(0));
     158           4 :         const auto eta = materialize(Rx.getBlock(1));
     159             : 
     160           4 :         using BlockType = typename BlockLinearOperator<data_t>::BlockType;
     161             : 
     162           4 :         auto d = exp(-axdt_op_->apply(eta));
     163           4 :         auto a = exp(-absorp_op_->apply(mu)) * ffa_;
     164             : 
     165           4 :         auto alpha_d = alpha_ * d;
     166             : 
     167           4 :         auto H_1_1 =
     168           4 :             N_ * (2.f * sq(a) + 2.f * sq(a_tilde_) + (sq(a * alpha_d)) + sq(b_tilde_)) / 4.f / a;
     169           4 :         auto H_1_2 = N_ * 0.5f * a * sq(alpha_d);
     170           4 :         auto H_2_2 = emptylike(a);
     171           4 :         if (approximate_) {
     172           2 :             H_2_2 = (N_ * 0.5f) * (alpha_d * (2.f * a * alpha_d - b_tilde_));
     173           2 :         } else {
     174           2 :             H_2_2 = N_ * a * sq(alpha_d);
     175           2 :             auto z = (N_ * 0.5f) * b_tilde_ * alpha_d;
     176           2 :             auto quot_z = axdt::quot_bessel_1_0(z);
     177           2 :             H_2_2 -= sq(z) * (1.f - sq(quot_z));
     178           2 :         }
     179             : 
     180           4 :         auto hessian_absorp = [&]() {
     181           4 :             auto hessian_absorp_0 = adjoint(*absorp_op_) * H_1_1 * *absorp_op_;
     182           4 :             auto hessian_absorp_1 = adjoint(*absorp_op_) * H_1_2 * *axdt_op_;
     183             : 
     184           4 :             typename BlockLinearOperator<data_t>::OperatorList ops;
     185           4 :             ops.emplace_back(hessian_absorp_0.clone());
     186           4 :             ops.emplace_back(hessian_absorp_1.clone());
     187           4 :             return BlockLinearOperator<data_t>(ops, BlockType::COL).clone();
     188           4 :         }();
     189             : 
     190           4 :         auto hessian_axdt = [&]() {
     191           4 :             auto hessian_axdt_0 = adjoint(*axdt_op_) * H_1_2 * *absorp_op_;
     192           4 :             auto hessian_axdt_1 = adjoint(*axdt_op_) * H_2_2 * *axdt_op_;
     193             : 
     194           4 :             typename BlockLinearOperator<data_t>::OperatorList ops;
     195           4 :             ops.emplace_back(hessian_axdt_0.clone());
     196           4 :             ops.emplace_back(hessian_axdt_1.clone());
     197           4 :             return BlockLinearOperator<data_t>(ops, BlockType::COL).clone();
     198           4 :         }();
     199             : 
     200           4 :         typename BlockLinearOperator<data_t>::OperatorList ops;
     201           4 :         ops.emplace_back(std::move(hessian_absorp));
     202           4 :         ops.emplace_back(std::move(hessian_axdt));
     203           4 :         return leaf(BlockLinearOperator<data_t>(ops, BlockType::ROW));
     204           4 :     }
     205             : 
     206             :     template <typename data_t>
     207             :     RicianLoss<data_t>* RicianLoss<data_t>::cloneImpl() const
     208           4 :     {
     209           4 :         return new RicianLoss(ffa_, ffb_, a_tilde_, b_tilde_, *absorp_op_, *axdt_op_,
     210           4 :                               static_cast<index_t>(N_), approximate_);
     211           4 :     }
     212             : 
     213             :     template <typename data_t>
     214             :     bool RicianLoss<data_t>::isEqual(const Functional<data_t>& other) const
     215           4 :     {
     216           4 :         if (!Functional<data_t>::isEqual(other))
     217           0 :             return false;
     218             : 
     219           4 :         auto otherFn = downcast_safe<RicianLoss>(&other);
     220           4 :         if (!otherFn)
     221           0 :             return false;
     222             : 
     223           4 :         if (otherFn->approximate_ != approximate_)
     224           0 :             return false;
     225             : 
     226           4 :         if (otherFn->ffa_ != ffa_ || otherFn->ffb_ != ffb_ || otherFn->a_tilde_ != a_tilde_
     227           4 :             || otherFn->b_tilde_ != b_tilde_ || *(otherFn->absorp_op_) != *(absorp_op_)
     228           4 :             || *(otherFn->axdt_op_) != *(axdt_op_) || otherFn->N_ != N_)
     229           0 :             return false;
     230           4 :         else
     231           4 :             return true;
     232           4 :     }
     233             : 
     234             :     // ------------------------------------------
     235             :     // explicit template instantiation
     236             :     template class RicianLoss<float>;
     237             :     template class RicianLoss<double>;
     238             : } // namespace elsa

Generated by: LCOV version 1.14