LCOV - code coverage report
Current view: top level - elsa/problems - Problem.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 100 107 93.5 %
Date: 2022-08-25 03:05:39 Functions: 64 88 72.7 %

          Line data    Source code
       1             : #include "Problem.h"
       2             : #include "Scaling.h"
       3             : #include "Logger.h"
       4             : #include "Timer.h"
       5             : 
       6             : namespace elsa
       7             : {
       8             :     template <typename data_t>
       9             :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
      10             :                              const std::vector<RegularizationTerm<data_t>>& regTerms,
      11             :                              const DataContainer<data_t>& x0,
      12             :                              const std::optional<data_t> lipschitzConstant)
      13             :         : _dataTerm{dataTerm.clone()},
      14             :           _regTerms{regTerms},
      15             :           _currentSolution{x0},
      16             :           _lipschitzConstant{lipschitzConstant}
      17          64 :     {
      18             :         // sanity checks
      19          64 :         if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients()
      20          64 :             != this->_currentSolution.getSize())
      21           0 :             throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match");
      22          64 :     }
      23             : 
      24             :     template <typename data_t>
      25             :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
      26             :                              const std::vector<RegularizationTerm<data_t>>& regTerms,
      27             :                              const std::optional<data_t> lipschitzConstant)
      28             :         : _dataTerm{dataTerm.clone()},
      29             :           _regTerms{regTerms},
      30             :           _currentSolution{dataTerm.getDomainDescriptor()},
      31             :           _lipschitzConstant{lipschitzConstant}
      32          19 :     {
      33          19 :         _currentSolution = 0;
      34          19 :     }
      35             : 
      36             :     template <typename data_t>
      37             :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
      38             :                              const RegularizationTerm<data_t>& regTerm,
      39             :                              const DataContainer<data_t>& x0,
      40             :                              const std::optional<data_t> lipschitzConstant)
      41             :         : _dataTerm{dataTerm.clone()},
      42             :           _regTerms{regTerm},
      43             :           _currentSolution{x0},
      44             :           _lipschitzConstant{lipschitzConstant}
      45          42 :     {
      46             :         // sanity checks
      47          42 :         if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients()
      48          42 :             != this->_currentSolution.getSize())
      49           0 :             throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match");
      50          42 :     }
      51             : 
      52             :     template <typename data_t>
      53             :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
      54             :                              const RegularizationTerm<data_t>& regTerm,
      55             :                              const std::optional<data_t> lipschitzConstant)
      56             :         : _dataTerm{dataTerm.clone()},
      57             :           _regTerms{regTerm},
      58             :           _currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType},
      59             :           _lipschitzConstant{lipschitzConstant}
      60          39 :     {
      61          39 :         _currentSolution = 0;
      62          39 :     }
      63             : 
      64             :     template <typename data_t>
      65             :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm, const DataContainer<data_t>& x0,
      66             :                              const std::optional<data_t> lipschitzConstant)
      67             :         : _dataTerm{dataTerm.clone()}, _currentSolution{x0}, _lipschitzConstant{lipschitzConstant}
      68         114 :     {
      69             :         // sanity check
      70         114 :         if (_dataTerm->getDomainDescriptor().getNumberOfCoefficients()
      71         114 :             != this->_currentSolution.getSize())
      72           0 :             throw InvalidArgumentError("Problem: domain of dataTerm and solution do not match");
      73         114 :     }
      74             : 
      75             :     template <typename data_t>
      76             :     Problem<data_t>::Problem(const Functional<data_t>& dataTerm,
      77             :                              const std::optional<data_t> lipschitzConstant)
      78             :         : _dataTerm{dataTerm.clone()},
      79             :           _currentSolution{dataTerm.getDomainDescriptor(), defaultHandlerType},
      80             :           _lipschitzConstant{lipschitzConstant}
      81         188 :     {
      82         188 :         _currentSolution = 0;
      83         188 :     }
      84             : 
      85             :     template <typename data_t>
      86             :     Problem<data_t>::Problem(const Problem<data_t>& problem)
      87             :         : Cloneable<Problem<data_t>>(),
      88             :           _dataTerm{problem._dataTerm->clone()},
      89             :           _regTerms{problem._regTerms},
      90             :           _currentSolution{problem._currentSolution},
      91             :           _lipschitzConstant{problem._lipschitzConstant}
      92         213 :     {
      93         213 :     }
      94             : 
      95             :     template <typename data_t>
      96             :     const Functional<data_t>& Problem<data_t>::getDataTerm() const
      97         260 :     {
      98         260 :         return *_dataTerm;
      99         260 :     }
     100             : 
     101             :     template <typename data_t>
     102             :     const std::vector<RegularizationTerm<data_t>>& Problem<data_t>::getRegularizationTerms() const
     103         266 :     {
     104         266 :         return _regTerms;
     105         266 :     }
     106             : 
     107             :     template <typename data_t>
     108             :     const DataContainer<data_t>& Problem<data_t>::getCurrentSolution() const
     109         120 :     {
     110         120 :         return _currentSolution;
     111         120 :     }
     112             : 
     113             :     template <typename data_t>
     114             :     DataContainer<data_t>& Problem<data_t>::getCurrentSolution()
     115       10246 :     {
     116       10246 :         return _currentSolution;
     117       10246 :     }
     118             : 
     119             :     template <typename data_t>
     120             :     data_t Problem<data_t>::evaluateImpl()
     121          58 :     {
     122          58 :         data_t result = _dataTerm->evaluate(_currentSolution);
     123             : 
     124          58 :         for (auto& regTerm : _regTerms)
     125          22 :             result += regTerm.getWeight() * regTerm.getFunctional().evaluate(_currentSolution);
     126             : 
     127          58 :         return result;
     128          58 :     }
     129             : 
     130             :     template <typename data_t>
     131             :     void Problem<data_t>::getGradientImpl(DataContainer<data_t>& result)
     132       10997 :     {
     133       10997 :         _dataTerm->getGradient(_currentSolution, result);
     134             : 
     135       10997 :         for (auto& regTerm : _regTerms)
     136          42 :             result += regTerm.getWeight() * regTerm.getFunctional().getGradient(_currentSolution);
     137       10997 :     }
     138             : 
     139             :     template <typename data_t>
     140             :     LinearOperator<data_t> Problem<data_t>::getHessianImpl() const
     141          87 :     {
     142          87 :         auto hessian = _dataTerm->getHessian(_currentSolution);
     143             : 
     144          87 :         for (auto& regTerm : _regTerms) {
     145          26 :             Scaling weight(_currentSolution.getDataDescriptor(), regTerm.getWeight());
     146          26 :             hessian = hessian + (weight * regTerm.getFunctional().getHessian(_currentSolution));
     147          26 :         }
     148             : 
     149          87 :         return hessian;
     150          87 :     }
     151             : 
     152             :     template <typename data_t>
     153             :     data_t Problem<data_t>::getLipschitzConstantImpl(index_t nIterations) const
     154          34 :     {
     155          34 :         Timer guard("Problem", "Calculating Lipschitz constant");
     156          34 :         Logger::get("Problem")->info("Calculating Lipschitz constant");
     157             : 
     158          34 :         if (_lipschitzConstant.has_value()) {
     159           8 :             return _lipschitzConstant.value();
     160           8 :         }
     161             :         // compute the Lipschitz Constant as the largest eigenvalue of the Hessian
     162          26 :         const auto hessian = getHessian();
     163          26 :         DataContainer<data_t> dcB(hessian.getDomainDescriptor());
     164          26 :         dcB = 1;
     165         741 :         for (index_t i = 0; i < nIterations; i++) {
     166         715 :             dcB = hessian.apply(dcB);
     167         715 :             dcB = dcB / dcB.l2Norm();
     168         715 :         }
     169             : 
     170          26 :         return dcB.dot(hessian.apply(dcB)) / dcB.l2Norm();
     171          26 :     }
     172             : 
     173             :     template <typename data_t>
     174             :     Problem<data_t>* Problem<data_t>::cloneImpl() const
     175          12 :     {
     176          12 :         return new Problem(*this);
     177          12 :     }
     178             : 
     179             :     template <typename data_t>
     180             :     bool Problem<data_t>::isEqual(const Problem<data_t>& other) const
     181          74 :     {
     182          74 :         if (typeid(*this) != typeid(other))
     183          16 :             return false;
     184             : 
     185          58 :         if (_currentSolution != other._currentSolution)
     186           0 :             return false;
     187             : 
     188          58 :         if (*_dataTerm != *other._dataTerm)
     189           0 :             return false;
     190             : 
     191          58 :         if (_regTerms.size() != other._regTerms.size())
     192           0 :             return false;
     193             : 
     194          84 :         for (std::size_t i = 0; i < _regTerms.size(); ++i)
     195          26 :             if (_regTerms.at(i) != other._regTerms.at(i))
     196           0 :                 return false;
     197             : 
     198          58 :         return true;
     199          58 :     }
     200             : 
     201             :     template <typename data_t>
     202             :     data_t Problem<data_t>::evaluate()
     203          62 :     {
     204          62 :         return evaluateImpl();
     205          62 :     }
     206             : 
     207             :     template <typename data_t>
     208             :     DataContainer<data_t> Problem<data_t>::getGradient()
     209       11001 :     {
     210       11001 :         DataContainer<data_t> result(_currentSolution.getDataDescriptor(),
     211       11001 :                                      _currentSolution.getDataHandlerType());
     212       11001 :         getGradient(result);
     213       11001 :         return result;
     214       11001 :     }
     215             : 
     216             :     template <typename data_t>
     217             :     void Problem<data_t>::getGradient(DataContainer<data_t>& result)
     218       11001 :     {
     219       11001 :         getGradientImpl(result);
     220       11001 :     }
     221             : 
     222             :     template <typename data_t>
     223             :     LinearOperator<data_t> Problem<data_t>::getHessian() const
     224          91 :     {
     225          91 :         return getHessianImpl();
     226          91 :     }
     227             : 
     228             :     template <typename data_t>
     229             :     data_t Problem<data_t>::getLipschitzConstant(index_t nIterations) const
     230          46 :     {
     231          46 :         return getLipschitzConstantImpl(nIterations);
     232          46 :     }
     233             : 
     234             :     // ------------------------------------------
     235             :     // explicit template instantiation
     236             :     template class Problem<float>;
     237             : 
     238             :     template class Problem<double>;
     239             : 
     240             :     template class Problem<complex<float>>;
     241             : 
     242             :     template class Problem<complex<double>>;
     243             : 
     244             : } // namespace elsa

Generated by: LCOV version 1.14