LCOV - code coverage report
Current view: top level - ml - Merging.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 0 29 0.0 %
Date: 2022-07-06 02:47:47 Functions: 0 7 0.0 %

          Line data    Source code
       1             : #include "Merging.h"
       2             : 
       3             : namespace elsa::ml
       4             : {
       5             :     template <typename data_t>
       6           0 :     Merging<data_t>::Merging(LayerType layerType, std::initializer_list<Layer<data_t>*> inputs,
       7             :                              const std::string& name)
       8           0 :         : Layer<data_t>(layerType, name)
       9             :     {
      10           0 :         this->setInput(inputs);
      11           0 :     }
      12             : 
      13             :     template <typename data_t>
      14           0 :     bool Merging<data_t>::canMerge() const
      15             :     {
      16           0 :         return true;
      17             :     }
      18             : 
      19             :     template <typename data_t>
      20           0 :     Sum<data_t>::Sum(std::initializer_list<Layer<data_t>*> inputs, const std::string& name)
      21           0 :         : Merging<data_t>(LayerType::Sum, inputs, name)
      22             :     {
      23           0 :     }
      24             : 
      25             :     template <typename data_t>
      26           0 :     void Sum<data_t>::computeOutputDescriptor()
      27             :     {
      28           0 :         if (std::adjacent_find(this->inputDescriptors_.begin(), this->inputDescriptors_.end(),
      29           0 :                                [](const auto& a, const auto& b) { return *a != *b; })
      30           0 :             != this->inputDescriptors_.end()) {
      31           0 :             throw std::invalid_argument("All inputs for Sum layer must have the same shape");
      32             :         }
      33             : 
      34             :         // At this point we are sure that all input descriptors match. Since we just
      35             :         // compute the coeff-wise sum it's enough to take just one of the input
      36             :         // descriptors as output descriptor
      37           0 :         this->outputDescriptor_ = this->inputDescriptors_.front()->clone();
      38           0 :     }
      39             : 
      40             :     template <typename data_t>
      41           0 :     Concatenate<data_t>::Concatenate(index_t axis, std::initializer_list<Layer<data_t>*> inputs,
      42             :                                      const std::string& name)
      43           0 :         : Merging<data_t>(LayerType::Concatenate, inputs, name), axis_(axis)
      44             :     {
      45           0 :     }
      46             : 
      47             :     template <typename data_t>
      48           0 :     void Concatenate<data_t>::computeOutputDescriptor()
      49             :     {
      50             : 
      51           0 :         index_t concatDim = 0;
      52           0 :         for (const auto& in : this->inputDescriptors_) {
      53           0 :             concatDim += in->getNumberOfCoefficientsPerDimension()[axis_];
      54             :         }
      55           0 :         IndexVector_t dims(this->inputDescriptors_.front()->getNumberOfDimensions());
      56           0 :         for (int i = 0; i < this->inputDescriptors_.front()->getNumberOfDimensions(); ++i)
      57           0 :             dims[i] = this->inputDescriptors_.front()->getNumberOfCoefficientsPerDimension()[i];
      58           0 :         dims[axis_] = concatDim;
      59           0 :         this->outputDescriptor_ = VolumeDescriptor(dims).clone();
      60           0 :     }
      61             : 
      62             :     template class Merging<float>;
      63             :     template class Sum<float>;
      64             :     template class Concatenate<float>;
      65             : } // namespace elsa::ml

Generated by: LCOV version 1.15