Line data Source code
1 : #include "Merging.h" 2 : 3 : namespace elsa::ml 4 : { 5 : template <typename data_t> 6 0 : Merging<data_t>::Merging(LayerType layerType, std::initializer_list<Layer<data_t>*> inputs, 7 : const std::string& name) 8 0 : : Layer<data_t>(layerType, name) 9 : { 10 0 : this->setInput(inputs); 11 0 : } 12 : 13 : template <typename data_t> 14 0 : bool Merging<data_t>::canMerge() const 15 : { 16 0 : return true; 17 : } 18 : 19 : template <typename data_t> 20 0 : Sum<data_t>::Sum(std::initializer_list<Layer<data_t>*> inputs, const std::string& name) 21 0 : : Merging<data_t>(LayerType::Sum, inputs, name) 22 : { 23 0 : } 24 : 25 : template <typename data_t> 26 0 : void Sum<data_t>::computeOutputDescriptor() 27 : { 28 0 : if (std::adjacent_find(this->inputDescriptors_.begin(), this->inputDescriptors_.end(), 29 0 : [](const auto& a, const auto& b) { return *a != *b; }) 30 0 : != this->inputDescriptors_.end()) { 31 0 : throw std::invalid_argument("All inputs for Sum layer must have the same shape"); 32 : } 33 : 34 : // At this point we are sure that all input descriptors match. Since we just 35 : // compute the coeff-wise sum it's enough to take just one of the input 36 : // descriptors as output descriptor 37 0 : this->outputDescriptor_ = this->inputDescriptors_.front()->clone(); 38 0 : } 39 : 40 : template <typename data_t> 41 0 : Concatenate<data_t>::Concatenate(index_t axis, std::initializer_list<Layer<data_t>*> inputs, 42 : const std::string& name) 43 0 : : Merging<data_t>(LayerType::Concatenate, inputs, name), axis_(axis) 44 : { 45 0 : } 46 : 47 : template <typename data_t> 48 0 : void Concatenate<data_t>::computeOutputDescriptor() 49 : { 50 : 51 0 : index_t concatDim = 0; 52 0 : for (const auto& in : this->inputDescriptors_) { 53 0 : concatDim += in->getNumberOfCoefficientsPerDimension()[axis_]; 54 : } 55 0 : IndexVector_t dims(this->inputDescriptors_.front()->getNumberOfDimensions()); 56 0 : for (int i = 0; i < this->inputDescriptors_.front()->getNumberOfDimensions(); ++i) 57 0 : dims[i] = this->inputDescriptors_.front()->getNumberOfCoefficientsPerDimension()[i]; 58 0 : dims[axis_] = concatDim; 59 0 : this->outputDescriptor_ = VolumeDescriptor(dims).clone(); 60 0 : } 61 : 62 : template class Merging<float>; 63 : template class Sum<float>; 64 : template class Concatenate<float>; 65 : } // namespace elsa::ml