forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathHuber.cs
More file actions
29 lines (26 loc) · 1.25 KB
/
Huber.cs
File metadata and controls
29 lines (26 loc) · 1.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
namespace Tensorflow.Keras.Losses;
/// <summary>
/// Huber loss: quadratic for small residuals and linear for large ones.
/// For each element, with error e = y_pred - y_true:
///   loss = 0.5 * e^2                              if |e| &lt;= delta
///   loss = 0.5 * delta^2 + delta * (|e| - delta)  otherwise
/// The element-wise values are then averaged along the requested axis in <c>Apply</c>.
/// </summary>
public class Huber : LossFunctionWrapper
{
    /// <summary>
    /// Threshold at which the loss switches from quadratic to linear.
    /// Set by the constructor; defaults to a variable holding 1.0 when no value is supplied.
    /// </summary>
    protected Tensor delta;

    /// <summary>Creates a Huber loss.</summary>
    /// <param name="reduction">Reduction type forwarded to the base loss wrapper.</param>
    /// <param name="delta">Quadratic-to-linear threshold; <c>null</c> means 1.0.</param>
    /// <param name="name">Loss name; <c>null</c> means "huber".</param>
    public Huber(
        string reduction = null,
        Tensor delta = null,
        string name = null) :
        base(reduction: reduction, name: name ?? "huber")
    {
        // Only build the default variable when the caller did not provide a delta;
        // a field initializer would create (and then discard) one on every construction.
        this.delta = delta ?? tf.Variable(1.0);
    }

    /// <summary>
    /// Computes the Huber loss between <paramref name="y_true"/> and <paramref name="y_pred"/>.
    /// </summary>
    /// <param name="y_true">Ground-truth values; cast to float32 before use.</param>
    /// <param name="y_pred">Predicted values; cast to float32 before use.</param>
    /// <param name="from_logits">Not used by this loss.</param>
    /// <param name="axis">Axis along which the element-wise losses are averaged (default -1).</param>
    /// <returns>The mean of the element-wise Huber loss over <paramref name="axis"/>.</returns>
    public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
    {
        Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT);
        Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT);
        Tensor delta_cast = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT);
        Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
        Tensor abs_error = math_ops.abs(error);
        Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);

        // Quadratic branch for |e| <= delta, linear branch beyond it; the linear branch is
        // offset so the two pieces meet at |e| == delta.
        Tensor quadratic = half * math_ops.pow(error, 2);
        Tensor linear = half * math_ops.pow(delta_cast, 2) + delta_cast * (abs_error - delta_cast);
        Tensor per_element = array_ops.where_v2(abs_error <= delta_cast, quadratic, linear);

        // Reduce over the caller-supplied axis (previously hard-coded to -1, ignoring `axis`).
        return gen_math_ops.mean(per_element, ops.convert_to_tensor(axis));
    }
}