#pragma once

#include <numeric>
#include <vector>
#include <functional>
#include <algorithm>
#include <string>

#include "DataContainer.h"
#include "IdenticalBlocksDescriptor.h"
#include "Common.h"
#include "Utils.h"

namespace elsa::ml
{
    /// Reduction types for loss functions
    enum class LossReduction {
        /// reduce the loss by summing it up across the batch
        Sum,
        /// reduce the loss by summing it up across the batch and dividing by the batch size
        SumOverBatchSize
    };

    /// Base class for all loss functions
    ///
    /// @author David Tellenbach
    template <typename data_t = real_t>
    class Loss
    {
    public:
        /// default constructor
        Loss() = default;

        /// virtual destructor
        virtual ~Loss() = default;

        /// @returns the loss between predictions x and labels y
        data_t getLoss(const DataContainer<data_t>& x, const DataContainer<data_t>& y) const;

        /// @returns the loss between predictions x and labels y
        data_t operator()(const DataContainer<data_t>& x, const DataContainer<data_t>& y);

        /// @returns the loss-gradient between predictions x and labels y
        DataContainer<data_t> getLossGradient(const DataContainer<data_t>& x,
                                              const DataContainer<data_t>& y) const;

        /// @returns the name of this loss function
        std::string getName() const;

    protected:
        Loss(LossReduction reduction, const std::string& name);

        using LossFunctionType = std::function<data_t(LossReduction, DataContainer<data_t> const&,
                                                      DataContainer<data_t> const&)>;

        using LossGradientFunctionType = std::function<DataContainer<data_t>(
            LossReduction, DataContainer<data_t> const&, DataContainer<data_t> const&)>;

        LossFunctionType lossFunction_;
        LossGradientFunctionType lossGradientFunction_;
        LossReduction reduction_;
        std::string name_;
    };

    /// @brief Computes the cross-entropy loss between true labels and predicted
    /// labels.
    ///
    /// @author David Tellenbach
    ///
    /// Use this cross-entropy loss when there are only two label classes
    /// (assumed to be 0 and 1). For each example, there should be a single
    /// floating-point value per prediction.
    template <typename data_t = real_t>
    class BinaryCrossentropy : public Loss<data_t>
    {
    public:
        /// Construct a BinaryCrossentropy loss by optionally specifying the
        /// reduction type.
        explicit BinaryCrossentropy(LossReduction reduction = LossReduction::SumOverBatchSize);

    private:
        static data_t lossImpl(LossReduction, const DataContainer<data_t>&,
                               const DataContainer<data_t>&);

        static DataContainer<data_t> lossGradientImpl(LossReduction, const DataContainer<data_t>&,
                                                      const DataContainer<data_t>&);
    };

    /// @brief Computes the cross-entropy loss between the labels and predictions.
    ///
    /// @author David Tellenbach
    ///
    /// Use this cross-entropy loss function when there are two or more label
    /// classes and the labels are provided in a one-hot representation. If you
    /// don't want to use one-hot encoded labels, please use the
    /// SparseCategoricalCrossentropy loss instead. There should be one
    /// floating-point value per class for each feature.
    template <typename data_t = real_t>
    class CategoricalCrossentropy : public Loss<data_t>
    {
    public:
        /// Construct a CategoricalCrossentropy loss by optionally specifying the
        /// reduction type.
        explicit CategoricalCrossentropy(LossReduction reduction = LossReduction::SumOverBatchSize);

    private:
        static data_t lossImpl(LossReduction reduction, const DataContainer<data_t>& x,
                               const DataContainer<data_t>& y);

        static DataContainer<data_t> lossGradientImpl(LossReduction reduction,
                                                      const DataContainer<data_t>& x,
                                                      const DataContainer<data_t>& y);
    };

    /// @brief Computes the cross-entropy loss between the labels and predictions.
    ///
    /// @author David Tellenbach
    ///
    /// Use this cross-entropy loss function when there are two or more label
    /// classes and the labels are provided as class indices rather than in a
    /// one-hot representation. If you want to provide one-hot encoded labels,
    /// please use the CategoricalCrossentropy loss instead. There should be one
    /// floating-point value per class for each feature of x and a single
    /// floating-point value per feature of y.
    template <typename data_t = real_t>
    class SparseCategoricalCrossentropy : public Loss<data_t>
    {
    public:
        /// Construct a SparseCategoricalCrossentropy loss by optionally specifying the
        /// reduction type.
        explicit SparseCategoricalCrossentropy(
            LossReduction reduction = LossReduction::SumOverBatchSize);

    private:
        static data_t lossImpl(LossReduction, const DataContainer<data_t>&,
                               const DataContainer<data_t>&);

        static DataContainer<data_t> lossGradientImpl(LossReduction reduction,
                                                      const DataContainer<data_t>& x,
                                                      const DataContainer<data_t>& y);
    };

    /// @brief Computes the mean squared error between labels y and predictions x:
    ///
    /// @author David Tellenbach
    ///
    /// \f[ \text{loss} = \frac{1}{n} \sum_{i=1}^{n} (y_i - x_i)^2 \f]
    template <typename data_t = real_t>
    class MeanSquaredError : public Loss<data_t>
    {
    public:
        /// Construct a MeanSquaredError loss by optionally specifying the
        /// reduction type.
        explicit MeanSquaredError(LossReduction reduction = LossReduction::SumOverBatchSize);

    private:
        static data_t lossImpl(LossReduction, const DataContainer<data_t>&,
                               const DataContainer<data_t>&);

        static DataContainer<data_t> lossGradientImpl(LossReduction, const DataContainer<data_t>&,
                                                      const DataContainer<data_t>&);
    };

} // namespace elsa::ml
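// ---------------------------------------------------------------------------
// Illustrative usage sketch (editor-added, not part of this header). It shows
// how a loss declared above might be constructed and evaluated. It assumes
// the usual elsa construction of a DataContainer from a VolumeDescriptor and
// an IndexVector_t, scalar assignment to fill a container, and that the
// second dimension indexes the batch; none of these assumptions are stated in
// this header, so adjust them to the actual expected data layout (e.g. an
// IdenticalBlocksDescriptor for batched data).
//
//     #include "Loss.h"
//     #include "VolumeDescriptor.h"
//
//     inline void lossExample()
//     {
//         using namespace elsa;
//
//         // a tiny "batch": 4 examples with a single prediction value each
//         IndexVector_t shape(2);
//         shape << 1, 4;
//         VolumeDescriptor desc(shape);
//
//         DataContainer<real_t> predictions(desc);
//         DataContainer<real_t> labels(desc);
//         predictions = static_cast<real_t>(0.7); // fill with a constant value
//         labels = static_cast<real_t>(1.0);
//
//         // SumOverBatchSize reduction is the default
//         ml::MeanSquaredError<real_t> mse;
//         real_t loss = mse(predictions, labels);
//
//         // gradient of the loss w.r.t. the predictions, same shape as the inputs
//         DataContainer<real_t> grad = mse.getLossGradient(predictions, labels);
//     }
// ---------------------------------------------------------------------------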