#pragma once

#include <unordered_map>
#include <memory>
#include <utility>
#include <string>
#include <vector>

#include "elsaDefines.h"
#include "Common.h"
#include "DataContainer.h"
#include "DataDescriptor.h"
#include "VolumeDescriptor.h"
#include "Logger.h"
#include "TypeCasts.hpp"

#include "dnnl.hpp"

namespace elsa::ml
{
    namespace detail
    {
        template <typename data_t>
        struct TypeToDnnlTypeTag {
            static constexpr dnnl::memory::data_type tag = dnnl::memory::data_type::undef;
        };

        template <>
        struct TypeToDnnlTypeTag<float> {
            static constexpr dnnl::memory::data_type tag = dnnl::memory::data_type::f32;
        };

        /// A Dnnl layer
        ///
        /// This class is the base for all Dnnl backend layers in elsa.
        template <typename data_t>
        class DnnlLayer
        {
        public:
            /// Virtual destructor
            virtual ~DnnlLayer() = default;

            /// Execute this layer's forward primitives on executionStream
            virtual void forwardPropagate(dnnl::stream& executionStream);

            /// Execute this layer's backward primitives on executionStream
            virtual void backwardPropagate(dnnl::stream& executionStream);

            /// Get this layer's input-descriptor at a given index.
            ///
            /// @param index the index of the input-descriptor in this layer's
            /// list of input-descriptors
            /// @return this layer's input-descriptor at the given index
            VolumeDescriptor getInputDescriptor(index_t index = 0) const
            {
                validateVectorIndex(_inputDescriptor, index);
                return downcast_safe<VolumeDescriptor>(*_inputDescriptor[asUnsigned(index)]);
            }

            /// Get this layer's output-descriptor
            VolumeDescriptor getOutputDescriptor() const
            {
                assert(_outputDescriptor != nullptr
                       && "Cannot get output-descriptor since it is null");
                return downcast_safe<VolumeDescriptor>(*_outputDescriptor);
            }

            /// Set this layer's input at a given index.
            ///
            /// @param input DataContainer containing the input data
            /// @param index Index of the input to set in the list of layer
            /// inputs.
            /// @warning This performs a copy from the DataContainer to Dnnl memory
            /// and is therefore potentially expensive.
            void setInput(const DataContainer<data_t>& input, index_t index = 0);

            /// Set this layer's input memory by passing a pointer to another Dnnl memory
            void setInputMemory(std::shared_ptr<dnnl::memory> input, index_t index = 0);

            /// Set this layer's input memory at the next free input index by
            /// passing a pointer to another Dnnl memory
            void setNextInputMemory(std::shared_ptr<dnnl::memory> input);

            /// Set this layer's output-gradient at a given index.
            ///
            /// @param gradient DataContainer containing the gradient data.
            /// @param index Index of the gradient to set in the list of layer
            /// gradients.
            void setOutputGradient(const DataContainer<data_t>& gradient, index_t index = 0);

            /// Set this layer's raw memory storing the gradient of its output
            void setOutputGradientMemory(std::shared_ptr<dnnl::memory> outputGradient,
                                         index_t index = 0);

            /// Set this layer's raw memory storing the gradient of its output
            /// at the next free gradient index
            void setNextOutputGradientMemory(std::shared_ptr<dnnl::memory> outputGradient);
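            /// Example (a minimal forward-pass sketch, not taken from elsa's
            /// sources; it assumes a concrete DnnlLayer subclass `layer` whose
            /// engine has been set, and an `input` container matching the
            /// input-descriptor):
            ///
            ///     layer.compile(PropagationKind::Forward);
            ///     layer.setInput(input); // copies into Dnnl memory
            ///     dnnl::stream stream(*layer.getEngine());
            ///     layer.forwardPropagate(stream);
            ///     stream.wait();
            ///     auto output = layer.getOutput();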
            /// Get the layer's output by copying it into a DataContainer.
            /// If the layer's output memory was reordered during execution,
            /// it is reordered back.
            ///
            /// \note This method is meant to expose a layer's output. Since
            /// elsa uses whcn as its memory-format, the output gets reshaped
            /// to match this memory-format, regardless of the memory-format
            /// that is used internally.
            ///
            /// @warning This function performs a copy and is therefore potentially
            /// expensive. It should not be used internally to connect network
            /// layers.
            DataContainer<data_t> getOutput() const;

            /// Get a pointer to this layer's dnnl output memory.
            ///
            /// \note In case of reordering primitives the memory returned by
            /// this function can differ from what is expected. In other words,
            /// this function doesn't revert possible memory reordering and
            /// should therefore be used for internal purposes only, not
            /// for final reporting of layer outputs.
            std::shared_ptr<dnnl::memory> getOutputMemory() const;

            /// Get this layer's input gradient
            DataContainer<data_t> getInputGradient(index_t index = 0) const;

            /// Get this layer's input gradient memory
            std::shared_ptr<dnnl::memory> getInputGradientMemory(index_t index = 0);

            /// @returns the number of inputs of this layer
            index_t getNumberOfInputs() const
            {
                return static_cast<index_t>(_inputDescriptor.size());
            }

            /// Set the number of output-gradients of this layer
            void setNumberOfOutputGradients(index_t num)
            {
                _outputGradient =
                    std::vector<DnnlMemory>(num == 0 ? 1 : asUnsigned(num), _outputGradient[0]);
            }

            /// @returns the number of output-gradients of this layer
            index_t getNumberOfOutputGradients() const { return asSigned(_outputGradient.size()); }

            /// Compile this layer, i.e., construct all necessary layer logic based on arguments
            /// defined beforehand.
            ///
            /// @param propagation The kind of propagation this layer should be compiled for
            void compile(PropagationKind propagation = PropagationKind::Forward);
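            /// Example (an illustrative backward-pass sketch, not from elsa's
            /// sources; it assumes a forward pass has already been executed,
            /// that PropagationKind provides a Backward enumerator, and that
            /// `gradient` matches the output-descriptor):
            ///
            ///     layer.compile(PropagationKind::Backward);
            ///     layer.setOutputGradient(gradient);
            ///     dnnl::stream stream(*layer.getEngine());
            ///     layer.backwardPropagate(stream);
            ///     stream.wait();
            ///     auto inputGradient = layer.getInputGradient();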
            /// Return a pointer to this layer's execution engine.
            std::shared_ptr<dnnl::engine> getEngine() const;

            /// Set this layer's Dnnl execution engine
            void setEngine(std::shared_ptr<dnnl::engine> engine);

            /// Initialize all parameters of this layer
            virtual void initialize() {}

            /// @returns true if this layer is trainable, false otherwise
            virtual bool isTrainable() const;

            /// @returns true if this layer can merge multiple inputs together,
            /// false otherwise
            virtual bool canMerge() const;

        protected:
            struct DnnlMemory {
                /// Memory dimensions
                dnnl::memory::dims dimensions;

                /// Memory descriptor
                dnnl::memory::desc descriptor;

                /// Pointer to memory that was described during layer construction
                std::shared_ptr<dnnl::memory> describedMemory = nullptr;

                /// Pointer to memory that was possibly reordered during execution of
                /// a primitive
                std::shared_ptr<dnnl::memory> effectiveMemory = nullptr;

                /// Dnnl format tag for the memory descriptor
                dnnl::memory::format_tag formatTag;

                /// Flag to indicate whether this memory has been reordered by a
                /// primitive
                bool wasReordered = false;

                /// Flag to indicate whether this memory could be reordered by a
                /// primitive
                bool canBeReordered = false;
            };

            /// A propagation-stream, i.e., a collection of Dnnl primitives
            /// and arguments that can be executed using a Dnnl engine.
            struct PropagationStream {
                /// Vector of primitives this propagation stream consists of
                std::vector<dnnl::primitive> primitives;

                /// Vector of arguments this propagation stream consists of
                std::vector<std::unordered_map<int, dnnl::memory>> arguments;

                /// Flag to indicate whether this propagation stream has been compiled
                bool isCompiled = false;

                // In the case of a debug build we keep a list of primitive names
                std::vector<std::string> names;
            };

#define ELSA_ML_ADD_DNNL_PRIMITIVE(propagationStream, primitive) \
    propagationStream.primitives.push_back(primitive);           \
    propagationStream.names.push_back(#primitive);               \
    Logger::get(this->_name)->trace("Adding Dnnl primitive {}", #primitive)

            /// @returns the names of all primitives in a propagation stream
            std::vector<std::string> getDnnlPrimitiveNames(const PropagationStream& stream)
            {
                return stream.names;
            }
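            /// Example (a hypothetical sketch of how a derived layer might use
            /// ELSA_ML_ADD_DNNL_PRIMITIVE when compiling its forward stream;
            /// `fwdPrimDesc` and the choice of a convolution primitive are made
            /// up for illustration):
            ///
            ///     void compileForwardStream() override
            ///     {
            ///         // fwdPrimDesc is a dnnl::convolution_forward::primitive_desc
            ///         // created from this layer's memory descriptors
            ///         ELSA_ML_ADD_DNNL_PRIMITIVE(_forwardStream,
            ///                                    dnnl::convolution_forward(fwdPrimDesc));
            ///         _forwardStream.arguments.push_back(
            ///             {{DNNL_ARG_SRC, *_input[0].effectiveMemory},
            ///              {DNNL_ARG_DST, *_output.effectiveMemory}});
            ///     }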
            /// Validate a parameter pack of DnnlMemory
            template <typename... T>
            inline static void validateDnnlMemory([[maybe_unused]] T&&... mem)
            {
#if !defined(NDEBUG)
                (assert(mem != nullptr && "Pointer to Dnnl memory cannot be null"), ...);
                (assert(mem->get_desc().get_size() != 0
                        && "Dnnl memory descriptor cannot be of size 0"),
                 ...);
#endif
            }

            /// Validate that an index is within the bounds of a vector
            template <typename T>
            inline static void validateVectorIndex([[maybe_unused]] const std::vector<T>& vec,
                                                   [[maybe_unused]] index_t index)
            {
                assert(index >= 0 && asUnsigned(index) < vec.size()
                       && "Vector index is out of bounds");
            }

            /// Construct a DnnlLayer by providing a volume-descriptor for its input and output
            DnnlLayer(const VolumeDescriptor& inputDescriptor,
                      const VolumeDescriptor& outputDescriptor, const std::string& name,
                      int allowedNumberOfInputs = 1);

            /// Construct a DnnlLayer by providing a list of volume-descriptors
            /// for its input and a single volume-descriptor for its output
            DnnlLayer(const std::vector<VolumeDescriptor>& inputDescriptor,
                      const VolumeDescriptor& outputDescriptor, const std::string& name,
                      int allowedNumberOfInputs = 1);

            /// Explicitly deleted copy constructor
            DnnlLayer(const DnnlLayer&) = delete;

            /// Type of all coefficients in this layer, expressed as a Dnnl data-type tag
            static constexpr dnnl::memory::data_type _typeTag = TypeToDnnlTypeTag<data_t>::tag;

            /// Reorders memory from described to effective if the memory descriptor differs
            /// from the primitive's description
            void reorderMemory(const dnnl::memory::desc& memoryDesc, DnnlMemory& memory,
                               PropagationStream& stream);

            /// Compile this layer's backward stream
            virtual void compileBackwardStream();

            /// Compile this layer's forward stream
            virtual void compileForwardStream();

            /// If a layer has multiple outputs, it will receive multiple
            /// output-gradients. This function adds a primitive to the
            /// Dnnl backward-stream that sums up all of these output-gradients.
            void handleMultipleOutputGradients();

            /// Write the content of a DataContainer to Dnnl memory
            ///
            /// \note This performs a copy and is therefore potentially expensive.
            static void writeToDnnlMemory(const DataContainer<data_t>& data, dnnl::memory& memory);

            /// Read the content from Dnnl memory into a DataContainer
            ///
            /// \note This performs a copy and is therefore potentially expensive.
            static void readFromDnnlMemory(DataContainer<data_t>& data,
                                           const dnnl::memory& memory);

            /// @returns true if this layer needs to synchronize its Dnnl
            /// execution-stream during a forward-pass, false otherwise.
            ///
            /// This is particularly true for any merging layer.
            virtual bool needsForwardSynchronisation() const { return false; }

            /// @returns true if this layer needs to synchronize its Dnnl
            /// execution-stream during a backward-pass, false otherwise.
            ///
            /// This is particularly true for all layers with multiple outputs.
            virtual bool needsBackwardSynchronisation() const { return _outputGradient.size() > 1; }

            /// @returns this layer's name
            std::string getName() const { return _name; }
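            /// Example (an illustrative sketch of the copy helpers above, not
            /// from elsa's sources; it assumes a derived layer with allocated
            /// Dnnl memory):
            ///
            ///     DataContainer<data_t> data(getInputDescriptor());
            ///     // ... fill data ...
            ///     writeToDnnlMemory(data, *_input[0].describedMemory); // host -> Dnnl
            ///     // ... execute primitives ...
            ///     readFromDnnlMemory(data, *_output.effectiveMemory);  // Dnnl -> host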
            /// Choose a Dnnl memory format tag, given a VolumeDescriptor.
            ///
            /// The following format tags are chosen:
            ///
            /// +-----------+-------+---------+
            /// | Dimension | Input | Weights |
            /// +-----------+-------+---------+
            /// | 2D        | nc    | oi      |
            /// | 3D        | ncw   | oiw     |
            /// | 4D        | nchw  | oihw    |
            /// | 5D        | ncdhw | oidhw   |
            /// +-----------+-------+---------+
            ///
            /// where each letter has the following meaning:
            ///
            /// Input case:
            ///   n: Number of batches
            ///   c: Number of input channels
            ///   d: Depth (spatial dimension)
            ///   h: Height (spatial dimension)
            ///   w: Width (spatial dimension)
            ///
            /// Weights case:
            ///   o: Number of output channels, i.e., number of weights
            ///   i: Number of input channels
            ///   d: Depth (spatial dimension)
            ///   h: Height (spatial dimension)
            ///   w: Width (spatial dimension)
            ///
            /// @param desc DataDescriptor to choose a format tag for.
            /// @param isInput True if the DataDescriptor describes an input,
            /// false if it describes weights.
            /// @return Dnnl memory format tag corresponding to the above table.
            static dnnl::memory::format_tag
                dataDescriptorToDnnlMemoryFormatTag(const VolumeDescriptor& desc, bool isInput);

            /// @returns a string representation of a Dnnl memory format-tag
            static std::string dnnlMemoryFormatTagToString(dnnl::memory::format_tag tag);

            /// This layer's forward propagation stream
            PropagationStream _forwardStream;

            /// This layer's backward propagation stream
            PropagationStream _backwardStream;

            /// This layer's input memory
            std::vector<DnnlMemory> _input;

            /// This layer's input gradient memory
            std::vector<DnnlMemory> _inputGradient;

            /// This layer's output memory
            DnnlMemory _output;

            /// Vector with this layer's output-gradients
            std::vector<DnnlMemory> _outputGradient;

            /// This layer's output DataDescriptor
            std::unique_ptr<DataDescriptor> _outputDescriptor;

            /// This layer's input DataDescriptor
            std::vector<std::unique_ptr<DataDescriptor>> _inputDescriptor;

            /// This layer's Dnnl execution engine
            std::shared_ptr<dnnl::engine> _engine = nullptr;

            /// Sentinel value to indicate that a layer accepts any number of inputs
            constexpr static int anyNumberOfInputs = -1;

            /// The number of inputs this layer allows (or anyNumberOfInputs)
            int _allowedNumberOfInputs;

            /// This layer's name
            std::string _name;

        private:
            index_t _currentInputMemoryIndex = 0;
            index_t _currentOutputGradientMemoryIndex = 0;
        };
    } // namespace detail

} // namespace elsa::ml