Line data Source code
1 : #pragma once
2 :
#include <memory>
#include <string>
#include <utility>

#include "VolumeDescriptor.h"
#include "TypeCasts.hpp"
7 :
8 : namespace elsa
9 : {
10 : namespace ml
11 : {
/// Initializer that can be used to initialize trainable parameters in a network layer
enum class Initializer {
    /**
     * Ones initialization
     *
     * Initialize data with \f$ 1 \f$.
     */
    Ones,

    /**
     * Zeros initialization
     *
     * Initialize data with \f$ 0 \f$.
     */
    Zeros,

    /**
     * Uniform initialization
     *
     * Initialize data with random samples from a uniform distribution on
     * the interval \f$ [-1, 1] \f$.
     */
    Uniform,

    /**
     * Normal initialization
     *
     * Initialize data with random samples from a standard normal
     * distribution, i.e., a normal distribution with mean \f$ 0 \f$ and
     * standard deviation \f$ 1 \f$.
     */
    Normal,

    /**
     * Truncated normal initialization
     *
     * Initialize data with random samples from a truncated standard normal
     * distribution, i.e., a normal distribution with mean \f$ 0 \f$ and
     * standard deviation \f$ 1 \f$ where samples lying more than \f$ 2 \f$
     * standard deviations away from the mean are discarded.
     */
    TruncatedNormal,

    /**
     * Glorot uniform initialization
     *
     * Initialize data with random samples from a uniform distribution on
     * the interval
     * \f$ \left[ -\sqrt{\frac{6}{\text{fanIn} + \text{fanOut}}},
     * \sqrt{\frac{6}{\text{fanIn} + \text{fanOut}}} \right] \f$.
     */
    GlorotUniform,

    /**
     * Glorot normal initialization
     *
     * Initialize data with random samples from a truncated normal distribution
     * with mean \f$ 0 \f$ and standard deviation
     * \f$ \sqrt{\frac{2}{\text{fanIn} + \text{fanOut}}} \f$.
     */
    GlorotNormal,

    /**
     * He normal initialization
     *
     * Initialize data with random samples from a truncated normal distribution
     * with mean \f$ 0 \f$ and standard deviation \f$ \sqrt{\frac{2}{\text{fanIn}}} \f$.
     */
    HeNormal,

    /**
     * He uniform initialization
     *
     * Initialize data with random samples from a uniform distribution on
     * the interval \f$ \left[ -\sqrt{\frac{6}{\text{fanIn}}},
     * \sqrt{\frac{6}{\text{fanIn}}} \right] \f$.
     */
    HeUniform,

    /**
     * RamLak filter initialization
     *
     * Initialize data with values of the RamLak filter, the discrete
     * version of the Ramp filter in the spatial domain.
     *
     * Let \f$ c = \frac{\text{size}-1}{2} \f$ be the centre index and
     * \f$ n = i - c \f$ the offset of index \f$ i \f$ from it. The
     * (textbook) RamLak kernel is then
     *
     * \f[
     * \text{data}[i] = \begin{cases}
     * \frac 14, & n = 0 \\
     * 0, & n \text{ even}, n \neq 0 \\
     * -\frac{1}{n^2 \pi^2}, & n \text{ odd}.
     * \end{cases}
     * \f]
     *
     * NOTE(review): confirm that the implementation centres the kernel and
     * uses this sign/parity convention.
     */
    RamLak
};
110 :
/// Padding strategy applied by pooling and convolutional layers.
enum class Padding {
    Valid, ///< No padding is added to the input.
    Same   ///< The input is padded so that the output shape equals the input shape.
};
118 :
/// Selects the backend that executes model primitives.
enum class MlBackend {
    Auto,  ///< Automatically pick the fastest backend that is available.
    Dnnl,  ///< Dnnl (a.k.a. OneDNN) backend, optimized for CPUs.
    Cudnn  ///< Cudnn backend, optimized for Nvidia GPUs.
};
128 :
/// Interpolation method used by Upsampling layers.
enum class Interpolation {
    /// Perform nearest-neighbour interpolation
    NearestNeighbour,
    /// Perform bilinear interpolation
    Bilinear
};
136 :
/// Kind of a network layer.
enum class LayerType {
    // Generic
    Undefined,
    Input,

    // Fully connected
    Dense,

    // Activation layers (cf. the Activation enum)
    Activation,
    Sigmoid,
    Relu,
    Tanh,
    ClippedRelu,
    Elu,
    Identity,

    // Convolutions, including transposed convolutions
    Conv1D,
    Conv2D,
    Conv3D,
    Conv2DTranspose,
    Conv3DTranspose,

    // Pooling
    MaxPooling1D,
    MaxPooling2D,
    MaxPooling3D,
    AveragePooling1D,
    AveragePooling2D,
    AveragePooling3D,

    // Merging
    Sum,
    Concatenate,

    // Shape manipulation
    Reshape,
    Flatten,

    Softmax,

    // Upsampling
    UpSampling1D,
    UpSampling2D,
    UpSampling3D,

    Projector
};
170 :
/// Direction of data propagation through a network
enum class PropagationKind {
    /// Perform a forward propagation
    Forward,
    /// Perform a backward propagation
    Backward,
    /// Perform both forward and backward propagation
    Full
};
180 :
/// Activation function applied by Dense and Convolutional layers.
enum class Activation {
    Sigmoid,     ///< Sigmoid activation function.
    Relu,        ///< Rectified linear unit (Relu).
    ClippedRelu, ///< Clipped Relu activation function.
    Tanh,        ///< Hyperbolic tangent activation function.
    Elu,         ///< Exponential linear unit.
    Identity     ///< Identity activation.
};
196 :
197 : namespace detail
198 : {
199 : std::string getEnumMemberAsString(LayerType);
200 : std::string getEnumMemberAsString(Initializer);
201 : std::string getEnumMemberAsString(MlBackend);
202 : std::string getEnumMemberAsString(PropagationKind);
203 :
/**
 * Unchecked downcast of a std::unique_ptr that transfers ownership.
 *
 * Not declared `static`: a function template defined in a header needs no
 * internal linkage, and internal linkage would give every translation unit
 * its own copy (an ODR hazard when called from inline/header code).
 *
 * @tparam Derived target pointee type
 * @tparam Base    source pointee type
 * @tparam Del     deleter type, carried over unchanged to the result
 * @param p        source pointer; it is always left empty afterwards
 * @return a unique_ptr owning the object previously owned by `p` (empty if
 *         `p` was empty)
 *
 * The caller must guarantee that the object's dynamic type is `Derived` (or
 * derived from it); static_cast performs no runtime check.
 */
template <typename Derived, typename Base, typename Del>
std::unique_ptr<Derived, Del> static_unique_ptr_cast(std::unique_ptr<Base, Del>&& p)
{
    auto* const derived = static_cast<Derived*>(p.release());
    return std::unique_ptr<Derived, Del>(derived, std::move(p.get_deleter()));
}
211 :
/**
 * Checked downcast of a std::unique_ptr using dynamic_cast.
 *
 * Not declared `static`: a function template defined in a header needs no
 * internal linkage, and internal linkage would give every translation unit
 * its own copy (an ODR hazard when called from inline/header code).
 *
 * @tparam Derived target pointee type
 * @tparam Base    source pointee type (must be polymorphic for dynamic_cast)
 * @tparam Del     deleter type, carried over unchanged to the result
 * @param p        source pointer
 * @return on success, a unique_ptr owning the object, and `p` is left empty;
 *         on failure (including an empty `p`), an empty unique_ptr holding a
 *         copy of `p`'s deleter, and `p` keeps ownership of its object.
 */
template <typename Derived, typename Base, typename Del>
std::unique_ptr<Derived, Del> dynamic_unique_ptr_cast(std::unique_ptr<Base, Del>&& p)
{
    if (Derived* const result = dynamic_cast<Derived*>(p.get())) {
        // Release only after the cast succeeded, so a failed cast never drops ownership.
        p.release();
        return std::unique_ptr<Derived, Del>(result, std::move(p.get_deleter()));
    }
    // Copy (not move) the deleter: `p` still owns its object and needs it.
    return std::unique_ptr<Derived, Del>(nullptr, p.get_deleter());
}
222 : } // namespace detail
223 : } // namespace ml
224 : } // namespace elsa
|