Line data Source code
1 : #include "DnnlPoolingLayer.h" 2 : 3 : namespace elsa::ml 4 : { 5 : namespace detail 6 : { 7 : template <typename data_t> 8 1 : DnnlPoolingLayer<data_t>::DnnlPoolingLayer(const VolumeDescriptor& inputDescriptor, 9 : const VolumeDescriptor& outputDescriptor, 10 : const IndexVector_t& poolingWindow, 11 : const IndexVector_t& poolingStride) 12 1 : : DnnlLayer<data_t>(inputDescriptor, outputDescriptor, "DnnlPoolingLayer") 13 : { 14 3 : for (const auto& dim : poolingWindow) { 15 2 : _poolingWindow.push_back(dim); 16 2 : _poolingPadding.push_back(0); 17 : } 18 3 : for (const auto& dim : poolingStride) 19 2 : _poolingStride.push_back(dim); 20 1 : } 21 : 22 : template <typename data_t> 23 1 : void DnnlPoolingLayer<data_t>::compileForwardStream() 24 : { 25 1 : BaseType::compileForwardStream(); 26 1 : auto desc = dnnl::pooling_forward::desc( 27 : /* Propagation kind */ dnnl::prop_kind::forward, 28 : /* Pooling algorithm */ dnnl::algorithm::pooling_max, 29 1 : /* Source memory descriptor */ _input.front().descriptor, 30 1 : /* Destination memory descriptor */ _output.descriptor, 31 1 : /* Pooling strides */ _poolingStride, 32 1 : /* Pooling window */ _poolingWindow, 33 1 : /* Input padding for lower dims */ _poolingPadding, 34 1 : /* Input padding for higher dims */ _poolingPadding); 35 : 36 1 : _forwardPrimitiveDescriptor = dnnl::pooling_forward::primitive_desc(desc, *_engine); 37 : 38 1 : _workspaceMemory.effectiveMemory = std::make_shared<dnnl::memory>( 39 1 : _forwardPrimitiveDescriptor.workspace_desc(), *_engine); 40 : 41 1 : ELSA_ML_ADD_DNNL_PRIMITIVE(_forwardStream, 42 : dnnl::pooling_forward(_forwardPrimitiveDescriptor)); 43 4 : _forwardStream.arguments.push_back( 44 1 : {{DNNL_ARG_SRC, *_input.front().effectiveMemory}, 45 1 : {DNNL_ARG_WORKSPACE, *_workspaceMemory.effectiveMemory}}); 46 : 47 1 : auto outDesc = dnnl::memory::desc({_output.dimensions}, _typeTag, _output.formatTag); 48 : 49 1 : _output.describedMemory = std::make_shared<dnnl::memory>(outDesc, *_engine); 50 : 51 1 : _output.effectiveMemory = _output.describedMemory; 52 1 : if (_forwardPrimitiveDescriptor.dst_desc() != _output.describedMemory->get_desc()) { 53 0 : _output.wasReordered = true; 54 0 : _output.describedMemory = std::make_shared<dnnl::memory>( 55 0 : _forwardPrimitiveDescriptor.dst_desc(), *_engine); 56 0 : _forwardStream.arguments.back().insert({DNNL_ARG_DST, *_output.describedMemory}); 57 0 : ELSA_ML_ADD_DNNL_PRIMITIVE(_forwardStream, dnnl::reorder(*_output.describedMemory, 58 : *_output.effectiveMemory)); 59 0 : _forwardStream.arguments.push_back({{DNNL_ARG_FROM, *_output.describedMemory}, 60 0 : {DNNL_ARG_TO, *_output.effectiveMemory}}); 61 : } else { 62 1 : _forwardStream.arguments.back().insert({DNNL_ARG_DST, *_output.effectiveMemory}); 63 : } 64 1 : } 65 : 66 : template <typename data_t> 67 1 : void DnnlPoolingLayer<data_t>::compileBackwardStream() 68 : { 69 1 : BaseType::compileBackwardStream(); 70 1 : auto desc = dnnl::pooling_backward::desc( 71 : /* Pooling algorithm */ dnnl::algorithm::pooling_max, 72 1 : /* Input gradient descriptor */ _inputGradient.front().descriptor, 73 1 : /* Output gradient descriptor */ _outputGradient.front().descriptor, 74 1 : /* Strides */ _poolingStride, 75 1 : /* Pooling window */ _poolingWindow, 76 1 : /* Padding */ _poolingPadding, _poolingPadding); 77 : 78 1 : _backwardPrimitiveDescriptor = 79 1 : dnnl::pooling_backward::primitive_desc(desc, *_engine, _forwardPrimitiveDescriptor); 80 : 81 1 : this->reorderMemory(_backwardPrimitiveDescriptor.diff_dst_desc(), 82 1 : _outputGradient.front(), _backwardStream); 83 : 84 1 : _inputGradient.front().effectiveMemory = std::make_shared<dnnl::memory>( 85 1 : _backwardPrimitiveDescriptor.diff_src_desc(), *_engine); 86 : 87 1 : BaseType::validateDnnlMemory(_outputGradient.front().effectiveMemory, 88 1 : _inputGradient.front().effectiveMemory, 89 1 : _workspaceMemory.effectiveMemory); 90 : 91 1 : ELSA_ML_ADD_DNNL_PRIMITIVE(_backwardStream, 92 : dnnl::pooling_backward(_backwardPrimitiveDescriptor)); 93 6 : _backwardStream.arguments.push_back( 94 1 : {{DNNL_ARG_DIFF_DST, *_outputGradient.front().effectiveMemory}, 95 1 : {DNNL_ARG_DIFF_SRC, *_inputGradient.front().effectiveMemory}, 96 1 : {DNNL_ARG_WORKSPACE, *_workspaceMemory.effectiveMemory}}); 97 1 : } 98 : 99 : template class DnnlPoolingLayer<float>; 100 : 101 : } // namespace detail 102 : } // namespace elsa::ml