Line data Source code
1 : #include "Dictionary.h" 2 : #include "TypeCasts.hpp" 3 : 4 : namespace elsa 5 : { 6 : template <typename data_t> 7 0 : Dictionary<data_t>::Dictionary(const DataDescriptor& signalDescriptor, index_t nAtoms) 8 0 : : LinearOperator<data_t>(VolumeDescriptor({nAtoms}), signalDescriptor), 9 0 : _dictionary{DataContainer<data_t>(generateInitialData(signalDescriptor, nAtoms))}, 10 0 : _nAtoms{nAtoms} 11 : { 12 0 : } 13 : 14 : template <typename data_t> 15 0 : Dictionary<data_t>::Dictionary(const DataContainer<data_t>& dictionary) 16 : : LinearOperator<data_t>( 17 0 : VolumeDescriptor({getIdenticalBlocksDescriptor(dictionary).getNumberOfBlocks()}), 18 0 : getIdenticalBlocksDescriptor(dictionary).getDescriptorOfBlock(0)), 19 0 : _dictionary{dictionary}, 20 0 : _nAtoms{getIdenticalBlocksDescriptor(dictionary).getNumberOfBlocks()} 21 : { 22 0 : for (int i = 0; i < _nAtoms; ++i) { 23 0 : auto block = _dictionary.getBlock(i); 24 0 : data_t l2Norm = block.l2Norm(); 25 0 : if (l2Norm == 0) { 26 0 : throw InvalidArgumentError("Dictionary: initializing with 0-atom not possible"); 27 : } 28 : // don't normalize if the norm is very close to 1 already 29 0 : if (std::abs(l2Norm - 1) 30 0 : > std::numeric_limits<data_t>::epsilon() * std::abs(l2Norm + 1)) { 31 0 : block /= l2Norm; 32 : } 33 : } 34 0 : } 35 : 36 : template <typename data_t> 37 : const IdenticalBlocksDescriptor& 38 0 : Dictionary<data_t>::getIdenticalBlocksDescriptor(const DataContainer<data_t>& data) 39 : { 40 : try { 41 0 : return downcast_safe<IdenticalBlocksDescriptor>(data.getDataDescriptor()); 42 0 : } catch (const BadCastError&) { 43 0 : throw InvalidArgumentError( 44 : "Dictionary: cannot initialize from data without IdenticalBlocksDescriptor"); 45 : } 46 : } 47 : 48 : template <typename data_t> 49 : DataContainer<data_t> 50 0 : Dictionary<data_t>::generateInitialData(const DataDescriptor& signalDescriptor, 51 : index_t nAtoms) 52 : { 53 0 : Vector_t<data_t> randomData(signalDescriptor.getNumberOfCoefficients() * nAtoms); 54 0 : randomData.setRandom(); 55 0 : IdenticalBlocksDescriptor desc(nAtoms, signalDescriptor); 56 0 : DataContainer<data_t> initData{desc, randomData}; 57 : 58 0 : for (int i = 0; i < desc.getNumberOfBlocks(); ++i) { 59 0 : auto block = initData.getBlock(i); 60 0 : block /= block.l2Norm(); 61 : } 62 : 63 0 : return initData; 64 0 : } 65 : 66 : template <typename data_t> 67 0 : void Dictionary<data_t>::updateAtom(index_t j, const DataContainer<data_t>& atom) 68 : { 69 0 : if (j < 0 || j >= _nAtoms) 70 0 : throw InvalidArgumentError("Dictionary::updateAtom: atom index out of bounds"); 71 0 : if (*_rangeDescriptor != atom.getDataDescriptor()) 72 0 : throw InvalidArgumentError("Dictionary::updateAtom: atom has invalid size"); 73 0 : data_t l2Norm = atom.l2Norm(); 74 0 : if (l2Norm == 0) { 75 0 : throw InvalidArgumentError("Dictionary::updateAtom: updating to 0-atom not possible"); 76 : } 77 : // don't normalize if the norm is very close to 1 already 78 0 : _dictionary.getBlock(j) = 79 0 : (std::abs(l2Norm - 1) < std::numeric_limits<data_t>::epsilon() * std::abs(l2Norm + 1)) 80 0 : ? atom 81 : : (atom / l2Norm); 82 0 : } 83 : 84 : template <typename data_t> 85 0 : const DataContainer<data_t> Dictionary<data_t>::getAtom(index_t j) const 86 : { 87 0 : if (j < 0 || j >= _nAtoms) 88 0 : throw InvalidArgumentError("Dictionary: atom index out of bounds"); 89 0 : return _dictionary.getBlock(j); 90 : } 91 : 92 : template <typename data_t> 93 0 : index_t Dictionary<data_t>::getNumberOfAtoms() const 94 : { 95 0 : return _nAtoms; 96 : } 97 : 98 : template <typename data_t> 99 0 : Dictionary<data_t> Dictionary<data_t>::getSupportedDictionary(IndexVector_t support) const 100 : { 101 0 : Dictionary supportDict(*_rangeDescriptor, support.rows()); 102 0 : index_t j = 0; 103 : 104 0 : for (const auto& i : support) { 105 0 : if (i < 0 || i >= _nAtoms) 106 0 : throw InvalidArgumentError( 107 : "Dictionary::getSupportedDictionary: support contains out-of-bounds index"); 108 : 109 0 : supportDict.updateAtom(j, getAtom(i)); 110 0 : ++j; 111 : } 112 : 113 0 : return supportDict; 114 0 : } 115 : 116 : template <typename data_t> 117 0 : void Dictionary<data_t>::applyImpl(const DataContainer<data_t>& x, 118 : DataContainer<data_t>& Ax) const 119 : { 120 0 : Timer timeguard("Dictionary", "apply"); 121 : 122 0 : if (x.getSize() != _nAtoms || Ax.getDataDescriptor() != *_rangeDescriptor) 123 0 : throw InvalidArgumentError("Dictionary::apply: incorrect input/output sizes"); 124 : 125 0 : index_t j = 0; 126 0 : Ax = 0; 127 : 128 0 : for (const auto& x_j : x) { 129 0 : const auto& atom = getAtom(j); 130 0 : Ax += atom * x_j; // vector*scalar 131 : 132 0 : ++j; 133 : } 134 0 : } 135 : 136 : template <typename data_t> 137 0 : void Dictionary<data_t>::applyAdjointImpl(const DataContainer<data_t>& y, 138 : DataContainer<data_t>& Aty) const 139 : { 140 0 : Timer timeguard("Dictionary", "applyAdjoint"); 141 : 142 0 : if (Aty.getSize() != _nAtoms || y.getDataDescriptor() != *_rangeDescriptor) 143 0 : throw InvalidArgumentError("Dictionary::applyAdjoint: incorrect input/output sizes"); 144 : 145 0 : index_t i = 0; 146 0 : Aty = 0; 147 0 : for (auto& Aty_i : Aty) { 148 0 : const auto& atom_i = getAtom(i); 149 0 : for (int j = 0; j < atom_i.getSize(); ++j) { 150 0 : Aty_i += atom_i[j] * y[j]; 151 : } 152 0 : ++i; 153 : } 154 0 : } 155 : 156 : template <typename data_t> 157 0 : Dictionary<data_t>* Dictionary<data_t>::cloneImpl() const 158 : { 159 0 : return new Dictionary(_dictionary); 160 : } 161 : 162 : template <typename data_t> 163 0 : bool Dictionary<data_t>::isEqual(const LinearOperator<data_t>& other) const 164 : { 165 0 : if (!LinearOperator<data_t>::isEqual(other)) 166 0 : return false; 167 : 168 0 : auto otherDictionary = downcast_safe<Dictionary>(&other); 169 0 : if (!otherDictionary) 170 0 : return false; 171 : 172 0 : if (_dictionary != otherDictionary->_dictionary) 173 0 : return false; 174 : 175 0 : return true; 176 : } 177 : 178 : // ------------------------------------------ 179 : // explicit template instantiation 180 : template class Dictionary<float>; 181 : template class Dictionary<double>; 182 : 183 : } // namespace elsa