Line data Source code
1 : #include "Reshape.h"
2 :
3 : namespace elsa::ml
4 : {
5 : template <typename data_t>
6 1 : Reshape<data_t>::Reshape(const VolumeDescriptor& targetShape, const std::string& name)
7 1 : : Layer<data_t>(LayerType::Reshape, name)
8 : {
9 :
10 1 : this->outputDescriptor_ = targetShape.clone();
11 1 : }
12 :
13 : template <typename data_t>
14 1 : void Reshape<data_t>::computeOutputDescriptor()
15 : {
16 1 : if (this->outputDescriptor_->getNumberOfCoefficients()
17 1 : != this->inputDescriptors_.front()->getNumberOfCoefficients())
18 : throw std::invalid_argument(
19 0 : "Descriptors of input and reshaping target must be of same size");
20 1 : }
21 :
22 : template <typename data_t>
23 1 : Flatten<data_t>::Flatten(const std::string& name) : Layer<data_t>(LayerType::Flatten, name)
24 : {
25 1 : }
26 :
27 : template <typename data_t>
28 1 : void Flatten<data_t>::computeOutputDescriptor()
29 : {
30 1 : IndexVector_t dims(1);
31 1 : dims[0] = this->inputDescriptors_.front()->getNumberOfCoefficientsPerDimension().prod();
32 1 : this->outputDescriptor_ = VolumeDescriptor(dims).clone();
33 1 : }
34 :
35 : template <typename data_t, LayerType layerType, index_t upSamplingDimensions>
36 0 : UpSampling<data_t, layerType, upSamplingDimensions>::UpSampling(
37 : const std::array<index_t, upSamplingDimensions>& size, Interpolation interpolation,
38 : const std::string& name)
39 0 : : Layer<data_t>(layerType, name), size_(size), interpolation_(interpolation)
40 : {
41 0 : }
42 :
43 : template <typename data_t, LayerType layerType, index_t upSamplingDimensions>
44 0 : void UpSampling<data_t, layerType, upSamplingDimensions>::computeOutputDescriptor()
45 : {
46 0 : IndexVector_t dims = this->inputDescriptors_.front()->getNumberOfCoefficientsPerDimension();
47 :
48 : // spatial dims that get upsampled
49 : if constexpr (upSamplingDimensions >= 1) {
50 0 : dims[0] *= size_[0];
51 : }
52 :
53 : if constexpr (upSamplingDimensions >= 2) {
54 0 : dims[1] *= size_[1];
55 : }
56 :
57 : if constexpr (upSamplingDimensions == 3) {
58 : dims[2] *= size_[2];
59 : }
60 :
61 0 : this->outputDescriptor_ = VolumeDescriptor(dims).clone();
62 0 : }
63 :
64 : template <typename data_t, LayerType layerType, index_t upSamplingDimensions>
65 0 : Interpolation UpSampling<data_t, layerType, upSamplingDimensions>::getInterpolation() const
66 : {
67 0 : return interpolation_;
68 : }
69 :
70 : template <typename data_t>
71 : UpSampling1D<data_t>::UpSampling1D(const std::array<index_t, 1>& size,
72 : Interpolation interpolation, const std::string& name)
73 : : UpSampling<data_t, LayerType::UpSampling1D, 1>(size, interpolation, name)
74 : {
75 : }
76 :
77 : template <typename data_t>
78 0 : UpSampling2D<data_t>::UpSampling2D(const std::array<index_t, 2>& size,
79 : Interpolation interpolation, const std::string& name)
80 0 : : UpSampling<data_t, LayerType::UpSampling2D, 2>(size, interpolation, name)
81 : {
82 0 : }
83 :
84 : template <typename data_t>
85 : UpSampling3D<data_t>::UpSampling3D(const std::array<index_t, 3>& size,
86 : Interpolation interpolation, const std::string& name)
87 : : UpSampling<data_t, LayerType::UpSampling3D, 3>(size, interpolation, name)
88 : {
89 : }
90 :
91 : template class Reshape<float>;
92 : template class Flatten<float>;
93 : template class UpSampling<float, LayerType::UpSampling2D, 2>;
94 : template struct UpSampling2D<float>;
95 : } // namespace elsa::ml
|