LCOV - code coverage report
Current view: top level - ml - Initializer.cpp (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 22 107 20.6 %
Date: 2022-07-06 02:47:47 Functions: 5 17 29.4 %

          Line data    Source code
       1             : #include "Initializer.h"
       2             : 
       3             : namespace elsa
       4             : {
       5             :     namespace ml
       6             :     {
       7             :         namespace detail
       8             :         {
       9             :             template <typename data_t>
      10             :             std::random_device InitializerImpl<data_t>::randomDevice_{}; // Seed source read by getEngine() when useSeed_ is false.
      11             : 
      12             :             template <typename data_t>
      13             :             bool InitializerImpl<data_t>::useSeed_ = false; // When true, getEngine() seeds from seed_; starts out false.
      14             : 
      15             :             template <typename data_t>
      16             :             uint64_t InitializerImpl<data_t>::seed_ = 1; // Seed value used by getEngine() while useSeed_ is true; overwritten by setSeed().
      17             : 
      18             :             template <typename data_t> // Stores `seed` in the static seed_ and enables seeded engine construction.
      19           0 :             void InitializerImpl<data_t>::setSeed(uint64_t seed)
      20             :             {
      21           0 :                 seed_ = seed;
      22           0 :                 useSeed_ = true; // getEngine() now constructs its engine from seed_ instead of randomDevice_
      23           0 :             }
      24             : 
      25             :             template <typename data_t> // Turns seeded mode off again; the stored seed_ value is left as it is.
      26           0 :             void InitializerImpl<data_t>::clearSeed()
      27             :             {
      28           0 :                 useSeed_ = false; // getEngine() goes back to seeding from randomDevice_
      29           0 :             }
      30             : 
      31             :             template <typename data_t> // Fills data[0..size) using the routine selected by `initializer`; throws std::invalid_argument when no case matches.
      32           2 :             void InitializerImpl<data_t>::initialize(
      33             :                 data_t* data, index_t size, Initializer initializer,
      34             :                 [[maybe_unused]] const InitializerImpl<data_t>::FanPairType& fanInOut) // NOTE(review): fanInOut is forwarded below (Glorot*/HeUniform), so [[maybe_unused]] looks unnecessary - confirm
      35             :             {
      36           2 :                 switch (initializer) {
      37           1 :                     case Initializer::Ones:
      38           1 :                         InitializerImpl::ones(data, size);
      39           1 :                         return;
      40           1 :                     case Initializer::Zeros:
      41           1 :                         InitializerImpl::zeros(data, size);
      42           1 :                         return;
      43           0 :                     case Initializer::Uniform:
      44           0 :                         InitializerImpl::uniform(data, size, -1, 1); // lowerBound = -1, upperBound = 1
      45           0 :                         return;
      46           0 :                     case Initializer::GlorotUniform:
      47           0 :                         InitializerImpl::glorotUniform(data, size, fanInOut);
      48           0 :                         return;
      49           0 :                     case Initializer::HeUniform:
      50           0 :                         InitializerImpl::heUniform(data, size, fanInOut);
      51           0 :                         return;
      52           0 :                     case Initializer::TruncatedNormal:
      53           0 :                         InitializerImpl::truncatedNormal(data, size, 0, 1); // mean = 0, stddev = 1
      54           0 :                         return;
      55           0 :                     case Initializer::Normal:
      56           0 :                         InitializerImpl::normal(data, size, 0, 1); // mean = 0, stddev = 1
      57           0 :                         return;
      58           0 :                     case Initializer::GlorotNormal:
      59           0 :                         InitializerImpl::glorotNormal(data, size, fanInOut);
      60           0 :                         return;
      61           0 :                     case Initializer::RamLak:
      62           0 :                         InitializerImpl::ramlak(data, size);
      63           0 :                         return;
      64           0 :                     default: // NOTE(review): heNormal() is defined in this file but no case above dispatches to it - confirm whether that is intended
      65           0 :                         throw std::invalid_argument("Unkown random initializer"); // NOTE(review): the message misspells Unknown; it is runtime text, so fix it in the real source file
      66             :                 }
      67             :             }
      68             : 
      69             :             template <typename data_t> // Convenience overload: calls the four-argument initialize() with a fan pair of {0, 0}.
      70           2 :             void InitializerImpl<data_t>::initialize(data_t* data, index_t size,
      71             :                                                      Initializer initializer)
      72             :             {
      73           2 :                 FanPairType fan{0, 0}; // NOTE(review): glorotUniform/glorotNormal/heUniform divide by these fan values, so this overload divides by zero for those initializers - confirm callers avoid it
      74           2 :                 initialize(data, size, initializer, fan);
      75           2 :             }
      76             : 
      77             :             template <typename data_t> // Returns a newly constructed std::mt19937_64 on every call.
      78           0 :             std::mt19937_64 InitializerImpl<data_t>::getEngine()
      79             :             {
      80           0 :                 if (useSeed_)
      81           0 :                     return std::mt19937_64(seed_); // seeded mode: every call constructs the engine from the same seed_
      82             :                 else
      83           0 :                     return std::mt19937_64(randomDevice_()); // unseeded mode: one draw from randomDevice_ is the seed
      84             :             }
      85             : 
      86             :             template <typename data_t> // Fills data[0..size) with draws from UniformDistributionType(lowerBound, upperBound).
      87           0 :             void InitializerImpl<data_t>::uniform(data_t* data, index_t size, data_t lowerBound,
      88             :                                                   data_t upperBound)
      89             :             {
      90           0 :                 UniformDistributionType dist(lowerBound, upperBound);
      91           0 :                 std::mt19937_64 engine = getEngine(); // one engine per call, shared by all `size` draws
      92             : 
      93           0 :                 for (index_t i = 0; i < size; ++i)
      94           0 :                     data[i] = dist(engine);
      95           0 :             }
      96             : 
      97             :             template <typename data_t> // Overload without bounds: forwards to uniform() with fixed default bounds.
      98           0 :             void InitializerImpl<data_t>::uniform(data_t* data, index_t size)
      99             :             {
     100           0 :                 InitializerImpl<data_t>::uniform(data, size, 0, std::numeric_limits<data_t>::max()); // lowerBound = 0, upperBound = largest finite data_t
     101           0 :             }
     102             : 
     103             :             template <typename data_t> // Writes the value `constant` into each of the first `size` elements of `data`.
     104           2 :             void InitializerImpl<data_t>::constant(data_t* data, index_t size, data_t constant)
     105             :             {
     106        2002 :                 for (index_t i = 0; i < size; ++i)
     107        2000 :                     data[i] = constant;
     108           2 :             }
     109             : 
     110             :             template <typename data_t> // Sets the first `size` elements of `data` to 1 via constant().
     111           1 :             void InitializerImpl<data_t>::ones(data_t* data, index_t size)
     112             :             {
     113           1 :                 constant(data, size, static_cast<data_t>(1));
     114           1 :             }
     115             : 
     116             :             template <typename data_t> // Sets the first `size` elements of `data` to 0 via constant().
     117           1 :             void InitializerImpl<data_t>::zeros(data_t* data, index_t size)
     118             :             {
     119           1 :                 constant(data, size, static_cast<data_t>(0));
     120           1 :             }
     121             : 
     122             :             template <typename data_t> // Uniform fill between -bound and bound, where bound = sqrt(6 / (fan.first + fan.second)).
     123           0 :             void InitializerImpl<data_t>::glorotUniform(
     124             :                 data_t* data, index_t size, const InitializerImpl<data_t>::FanPairType& fan)
     125             :             {
     126           0 :                 auto bound = static_cast<data_t>(std::sqrt(
     127           0 :                     6 / (static_cast<data_t>(fan.first) + static_cast<data_t>(fan.second)))); // NOTE(review): divides by zero when both fan values are 0 - confirm callers pass a non-zero fan
     128           0 :                 uniform(data, size, -1 * bound, bound);
     129           0 :             }
     130             : 
     131             :             template <typename data_t> // Truncated-normal fill with mean 0 and stddev = sqrt(2 / (fan.first + fan.second)).
     132           0 :             void InitializerImpl<data_t>::glorotNormal(
     133             :                 data_t* data, index_t size, const InitializerImpl<data_t>::FanPairType& fan)
     134             :             {
     135           0 :                 auto stddev = static_cast<data_t>(std::sqrt(
     136           0 :                     2 / (static_cast<data_t>(fan.first) + static_cast<data_t>(fan.second)))); // NOTE(review): divides by zero when both fan values are 0 - confirm callers pass a non-zero fan
     137           0 :                 truncatedNormal(data, size, 0, stddev); // delegates to truncatedNormal(), not normal()
     138           0 :             }
     139             : 
     140             :             template <typename data_t> // Fills data[0..size) with draws from NormalDistributionType(mean, stddev).
     141           0 :             void InitializerImpl<data_t>::normal(data_t* data, index_t size, data_t mean,
     142             :                                                  data_t stddev)
     143             :             {
     144             :                 static_assert(!std::is_same<std::false_type, NormalDistributionType>::value, // compile-time rejection when NormalDistributionType is std::false_type
     145             :                               "Cannot use normal distribution with the given data-type");
     146             : 
     147           0 :                 NormalDistributionType dist(mean, stddev);
     148           0 :                 std::mt19937_64 engine = getEngine(); // one engine per call, shared by all `size` draws
     149             : 
     150           0 :                 for (index_t i = 0; i < size; ++i)
     151           0 :                     data[i] = dist(engine);
     152           0 :             }
     153             : 
     154             :             template <typename data_t> // Like normal(), but every stored sample lies within 2 * stddev of `mean`.
     155           0 :             void InitializerImpl<data_t>::truncatedNormal(data_t* data, index_t size, data_t mean,
     156             :                                                           data_t stddev)
     157             :             {
     158             :                 static_assert(!std::is_same<std::false_type, NormalDistributionType>::value, // compile-time rejection when NormalDistributionType is std::false_type
     159             :                               "Cannot use normal distribution with the given data-type");
     160             : 
     161           0 :                 NormalDistributionType dist(mean, stddev);
     162           0 :                 std::mt19937_64 engine = getEngine(); // one engine per call, shared by all draws and redraws
     163             : 
     164           0 :                 for (index_t i = 0; i < size; ++i) {
     165           0 :                     auto value = dist(engine);
     166           0 :                     while (std::abs(mean - value) > 2 * stddev) { // rejection sampling: redraw until the sample is inside the band
     167           0 :                         value = dist(engine);
     168             :                     }
     169           0 :                     data[i] = value;
     170             :                 }
     171           0 :             }
     172             : 
     173             :             template <typename data_t> // Truncated-normal fill with mean 0 and stddev = sqrt(2 / fanInOut.first); fanInOut.second is not read.
     174           0 :             void InitializerImpl<data_t>::heNormal(
     175             :                 data_t* data, index_t size, const InitializerImpl<data_t>::FanPairType& fanInOut)
     176             :             {
     177             :                 auto stddev =
     178           0 :                     std::sqrt(static_cast<data_t>(2) / static_cast<data_t>(fanInOut.first)); // presumably fanInOut.first is the fan-in - TODO confirm against FanPairType's declaration
     179           0 :                 truncatedNormal(data, size, 0, stddev);
     180           0 :             }
     181             : 
     182             :             template <typename data_t> // Uniform fill between -bound and bound, where bound = sqrt(6 / fan.first); fan.second is not read.
     183           0 :             void InitializerImpl<data_t>::heUniform(data_t* data, index_t size,
     184             :                                                     const InitializerImpl<data_t>::FanPairType& fan)
     185             :             {
     186           0 :                 auto bound = static_cast<data_t>(std::sqrt(6 / (static_cast<data_t>(fan.first)))); // NOTE(review): divides by zero when fan.first is 0 - confirm callers pass a non-zero fan
     187           0 :                 uniform(data, size, -1 * bound, bound);
     188           0 :             }
     189             : 
     190             :             template <typename data_t> // Writes a kernel centred at index hw: 0.25 at the centre, -1 / (i^2 * pi^2) at odd offsets i, 0 at even offsets.
     191           0 :             void InitializerImpl<data_t>::ramlak(data_t* data, index_t size)
     192             :             {
     193           0 :                 const index_t hw = as<index_t>((as<data_t>(size) - 1) / 2); // half-width; NOTE(review): if as<index_t> truncates, an even `size` leaves data[size - 1] unwritten - confirm callers pass odd sizes
     194             : 
     195           0 :                 for (index_t i = -hw; i <= hw; ++i) { // i is the offset from the centre; writes indices 0 .. 2 * hw
     196           0 :                     if ((i % 2) != 0)
     197           0 :                         data[i + hw] =
     198           0 :                             data_t(-1) / (as<data_t>(i) * as<data_t>(i) * pi<data_t> * pi<data_t>); // i is odd here, so the divisor is never zero
     199             :                     else
     200           0 :                         data[i + hw] = data_t(0);
     201             :                 }
     202           0 :                 data[hw] = data_t(0.25); // replaces the 0 the loop stored for offset i == 0
     203           0 :             }
     204             : 
     205             :             template class InitializerImpl<float>; // Explicit instantiation of the member definitions above for data_t = float.
     206             :         } // namespace detail
     207             :     }     // namespace ml
     208             : } // namespace elsa

Generated by: LCOV version 1.15