LCOV - code coverage report
Current view: top level - ml - Loss.h (source / functions) Hit Total Coverage
Test: test_coverage.info.cleaned Lines: 1 2 50.0 %
Date: 2022-07-06 02:47:47 Functions: 1 3 33.3 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <numeric>
       4             : #include <vector>
       5             : #include <functional>
       6             : #include <algorithm>
       7             : #include <string>
       8             : 
       9             : #include "DataContainer.h"
      10             : #include "IdenticalBlocksDescriptor.h"
      11             : #include "Common.h"
      12             : #include "Utils.h"
      13             : 
      14             : namespace elsa::ml
      15             : {
      16             :     /// Reduction types for loss functions
      17             :     enum class LossReduction {
      18             :         /// reduce the loss by summing it up across the batch
      19             :         Sum,
      20             :         /// reduce the loss by summing it up across the batch, then dividing by the batch size
      21             :         SumOverBatchSize
      22             :     };
      23             : 
      24             :     /// Base class for all loss functions
      25             :     ///
      26             :     /// @author David Tellenbach
      27             :     template <typename data_t = real_t>
      28             :     class Loss
      29             :     {
      30             :     public:
      31             :         /// default constructor
      32           0 :         Loss() = default;
      33             : 
      34             :         /// virtual destructor
      35          18 :         virtual ~Loss() = default;
      36             : 
      37             :         /// @returns the loss between predictions x and labels y
      38             :         data_t getLoss(const DataContainer<data_t>& x, const DataContainer<data_t>& y) const;
      39             : 
      40             :         /// @returns the loss between predictions x and labels y
      41             :         data_t operator()(const DataContainer<data_t>& x, const DataContainer<data_t>& y);
      42             : 
      43             :         /// @returns the loss-gradient between predictions x and labels y
      44             :         DataContainer<data_t> getLossGradient(const DataContainer<data_t>& x,
      45             :                                               const DataContainer<data_t>& y) const;
      46             : 
      47             :         /// @returns the name of this loss function
      48             :         std::string getName() const;
      49             : 
      50             :     protected:
      51             :         Loss(LossReduction reduction, const std::string& name); ///< derived losses supply their reduction type and display name
      52             : 
      53             :         using LossFunctionType = std::function<data_t(LossReduction, DataContainer<data_t> const&,
      54             :                                                       DataContainer<data_t> const&)>; ///< callable evaluating a loss under a given reduction
      55             : 
      56             :         using LossGradientFunctionType = std::function<DataContainer<data_t>(
      57             :             LossReduction, DataContainer<data_t> const&, DataContainer<data_t> const&)>; ///< callable evaluating a loss gradient under a given reduction
      58             : 
      59             :         LossFunctionType lossFunction_;                 ///< evaluates the loss; expected to be set by derived classes
      60             :         LossGradientFunctionType lossGradientFunction_; ///< evaluates the loss gradient; expected to be set by derived classes
      61             :         LossReduction reduction_;                       ///< the configured reduction type
      62             :         std::string name_;                              ///< name of this loss function (see getName())
      63             :     };
      64             : 
      65             :     /// @brief Computes the cross-entropy loss between true labels and predicted
      66             :     /// labels.
      67             :     ///
      68             :     /// @author David Tellenbach
      69             :     ///
      70             :     /// Use this cross-entropy loss when there are only two label classes
      71             :     /// (assumed to be 0 and 1). For each example, there should be a single
      72             :     /// floating-point value per prediction.
      73             :     template <typename data_t = real_t>
      74             :     class BinaryCrossentropy : public Loss<data_t>
      75             :     {
      76             :     public:
      77             :         /// Construct a BinaryCrossentropy loss by optionally specifying the
      78             :         /// reduction type.
      79             :         explicit BinaryCrossentropy(LossReduction reduction = LossReduction::SumOverBatchSize);
      80             : 
      81             :     private:
      82             :         static data_t lossImpl(LossReduction, const DataContainer<data_t>&,
      83             :                                const DataContainer<data_t>&); ///< loss implementation for this class
      84             : 
      85             :         static DataContainer<data_t> lossGradientImpl(LossReduction, const DataContainer<data_t>&,
      86             :                                                       const DataContainer<data_t>&); ///< loss-gradient implementation for this class
      87             :     };
      88             : 
      89             :     /// @brief Computes the cross-entropy loss between the labels and predictions.
      90             :     ///
      91             :     /// @author David Tellenbach
      92             :     ///
      93             :     /// Use this cross-entropy loss function when there are two or more label
      94             :     /// classes. We expect labels to be provided in a one-hot representation.
      95             :     /// If you don't want to use one-hot encoded labels, please use
      96             :     /// SparseCategoricalCrossentropy loss. There should be # classes floating
      97             :     /// point values per feature.
      98             :     template <typename data_t = real_t>
      99             :     class CategoricalCrossentropy : public Loss<data_t>
     100             :     {
     101             :     public:
     102             :         /// Construct a CategoricalCrossentropy loss by optionally specifying the
     103             :         /// reduction type.
     104             :         explicit CategoricalCrossentropy(LossReduction reduction = LossReduction::SumOverBatchSize);
     105             : 
     106             :     private:
     107             :         static data_t lossImpl(LossReduction reduction, const DataContainer<data_t>& x,
     108             :                                const DataContainer<data_t>& y); ///< loss implementation for this class
     109             : 
     110             :         static DataContainer<data_t> lossGradientImpl(LossReduction reduction,
     111             :                                                       const DataContainer<data_t>& x,
     112             :                                                       const DataContainer<data_t>& y); ///< loss-gradient implementation for this class
     113             :     };
     114             : 
     115             :     /// @brief Computes the cross-entropy loss between the labels and predictions.
     116             :     ///
     117             :     /// @author David Tellenbach
     118             :     ///
     119             :     /// Use this cross-entropy loss function when there are two or more label
     120             :     /// classes. If you want to provide labels using one-hot representation,
     121             :     /// please use CategoricalCrossentropy loss. There should be # classes
     122             :     /// floating point values per feature for x and a single floating point
     123             :     /// value per feature for y.
     124             :     template <typename data_t = real_t>
     125             :     class SparseCategoricalCrossentropy : public Loss<data_t>
     126             :     {
     127             :     public:
     128             :         /// Construct a SparseCategoricalCrossentropy loss by optionally specifying the
     129             :         /// reduction type.
     130             :         explicit SparseCategoricalCrossentropy(
     131             :             LossReduction reduction = LossReduction::SumOverBatchSize);
     132             : 
     133             :     private:
     134             :         static data_t lossImpl(LossReduction, const DataContainer<data_t>&,
     135             :                                const DataContainer<data_t>&); ///< loss implementation for this class
     136             : 
     137             :         static DataContainer<data_t> lossGradientImpl(LossReduction reduction,
     138             :                                                       const DataContainer<data_t>& x,
     139             :                                                       const DataContainer<data_t>& y); ///< loss-gradient implementation for this class
     140             :     };
     141             : 
     142             :     /// @brief Computes the mean squared error between labels y and predictions x:
     143             :     ///
     144             :     /// @author David Tellenbach
     145             :     ///
     146             :     /// \f[ \text{loss} = (y - x)^2 \f] (per element; aggregated according to the chosen LossReduction)
     147             :     template <typename data_t = real_t>
     148             :     class MeanSquaredError : public Loss<data_t>
     149             :     {
     150             :     public:
     151             :         /// Construct a MeanSquaredError loss by optionally specifying the
     152             :         /// reduction type.
     153             :         explicit MeanSquaredError(LossReduction reduction = LossReduction::SumOverBatchSize);
     154             : 
     155             :     private:
     156             :         static data_t lossImpl(LossReduction, const DataContainer<data_t>&,
     157             :                                const DataContainer<data_t>&); ///< loss implementation for this class
     158             : 
     159             :         static DataContainer<data_t> lossGradientImpl(LossReduction, const DataContainer<data_t>&,
     160             :                                                       const DataContainer<data_t>&); ///< loss-gradient implementation for this class
     161             :     };
     162             : 
     163             : } // namespace elsa::ml

Generated by: LCOV version 1.15