LCOV - code coverage report
Current view: top level - elsa/operators - Dictionary.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 98 108 90.7 %
Date: 2024-05-16 04:22:26 Functions: 22 24 91.7 %

          Line data    Source code
       1             : #include "Dictionary.h"
       2             : #include "TypeCasts.hpp"
       3             : 
       4             : namespace elsa
       5             : {
       6             :     template <typename data_t>
       7             :     Dictionary<data_t>::Dictionary(const DataDescriptor& signalDescriptor, index_t nAtoms)
       8             :         : LinearOperator<data_t>(VolumeDescriptor({nAtoms}), signalDescriptor),
       9             :           _dictionary{DataContainer<data_t>(generateInitialData(signalDescriptor, nAtoms))},
      10             :           _nAtoms{nAtoms}
      11          10 :     {
      12          10 :     }
      13             : 
      14             :     template <typename data_t>
      15             :     Dictionary<data_t>::Dictionary(const DataContainer<data_t>& dictionary)
      16             :         : LinearOperator<data_t>(
      17             :             VolumeDescriptor({getIdenticalBlocksDescriptor(dictionary).getNumberOfBlocks()}),
      18             :             getIdenticalBlocksDescriptor(dictionary).getDescriptorOfBlock(0)),
      19             :           _dictionary{dictionary},
      20             :           _nAtoms{getIdenticalBlocksDescriptor(dictionary).getNumberOfBlocks()}
      21          36 :     {
      22         297 :         for (int i = 0; i < _nAtoms; ++i) {
      23         261 :             auto block = _dictionary.getBlock(i);
      24         261 :             data_t l2Norm = block.l2Norm();
      25         261 :             if (l2Norm == 0) {
      26           0 :                 throw InvalidArgumentError("Dictionary: initializing with 0-atom not possible");
      27           0 :             }
      28             :             // don't normalize if the norm is very close to 1 already
      29         261 :             if (std::abs(l2Norm - 1)
      30         261 :                 > std::numeric_limits<data_t>::epsilon() * std::abs(l2Norm + 1)) {
      31          94 :                 block /= l2Norm;
      32          94 :             }
      33         261 :         }
      34          36 :     }
      35             : 
      36             :     template <typename data_t>
      37             :     const IdenticalBlocksDescriptor&
      38             :         Dictionary<data_t>::getIdenticalBlocksDescriptor(const DataContainer<data_t>& data)
      39         104 :     {
      40         104 :         try {
      41         104 :             return downcast_safe<IdenticalBlocksDescriptor>(data.getDataDescriptor());
      42         104 :         } catch (const BadCastError&) {
      43           2 :             throw InvalidArgumentError(
      44           2 :                 "Dictionary: cannot initialize from data without IdenticalBlocksDescriptor");
      45           2 :         }
      46         104 :     }
      47             : 
      48             :     template <typename data_t>
      49             :     DataContainer<data_t>
      50             :         Dictionary<data_t>::generateInitialData(const DataDescriptor& signalDescriptor,
      51             :                                                 index_t nAtoms)
      52          10 :     {
      53          10 :         Vector_t<data_t> randomData(signalDescriptor.getNumberOfCoefficients() * nAtoms);
      54          10 :         randomData.setRandom();
      55          10 :         IdenticalBlocksDescriptor desc(nAtoms, signalDescriptor);
      56          10 :         DataContainer<data_t> initData{desc, randomData};
      57             : 
      58          79 :         for (int i = 0; i < desc.getNumberOfBlocks(); ++i) {
      59          69 :             auto block = initData.getBlock(i);
      60          69 :             block /= block.l2Norm();
      61          69 :         }
      62             : 
      63          10 :         return initData;
      64          10 :     }
      65             : 
      66             :     template <typename data_t>
      67             :     void Dictionary<data_t>::updateAtom(index_t j, const DataContainer<data_t>& atom)
      68          17 :     {
      69          17 :         if (j < 0 || j >= _nAtoms)
      70           2 :             throw InvalidArgumentError("Dictionary::updateAtom: atom index out of bounds");
      71          15 :         if (*_rangeDescriptor != atom.getDataDescriptor())
      72           2 :             throw InvalidArgumentError("Dictionary::updateAtom: atom has invalid size");
      73          13 :         data_t l2Norm = atom.l2Norm();
      74          13 :         if (l2Norm == 0) {
      75           0 :             throw InvalidArgumentError("Dictionary::updateAtom: updating to 0-atom not possible");
      76           0 :         }
      77             :         // don't normalize if the norm is very close to 1 already
      78          13 :         _dictionary.getBlock(j) =
      79          13 :             (std::abs(l2Norm - 1) < std::numeric_limits<data_t>::epsilon() * std::abs(l2Norm + 1))
      80          13 :                 ? atom
      81          13 :                 : (atom / l2Norm);
      82          13 :     }
      83             : 
      84             :     template <typename data_t>
      85             :     const DataContainer<data_t> Dictionary<data_t>::getAtom(index_t j) const
      86         116 :     {
      87         116 :         if (j < 0 || j >= _nAtoms)
      88           2 :             throw InvalidArgumentError("Dictionary: atom index out of bounds");
      89         114 :         return materialize(_dictionary.getBlock(j));
      90         114 :     }
      91             : 
      92             :     template <typename data_t>
      93             :     index_t Dictionary<data_t>::getNumberOfAtoms() const
      94          60 :     {
      95          60 :         return _nAtoms;
      96          60 :     }
      97             : 
      98             :     template <typename data_t>
      99             :     Dictionary<data_t> Dictionary<data_t>::getSupportedDictionary(IndexVector_t support) const
     100           4 :     {
     101           4 :         Dictionary supportDict(*_rangeDescriptor, support.rows());
     102           4 :         index_t j = 0;
     103             : 
     104           4 :         if ((support.array() < 0).any() || (support.array() >= _nAtoms).any()) {
     105           0 :             throw InvalidArgumentError(
     106           0 :                 "Dictionary::getSupportedDictionary: support contains out-of-bounds index");
     107           0 :         }
     108             : 
     109           9 :         for (const auto& i : support) {
     110           9 :             supportDict.updateAtom(j, getAtom(i));
     111           9 :             ++j;
     112           9 :         }
     113             : 
     114           4 :         return supportDict;
     115           4 :     }
     116             : 
     117             :     template <typename data_t>
     118             :     void Dictionary<data_t>::applyImpl(const DataContainer<data_t>& x,
     119             :                                        DataContainer<data_t>& Ax) const
     120          10 :     {
     121          10 :         Timer timeguard("Dictionary", "apply");
     122             : 
     123          10 :         if (x.getSize() != _nAtoms || Ax.getDataDescriptor() != *_rangeDescriptor)
     124           2 :             throw InvalidArgumentError("Dictionary::apply: incorrect input/output sizes");
     125             : 
     126           8 :         index_t j = 0;
     127           8 :         Ax = 0;
     128             : 
     129          23 :         for (const auto& x_j : x) {
     130          23 :             const auto& atom = getAtom(j);
     131          23 :             Ax += atom * x_j; // vector*scalar
     132             : 
     133          23 :             ++j;
     134          23 :         }
     135           8 :     }
     136             : 
     137             :     template <typename data_t>
     138             :     void Dictionary<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
     139             :                                               DataContainer<data_t>& Aty) const
     140           8 :     {
     141           8 :         Timer timeguard("Dictionary", "applyAdjoint");
     142             : 
     143           8 :         if (Aty.getSize() != _nAtoms || y.getDataDescriptor() != *_rangeDescriptor)
     144           2 :             throw InvalidArgumentError("Dictionary::applyAdjoint: incorrect input/output sizes");
     145             : 
     146           6 :         index_t i = 0;
     147           6 :         Aty = 0;
     148          14 :         for (auto& Aty_i : Aty) {
     149          14 :             const auto& atom_i = getAtom(i);
     150          42 :             for (int j = 0; j < atom_i.getSize(); ++j) {
     151          28 :                 Aty_i += atom_i[j] * y[j];
     152          28 :             }
     153          14 :             ++i;
     154          14 :         }
     155           6 :     }
     156             : 
     157             :     template <typename data_t>
     158             :     Dictionary<data_t>* Dictionary<data_t>::cloneImpl() const
     159           4 :     {
     160           4 :         return new Dictionary(_dictionary);
     161           4 :     }
     162             : 
     163             :     template <typename data_t>
     164             :     bool Dictionary<data_t>::isEqual(const LinearOperator<data_t>& other) const
     165           2 :     {
     166           2 :         if (!LinearOperator<data_t>::isEqual(other))
     167           0 :             return false;
     168             : 
     169           2 :         auto otherDictionary = downcast_safe<Dictionary>(&other);
     170           2 :         if (!otherDictionary)
     171           0 :             return false;
     172             : 
     173           2 :         if (_dictionary != otherDictionary->_dictionary)
     174           0 :             return false;
     175             : 
     176           2 :         return true;
     177           2 :     }
     178             : 
    // ------------------------------------------
    // explicit template instantiation
    // The member definitions live in this translation unit, so only the scalar types listed
    // here (float and double) get compiled instances of Dictionary.
    template class Dictionary<float>;
    template class Dictionary<double>;
     183             : 
     184             : } // namespace elsa

Generated by: LCOV version 1.14