Line data Source code
1 : #include "Dictionary.h" 2 : #include "TypeCasts.hpp" 3 : 4 : namespace elsa 5 : { 6 : template <typename data_t> 7 : Dictionary<data_t>::Dictionary(const DataDescriptor& signalDescriptor, index_t nAtoms) 8 : : LinearOperator<data_t>(VolumeDescriptor({nAtoms}), signalDescriptor), 9 : _dictionary{DataContainer<data_t>(generateInitialData(signalDescriptor, nAtoms))}, 10 : _nAtoms{nAtoms} 11 18 : { 12 18 : } 13 : 14 : template <typename data_t> 15 : Dictionary<data_t>::Dictionary(const DataContainer<data_t>& dictionary) 16 : : LinearOperator<data_t>( 17 : VolumeDescriptor({getIdenticalBlocksDescriptor(dictionary).getNumberOfBlocks()}), 18 : getIdenticalBlocksDescriptor(dictionary).getDescriptorOfBlock(0)), 19 : _dictionary{dictionary}, 20 : _nAtoms{getIdenticalBlocksDescriptor(dictionary).getNumberOfBlocks()} 21 97 : { 22 718 : for (int i = 0; i < _nAtoms; ++i) { 23 621 : auto block = _dictionary.getBlock(i); 24 621 : data_t l2Norm = block.l2Norm(); 25 621 : if (l2Norm == 0) { 26 0 : throw InvalidArgumentError("Dictionary: initializing with 0-atom not possible"); 27 0 : } 28 : // don't normalize if the norm is very close to 1 already 29 621 : if (std::abs(l2Norm - 1) 30 621 : > std::numeric_limits<data_t>::epsilon() * std::abs(l2Norm + 1)) { 31 94 : block /= l2Norm; 32 94 : } 33 621 : } 34 97 : } 35 : 36 : template <typename data_t> 37 : const IdenticalBlocksDescriptor& 38 : Dictionary<data_t>::getIdenticalBlocksDescriptor(const DataContainer<data_t>& data) 39 287 : { 40 287 : try { 41 287 : return downcast_safe<IdenticalBlocksDescriptor>(data.getDataDescriptor()); 42 287 : } catch (const BadCastError&) { 43 2 : throw InvalidArgumentError( 44 2 : "Dictionary: cannot initialize from data without IdenticalBlocksDescriptor"); 45 2 : } 46 287 : } 47 : 48 : template <typename data_t> 49 : DataContainer<data_t> 50 : Dictionary<data_t>::generateInitialData(const DataDescriptor& signalDescriptor, 51 : index_t nAtoms) 52 18 : { 53 18 : Vector_t<data_t> randomData(signalDescriptor.getNumberOfCoefficients() * nAtoms); 54 18 : randomData.setRandom(); 55 18 : IdenticalBlocksDescriptor desc(nAtoms, signalDescriptor); 56 18 : DataContainer<data_t> initData{desc, randomData}; 57 : 58 167 : for (int i = 0; i < desc.getNumberOfBlocks(); ++i) { 59 149 : auto block = initData.getBlock(i); 60 149 : block /= block.l2Norm(); 61 149 : } 62 : 63 18 : return initData; 64 18 : } 65 : 66 : template <typename data_t> 67 : void Dictionary<data_t>::updateAtom(index_t j, const DataContainer<data_t>& atom) 68 17 : { 69 17 : if (j < 0 || j >= _nAtoms) 70 2 : throw InvalidArgumentError("Dictionary::updateAtom: atom index out of bounds"); 71 15 : if (*_rangeDescriptor != atom.getDataDescriptor()) 72 2 : throw InvalidArgumentError("Dictionary::updateAtom: atom has invalid size"); 73 13 : data_t l2Norm = atom.l2Norm(); 74 13 : if (l2Norm == 0) { 75 0 : throw InvalidArgumentError("Dictionary::updateAtom: updating to 0-atom not possible"); 76 0 : } 77 : // don't normalize if the norm is very close to 1 already 78 13 : _dictionary.getBlock(j) = 79 13 : (std::abs(l2Norm - 1) < std::numeric_limits<data_t>::epsilon() * std::abs(l2Norm + 1)) 80 13 : ? atom 81 13 : : (atom / l2Norm); 82 13 : } 83 : 84 : template <typename data_t> 85 : const DataContainer<data_t> Dictionary<data_t>::getAtom(index_t j) const 86 168 : { 87 168 : if (j < 0 || j >= _nAtoms) 88 2 : throw InvalidArgumentError("Dictionary: atom index out of bounds"); 89 166 : return _dictionary.getBlock(j); 90 166 : } 91 : 92 : template <typename data_t> 93 : index_t Dictionary<data_t>::getNumberOfAtoms() const 94 60 : { 95 60 : return _nAtoms; 96 60 : } 97 : 98 : template <typename data_t> 99 : Dictionary<data_t> Dictionary<data_t>::getSupportedDictionary(IndexVector_t support) const 100 4 : { 101 4 : Dictionary supportDict(*_rangeDescriptor, support.rows()); 102 4 : index_t j = 0; 103 : 104 9 : for (const auto& i : support) { 105 9 : if (i < 0 || i >= _nAtoms) 106 0 : throw InvalidArgumentError( 107 0 : "Dictionary::getSupportedDictionary: support contains out-of-bounds index"); 108 : 109 9 : supportDict.updateAtom(j, getAtom(i)); 110 9 : ++j; 111 9 : } 112 : 113 4 : return supportDict; 114 4 : } 115 : 116 : template <typename data_t> 117 : void Dictionary<data_t>::applyImpl(const DataContainer<data_t>& x, 118 : DataContainer<data_t>& Ax) const 119 18 : { 120 18 : Timer timeguard("Dictionary", "apply"); 121 : 122 18 : if (x.getSize() != _nAtoms || Ax.getDataDescriptor() != *_rangeDescriptor) 123 2 : throw InvalidArgumentError("Dictionary::apply: incorrect input/output sizes"); 124 : 125 16 : index_t j = 0; 126 16 : Ax = 0; 127 : 128 69 : for (const auto& x_j : x) { 129 69 : const auto& atom = getAtom(j); 130 69 : Ax += atom * x_j; // vector*scalar 131 : 132 69 : ++j; 133 69 : } 134 16 : } 135 : 136 : template <typename data_t> 137 : void Dictionary<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 138 : DataContainer<data_t>& Aty) const 139 12 : { 140 12 : Timer timeguard("Dictionary", "applyAdjoint"); 141 : 142 12 : if (Aty.getSize() != _nAtoms || y.getDataDescriptor() != *_rangeDescriptor) 143 2 : throw InvalidArgumentError("Dictionary::applyAdjoint: incorrect input/output sizes"); 144 : 145 10 : index_t i = 0; 146 10 : Aty = 0; 147 20 : for (auto& Aty_i : Aty) { 148 20 : const auto& atom_i = getAtom(i); 149 60 : for (int j = 0; j < atom_i.getSize(); ++j) { 150 40 : Aty_i += atom_i[j] * y[j]; 151 40 : } 152 20 : ++i; 153 20 : } 154 10 : } 155 : 156 : template <typename data_t> 157 : Dictionary<data_t>* Dictionary<data_t>::cloneImpl() const 158 65 : { 159 65 : return new Dictionary(_dictionary); 160 65 : } 161 : 162 : template <typename data_t> 163 : bool Dictionary<data_t>::isEqual(const LinearOperator<data_t>& other) const 164 6 : { 165 6 : if (!LinearOperator<data_t>::isEqual(other)) 166 0 : return false; 167 : 168 6 : auto otherDictionary = downcast_safe<Dictionary>(&other); 169 6 : if (!otherDictionary) 170 0 : return false; 171 : 172 6 : if (_dictionary != otherDictionary->_dictionary) 173 0 : return false; 174 : 175 6 : return true; 176 6 : } 177 : 178 : // ------------------------------------------ 179 : // explicit template instantiation 180 : template class Dictionary<float>; 181 : template class Dictionary<double>; 182 : 183 : } // namespace elsa