LCOV - code coverage report
Current view: top level - ml/backend/Dnnl - DnnlMerging.cpp (source / functions)
Test: test_coverage.info.cleaned
Date: 2022-07-06 02:47:47

                   Hit    Total    Coverage
Lines:              58       96      60.4 %
Functions:           8       13      61.5 %

          Line data    Source code
       1             : #include "DnnlMerging.h"
       2             : #include "TypeCasts.hpp"
       3             : 
       4             : namespace elsa::ml
       5             : {
       6             :     namespace detail
       7             :     {
       8             :         template <typename data_t>
       9           2 :         DnnlMerging<data_t>::DnnlMerging(const std::vector<VolumeDescriptor>& inputDescriptors,
      10             :                                          const VolumeDescriptor& outputDescriptor)
      11             :             : DnnlLayer<data_t>(inputDescriptors, outputDescriptor, "DnnlMerging",
      12           2 :                                 DnnlLayer<data_t>::anyNumberOfInputs)
      13             :         {
      14           2 :         }
      15             : 
      16             :         template <typename data_t>
      17           3 :         bool DnnlMerging<data_t>::needsForwardSynchronisation() const
      18             :         {
      19           3 :             return true;
      20             :         }
      21             : 
      22             :         template <typename data_t>
      23           2 :         bool DnnlMerging<data_t>::canMerge() const
      24             :         {
      25           2 :             return true;
      26             :         }
      27             : 
      28             :         template <typename data_t>
      29           2 :         DnnlSum<data_t>::DnnlSum(const std::vector<VolumeDescriptor>& inputDescriptors,
      30             :                                  const VolumeDescriptor& outputDescriptor)
      31           2 :             : DnnlMerging<data_t>(inputDescriptors, outputDescriptor)
      32             :         {
      33             :             // Check that all input-descriptors are equal
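                      :             // (std::adjacent_find with operator!= returns end() exactly when no two
                      :             //  neighbouring descriptors differ, i.e. when all descriptors are equal.)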
      34           6 :             assert(std::adjacent_find(inputDescriptors.begin(), inputDescriptors.end(),
      35             :                                       [](const auto& a, const auto& b) { return a != b; })
      36             :                        == inputDescriptors.end()
      37             :                    && "All input-descriptors for DnnlSum must be equal");
      38           2 :         }
      39             : 
      40             :         template <typename data_t>
      41           2 :         void DnnlSum<data_t>::compileForwardStream()
      42             :         {
      43           2 :             BaseType::compileForwardStream();
      44             : 
      45           4 :             std::vector<dnnl::memory> mem;
      46           4 :             std::vector<dnnl::memory::desc> memDesc;
      47           8 :             for (std::size_t i = 0; i < _input.size(); ++i) {
      48           6 :                 memDesc.push_back(_input[i].descriptor);
      49           6 :                 BaseType::validateDnnlMemory(_input[i].effectiveMemory);
      50           6 :                 mem.push_back(*_input[i].effectiveMemory);
      51             :             }
      52             : 
       53             :             // We currently do not support custom scaling since the API does not expose it; all scales are fixed to 1
      54           4 :             std::vector<data_t> scales(_input.size(), data_t(1));
      55             : 
      56             :             // Create primitive-descriptor
      57           2 :             _forwardPrimitiveDescriptor = dnnl::sum::primitive_desc(scales, memDesc, *_engine);
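                      :             // (The oneDNN sum primitive computes dst = sum_i scales[i] * src[i];
                      :             //  with all scales equal to 1 this is a plain element-wise sum of the
                      :             //  inputs.)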
      58             : 
      59           8 :             for (std::size_t i = 0; i < _input.size(); ++i) {
       60             :                 // Reorder input memory if necessary
      61           6 :                 this->reorderMemory(_forwardPrimitiveDescriptor.src_desc(), _input[i],
      62           6 :                                     _forwardStream);
      63             :             }
      64             : 
      65             :             // Add sum primitive to forward-stream
      66           2 :             ELSA_ML_ADD_DNNL_PRIMITIVE(_forwardStream, dnnl::sum(_forwardPrimitiveDescriptor));
      67             : 
      68             :             // Allocate output memory
      69           2 :             _output.effectiveMemory =
      70           2 :                 std::make_shared<dnnl::memory>(_forwardPrimitiveDescriptor.dst_desc(), *_engine);
      71             : 
      72             :             // Validate memory
      73           2 :             BaseType::validateDnnlMemory(_output.effectiveMemory);
      74             : 
      75             :             // Add arguments to forward-stream
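                      :             // (oneDNN expects the i-th source of a multi-input primitive such as
                      :             //  sum or concat under the argument index DNNL_ARG_MULTIPLE_SRC + i.)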
      76           4 :             _forwardStream.arguments.push_back({{DNNL_ARG_DST, *_output.effectiveMemory}});
      77           8 :             for (std::size_t i = 0; i < _input.size(); ++i) {
      78           6 :                 _forwardStream.arguments.back().insert({DNNL_ARG_MULTIPLE_SRC + i, mem[i]});
      79             :             }
      80           2 :         }
      81             : 
      82             :         template <typename data_t>
      83           1 :         void DnnlSum<data_t>::compileBackwardStream()
      84             :         {
      85           1 :             BaseType::compileBackwardStream();
      86             : 
      87             :             // Allocate memory for input-gradient if necessary
      88           4 :             for (std::size_t i = 0; i < _inputGradient.size(); ++i) {
      89           3 :                 if (!_inputGradient[i].effectiveMemory) {
      90           3 :                     _inputGradient[i].effectiveMemory = std::make_shared<dnnl::memory>(
      91           3 :                         dnnl::memory::desc({{_inputGradient[i].dimensions},
      92             :                                             this->_typeTag,
      93           3 :                                             _inputGradient[i].formatTag}),
      94           3 :                         *_engine);
      95             :                 }
      96             :             }
      97           1 :             _outputGradient.front().effectiveMemory = _outputGradient.front().describedMemory;
      98           1 :         }
      99             : 
     100             :         template <typename data_t>
     101           1 :         void DnnlSum<data_t>::backwardPropagate([[maybe_unused]] dnnl::stream& executionStream)
     102             :         {
     103             :             // make sure backward stream has been compiled
     104           1 :             assert(_backwardStream.isCompiled
     105             :                    && "Cannot backward propagate because backward-stream has not been compiled");
     106             : 
     107             :             // We derive the gradient for a sum layer as follows:
     108             :             //
     109             :             //
     110             :             //      |   ^                   |   ^
     111             :             //   i0 |   | dE/di0         i1 |   | dE/di1
     112             :             //      v   |                   v   |
     113             :             //   +--------------------------------+
     114             :             //   |             SUM                |
     115             :             //   +--------------------------------+
     116             :             //                |    ^
     117             :             //              o |    | dE/do
     118             :             //                v    |
     119             :             //
      120             :             // The input-gradient along the path of input i0 is given by
      121             :             //    dE/di0 = dE/do * do/di0
      122             :             //             ^^^^^   ^^^^^^
      123             :             //             |       | partial derivative of o = i0+i1 w.r.t. i0, which is 1
      124             :             //             | output-gradient
                      :             //
                      :             // Since do/di0 = 1, we get dE/di0 = dE/do: each input receives an
                      :             // unmodified copy of the output-gradient.
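                      :             //
                      :             // A small numeric illustration (scalar inputs, purely for intuition):
                      :             // if i0 = 2 and i1 = 5, then o = 7 and the backward pass yields
                      :             // dE/di0 = dE/di1 = dE/do.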
     125             : 
     126             :             // Get output-gradient memory
     127           1 :             Eigen::Map<Eigen::ArrayX<data_t>> outputGrad(
     128           1 :                 static_cast<data_t*>(_outputGradient.front().effectiveMemory->get_data_handle()),
     129             :                 _outputDescriptor->getNumberOfCoefficients());
     130             : 
     131           4 :             for (std::size_t i = 0; i < _inputGradient.size(); ++i) {
     132           3 :                 BaseType::validateDnnlMemory(_inputGradient[i].effectiveMemory);
     133           3 :                 BaseType::validateDnnlMemory(_outputGradient.front().effectiveMemory);
     134           3 :                 BaseType::validateDnnlMemory(_input[i].effectiveMemory);
     135             : 
     136             :                 // Get input-gradient memory
     137           6 :                 Eigen::Map<Eigen::ArrayX<data_t>> inputGrad(
     138           3 :                     static_cast<data_t*>(_inputGradient[i].effectiveMemory->get_data_handle()),
     139           3 :                     _inputDescriptor[i]->getNumberOfCoefficients());
     140             : 
      146             :                 // Compute input-gradient: the partial derivative of the sum w.r.t.
                      :                 // each of its inputs is 1, so the input-gradient is an unmodified
                      :                 // copy of the output-gradient
      147           3 :                 inputGrad = outputGrad;
     148             :             }
     149           1 :         }
     150             : 
     151             :         template <typename data_t>
     152           0 :         DnnlConcatenate<data_t>::DnnlConcatenate(
     153             :             index_t axis, const std::vector<VolumeDescriptor>& inputDescriptors,
     154             :             const VolumeDescriptor& outputDescriptor)
     155           0 :             : DnnlMerging<data_t>(inputDescriptors, outputDescriptor), _axis(axis)
     156             :         {
     157             : 
     158             :             // Check that all input-descriptors are equal
     159           0 :             assert(std::adjacent_find(inputDescriptors.begin(), inputDescriptors.end(),
     160             :                                       [](const auto& a, const auto& b) { return a != b; })
     161             :                        == inputDescriptors.end()
      162             :                    && "All input-descriptors for DnnlConcatenate must be equal");
     163           0 :         }
     164             : 
     165             :         template <typename data_t>
     166           0 :         void DnnlConcatenate<data_t>::compileForwardStream()
     167             :         {
     168           0 :             BaseType::compileForwardStream();
     169             : 
     170           0 :             std::vector<dnnl::memory> mem;
     171           0 :             std::vector<dnnl::memory::desc> memDesc;
     172           0 :             for (std::size_t i = 0; i < _input.size(); ++i) {
     173           0 :                 memDesc.push_back(_input[i].descriptor);
     174           0 :                 BaseType::validateDnnlMemory(_input[i].effectiveMemory);
     175           0 :                 mem.push_back(*_input[i].effectiveMemory);
     176             :             }
     177             : 
     178             :             // Create primitive-descriptor
     179           0 :             _forwardPrimitiveDescriptor =
     180           0 :                 dnnl::concat::primitive_desc(as<int>(_axis), memDesc, *_engine);
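                      :             // (The oneDNN concat primitive concatenates its src tensors along the
                      :             //  given axis; all remaining dimensions of the inputs must match.)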
     181             : 
     182           0 :             for (std::size_t i = 0; i < _input.size(); ++i) {
      183             :                 // Reorder input memory if necessary
     184           0 :                 this->reorderMemory(_forwardPrimitiveDescriptor.src_desc(), _input[i],
     185           0 :                                     _forwardStream);
     186             :             }
     187             : 
      188             :             // Add concat primitive to forward-stream
     189           0 :             ELSA_ML_ADD_DNNL_PRIMITIVE(_forwardStream, dnnl::concat(_forwardPrimitiveDescriptor));
     190             : 
     191             :             // Allocate output memory
     192           0 :             _output.effectiveMemory =
     193           0 :                 std::make_shared<dnnl::memory>(_forwardPrimitiveDescriptor.dst_desc(), *_engine);
     194             : 
     195             :             // Validate memory
     196           0 :             BaseType::validateDnnlMemory(_output.effectiveMemory);
     197             : 
     198             :             // Add arguments to forward-stream
     199           0 :             _forwardStream.arguments.push_back({{DNNL_ARG_DST, *_output.effectiveMemory}});
     200           0 :             for (std::size_t i = 0; i < _input.size(); ++i) {
     201           0 :                 _forwardStream.arguments.back().insert({DNNL_ARG_MULTIPLE_SRC + i, mem[i]});
     202             :             }
     203           0 :         }
     204             : 
     205             :         template <typename data_t>
     206           0 :         void DnnlConcatenate<data_t>::compileBackwardStream()
     207             :         {
     208           0 :             BaseType::compileBackwardStream();
     209             : 
     210             :             // Allocate memory for input-gradient if necessary
     211           0 :             for (std::size_t i = 0; i < _inputGradient.size(); ++i) {
     212           0 :                 if (!_inputGradient[i].effectiveMemory) {
     213           0 :                     _inputGradient[i].effectiveMemory = std::make_shared<dnnl::memory>(
     214           0 :                         dnnl::memory::desc({{_inputGradient[i].dimensions},
     215             :                                             this->_typeTag,
     216           0 :                                             _inputGradient[i].formatTag}),
     217           0 :                         *_engine);
     218             :                 }
     219             :             }
     220           0 :             _outputGradient.front().effectiveMemory = _outputGradient.front().describedMemory;
     221           0 :         }
     222             : 
     223             :         template <typename data_t>
      224           0 :         void DnnlConcatenate<data_t>::backwardPropagate(
      225             :             [[maybe_unused]] dnnl::stream& executionStream)
     226             :         {
     227             :             // make sure backward stream has been compiled
     228           0 :             assert(_backwardStream.isCompiled
     229             :                    && "Cannot backward propagate because backward-stream has not been compiled");
     230             : 
     231             :             // We derive the gradient for a concat layer as follows:
     232             :             //
     233             :             // If the Concatenate layer receives three inputs i0, i1, i2 with
     234             :             // shapes (n, c0, h, w), (n, c1, h, w) and (n, c2, h, w)
     235             :             // respectively and c is the concatenation axis, the output has
     236             :             // shape (n, c0+c1+c2, h, w).
     237             :             //
      238             :             // The incoming gradient for the Concatenation layer then also has
      239             :             // shape (n, c0+c1+c2, h, w).
     240             :             //
     241             :             // The gradient for each of the inputs is then the slice of the
     242             :             // incoming gradient along c that matches the slice of the input,
     243             :             // e.g. i0 gets slice (n, c0, h, w) of the incoming gradient.
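                      :             //
                      :             // Purely as an illustration: with n = 1, h = w = 4 and channel counts
                      :             // c0 = 2, c1 = 3, c2 = 1, the incoming gradient has shape (1, 6, 4, 4);
                      :             // i0 then receives channels [0, 2), i1 channels [2, 5) and i2 channel
                      :             // [5, 6) of that gradient tensor.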
     244           0 :         }
     245             : 
     246             :         template class DnnlMerging<float>;
     247             :         template class DnnlSum<float>;
     248             :         template class DnnlConcatenate<float>;
     249             :     } // namespace detail
     250             : } // namespace elsa::ml

Generated by: LCOV version 1.15