#pragma once

#include <vector>
#include <list>
#include <memory>
#include <set>
#include <string>
// Needed for assert(), uint32_t, std::cout and std::logic_error used below.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <stdexcept>

#include "elsaDefines.h"
#include "TypeCasts.hpp"
#include "Common.h"
#include "State.h"
#include "Layer.h"
#include "Dense.h"
#include "Conv.h"
#include "Pooling.h"
#include "ProgressBar.h"
#include "Utils.h"
#include "IdenticalBlocksDescriptor.h"

#include "DnnlNoopLayer.h"
#include "DnnlDenseLayer.h"
#include "DnnlConvolution.h"
#include "DnnlPoolingLayer.h"
#include "DnnlFlattenLayer.h"
#include "DnnlSoftmaxLayer.h"
#include "DnnlMerging.h"
#include "DnnlActivationLayer.h"

namespace elsa::ml
{
    namespace detail
    {
        ELSA_ML_MAKE_BACKEND_LAYER_SELECTOR(Dnnl, Undefined, DnnlLayer);
        ELSA_ML_MAKE_BACKEND_LAYER_SELECTOR(Dnnl, Dense, DnnlDenseLayer);
        ELSA_ML_MAKE_BACKEND_LAYER_SELECTOR(Dnnl, Softmax, DnnlSoftmaxLayer);
        ELSA_ML_MAKE_BACKEND_LAYER_SELECTOR(Dnnl, Conv1D, DnnlConvolution);
        ELSA_ML_MAKE_BACKEND_LAYER_SELECTOR(Dnnl, Conv2D, DnnlConvolution);
        ELSA_ML_MAKE_BACKEND_LAYER_SELECTOR(Dnnl, Conv3D, DnnlConvolution);

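        // Map a front-end Activation value to the corresponding Dnnl backend
        // layer and register it at the given node index. A usage sketch
        // (hypothetical graph `g` and descriptor `desc`, for illustration
        // only):
        //
        //     addActivationNode<real_t>(3, Activation::Relu, desc, &g);
        //
        // GraphType only needs to provide setData(index_t, std::shared_ptr<...>).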
        template <typename data_t, typename GraphType>
        static void addActivationNode(index_t index, Activation activation,
                                      const VolumeDescriptor& inputDescriptor, GraphType* graph)
        {
            switch (activation) {
                case Activation::Identity:
                    graph->setData(index, std::make_shared<DnnlNoopLayer<data_t>>(inputDescriptor));
                    break;
                case Activation::Sigmoid:
                    graph->setData(index, std::make_shared<DnnlLogistic<data_t>>(inputDescriptor));
                    break;
                case Activation::Relu:
                    graph->setData(index, std::make_shared<DnnlRelu<data_t>>(inputDescriptor));
                    break;
                case Activation::Tanh:
                    graph->setData(index, std::make_shared<DnnlTanh<data_t>>(inputDescriptor));
                    break;
                case Activation::Elu:
                    graph->setData(index, std::make_shared<DnnlElu<data_t>>(inputDescriptor));
                    break;
                default:
                    assert(false && "This execution path of the code should never be reached");
            }
        }

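        // Translate a single front-end layer into its Dnnl backend counterpart
        // and register it in the backend-graph. All descriptors get the batch
        // size attached and are reversed, presumably to match Dnnl's dimension
        // ordering. A sketch (hypothetical shapes): a 2D input of shape
        // {28, 28} with batch size N becomes {28, 28, N} after
        // attachBatchSizeToVolumeDescriptor and {N, 28, 28} after
        // reverseVolumeDescriptor.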
        template <typename GraphType, typename data_t>
        static void addNodeToBackendGraph(index_t batchSize, GraphType* graph,
                                          const Layer<data_t>* node)
        {
            const index_t nodeIdx = node->getGlobalIndex();

            VolumeDescriptor inputDescriptorWithBatchSize = reverseVolumeDescriptor(
                attachBatchSizeToVolumeDescriptor(batchSize, node->getInputDescriptor()));

            VolumeDescriptor outputDescriptorWithBatchSize = reverseVolumeDescriptor(
                attachBatchSizeToVolumeDescriptor(batchSize, node->getOutputDescriptor()));

            switch (node->getLayerType()) {
                case LayerType::Input: {
                    graph->setData(nodeIdx, std::make_shared<DnnlNoopLayer<data_t>>(
                                                /* input-descriptor */
                                                inputDescriptorWithBatchSize));
                    break;
                }
                case LayerType::Flatten: {
                    graph->setData(nodeIdx, std::make_shared<DnnlFlattenLayer<data_t>>(
                                                /* input-descriptor */
                                                inputDescriptorWithBatchSize,
                                                /* output-descriptor */
                                                outputDescriptorWithBatchSize));
                    break;
                }
                case LayerType::Relu:
                case LayerType::ClippedRelu:
                case LayerType::Sigmoid:
                case LayerType::Tanh:
                case LayerType::Elu: {
                    auto downcastedLayer = downcast_safe<ActivationBase<data_t>>(node);
                    addActivationNode<data_t>(nodeIdx, downcastedLayer->getActivation(),
                                              inputDescriptorWithBatchSize, graph);
                    break;
                }
                case LayerType::Dense: {
                    // Downcast to the polymorphic layer type
                    auto downcastedLayer = downcast_safe<Dense<data_t>>(node);

                    assert(node->getInputDescriptor().getNumberOfDimensions() == 1
                           && "Dense layer requires 1D input");

                    // Build the weights descriptor
                    IndexVector_t weightsDims{
                        {downcastedLayer->getNumberOfUnits(),
                         node->getInputDescriptor().getNumberOfCoefficients()}};
                    VolumeDescriptor weightsDescriptor(weightsDims);

                    // Add to the backend-graph
                    graph->setData(nodeIdx, std::make_shared<DnnlDenseLayer<data_t>>(
                                                /* input-descriptor */
                                                inputDescriptorWithBatchSize,
                                                /* output-descriptor */
                                                outputDescriptorWithBatchSize,
                                                /* weights-descriptor */
                                                weightsDescriptor,
                                                /* kernel-initializer */
                                                downcastedLayer->getKernelInitializer()));
                    break;
                }
                case LayerType::Softmax: {
                    // We don't need any layer-specific information from the
                    // front-end softmax layer, so the base type Layer*
                    // suffices and no downcast is necessary.
                    graph->setData(nodeIdx,
                                   std::make_shared<DnnlSoftmaxLayer<data_t>>(
                                       /* input-descriptor */
                                       inputDescriptorWithBatchSize,
                                       /* output-descriptor */
                                       outputDescriptorWithBatchSize));
                    break;
                }
                case LayerType::Conv1D:
                case LayerType::Conv2D:
                case LayerType::Conv3D: {
                    // Downcast to the polymorphic layer type
                    auto downcastedLayer = downcast_safe<Conv<data_t>>(node);

                    // Build the filter descriptor; the number of filters acts
                    // as the batch dimension of the weights
                    VolumeDescriptor filterDescriptor = reverseVolumeDescriptor(
                        attachBatchSizeToVolumeDescriptor(downcastedLayer->getNumberOfFilters(),
                                                          downcastedLayer->getFilterDescriptor()));

                    IndexVector_t strideVector(
                        downcastedLayer->getFilterDescriptor().getNumberOfDimensions() - 1);
                    strideVector.fill(downcastedLayer->getStrides());

                    // TODO(tellenbach): Handle padding correctly
                    IndexVector_t paddingVector(
                        downcastedLayer->getFilterDescriptor().getNumberOfDimensions() - 1);
                    paddingVector.fill(0);

                    // Add to the backend-graph
                    graph->setData(nodeIdx, std::make_shared<DnnlConvolution<data_t>>(
                                                /* input-descriptor */
                                                inputDescriptorWithBatchSize,
                                                /* output-descriptor */
                                                outputDescriptorWithBatchSize,
                                                /* weights-descriptor */
                                                filterDescriptor,
                                                /* strides */
                                                strideVector,
                                                /* padding high */
                                                paddingVector,
                                                /* padding low */
                                                paddingVector,
                                                /* kernel-initializer */
                                                downcastedLayer->getKernelInitializer()));
                    break;
                }
                case LayerType::MaxPooling1D:
                case LayerType::MaxPooling2D: {
                    // Downcast to the polymorphic layer type
                    auto downcastedLayer = downcast_safe<Pooling<data_t>>(node);

                    IndexVector_t poolingWindow = downcastedLayer->getPoolSize();
                    IndexVector_t strides(poolingWindow.size());
                    strides.fill(downcastedLayer->getStrides());

                    graph->setData(nodeIdx, std::make_shared<DnnlPoolingLayer<data_t>>(
                                                /* input-descriptor */
                                                inputDescriptorWithBatchSize,
                                                /* output-descriptor */
                                                outputDescriptorWithBatchSize,
                                                /* pooling-window */
                                                poolingWindow,
                                                /* strides */
                                                strides));
                    break;
                }
                case LayerType::Sum: {
                    // A merging layer has multiple inputs, so collect all of
                    // their descriptors
                    std::vector<VolumeDescriptor> inputDesc;
                    for (index_t i = 0; i < node->getNumberOfInputs(); ++i) {
                        inputDesc.push_back(
                            reverseVolumeDescriptor(attachBatchSizeToVolumeDescriptor(
                                batchSize, node->getInputDescriptor(i))));
                    }
                    graph->setData(nodeIdx, std::make_shared<DnnlSum<data_t>>(
                                                /* input-descriptors */
                                                inputDesc,
                                                /* output-descriptor */
                                                outputDescriptorWithBatchSize));
                    break;
                }
                default:
                    throw std::logic_error("Layer of type "
                                           + getEnumMemberAsString(node->getLayerType())
                                           + " is not available for the Dnnl backend");
            }
        }

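        // Dnnl represents activations as layers of their own (fusing them into
        // the preceding layer is only an option for inference), while the
        // front-end attaches an activation directly to each trainable layer.
        // Sketch of the resulting rewrite (hypothetical node indices):
        //
        //     ... -> Dense(idx) -> ...                      (front-end graph)
        //     ... -> Dense(idx) -> Relu(idx + 1) -> ...     (backend-graph)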
        template <typename data_t, typename GraphType>
        static void appendActivation(index_t batchSize, GraphType* graph,
                                     std::set<index_t>* outputs)
        {
            // The API layers allow specifying an activation function during
            // the construction of a trainable layer. Dnnl handles activations
            // (at least for training; inference would allow fusing the
            // activation into the layer) as separate layers. We therefore
            // insert a separate activation layer after every trainable layer
            // in the backend-graph.
            auto& nodes = graph->getNodes();
            for (auto& node : nodes) {
                auto layer = getCheckedLayerPtr(node.second);
                if (layer->isTrainable()) {
                    const index_t nodeIdx = node.first;

                    // Get the corresponding node in the front-end graph as
                    // a trainable layer
                    auto trainableLayer =
                        downcast<Trainable<data_t>>(State<data_t>::getGraph().getData(nodeIdx));
                    Activation activation = trainableLayer->getActivation();

                    // Insert a new node that will hold our activation layer
                    const index_t newIdx = nodeIdx + 1;
                    graph->insertNodeAfter(nodeIdx, newIdx);

                    // Input and output descriptors are the same for
                    // activation layers
                    VolumeDescriptor inputOutputDescriptorWithBatchSize =
                        reverseVolumeDescriptor(attachBatchSizeToVolumeDescriptor(
                            batchSize, trainableLayer->getOutputDescriptor()));

                    // Add the activation layer
                    addActivationNode<data_t>(newIdx, activation,
                                              inputOutputDescriptorWithBatchSize, graph);

                    // If the trainable layer we just handled was one of the
                    // model's output layers, we have to remove it from the
                    // list of outputs and add the newly inserted activation
                    // layer instead.
                    if (outputs->find(nodeIdx) != outputs->end()) {
                        outputs->erase(nodeIdx);
                        outputs->insert(newIdx);
                    }
                }
            }
        }

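        // All backend layers share a single Dnnl engine. Constructing
        // dnnl::engine(dnnl::engine::kind::cpu, 0) selects the first CPU
        // engine (oneDNN typically exposes exactly one; the count is
        // available via dnnl::engine::get_count).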
        template <typename GraphType>
        static void setDnnlEngine(GraphType* graph)
        {
            // Construct the Dnnl CPU-engine
            auto engine = std::make_shared<dnnl::engine>(dnnl::engine::kind::cpu, 0);

            // Set the engine for all backend layers
            for (auto&& node : graph->getNodes()) {
                auto layer = getCheckedLayerPtr(node.second);
                layer->setEngine(engine);
            }
        }

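        // Attach the model's optimizer to every trainable backend layer; this
        // initializes the layers' trainable parameters (e.g. weights and
        // biases). Layers reporting isTrainable() are expected to be
        // DnnlTrainableLayer instances.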
        template <typename GraphType, typename data_t>
        static void initializeTrainableParameters(GraphType* graph, Optimizer<data_t>* optimizer)
        {
            // Initialize all trainable parameters of a layer, e.g. weights
            // and biases
            auto& nodes = graph->getNodes();
            for (auto& node : nodes) {
                auto layer = getCheckedLayerPtr(node.second);
                if (layer->isTrainable()) {
                    auto trainableLayer =
                        std::dynamic_pointer_cast<DnnlTrainableLayer<data_t>>(layer);
                    assert(trainableLayer != nullptr
                           && "Trainable layer must be a DnnlTrainableLayer");
                    trainableLayer->setOptimizer(optimizer);
                }
            }
        }

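        // Specialization of the backend-adaptor for the Dnnl (oneDNN)
        // backend. It builds and compiles the backend-graph and drives
        // forward and backward propagation. A usage sketch (hypothetical;
        // these functions are normally called by Model<data_t, MlBackend::Dnnl>
        // itself rather than by user code):
        //
        //     Model<real_t, MlBackend::Dnnl> model(/* ... */);
        //     BackendAdaptor<real_t, MlBackend::Dnnl>::constructBackendGraph(&model);
        //     auto prediction = BackendAdaptor<real_t, MlBackend::Dnnl>::predict(&model, x);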
        template <typename data_t>
        struct BackendAdaptor<data_t, MlBackend::Dnnl> {
            static void constructBackendGraph(Model<data_t, MlBackend::Dnnl>* model)
            {
                // Get the global layer graph
                auto& graph = State<data_t>::getGraph();

                // Get the backend-graph
                auto& backendGraph = model->getBackendGraph();

                // Get the batch size
                index_t batchSize = model->getBatchSize();

                // Get all edges of the front-end graph
                auto& edges = graph.getEdges();

                // For each edge in the front-end graph, add a corresponding
                // edge to the backend-graph
                for (auto&& edge : edges) {
                    auto beginNode = edge.begin()->getData();
                    index_t beginIdx = beginNode->getGlobalIndex();

                    auto endNode = edge.end()->getData();
                    index_t endIdx = endNode->getGlobalIndex();

                    backendGraph.addEdge(beginIdx, endIdx);
                }

                // Get the nodes of the front-end graph
                auto& nodes = graph.getNodes();

                // Set a backend-layer for each node in the backend-graph
                for (const auto& node : nodes) {
                    auto layer = getCheckedLayerPtr(node.second);
                    index_t idx = layer->getGlobalIndex();
                    if (!backendGraph.getData(idx))
                        addNodeToBackendGraph(batchSize, &backendGraph, layer);
                }

                // Set the backend outputs
                for (const auto& output : model->getOutputs())
                    model->backendOutputs_.insert(output->getGlobalIndex());

                // Add a separate activation-layer after each trainable layer
                appendActivation<data_t>(batchSize, &backendGraph, &model->backendOutputs_);

                // Create the Dnnl engine and set it for all backend-layers
                setDnnlEngine(&backendGraph);

                // Initialize all trainable parameters and set the optimizer
                initializeTrainableParameters(&backendGraph, model->getOptimizer());

                // Set the number of output-gradients based on the connections
                // in the backend-graph
                setNumberOfOutputGradients(&backendGraph);

                // Compile each backend-layer for forward usage and connect
                // the pointers of input and output memory
                index_t inputIdx = model->inputs_.front()->getGlobalIndex();
                std::shared_ptr<dnnl::memory> outputMemory = nullptr;
                backendGraph.visit(
                    // Start node for the traversal
                    inputIdx,
                    // Visitor for the current node in the traversal
                    [&outputMemory](auto node) {
                        node->compile(PropagationKind::Forward);
                        outputMemory = node->getOutputMemory();
                        assert(outputMemory != nullptr
                               && "Output memory is null during graph-traversal");
                    },
                    // Visitor for the current and the next node in the traversal
                    [&outputMemory]([[maybe_unused]] auto node, auto nextNode) {
                        nextNode->setNextInputMemory(outputMemory);
                    });

                // Compile each backend-layer for backward usage and connect
                // the pointers of output-gradient and input-gradient memory
                index_t outputIdx = *model->backendOutputs_.begin();

                std::vector<std::shared_ptr<dnnl::memory>> inputGradientMemory;
                index_t inputGradientCounter = 0;
                backendGraph.visitBackward(
                    // Start node for the traversal
                    outputIdx,
                    // Visitor for the current node in the traversal
                    [&inputGradientMemory, &inputGradientCounter](auto node) {
                        // Compile the current layer
                        node->compile(PropagationKind::Backward);

                        // Collect the input-gradient memory of all of this
                        // node's inputs
                        inputGradientMemory.clear();
                        inputGradientCounter = 0;
                        for (std::size_t i = 0; i < asUnsigned(node->getNumberOfInputs()); ++i) {
                            inputGradientMemory.push_back(
                                node->getInputGradientMemory(asSigned(i)));
                        }
                    },
                    // Visitor for the current and the previous node in the traversal
                    [&inputGradientMemory, &inputGradientCounter]([[maybe_unused]] auto node,
                                                                  auto prevNode) {
                        assert(inputGradientMemory[asUnsigned(inputGradientCounter)] != nullptr
                               && "Input-gradient memory is null during backward graph-traversal");
                        prevNode->setNextOutputGradientMemory(
                            inputGradientMemory[asUnsigned(inputGradientCounter)]);
                        ++inputGradientCounter;
                    },
                    []([[maybe_unused]] auto node) { return false; });
            }

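            // Run a single forward pass over the compiled backend-graph and
            // return the output of the model's (first) backend output layer.
            // Nodes with several inputs (e.g. Sum) are visited only after all
            // of their predecessors have been forward-propagated.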
            static DataContainer<data_t> predict(Model<data_t, MlBackend::Dnnl>* model,
                                                 const DataContainer<data_t>& x)
            {
                auto& backendGraph = model->getBackendGraph();
                index_t inputIdx = model->inputs_.front()->getGlobalIndex();
                auto inputLayer = getCheckedLayerPtr(backendGraph.getNodes().at(inputIdx));

                // Create the Dnnl execution-stream
                dnnl::stream executionStream(*inputLayer->getEngine());

                // Set the model input
                inputLayer->setInput(x);

                // Keep track of all nodes we have already handled
                std::vector<bool> forwardPropagationList(
                    asUnsigned(backendGraph.getNumberOfNodes()), false);

                backendGraph.visitWithIndex(
                    // Start node for the traversal
                    inputIdx,
                    // Visitor for the current node in the traversal
                    [&executionStream, &forwardPropagationList](auto node, index_t s) {
                        node->forwardPropagate(executionStream);
                        assert(forwardPropagationList[asUnsigned(s)] == false);
                        forwardPropagationList[asUnsigned(s)] = true;
                    },
                    // Visitor for the current and the next node in the traversal
                    []([[maybe_unused]] auto node, [[maybe_unused]] index_t s,
                       [[maybe_unused]] auto nextNode, [[maybe_unused]] index_t nextS) {},
                    // We cut a traversal path if we haven't yet handled all
                    // predecessors of the current node (relevant for merging
                    // nodes with several inputs)
                    [&backendGraph, &forwardPropagationList]([[maybe_unused]] auto node,
                                                             index_t s) {
                        for (const auto& inEdge : backendGraph.getIncomingEdges(s)) {
                            if (!forwardPropagationList[asUnsigned(inEdge.begin()->getIndex())])
                                return true;
                        }
                        return false;
                    });

                index_t outputIdx = *std::begin(model->backendOutputs_);
                auto outputLayer = getCheckedLayerPtr(backendGraph.getNodes().at(outputIdx));
                return outputLayer->getOutput();
            }

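            // Train the model for the given number of epochs. x and y are
            // expected to hold pre-batched inputs and one-hot encoded labels,
            // one DataContainer per batch. Per batch the routine performs a
            // forward pass, computes loss and accuracy, back-propagates the
            // loss gradient, and updates all trainable parameters.
            //
            // Note: the accuracy computation below assumes 10 classes (the
            // literal in Utils::Encoding::fromOneHot), which fits e.g. MNIST.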
            static typename Model<data_t, MlBackend::Dnnl>::History
                fit(Model<data_t, MlBackend::Dnnl>* model,
                    const std::vector<DataContainer<data_t>>& x,
                    const std::vector<DataContainer<data_t>>& y, index_t epochs)
            {
                typename Model<data_t, MlBackend::Dnnl>::History trainingHistory;

                auto& backendGraph = model->getBackendGraph();
                index_t inputIdx = model->inputs_.front()->getGlobalIndex();
                auto inputLayer = getCheckedLayerPtr(backendGraph.getNodes().at(inputIdx));
                index_t outputIdx = *std::begin(model->backendOutputs_);
                auto outputLayer = getCheckedLayerPtr(backendGraph.getNodes().at(outputIdx));
                auto lossFunc = model->getLoss();

                // Create a Dnnl execution-stream that will be used during
                // forward-propagation
                dnnl::stream executionStream(*inputLayer->getEngine());

                // For all epochs
                double epochLoss = 0.0;
                double epochAccuracy = 0.0;
                index_t correct = 0;
                for (index_t epoch = 0; epoch < epochs; ++epoch) {
                    std::cout << "Epoch " << epoch + 1 << "/" << epochs << "\n";
                    epochLoss = 0.0;
                    epochAccuracy = 0.0;
                    correct = 0;
                    ProgressBar progBar(static_cast<uint32_t>(x.size()), 36);
                    for (std::size_t idx = 0; idx < x.size(); ++idx) {
                        // Set this batch as the input of the model's input
                        // layer
                        inputLayer->setInput(x[idx]);

                        // Keep track of all nodes we have already handled
                        std::vector<bool> nodeList(asUnsigned(backendGraph.getNumberOfNodes()),
                                                   false);

                        backendGraph.visitWithIndex(
                            // Start node for the traversal
                            inputIdx,
                            // Visitor for the current node in the traversal
                            [&executionStream, &nodeList](auto node, index_t s) {
                                node->forwardPropagate(executionStream);
                                assert(nodeList[asUnsigned(s)] == false);
                                nodeList[asUnsigned(s)] = true;
                            },
                            // Visitor for the current and the next node in the traversal
                            []([[maybe_unused]] auto node, [[maybe_unused]] index_t s,
                               [[maybe_unused]] auto nextNode, [[maybe_unused]] index_t nextS) {},
                            // We cut a traversal path if we haven't handled
                            // all predecessors of the current node yet
                            [&backendGraph, &nodeList]([[maybe_unused]] auto node, index_t s) {
                                for (const auto& inEdge : backendGraph.getIncomingEdges(s)) {
                                    if (!nodeList[asUnsigned(inEdge.begin()->getIndex())])
                                        return true;
                                }
                                return false;
                            });

                        // Reset the list for the backward pass below
                        nodeList =
                            std::vector<bool>(asUnsigned(backendGraph.getNumberOfNodes()), false);

                        auto output = outputLayer->getOutput();

                        // Compute the batch accuracy
                        auto label = Utils::Encoding::fromOneHot(output, 10);
                        for (index_t i = 0; i < model->batchSize_; ++i) {
                            if (label[i] == y[idx][i]) {
                                correct += 1;
                            }
                        }

                        epochAccuracy =
                            static_cast<double>(correct)
                            / static_cast<double>((idx + 1) * asUnsigned(model->batchSize_));

                        // Compute the batch loss
                        trainingHistory.loss.push_back(lossFunc(output, y[idx]));
                        epochLoss += trainingHistory.loss.back();

                        ++progBar;
                        std::string preMessage =
                            std::to_string(idx) + "/" + std::to_string(x.size()) + " ";
                        progBar.display(
                            preMessage,
                            "- " + lossFunc.getName() + ": "
                                + std::to_string(epochLoss / static_cast<double>(idx + 1))
                                + " - Accuracy: " + std::to_string(epochAccuracy));

                        // Seed the backward pass with the loss gradient
                        outputLayer->setOutputGradient(
                            lossFunc.getLossGradient(output, y[idx]));

                        // Backward-propagate all nodes, starting at the
                        // output, until we reach the input
                        backendGraph.visitBackwardWithIndex(
                            // Start node for the traversal
                            outputIdx,
                            // Visitor for the current node in the traversal
                            [&executionStream, &nodeList](auto node, index_t s) {
                                node->backwardPropagate(executionStream);
                                assert(nodeList[asUnsigned(s)] == false);
                                nodeList[asUnsigned(s)] = true;
                            },
                            // Visitor for the current and the next node in the traversal
                            []([[maybe_unused]] auto node, [[maybe_unused]] index_t s,
                               [[maybe_unused]] auto nextNode, [[maybe_unused]] index_t nextS) {},
                            // We cut a traversal path if we haven't handled
                            // all successors of the current node yet
                            [&backendGraph, &nodeList]([[maybe_unused]] auto node, index_t s) {
                                for (const auto& outEdge : backendGraph.getOutgoingEdges(s)) {
                                    if (!nodeList[asUnsigned(outEdge.end()->getIndex())])
                                        return true;
                                }
                                return false;
                            });

                        // Set a barrier for the backward-stream
                        executionStream.wait();

                        // Once we are done with this batch, update all
                        // trainable parameters. This also resets all
                        // accumulated gradients.
                        for (auto&& node : backendGraph.getNodes()) {
                            auto layer = getCheckedLayerPtr(node.second);
                            if (layer->isTrainable()) {
                                std::static_pointer_cast<DnnlTrainableLayer<data_t>>(layer)
                                    ->updateTrainableParameters();
                            }
                        }
                    }
                    progBar.done(std::to_string(x.size()) + "/" + std::to_string(x.size()) + " ",
                                 "");
                }
                return trainingHistory;
            }
        };
    } // namespace detail
} // namespace elsa::ml