LCOV - code coverage report
Current view: top level - elsa/operators - Dictionary.cpp (source / functions) Hit Total Coverage
Test: coverage-all.lcov Lines: 98 107 91.6 %
Date: 2022-08-25 03:05:39 Functions: 24 24 100.0 %

          Line data    Source code
       1             : #include "Dictionary.h"
       2             : #include "TypeCasts.hpp"
       3             : 
       4             : namespace elsa
       5             : {
       6             :     template <typename data_t>
       7             :     Dictionary<data_t>::Dictionary(const DataDescriptor& signalDescriptor, index_t nAtoms)
       8             :         : LinearOperator<data_t>(VolumeDescriptor({nAtoms}), signalDescriptor),
       9             :           _dictionary{DataContainer<data_t>(generateInitialData(signalDescriptor, nAtoms))},
      10             :           _nAtoms{nAtoms}
      11          18 :     {
      12          18 :     }
      13             : 
      14             :     template <typename data_t>
      15             :     Dictionary<data_t>::Dictionary(const DataContainer<data_t>& dictionary)
      16             :         : LinearOperator<data_t>(
      17             :             VolumeDescriptor({getIdenticalBlocksDescriptor(dictionary).getNumberOfBlocks()}),
      18             :             getIdenticalBlocksDescriptor(dictionary).getDescriptorOfBlock(0)),
      19             :           _dictionary{dictionary},
      20             :           _nAtoms{getIdenticalBlocksDescriptor(dictionary).getNumberOfBlocks()}
      21          97 :     {
      22         718 :         for (int i = 0; i < _nAtoms; ++i) {
      23         621 :             auto block = _dictionary.getBlock(i);
      24         621 :             data_t l2Norm = block.l2Norm();
      25         621 :             if (l2Norm == 0) {
      26           0 :                 throw InvalidArgumentError("Dictionary: initializing with 0-atom not possible");
      27           0 :             }
      28             :             // don't normalize if the norm is very close to 1 already
      29         621 :             if (std::abs(l2Norm - 1)
      30         621 :                 > std::numeric_limits<data_t>::epsilon() * std::abs(l2Norm + 1)) {
      31          94 :                 block /= l2Norm;
      32          94 :             }
      33         621 :         }
      34          97 :     }
      35             : 
      36             :     template <typename data_t>
      37             :     const IdenticalBlocksDescriptor&
      38             :         Dictionary<data_t>::getIdenticalBlocksDescriptor(const DataContainer<data_t>& data)
      39         287 :     {
      40         287 :         try {
      41         287 :             return downcast_safe<IdenticalBlocksDescriptor>(data.getDataDescriptor());
      42         287 :         } catch (const BadCastError&) {
      43           2 :             throw InvalidArgumentError(
      44           2 :                 "Dictionary: cannot initialize from data without IdenticalBlocksDescriptor");
      45           2 :         }
      46         287 :     }
      47             : 
      48             :     template <typename data_t>
      49             :     DataContainer<data_t>
      50             :         Dictionary<data_t>::generateInitialData(const DataDescriptor& signalDescriptor,
      51             :                                                 index_t nAtoms)
      52          18 :     {
      53          18 :         Vector_t<data_t> randomData(signalDescriptor.getNumberOfCoefficients() * nAtoms);
      54          18 :         randomData.setRandom();
      55          18 :         IdenticalBlocksDescriptor desc(nAtoms, signalDescriptor);
      56          18 :         DataContainer<data_t> initData{desc, randomData};
      57             : 
      58         167 :         for (int i = 0; i < desc.getNumberOfBlocks(); ++i) {
      59         149 :             auto block = initData.getBlock(i);
      60         149 :             block /= block.l2Norm();
      61         149 :         }
      62             : 
      63          18 :         return initData;
      64          18 :     }
      65             : 
      66             :     template <typename data_t>
      67             :     void Dictionary<data_t>::updateAtom(index_t j, const DataContainer<data_t>& atom)
      68          17 :     {
      69          17 :         if (j < 0 || j >= _nAtoms)
      70           2 :             throw InvalidArgumentError("Dictionary::updateAtom: atom index out of bounds");
      71          15 :         if (*_rangeDescriptor != atom.getDataDescriptor())
      72           2 :             throw InvalidArgumentError("Dictionary::updateAtom: atom has invalid size");
      73          13 :         data_t l2Norm = atom.l2Norm();
      74          13 :         if (l2Norm == 0) {
      75           0 :             throw InvalidArgumentError("Dictionary::updateAtom: updating to 0-atom not possible");
      76           0 :         }
      77             :         // don't normalize if the norm is very close to 1 already
      78          13 :         _dictionary.getBlock(j) =
      79          13 :             (std::abs(l2Norm - 1) < std::numeric_limits<data_t>::epsilon() * std::abs(l2Norm + 1))
      80          13 :                 ? atom
      81          13 :                 : (atom / l2Norm);
      82          13 :     }
      83             : 
      84             :     template <typename data_t>
      85             :     const DataContainer<data_t> Dictionary<data_t>::getAtom(index_t j) const
      86         168 :     {
      87         168 :         if (j < 0 || j >= _nAtoms)
      88           2 :             throw InvalidArgumentError("Dictionary: atom index out of bounds");
      89         166 :         return _dictionary.getBlock(j);
      90         166 :     }
      91             : 
      92             :     template <typename data_t>
      93             :     index_t Dictionary<data_t>::getNumberOfAtoms() const
      94          60 :     {
      95          60 :         return _nAtoms;
      96          60 :     }
      97             : 
      98             :     template <typename data_t>
      99             :     Dictionary<data_t> Dictionary<data_t>::getSupportedDictionary(IndexVector_t support) const
     100           4 :     {
     101           4 :         Dictionary supportDict(*_rangeDescriptor, support.rows());
     102           4 :         index_t j = 0;
     103             : 
     104           9 :         for (const auto& i : support) {
     105           9 :             if (i < 0 || i >= _nAtoms)
     106           0 :                 throw InvalidArgumentError(
     107           0 :                     "Dictionary::getSupportedDictionary: support contains out-of-bounds index");
     108             : 
     109           9 :             supportDict.updateAtom(j, getAtom(i));
     110           9 :             ++j;
     111           9 :         }
     112             : 
     113           4 :         return supportDict;
     114           4 :     }
     115             : 
     116             :     template <typename data_t>
     117             :     void Dictionary<data_t>::applyImpl(const DataContainer<data_t>& x,
     118             :                                        DataContainer<data_t>& Ax) const
     119          18 :     {
     120          18 :         Timer timeguard("Dictionary", "apply");
     121             : 
     122          18 :         if (x.getSize() != _nAtoms || Ax.getDataDescriptor() != *_rangeDescriptor)
     123           2 :             throw InvalidArgumentError("Dictionary::apply: incorrect input/output sizes");
     124             : 
     125          16 :         index_t j = 0;
     126          16 :         Ax = 0;
     127             : 
     128          69 :         for (const auto& x_j : x) {
     129          69 :             const auto& atom = getAtom(j);
     130          69 :             Ax += atom * x_j; // vector*scalar
     131             : 
     132          69 :             ++j;
     133          69 :         }
     134          16 :     }
     135             : 
     136             :     template <typename data_t>
     137             :     void Dictionary<data_t>::applyAdjointImpl(const DataContainer<data_t>& y,
     138             :                                               DataContainer<data_t>& Aty) const
     139          12 :     {
     140          12 :         Timer timeguard("Dictionary", "applyAdjoint");
     141             : 
     142          12 :         if (Aty.getSize() != _nAtoms || y.getDataDescriptor() != *_rangeDescriptor)
     143           2 :             throw InvalidArgumentError("Dictionary::applyAdjoint: incorrect input/output sizes");
     144             : 
     145          10 :         index_t i = 0;
     146          10 :         Aty = 0;
     147          20 :         for (auto& Aty_i : Aty) {
     148          20 :             const auto& atom_i = getAtom(i);
     149          60 :             for (int j = 0; j < atom_i.getSize(); ++j) {
     150          40 :                 Aty_i += atom_i[j] * y[j];
     151          40 :             }
     152          20 :             ++i;
     153          20 :         }
     154          10 :     }
     155             : 
     156             :     template <typename data_t>
     157             :     Dictionary<data_t>* Dictionary<data_t>::cloneImpl() const
     158          65 :     {
     159          65 :         return new Dictionary(_dictionary);
     160          65 :     }
     161             : 
     162             :     template <typename data_t>
     163             :     bool Dictionary<data_t>::isEqual(const LinearOperator<data_t>& other) const
     164           6 :     {
     165           6 :         if (!LinearOperator<data_t>::isEqual(other))
     166           0 :             return false;
     167             : 
     168           6 :         auto otherDictionary = downcast_safe<Dictionary>(&other);
     169           6 :         if (!otherDictionary)
     170           0 :             return false;
     171             : 
     172           6 :         if (_dictionary != otherDictionary->_dictionary)
     173           0 :             return false;
     174             : 
     175           6 :         return true;
     176           6 :     }
     177             : 
    // ------------------------------------------
    // explicit template instantiation
    // The member definitions above live in this translation unit, so Dictionary is
    // instantiated here for float and double only. Presumably these are the only
    // specializations other translation units can link against; add new data_t types here.
    template class Dictionary<float>;
    template class Dictionary<double>;
     182             : 
     183             : } // namespace elsa

Generated by: LCOV version 1.14