Line data Source code
1 : #include "Dictionary.h" 2 : #include "TypeCasts.hpp" 3 : 4 : namespace elsa 5 : { 6 : template <typename data_t> 7 : Dictionary<data_t>::Dictionary(const DataDescriptor& signalDescriptor, index_t nAtoms) 8 : : LinearOperator<data_t>(VolumeDescriptor({nAtoms}), signalDescriptor), 9 : _dictionary{DataContainer<data_t>(generateInitialData(signalDescriptor, nAtoms))}, 10 : _nAtoms{nAtoms} 11 10 : { 12 10 : } 13 : 14 : template <typename data_t> 15 : Dictionary<data_t>::Dictionary(const DataContainer<data_t>& dictionary) 16 : : LinearOperator<data_t>( 17 : VolumeDescriptor({getIdenticalBlocksDescriptor(dictionary).getNumberOfBlocks()}), 18 : getIdenticalBlocksDescriptor(dictionary).getDescriptorOfBlock(0)), 19 : _dictionary{dictionary}, 20 : _nAtoms{getIdenticalBlocksDescriptor(dictionary).getNumberOfBlocks()} 21 36 : { 22 297 : for (int i = 0; i < _nAtoms; ++i) { 23 261 : auto block = _dictionary.getBlock(i); 24 261 : data_t l2Norm = block.l2Norm(); 25 261 : if (l2Norm == 0) { 26 0 : throw InvalidArgumentError("Dictionary: initializing with 0-atom not possible"); 27 0 : } 28 : // don't normalize if the norm is very close to 1 already 29 261 : if (std::abs(l2Norm - 1) 30 261 : > std::numeric_limits<data_t>::epsilon() * std::abs(l2Norm + 1)) { 31 94 : block /= l2Norm; 32 94 : } 33 261 : } 34 36 : } 35 : 36 : template <typename data_t> 37 : const IdenticalBlocksDescriptor& 38 : Dictionary<data_t>::getIdenticalBlocksDescriptor(const DataContainer<data_t>& data) 39 104 : { 40 104 : try { 41 104 : return downcast_safe<IdenticalBlocksDescriptor>(data.getDataDescriptor()); 42 104 : } catch (const BadCastError&) { 43 2 : throw InvalidArgumentError( 44 2 : "Dictionary: cannot initialize from data without IdenticalBlocksDescriptor"); 45 2 : } 46 104 : } 47 : 48 : template <typename data_t> 49 : DataContainer<data_t> 50 : Dictionary<data_t>::generateInitialData(const DataDescriptor& signalDescriptor, 51 : index_t nAtoms) 52 10 : { 53 10 : Vector_t<data_t> randomData(signalDescriptor.getNumberOfCoefficients() * nAtoms); 54 10 : randomData.setRandom(); 55 10 : IdenticalBlocksDescriptor desc(nAtoms, signalDescriptor); 56 10 : DataContainer<data_t> initData{desc, randomData}; 57 : 58 79 : for (int i = 0; i < desc.getNumberOfBlocks(); ++i) { 59 69 : auto block = initData.getBlock(i); 60 69 : block /= block.l2Norm(); 61 69 : } 62 : 63 10 : return initData; 64 10 : } 65 : 66 : template <typename data_t> 67 : void Dictionary<data_t>::updateAtom(index_t j, const DataContainer<data_t>& atom) 68 17 : { 69 17 : if (j < 0 || j >= _nAtoms) 70 2 : throw InvalidArgumentError("Dictionary::updateAtom: atom index out of bounds"); 71 15 : if (*_rangeDescriptor != atom.getDataDescriptor()) 72 2 : throw InvalidArgumentError("Dictionary::updateAtom: atom has invalid size"); 73 13 : data_t l2Norm = atom.l2Norm(); 74 13 : if (l2Norm == 0) { 75 0 : throw InvalidArgumentError("Dictionary::updateAtom: updating to 0-atom not possible"); 76 0 : } 77 : // don't normalize if the norm is very close to 1 already 78 13 : _dictionary.getBlock(j) = 79 13 : (std::abs(l2Norm - 1) < std::numeric_limits<data_t>::epsilon() * std::abs(l2Norm + 1)) 80 13 : ? atom 81 13 : : (atom / l2Norm); 82 13 : } 83 : 84 : template <typename data_t> 85 : const DataContainer<data_t> Dictionary<data_t>::getAtom(index_t j) const 86 116 : { 87 116 : if (j < 0 || j >= _nAtoms) 88 2 : throw InvalidArgumentError("Dictionary: atom index out of bounds"); 89 114 : return materialize(_dictionary.getBlock(j)); 90 114 : } 91 : 92 : template <typename data_t> 93 : index_t Dictionary<data_t>::getNumberOfAtoms() const 94 60 : { 95 60 : return _nAtoms; 96 60 : } 97 : 98 : template <typename data_t> 99 : Dictionary<data_t> Dictionary<data_t>::getSupportedDictionary(IndexVector_t support) const 100 4 : { 101 4 : Dictionary supportDict(*_rangeDescriptor, support.rows()); 102 4 : index_t j = 0; 103 : 104 4 : if ((support.array() < 0).any() || (support.array() >= _nAtoms).any()) { 105 0 : throw InvalidArgumentError( 106 0 : "Dictionary::getSupportedDictionary: support contains out-of-bounds index"); 107 0 : } 108 : 109 9 : for (const auto& i : support) { 110 9 : supportDict.updateAtom(j, getAtom(i)); 111 9 : ++j; 112 9 : } 113 : 114 4 : return supportDict; 115 4 : } 116 : 117 : template <typename data_t> 118 : void Dictionary<data_t>::applyImpl(const DataContainer<data_t>& x, 119 : DataContainer<data_t>& Ax) const 120 10 : { 121 10 : Timer timeguard("Dictionary", "apply"); 122 : 123 10 : if (x.getSize() != _nAtoms || Ax.getDataDescriptor() != *_rangeDescriptor) 124 2 : throw InvalidArgumentError("Dictionary::apply: incorrect input/output sizes"); 125 : 126 8 : index_t j = 0; 127 8 : Ax = 0; 128 : 129 23 : for (const auto& x_j : x) { 130 23 : const auto& atom = getAtom(j); 131 23 : Ax += atom * x_j; // vector*scalar 132 : 133 23 : ++j; 134 23 : } 135 8 : } 136 : 137 : template <typename data_t> 138 : void Dictionary<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 139 : DataContainer<data_t>& Aty) const 140 8 : { 141 8 : Timer timeguard("Dictionary", "applyAdjoint"); 142 : 143 8 : if (Aty.getSize() != _nAtoms || y.getDataDescriptor() != *_rangeDescriptor) 144 2 : throw InvalidArgumentError("Dictionary::applyAdjoint: incorrect input/output sizes"); 145 : 146 6 : index_t i = 0; 147 6 : Aty = 0; 148 14 : for (auto& Aty_i : Aty) { 149 14 : const auto& atom_i = getAtom(i); 150 42 : for (int j = 0; j < atom_i.getSize(); ++j) { 151 28 : Aty_i += atom_i[j] * y[j]; 152 28 : } 153 14 : ++i; 154 14 : } 155 6 : } 156 : 157 : template <typename data_t> 158 : Dictionary<data_t>* Dictionary<data_t>::cloneImpl() const 159 4 : { 160 4 : return new Dictionary(_dictionary); 161 4 : } 162 : 163 : template <typename data_t> 164 : bool Dictionary<data_t>::isEqual(const LinearOperator<data_t>& other) const 165 2 : { 166 2 : if (!LinearOperator<data_t>::isEqual(other)) 167 0 : return false; 168 : 169 2 : auto otherDictionary = downcast_safe<Dictionary>(&other); 170 2 : if (!otherDictionary) 171 0 : return false; 172 : 173 2 : if (_dictionary != otherDictionary->_dictionary) 174 0 : return false; 175 : 176 2 : return true; 177 2 : } 178 : 179 : // ------------------------------------------ 180 : // explicit template instantiation 181 : template class Dictionary<float>; 182 : template class Dictionary<double>; 183 : 184 : } // namespace elsa