Line data Source code
1 : #include "DnnlFlattenLayer.h" 2 : 3 : namespace elsa::ml 4 : { 5 : namespace detail 6 : { 7 : template <typename data_t> 8 0 : DnnlFlattenLayer<data_t>::DnnlFlattenLayer(const VolumeDescriptor& inputDescriptor, 9 : const VolumeDescriptor& outputDescriptor) 10 0 : : DnnlLayer<data_t>(inputDescriptor, outputDescriptor, "DnnlFlattenLayer") 11 : { 12 0 : assert(inputDescriptor.getNumberOfCoefficients() 13 : == outputDescriptor.getNumberOfCoefficients() 14 : && "Cannot flatten if number of coefficients of input- and output-descriptor do " 15 : "not match"); 16 0 : } 17 : 18 : template <typename data_t> 19 0 : void DnnlFlattenLayer<data_t>::compileForwardStream() 20 : { 21 0 : BaseType::compileForwardStream(); 22 : 23 : // Set output-descriptor. This is the flattened input-descriptor 24 0 : _output.effectiveMemory = std::make_shared<dnnl::memory>( 25 0 : dnnl::memory::desc({{_output.dimensions}, _typeTag, _output.formatTag}), *_engine); 26 : 27 0 : BaseType::validateDnnlMemory(_input.front().effectiveMemory, _output.effectiveMemory); 28 : 29 0 : _output.effectiveMemory->set_data_handle( 30 0 : _input.front().effectiveMemory->get_data_handle()); 31 0 : } 32 : 33 : template <typename data_t> 34 0 : void DnnlFlattenLayer<data_t>::compileBackwardStream() 35 : { 36 0 : BaseType::compileBackwardStream(); 37 : // Set output-memory 38 0 : _inputGradient.front().effectiveMemory = std::make_shared<dnnl::memory>( 39 0 : dnnl::memory::desc({{_inputGradient.front().dimensions}, 40 : _typeTag, 41 0 : _inputGradient.front().formatTag}), 42 0 : *_engine); 43 : 44 0 : BaseType::validateDnnlMemory(_inputGradient.front().effectiveMemory, 45 0 : _outputGradient.front().effectiveMemory); 46 0 : _inputGradient.front().effectiveMemory->set_data_handle( 47 0 : _outputGradient.front().effectiveMemory->get_data_handle()); 48 0 : } 49 : 50 : template class DnnlFlattenLayer<float>; 51 : } // namespace detail 52 : } // namespace elsa::ml