LCOV - code coverage report
Current view: top level - elsa/functionals - WeightedL2Squared.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 24 29 82.8 %
Date: 2024-05-16 04:22:26 Functions: 26 32 81.2 %

          Line data    Source code
       1             : #include "WeightedL2Squared.h"
       2             : #include "DataContainer.h"
       3             : #include "LinearOperator.h"
       4             : #include "Scaling.h"
       5             : #include "TypeCasts.hpp"
       6             : 
       7             : #include <stdexcept>
       8             : 
       9             : namespace elsa
      10             : {
      11             :     template <typename data_t>
      12             :     WeightedL2Squared<data_t>::WeightedL2Squared(const DataContainer<data_t>& weights)
      13             :         : Functional<data_t>(weights.getDataDescriptor()), weights_{weights}
      14          24 :     {
      15          24 :     }
      16             : 
      17             :     template <typename data_t>
      18             :     bool WeightedL2Squared<data_t>::isDifferentiable() const
      19           0 :     {
      20           0 :         return true;
      21           0 :     }
      22             : 
      23             :     template <typename data_t>
      24             :     Scaling<data_t> WeightedL2Squared<data_t>::getWeightingOperator() const
      25           8 :     {
      26           8 :         return Scaling<data_t>(weights_);
      27           8 :     }
      28             : 
      29             :     template <typename data_t>
      30             :     data_t WeightedL2Squared<data_t>::evaluateImpl(const DataContainer<data_t>& Rx) const
      31           2 :     {
      32           2 :         auto temp = weights_ * Rx;
      33           2 :         return static_cast<data_t>(0.5) * Rx.dot(temp);
      34           2 :     }
      35             : 
      36             :     template <typename data_t>
      37             :     void WeightedL2Squared<data_t>::getGradientImpl(const DataContainer<data_t>& Rx,
      38             :                                                     DataContainer<data_t>& out) const
      39           4 :     {
      40           4 :         out = weights_ * Rx;
      41           4 :     }
      42             : 
      43             :     template <typename data_t>
      44             :     LinearOperator<data_t>
      45             :         WeightedL2Squared<data_t>::getHessianImpl(const DataContainer<data_t>&) const
      46           4 :     {
      47           4 :         return leaf(getWeightingOperator());
      48           4 :     }
      49             : 
      50             :     template <typename data_t>
      51             :     WeightedL2Squared<data_t>* WeightedL2Squared<data_t>::cloneImpl() const
      52           4 :     {
      53           4 :         return new WeightedL2Squared(weights_);
      54           4 :     }
      55             : 
      56             :     template <typename data_t>
      57             :     bool WeightedL2Squared<data_t>::isEqual(const Functional<data_t>& other) const
      58           4 :     {
      59           4 :         if (!Functional<data_t>::isEqual(other))
      60           0 :             return false;
      61             : 
      62           4 :         auto otherWL2 = downcast_safe<WeightedL2Squared>(&other);
      63           4 :         if (!otherWL2)
      64           0 :             return false;
      65             : 
      66           4 :         return weights_ == otherWL2->weights_;
      67           4 :     }
      68             : 
      69             :     // ------------------------------------------
      70             :     // explicit template instantiation
      71             :     template class WeightedL2Squared<float>;
      72             :     template class WeightedL2Squared<double>;
      73             :     template class WeightedL2Squared<complex<float>>;
      74             :     template class WeightedL2Squared<complex<double>>;
      75             : 
      76             : } // namespace elsa

Generated by: LCOV version 1.14