forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBatchDataset.cs
More file actions
38 lines (34 loc) · 1.17 KB
/
BatchDataset.cs
File metadata and controls
38 lines (34 loc) · 1.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
using System;
using System.Linq;
using static Tensorflow.Binding;
namespace Tensorflow
{
/// <summary>
/// A `Dataset` that batches contiguous elements from its input.
/// </summary>
public class BatchDataset : UnaryDataset
{
    // batch_size converted to a TF_INT64 tensor; passed to the batch op.
    private readonly Tensor _batch_size;

    // drop_remainder converted to a TF_BOOL tensor; passed to the batch op.
    private readonly Tensor _drop_remainder;

    /// <summary>
    /// Builds a batching dataset on top of <paramref name="input_dataset"/>.
    /// </summary>
    /// <param name="input_dataset">The dataset whose elements are grouped into batches.</param>
    /// <param name="batch_size">The batch size handed to the underlying batch op.</param>
    /// <param name="drop_remainder">
    /// Forwarded to the underlying batch op. Only <c>false</c> is currently supported.
    /// </param>
    /// <exception cref="NotImplementedException">
    /// Thrown when <paramref name="drop_remainder"/> is <c>true</c>.
    /// </exception>
    public BatchDataset(IDatasetV2 input_dataset, int batch_size, bool drop_remainder = false) :
        base(input_dataset)
    {
        // Fail fast: reject the unsupported mode before any tensors are created,
        // so a rejected call does not leave partially built state behind.
        if (drop_remainder)
        {
            throw new NotImplementedException(
                $"{nameof(BatchDataset)} does not support {nameof(drop_remainder)} = true yet.");
        }

        _input_dataset = input_dataset;
        _batch_size = tf.convert_to_tensor(batch_size, dtype: TF_DataType.TF_INT64, name: "batch_size");
        _drop_remainder = tf.convert_to_tensor(drop_remainder, dtype: TF_DataType.TF_BOOL, name: "drop_remainder");

        // Derive the batched element structure from every component spec of the input.
        // NOTE(review): -1 presumably marks the new leading batch dimension as unknown — confirm against _batch.
        structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray();

        // output_types / output_shapes are not declared in this class;
        // presumably inherited members derived from structure — confirm in the base class.
        variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor,
            _batch_size,
            _drop_remainder,
            output_types,
            output_shapes);
    }
}
}