LCOV - code coverage report
Current view: top level - ml/backend/Dnnl - DnnlActivationLayer.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 66 81 81.5 %
Date: 2022-07-06 02:47:47 Functions: 13 18 72.2 %

          Line data    Source code
       1             : #include "DnnlActivationLayer.h"
       2             : 
       3             : namespace elsa::ml
       4             : {
       5             :     namespace detail
       6             :     {
       7             :         template <typename data_t>
       8          24 :         DnnlActivationLayer<data_t>::DnnlActivationLayer(const VolumeDescriptor& inputDescriptor,
       9             :                                                          const VolumeDescriptor& outputDescriptor,
      10             :                                                          dnnl::algorithm algorithm)
      11             :             : DnnlLayer<data_t>(inputDescriptor, outputDescriptor, "DnnlActivationLayer"),
      12          24 :               algorithm_(algorithm)
      13             :         {
      14          24 :         }
      15             : 
      16             :         template <typename data_t>
      17          24 :         void DnnlActivationLayer<data_t>::setAlpha(data_t alpha)
      18             :         {
      19          24 :             _alpha = alpha;
      20          24 :         }
      21             : 
      22             :         template <typename data_t>
      23          24 :         void DnnlActivationLayer<data_t>::setBeta(data_t beta)
      24             :         {
      25          24 :             _beta = beta;
      26          24 :         }
      27             : 
      28             :         template <typename data_t>
      29          24 :         void DnnlActivationLayer<data_t>::compileForwardStream()
      30             :         {
      31          24 :             BaseType::compileForwardStream();
      32             : 
      33             :             // Set forward primitive description
      34          48 :             auto desc = dnnl::eltwise_forward::desc(
      35             :                 /* Inference type */ dnnl::prop_kind::forward_training,
      36             :                 /* Element-wise algorithm */ algorithm_,
      37          24 :                 /* Source memory descriptor */ _input.front().descriptor,
      38             :                 /* Alpha parameter */ _alpha,
      39             :                 /* Beta parameter */ _beta);
      40             : 
      41          24 :             _forwardPrimitiveDescriptor = dnnl::eltwise_forward::primitive_desc(desc, *_engine);
      42             : 
      43             :             // Set forward primitive
      44          24 :             ELSA_ML_ADD_DNNL_PRIMITIVE(_forwardStream,
      45             :                                        dnnl::eltwise_forward(_forwardPrimitiveDescriptor));
      46             : 
      47             :             // Set output memory. Since no activation layer can reorder we set effective memory
      48             :             // directly
      49          24 :             _output.effectiveMemory =
      50          24 :                 std::make_shared<dnnl::memory>(_forwardPrimitiveDescriptor.dst_desc(), *_engine);
      51             : 
      52          96 :             _forwardStream.arguments.push_back({{DNNL_ARG_SRC, *_input.front().effectiveMemory},
      53          24 :                                                 {DNNL_ARG_DST, *_output.effectiveMemory}});
      54          24 :         }
      55             : 
      56             :         template <typename data_t>
      57          24 :         void DnnlActivationLayer<data_t>::compileBackwardStream()
      58             :         {
      59          24 :             BaseType::compileBackwardStream();
      60             : 
      61          48 :             auto desc = dnnl::eltwise_backward::desc(
      62             :                 /* Element-wise algorithm */ algorithm_,
      63          24 :                 /* Gradient dst memory descriptor */ _outputGradient.front().descriptor,
      64          24 :                 /* Source memory descriptor */ _input.front().descriptor,
      65             :                 /* Alpha parameter */ _alpha,
      66             :                 /* Beta parameter */ _beta);
      67             : 
      68          24 :             _backwardPrimitiveDescriptor =
      69          24 :                 dnnl::eltwise_backward::primitive_desc(desc, *_engine, _forwardPrimitiveDescriptor);
      70             : 
      71             :             // Reorder if necessary
      72          24 :             this->reorderMemory(_backwardPrimitiveDescriptor.diff_dst_desc(),
      73          24 :                                 _outputGradient.front(), _backwardStream);
      74             : 
      75          24 :             _inputGradient.front().effectiveMemory = std::make_shared<dnnl::memory>(
      76          24 :                 _backwardPrimitiveDescriptor.diff_src_desc(), *_engine);
      77             : 
      78          24 :             _outputGradient.front().effectiveMemory = _outputGradient.front().describedMemory;
      79          24 :             BaseType::validateDnnlMemory(_input.front().effectiveMemory);
      80          24 :             BaseType::validateDnnlMemory(_outputGradient.front().effectiveMemory);
      81          24 :             BaseType::validateDnnlMemory(_outputGradient.front().describedMemory);
      82          24 :             BaseType::validateDnnlMemory(_inputGradient.front().effectiveMemory);
      83             : 
      84          24 :             ELSA_ML_ADD_DNNL_PRIMITIVE(_backwardStream,
      85             :                                        dnnl::eltwise_backward(_backwardPrimitiveDescriptor));
      86         144 :             _backwardStream.arguments.push_back(
      87             :                 {/* Input */
      88          24 :                  {DNNL_ARG_SRC, *_input.front().effectiveMemory},
      89          24 :                  {DNNL_ARG_DIFF_DST, *_outputGradient.front().effectiveMemory},
      90             :                  /* Output */
      91          24 :                  {DNNL_ARG_DIFF_SRC, *_inputGradient.front().effectiveMemory}});
      92          24 :         }
      93             : 
      94             :         template <typename data_t>
      95           3 :         DnnlAbs<data_t>::DnnlAbs(const VolumeDescriptor& inputDescriptor)
      96             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
      97           3 :                                           dnnl::algorithm::eltwise_abs)
      98             :         {
      99           3 :         }
     100             : 
     101             :         template <typename data_t>
     102           0 :         DnnlBoundedRelu<data_t>::DnnlBoundedRelu(const VolumeDescriptor& inputDescriptor)
     103             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     104           0 :                                           dnnl::algorithm::eltwise_bounded_relu)
     105             :         {
     106           0 :         }
     107             : 
     108             :         template <typename data_t>
     109           3 :         DnnlElu<data_t>::DnnlElu(const VolumeDescriptor& inputDescriptor)
     110             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     111           3 :                                           dnnl::algorithm::eltwise_elu)
     112             :         {
     113           3 :         }
     114             : 
     115             :         template <typename data_t>
     116           3 :         DnnlExp<data_t>::DnnlExp(const VolumeDescriptor& inputDescriptor)
     117             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     118           3 :                                           dnnl::algorithm::eltwise_exp)
     119             :         {
     120           3 :         }
     121             : 
     122             :         template <typename data_t>
     123           0 :         DnnlGelu<data_t>::DnnlGelu(const VolumeDescriptor& inputDescriptor)
     124             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     125           0 :                                           dnnl::algorithm::eltwise_gelu)
     126             :         {
     127           0 :         }
     128             : 
     129             :         template <typename data_t>
     130           3 :         DnnlLinear<data_t>::DnnlLinear(const VolumeDescriptor& inputDescriptor)
     131             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     132           3 :                                           dnnl::algorithm::eltwise_linear)
     133             :         {
     134           3 :         }
     135             : 
     136             :         template <typename data_t>
     137           3 :         DnnlLogistic<data_t>::DnnlLogistic(const VolumeDescriptor& inputDescriptor)
     138             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     139           3 :                                           dnnl::algorithm::eltwise_logistic)
     140             :         {
     141           3 :         }
     142             : 
     143             :         template <typename data_t>
     144           3 :         DnnlRelu<data_t>::DnnlRelu(const VolumeDescriptor& inputDescriptor)
     145             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     146           3 :                                           dnnl::algorithm::eltwise_relu)
     147             :         {
     148           3 :         }
     149             : 
     150             :         template <typename data_t>
     151           3 :         DnnlSoftRelu<data_t>::DnnlSoftRelu(const VolumeDescriptor& inputDescriptor)
     152             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     153           3 :                                           dnnl::algorithm::eltwise_soft_relu)
     154             :         {
     155           3 :         }
     156             : 
     157             :         template <typename data_t>
     158           0 :         DnnlSqrt<data_t>::DnnlSqrt(const VolumeDescriptor& inputDescriptor)
     159             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     160           0 :                                           dnnl::algorithm::eltwise_sqrt)
     161             :         {
     162           0 :         }
     163             : 
     164             :         template <typename data_t>
     165           0 :         DnnlSquare<data_t>::DnnlSquare(const VolumeDescriptor& inputDescriptor)
     166             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     167           0 :                                           dnnl::algorithm::eltwise_square)
     168             :         {
     169           0 :         }
     170             : 
     171             :         template <typename data_t>
     172           0 :         DnnlSwish<data_t>::DnnlSwish(const VolumeDescriptor& inputDescriptor)
     173             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     174           0 :                                           dnnl::algorithm::eltwise_swish)
     175             :         {
     176           0 :         }
     177             : 
     178             :         template <typename data_t>
     179           3 :         DnnlTanh<data_t>::DnnlTanh(const VolumeDescriptor& inputDescriptor)
     180             :             : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
     181           3 :                                           dnnl::algorithm::eltwise_tanh)
     182             :         {
     183           3 :         }
     184             : 
     185             :         template class DnnlActivationLayer<float>;
     186             : 
     187             :         template struct DnnlAbs<float>;
     188             :         template struct DnnlBoundedRelu<float>;
     189             :         template struct DnnlElu<float>;
     190             :         template struct DnnlExp<float>;
     191             :         template struct DnnlLinear<float>;
     192             :         template struct DnnlGelu<float>;
     193             :         template struct DnnlLogistic<float>;
     194             :         template struct DnnlRelu<float>;
     195             :         template struct DnnlSoftRelu<float>;
     196             :         template struct DnnlSqrt<float>;
     197             :         template struct DnnlSquare<float>;
     198             :         template struct DnnlSwish<float>;
     199             :         template struct DnnlTanh<float>;
     200             : 
     201             :     } // namespace detail
     202             : } // namespace elsa::ml

Generated by: LCOV version 1.15