#include "DnnlMerging.h"
#include "TypeCasts.hpp"

#include <algorithm>
#include <cassert>

namespace elsa::ml
{
    namespace detail
    {
        template <typename data_t>
        DnnlMerging<data_t>::DnnlMerging(const std::vector<VolumeDescriptor>& inputDescriptors,
                                         const VolumeDescriptor& outputDescriptor)
            : DnnlLayer<data_t>(inputDescriptors, outputDescriptor, "DnnlMerging",
                                DnnlLayer<data_t>::anyNumberOfInputs)
        {
        }

        template <typename data_t>
        bool DnnlMerging<data_t>::needsForwardSynchronisation() const
        {
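            // Merging layers combine the outputs of several predecessor layers, so all
            // incoming forward streams have to be completed before this layer can run.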
            return true;
        }

        template <typename data_t>
        bool DnnlMerging<data_t>::canMerge() const
        {
            return true;
        }

        template <typename data_t>
        DnnlSum<data_t>::DnnlSum(const std::vector<VolumeDescriptor>& inputDescriptors,
                                 const VolumeDescriptor& outputDescriptor)
            : DnnlMerging<data_t>(inputDescriptors, outputDescriptor)
        {
            // Check that all input-descriptors are equal
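            // (std::adjacent_find with operator!= returns end() exactly when every pair
            // of neighbouring elements, and therefore every element, compares equal.)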
            assert(std::adjacent_find(inputDescriptors.begin(), inputDescriptors.end(),
                                      [](const auto& a, const auto& b) { return a != b; })
                       == inputDescriptors.end()
                   && "All input-descriptors for DnnlSum must be equal");
        }

        template <typename data_t>
        void DnnlSum<data_t>::compileForwardStream()
        {
            BaseType::compileForwardStream();

            std::vector<dnnl::memory> mem;
            std::vector<dnnl::memory::desc> memDesc;
            for (std::size_t i = 0; i < _input.size(); ++i) {
                memDesc.push_back(_input[i].descriptor);
                BaseType::validateDnnlMemory(_input[i].effectiveMemory);
                mem.push_back(*_input[i].effectiveMemory);
            }

            // We currently do not support custom scaling since the API does not support it
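            // (With per-input scales s_i, dnnl::sum computes dst = sum_i s_i * src_i, so
            // all-ones scales yield a plain element-wise sum of the inputs.)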
            std::vector<data_t> scales(_input.size(), data_t(1));

            // Create primitive-descriptor
            _forwardPrimitiveDescriptor = dnnl::sum::primitive_desc(scales, memDesc, *_engine);

            for (std::size_t i = 0; i < _input.size(); ++i) {
                // Reorder input memory if necessary
                this->reorderMemory(_forwardPrimitiveDescriptor.src_desc(), _input[i],
                                    _forwardStream);
            }

            // Add sum primitive to forward-stream
            ELSA_ML_ADD_DNNL_PRIMITIVE(_forwardStream, dnnl::sum(_forwardPrimitiveDescriptor));

            // Allocate output memory
            _output.effectiveMemory =
                std::make_shared<dnnl::memory>(_forwardPrimitiveDescriptor.dst_desc(), *_engine);

            // Validate memory
            BaseType::validateDnnlMemory(_output.effectiveMemory);

            // Add arguments to forward-stream
            _forwardStream.arguments.push_back({{DNNL_ARG_DST, *_output.effectiveMemory}});
            for (std::size_t i = 0; i < _input.size(); ++i) {
                _forwardStream.arguments.back().insert({DNNL_ARG_MULTIPLE_SRC + i, mem[i]});
            }
        }
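
        // For reference, the forward stream assembled above corresponds to the following
        // self-contained oneDNN pattern. This is only an illustrative sketch using the
        // same v1.x-style sum API as this file; the names (engine, stream, src0, src1,
        // dst) and the example shape are placeholders, not part of this class:
        //
        //   dnnl::engine engine(dnnl::engine::kind::cpu, 0);
        //   dnnl::stream stream(engine);
        //
        //   dnnl::memory::desc desc({1, 3, 8, 8}, dnnl::memory::data_type::f32,
        //                           dnnl::memory::format_tag::nchw);
        //   dnnl::memory src0(desc, engine), src1(desc, engine);
        //
        //   std::vector<float> scales{1.f, 1.f};
        //   std::vector<dnnl::memory::desc> srcDescs{desc, desc};
        //   auto sumPd = dnnl::sum::primitive_desc(scales, srcDescs, engine);
        //
        //   dnnl::memory dst(sumPd.dst_desc(), engine);
        //   dnnl::sum(sumPd).execute(stream, {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
        //                                     {DNNL_ARG_MULTIPLE_SRC + 1, src1},
        //                                     {DNNL_ARG_DST, dst}});
        //   stream.wait();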

        template <typename data_t>
        void DnnlSum<data_t>::compileBackwardStream()
        {
            BaseType::compileBackwardStream();

            // Allocate memory for input-gradient if necessary
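            // (a dnnl::memory::desc is fully specified by dimensions, data type and
            // format tag)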
            for (std::size_t i = 0; i < _inputGradient.size(); ++i) {
                if (!_inputGradient[i].effectiveMemory) {
                    _inputGradient[i].effectiveMemory = std::make_shared<dnnl::memory>(
                        dnnl::memory::desc({{_inputGradient[i].dimensions},
                                            this->_typeTag,
                                            _inputGradient[i].formatTag}),
                        *_engine);
                }
            }
            _outputGradient.front().effectiveMemory = _outputGradient.front().describedMemory;
        }

        template <typename data_t>
        void DnnlSum<data_t>::backwardPropagate([[maybe_unused]] dnnl::stream& executionStream)
        {
            // make sure backward stream has been compiled
            assert(_backwardStream.isCompiled
                   && "Cannot backward propagate because backward-stream has not been compiled");

            // We derive the gradient for a sum layer as follows:
            //
            //
            //    |  ^                          |  ^
            // i0 |  | dE/di0                i1 |  | dE/di1
            //    v  |                          v  |
            //  +--------------------------------+
            //  |              SUM               |
            //  +--------------------------------+
            //                 |  ^
            //               o |  | dE/do
            //                 v  |
            //
            // The input-gradient along the path of input i0 is given by
            //   dE/di0 = dE/do * do/di0 = dE/do * 1 = dE/do,
            // since the partial derivative of the sum i0 + i1 with respect to i0 is 1.
            // The output-gradient is therefore passed through unchanged to every input.
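            //
            // Worked example (illustrative): for i0 = (1, 2) and i1 = (3, 4) the layer
            // outputs o = (4, 6); if the gradient arriving from the next layer is
            // dE/do = (0.1, 0.2), then dE/di0 = dE/di1 = (0.1, 0.2).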

            // Get output-gradient memory
            Eigen::Map<Eigen::ArrayX<data_t>> outputGrad(
                static_cast<data_t*>(_outputGradient.front().effectiveMemory->get_data_handle()),
                _outputDescriptor->getNumberOfCoefficients());

            for (std::size_t i = 0; i < _inputGradient.size(); ++i) {
                BaseType::validateDnnlMemory(_inputGradient[i].effectiveMemory);
                BaseType::validateDnnlMemory(_outputGradient.front().effectiveMemory);

                // Get input-gradient memory
                Eigen::Map<Eigen::ArrayX<data_t>> inputGrad(
                    static_cast<data_t*>(_inputGradient[i].effectiveMemory->get_data_handle()),
                    _inputDescriptor[i]->getNumberOfCoefficients());

                // Compute input-gradient: since do/di is 1 for every input of a sum,
                // the output-gradient is copied to each input-gradient
                inputGrad = outputGrad;
            }
        }

        template <typename data_t>
        DnnlConcatenate<data_t>::DnnlConcatenate(
            index_t axis, const std::vector<VolumeDescriptor>& inputDescriptors,
            const VolumeDescriptor& outputDescriptor)
            : DnnlMerging<data_t>(inputDescriptors, outputDescriptor), _axis(axis)
        {
            // Check that all input-descriptors are equal
            assert(std::adjacent_find(inputDescriptors.begin(), inputDescriptors.end(),
                                      [](const auto& a, const auto& b) { return a != b; })
                       == inputDescriptors.end()
                   && "All input-descriptors for DnnlConcatenate must be equal");
        }

        template <typename data_t>
        void DnnlConcatenate<data_t>::compileForwardStream()
        {
            BaseType::compileForwardStream();

            std::vector<dnnl::memory> mem;
            std::vector<dnnl::memory::desc> memDesc;
            for (std::size_t i = 0; i < _input.size(); ++i) {
                memDesc.push_back(_input[i].descriptor);
                BaseType::validateDnnlMemory(_input[i].effectiveMemory);
                mem.push_back(*_input[i].effectiveMemory);
            }

            // Create primitive-descriptor
            _forwardPrimitiveDescriptor =
                dnnl::concat::primitive_desc(as<int>(_axis), memDesc, *_engine);

            for (std::size_t i = 0; i < _input.size(); ++i) {
                // Reorder input memory if necessary
                this->reorderMemory(_forwardPrimitiveDescriptor.src_desc(), _input[i],
                                    _forwardStream);
            }

            // Add concat primitive to forward-stream
            ELSA_ML_ADD_DNNL_PRIMITIVE(_forwardStream, dnnl::concat(_forwardPrimitiveDescriptor));

            // Allocate output memory
            _output.effectiveMemory =
                std::make_shared<dnnl::memory>(_forwardPrimitiveDescriptor.dst_desc(), *_engine);

            // Validate memory
            BaseType::validateDnnlMemory(_output.effectiveMemory);

            // Add arguments to forward-stream
            _forwardStream.arguments.push_back({{DNNL_ARG_DST, *_output.effectiveMemory}});
            for (std::size_t i = 0; i < _input.size(); ++i) {
                _forwardStream.arguments.back().insert({DNNL_ARG_MULTIPLE_SRC + i, mem[i]});
            }
        }
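
        // For reference: dnnl::concat joins its sources along the given axis, so two
        // (n, c, h, w) inputs concatenated along axis 1 produce an (n, 2c, h, w) output.
        // Apart from the axis argument the setup mirrors the sum sketch above, e.g.
        // (illustrative only, reusing the placeholder names from that sketch):
        //
        //   auto concatPd = dnnl::concat::primitive_desc(/* axis */ 1, srcDescs, engine);
        //   dnnl::memory dst(concatPd.dst_desc(), engine);
        //   dnnl::concat(concatPd).execute(stream, {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
        //                                           {DNNL_ARG_MULTIPLE_SRC + 1, src1},
        //                                           {DNNL_ARG_DST, dst}});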

        template <typename data_t>
        void DnnlConcatenate<data_t>::compileBackwardStream()
        {
            BaseType::compileBackwardStream();

            // Allocate memory for input-gradient if necessary
            for (std::size_t i = 0; i < _inputGradient.size(); ++i) {
                if (!_inputGradient[i].effectiveMemory) {
                    _inputGradient[i].effectiveMemory = std::make_shared<dnnl::memory>(
                        dnnl::memory::desc({{_inputGradient[i].dimensions},
                                            this->_typeTag,
                                            _inputGradient[i].formatTag}),
                        *_engine);
                }
            }
            _outputGradient.front().effectiveMemory = _outputGradient.front().describedMemory;
        }

        template <typename data_t>
        void DnnlConcatenate<data_t>::backwardPropagate(
            [[maybe_unused]] dnnl::stream& executionStream)
        {
            // make sure backward stream has been compiled
            assert(_backwardStream.isCompiled
                   && "Cannot backward propagate because backward-stream has not been compiled");

            // We derive the gradient for a concat layer as follows:
            //
            // If the Concatenate layer receives three inputs i0, i1, i2 with
            // shapes (n, c0, h, w), (n, c1, h, w) and (n, c2, h, w)
            // respectively and c is the concatenation axis, the output has
            // shape (n, c0+c1+c2, h, w).
            //
            // The incoming gradient for the Concatenate layer then also has
            // shape (n, c0+c1+c2, h, w).
            //
            // The gradient for each of the inputs is then the slice of the
            // incoming gradient along c that matches the slice of the input,
            // e.g. i0 gets slice (n, c0, h, w) of the incoming gradient.
        }
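
        // A sketch of the slicing described above, assuming densely packed nchw memory
        // and raw float views of the gradients. All names (batchSize, h, w, channels,
        // outputGrad, inputGrad) are illustrative placeholders, not members of this
        // class:
        //
        //   const std::size_t spatial = h * w;
        //   const std::size_t totalChannels =
        //       std::accumulate(channels.begin(), channels.end(), std::size_t{0});
        //   std::size_t channelOffset = 0;
        //   for (std::size_t i = 0; i < channels.size(); ++i) {
        //       // contiguous slice belonging to input i within one batch entry
        //       const std::size_t sliceSize = channels[i] * spatial;
        //       for (std::size_t n = 0; n < batchSize; ++n) {
        //           const float* src =
        //               outputGrad + (n * totalChannels + channelOffset) * spatial;
        //           float* dst = inputGrad[i] + n * sliceSize;
        //           std::copy(src, src + sliceSize, dst);
        //       }
        //       channelOffset += channels[i];
        //   }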

        template class DnnlMerging<float>;
        template class DnnlSum<float>;
        template class DnnlConcatenate<float>;
    } // namespace detail
} // namespace elsa::ml