Line data Source code
1 : #include "DnnlActivationLayer.h"
2 :
3 : namespace elsa::ml
4 : {
5 : namespace detail
6 : {
7 : template <typename data_t>
8 24 : DnnlActivationLayer<data_t>::DnnlActivationLayer(const VolumeDescriptor& inputDescriptor,
9 : const VolumeDescriptor& outputDescriptor,
10 : dnnl::algorithm algorithm)
11 : : DnnlLayer<data_t>(inputDescriptor, outputDescriptor, "DnnlActivationLayer"),
12 24 : algorithm_(algorithm)
13 : {
14 24 : }
15 :
16 : template <typename data_t>
17 24 : void DnnlActivationLayer<data_t>::setAlpha(data_t alpha)
18 : {
19 24 : _alpha = alpha;
20 24 : }
21 :
22 : template <typename data_t>
23 24 : void DnnlActivationLayer<data_t>::setBeta(data_t beta)
24 : {
25 24 : _beta = beta;
26 24 : }
27 :
28 : template <typename data_t>
29 24 : void DnnlActivationLayer<data_t>::compileForwardStream()
30 : {
31 24 : BaseType::compileForwardStream();
32 :
33 : // Set forward primitive description
34 48 : auto desc = dnnl::eltwise_forward::desc(
35 : /* Inference type */ dnnl::prop_kind::forward_training,
36 : /* Element-wise algorithm */ algorithm_,
37 24 : /* Source memory descriptor */ _input.front().descriptor,
38 : /* Alpha parameter */ _alpha,
39 : /* Beta parameter */ _beta);
40 :
41 24 : _forwardPrimitiveDescriptor = dnnl::eltwise_forward::primitive_desc(desc, *_engine);
42 :
43 : // Set forward primitive
44 24 : ELSA_ML_ADD_DNNL_PRIMITIVE(_forwardStream,
45 : dnnl::eltwise_forward(_forwardPrimitiveDescriptor));
46 :
47 : // Set output memory. Since no activation layer can reorder we set effective memory
48 : // directly
49 24 : _output.effectiveMemory =
50 24 : std::make_shared<dnnl::memory>(_forwardPrimitiveDescriptor.dst_desc(), *_engine);
51 :
52 96 : _forwardStream.arguments.push_back({{DNNL_ARG_SRC, *_input.front().effectiveMemory},
53 24 : {DNNL_ARG_DST, *_output.effectiveMemory}});
54 24 : }
55 :
56 : template <typename data_t>
57 24 : void DnnlActivationLayer<data_t>::compileBackwardStream()
58 : {
59 24 : BaseType::compileBackwardStream();
60 :
61 48 : auto desc = dnnl::eltwise_backward::desc(
62 : /* Element-wise algorithm */ algorithm_,
63 24 : /* Gradient dst memory descriptor */ _outputGradient.front().descriptor,
64 24 : /* Source memory descriptor */ _input.front().descriptor,
65 : /* Alpha parameter */ _alpha,
66 : /* Beta parameter */ _beta);
67 :
68 24 : _backwardPrimitiveDescriptor =
69 24 : dnnl::eltwise_backward::primitive_desc(desc, *_engine, _forwardPrimitiveDescriptor);
70 :
71 : // Reorder if necessary
72 24 : this->reorderMemory(_backwardPrimitiveDescriptor.diff_dst_desc(),
73 24 : _outputGradient.front(), _backwardStream);
74 :
75 24 : _inputGradient.front().effectiveMemory = std::make_shared<dnnl::memory>(
76 24 : _backwardPrimitiveDescriptor.diff_src_desc(), *_engine);
77 :
78 24 : _outputGradient.front().effectiveMemory = _outputGradient.front().describedMemory;
79 24 : BaseType::validateDnnlMemory(_input.front().effectiveMemory);
80 24 : BaseType::validateDnnlMemory(_outputGradient.front().effectiveMemory);
81 24 : BaseType::validateDnnlMemory(_outputGradient.front().describedMemory);
82 24 : BaseType::validateDnnlMemory(_inputGradient.front().effectiveMemory);
83 :
84 24 : ELSA_ML_ADD_DNNL_PRIMITIVE(_backwardStream,
85 : dnnl::eltwise_backward(_backwardPrimitiveDescriptor));
86 144 : _backwardStream.arguments.push_back(
87 : {/* Input */
88 24 : {DNNL_ARG_SRC, *_input.front().effectiveMemory},
89 24 : {DNNL_ARG_DIFF_DST, *_outputGradient.front().effectiveMemory},
90 : /* Output */
91 24 : {DNNL_ARG_DIFF_SRC, *_inputGradient.front().effectiveMemory}});
92 24 : }
93 :
94 : template <typename data_t>
95 3 : DnnlAbs<data_t>::DnnlAbs(const VolumeDescriptor& inputDescriptor)
96 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
97 3 : dnnl::algorithm::eltwise_abs)
98 : {
99 3 : }
100 :
101 : template <typename data_t>
102 0 : DnnlBoundedRelu<data_t>::DnnlBoundedRelu(const VolumeDescriptor& inputDescriptor)
103 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
104 0 : dnnl::algorithm::eltwise_bounded_relu)
105 : {
106 0 : }
107 :
108 : template <typename data_t>
109 3 : DnnlElu<data_t>::DnnlElu(const VolumeDescriptor& inputDescriptor)
110 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
111 3 : dnnl::algorithm::eltwise_elu)
112 : {
113 3 : }
114 :
115 : template <typename data_t>
116 3 : DnnlExp<data_t>::DnnlExp(const VolumeDescriptor& inputDescriptor)
117 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
118 3 : dnnl::algorithm::eltwise_exp)
119 : {
120 3 : }
121 :
122 : template <typename data_t>
123 0 : DnnlGelu<data_t>::DnnlGelu(const VolumeDescriptor& inputDescriptor)
124 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
125 0 : dnnl::algorithm::eltwise_gelu)
126 : {
127 0 : }
128 :
129 : template <typename data_t>
130 3 : DnnlLinear<data_t>::DnnlLinear(const VolumeDescriptor& inputDescriptor)
131 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
132 3 : dnnl::algorithm::eltwise_linear)
133 : {
134 3 : }
135 :
136 : template <typename data_t>
137 3 : DnnlLogistic<data_t>::DnnlLogistic(const VolumeDescriptor& inputDescriptor)
138 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
139 3 : dnnl::algorithm::eltwise_logistic)
140 : {
141 3 : }
142 :
143 : template <typename data_t>
144 3 : DnnlRelu<data_t>::DnnlRelu(const VolumeDescriptor& inputDescriptor)
145 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
146 3 : dnnl::algorithm::eltwise_relu)
147 : {
148 3 : }
149 :
150 : template <typename data_t>
151 3 : DnnlSoftRelu<data_t>::DnnlSoftRelu(const VolumeDescriptor& inputDescriptor)
152 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
153 3 : dnnl::algorithm::eltwise_soft_relu)
154 : {
155 3 : }
156 :
157 : template <typename data_t>
158 0 : DnnlSqrt<data_t>::DnnlSqrt(const VolumeDescriptor& inputDescriptor)
159 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
160 0 : dnnl::algorithm::eltwise_sqrt)
161 : {
162 0 : }
163 :
164 : template <typename data_t>
165 0 : DnnlSquare<data_t>::DnnlSquare(const VolumeDescriptor& inputDescriptor)
166 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
167 0 : dnnl::algorithm::eltwise_square)
168 : {
169 0 : }
170 :
171 : template <typename data_t>
172 0 : DnnlSwish<data_t>::DnnlSwish(const VolumeDescriptor& inputDescriptor)
173 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
174 0 : dnnl::algorithm::eltwise_swish)
175 : {
176 0 : }
177 :
178 : template <typename data_t>
179 3 : DnnlTanh<data_t>::DnnlTanh(const VolumeDescriptor& inputDescriptor)
180 : : DnnlActivationLayer<data_t>(inputDescriptor, inputDescriptor,
181 3 : dnnl::algorithm::eltwise_tanh)
182 : {
183 3 : }
184 :
185 : template class DnnlActivationLayer<float>;
186 :
187 : template struct DnnlAbs<float>;
188 : template struct DnnlBoundedRelu<float>;
189 : template struct DnnlElu<float>;
190 : template struct DnnlExp<float>;
191 : template struct DnnlLinear<float>;
192 : template struct DnnlGelu<float>;
193 : template struct DnnlLogistic<float>;
194 : template struct DnnlRelu<float>;
195 : template struct DnnlSoftRelu<float>;
196 : template struct DnnlSqrt<float>;
197 : template struct DnnlSquare<float>;
198 : template struct DnnlSwish<float>;
199 : template struct DnnlTanh<float>;
200 :
201 : } // namespace detail
202 : } // namespace elsa::ml
|