LCOV - code coverage report
Current view: top level - ml - Common.h (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 4 5 80.0 %
Date: 2021-06-21 02:59:00 Functions: 1 1 100.0 %

          Line data    Source code
#pragma once

#include <memory>
#include <string>
#include <utility>

#include "VolumeDescriptor.h"
#include "TypeCasts.hpp"

       8             : namespace elsa
       9             : {
      10             :     namespace ml
      11             :     {
      12             :         /// Initializer that can be used to initialize trainable parameters in a network layer
      13             :         enum class Initializer {
      14             :             /**
      15             :              * Ones initialization
      16             :              *
      17             :              * Initialize data with \f$ 1 \f$
      18             :              */
      19             :             Ones,
      20             : 
      21             :             /**
      22             :              * Zeros initialization
      23             :              *
      24             :              * Initialize data with \f$ 0 \f$
      25             :              */
      26             :             Zeros,
      27             : 
      28             :             /**
      29             :              * Uniform initialization
      30             :              *
      31             :              * Initialize data with random samples from a uniform distribution in
      32             :              * the interval \f$ [-1, 1 ] \f$.
      33             :              */
      34             :             Uniform,
      35             : 
      36             :             /**
      37             :              * Normal initialization
      38             :              *
      39             :              * Initialize data with random samples from a standard normal
      40             :              * distribution, i.e., a normal distribution with mean 0 and
      41             :              * standard deviation \f$ 1 \f$.
      42             :              */
      43             :             Normal,
      44             : 
      45             :             /**
      46             :              * Truncated normal initialization
      47             :              *
      48             :              * Initialize data with random samples from a truncated standard normal
      49             :              * distribution, i.e., a normal distribution with mean 0 and standard deviation \f$ 1
      50             :              * \f$ where values with a distance of greater than \f$ 2 \times \f$ standard deviations
      51             :              * from the mean are discarded.
      52             :              */
      53             :             TruncatedNormal,
      54             : 
      55             :             /**
      56             :              * Glorot uniform initialization
      57             :              *
      58             :              * Initialize a data container with a random samples from a uniform
      59             :              * distribution on the interval
      60             :              * \f$ \left [ - \sqrt{\frac{6}{\text{fanIn} + \text{fanOut}}} ,
      61             :              * \sqrt{\frac{6}{\text{fanIn} + \text{fanOut}}} \right ] \f$
      62             :              */
      63             :             GlorotUniform,
      64             : 
      65             :             /**
      66             :              * Glorot normal initialization
      67             :              *
      68             :              * Initialize data with random samples from a truncated normal distribution
      69             :              * with mean \f$ 0 \f$ and stddev \f$ \sqrt{ \frac{2}{\text{fanIn} + \text{fanOut}}}
      70             :              * \f$.
      71             :              */
      72             :             GlorotNormal,
      73             : 
      74             :             /**
      75             :              * He normal initialization
      76             :              *
      77             :              * Initialize data with random samples from a truncated normal distribution
      78             :              * with mean \f$ 0 \f$ and stddev \f$ \sqrt{\frac{2}{\text{fanIn}}} \f$
      79             :              */
      80             :             HeNormal,
      81             : 
      82             :             /**
      83             :              * He uniform initialization
      84             :              *
      85             :              * Initialize a data container with a random samples from a uniform
      86             :              * distribution on the interval \f$  \left [ - \sqrt{\frac{6}{\text{fanIn}}} ,
      87             :              * \sqrt{\frac{6}{\text{fanIn}}} \right ] \f$
      88             :              */
      89             :             HeUniform,
      90             : 
      91             :             /**
      92             :              * RamLak filter initialization
      93             :              *
      94             :              * Initialize data with values of the RamLak filter, the discrete
      95             :              * version of the Ramp filter in the spatial domain.
      96             :              *
      97             :              * Values for this initialization are given by the following
      98             :              * equation:
      99             :              *
     100             :              * \f[
     101             :              *  \text{data}[i] = \begin{cases}
     102             :              *  \frac{1}{i^2 \pi^2}, & i \text{ even} \\
     103             :              *  \frac 14, & i = \frac{\text{size}-1}{2} \\
     104             :              *  0, & i \text{ odd}.
     105             :              * \end{cases}
     106             :              * \f]
     107             :              */
     108             :             RamLak
     109             :         };
     110             : 
     111             :         /// Padding type for Pooling and Convolutional layers
     112             :         enum class Padding {
     113             :             /// Do not pad the input
     114             :             Valid,
     115             :             /// Pad the input such that the output shape matches the input shape.
     116             :             Same
     117             :         };
     118             : 
     119             :         /// Backend to execute model primitives.
     120             :         enum class MlBackend {
     121             :             /// Automatically choose the fastest backend available.
     122             :             Auto,
     123             :             /// Use the Dnnl, aka. OneDNN backend which is optimized for CPUs.
     124             :             Dnnl,
     125             :             /// Use the Cudnn backend which is optimized for Nvidia GPUs.
     126             :             Cudnn
     127             :         };
     128             : 
     129             :         /// Type of the interpolation for Upsampling
     130             :         enum class Interpolation {
     131             :             /// Perform nearest neighbour interpolarion
     132             :             NearestNeighbour,
     133             :             /// Perform bilinear interpolarion
     134             :             Bilinear
     135             :         };
     136             : 
     137             :         /// type of a network layer
     138             :         enum class LayerType {
     139             :             Undefined,
     140             :             Input,
     141             :             Dense,
     142             :             Activation,
     143             :             Sigmoid,
     144             :             Relu,
     145             :             Tanh,
     146             :             ClippedRelu,
     147             :             Elu,
     148             :             Identity,
     149             :             Conv1D,
     150             :             Conv2D,
     151             :             Conv3D,
     152             :             Conv2DTranspose,
     153             :             Conv3DTranspose,
     154             :             MaxPooling1D,
     155             :             MaxPooling2D,
     156             :             MaxPooling3D,
     157             :             AveragePooling1D,
     158             :             AveragePooling2D,
     159             :             AveragePooling3D,
     160             :             Sum,
     161             :             Concatenate,
     162             :             Reshape,
     163             :             Flatten,
     164             :             Softmax,
     165             :             UpSampling1D,
     166             :             UpSampling2D,
     167             :             UpSampling3D,
     168             :             Projector
     169             :         };
     170             : 
     171             :         /// Direction of data propagation through a network
     172             :         enum class PropagationKind {
     173             :             /// perform a forward propagation
     174             :             Forward,
     175             :             /// perform a backward propagation
     176             :             Backward,
     177             :             // perform both, forward and backward propagation
     178             :             Full
     179             :         };
     180             : 
     181             :         /// Activation function for Dense and Convolutional layers
     182             :         enum class Activation {
     183             :             /// Sigmoid activation function
     184             :             Sigmoid,
     185             :             /// Relu activation function
     186             :             Relu,
     187             :             /// Clipped Relu activation function.
     188             :             ClippedRelu,
     189             :             /// Hyperbolic tangent activation function.
     190             :             Tanh,
     191             :             /// Exponential Linear Unit.
     192             :             Elu,
     193             :             /// Identity activation
     194             :             Identity
     195             :         };
     196             : 
     197             :         namespace detail
     198             :         {
     199             :             std::string getEnumMemberAsString(LayerType);
     200             :             std::string getEnumMemberAsString(Initializer);
     201             :             std::string getEnumMemberAsString(MlBackend);
     202             :             std::string getEnumMemberAsString(PropagationKind);
     203             : 
     204             :             template <typename Derived, typename Base, typename Del>
     205             :             static std::unique_ptr<Derived, Del>
     206             :                 static_unique_ptr_cast(std::unique_ptr<Base, Del>&& p)
     207             :             {
     208             :                 auto d = static_cast<Derived*>(p.release());
     209             :                 return std::unique_ptr<Derived, Del>(d, std::move(p.get_deleter()));
     210             :             }
     211             : 
     212             :             template <typename Derived, typename Base, typename Del>
     213             :             static std::unique_ptr<Derived, Del>
     214          31 :                 dynamic_unique_ptr_cast(std::unique_ptr<Base, Del>&& p)
     215             :             {
     216          31 :                 if (Derived* result = dynamic_cast<Derived*>(p.get())) {
     217          31 :                     p.release();
     218          31 :                     return std::unique_ptr<Derived, Del>(result, std::move(p.get_deleter()));
     219             :                 }
     220           0 :                 return std::unique_ptr<Derived, Del>(nullptr, p.get_deleter());
     221             :             }
     222             :         } // namespace detail
     223             :     }     // namespace ml
     224             : } // namespace elsa

Generated by: LCOV version 1.15