forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDatasets.cs
More file actions
43 lines (37 loc) · 1.31 KB
/
Datasets.cs
File metadata and controls
43 lines (37 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
using System;
using Tensorflow.NumPy;
namespace Tensorflow
{
/// <summary>
/// Holds the training, validation and test splits of a dataset, plus helpers
/// for shuffling and mini-batching paired feature/label arrays.
/// </summary>
/// <typeparam name="TDataSet">Concrete split type; must implement <see cref="IDataSet"/>.</typeparam>
public class Datasets<TDataSet> where TDataSet : IDataSet
{
    /// <summary>The training split, as passed to the constructor.</summary>
    public TDataSet Train { get; }

    /// <summary>The validation split, as passed to the constructor.</summary>
    public TDataSet Validation { get; }

    /// <summary>The test split, as passed to the constructor.</summary>
    public TDataSet Test { get; }

    /// <summary>
    /// Creates a container for the three dataset splits. The values are stored as-is.
    /// </summary>
    /// <param name="train">The training split.</param>
    /// <param name="validation">The validation split.</param>
    /// <param name="test">The test split.</param>
    public Datasets(TDataSet train, TDataSet validation, TDataSet test)
    {
        Train = train;
        Validation = validation;
        Test = test;
    }

    /// <summary>
    /// Reorders the rows (first dimension) of <paramref name="x"/> and <paramref name="y"/>
    /// with one shared random permutation, so each feature row stays paired with its label.
    /// </summary>
    /// <param name="x">Feature array; its first dimension must match that of <paramref name="y"/>.</param>
    /// <param name="y">Label array; its first dimension is the number of rows to permute.</param>
    /// <returns>The shuffled <c>(x, y)</c> pair.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
    /// <exception cref="ArgumentException">The first dimensions of <paramref name="x"/> and <paramref name="y"/> differ.</exception>
    public (NDArray, NDArray) Randomize(NDArray x, NDArray y)
    {
        if (x is null)
        {
            throw new ArgumentNullException(nameof(x));
        }
        if (y is null)
        {
            throw new ArgumentNullException(nameof(y));
        }
        if (x.dims[0] != y.dims[0])
        {
            // Without this check a larger x would silently lose its trailing rows.
            throw new ArgumentException(
                $"x and y must have the same number of rows (x: {x.dims[0]}, y: {y.dims[0]}).",
                nameof(y));
        }

        // permutation(n) already yields the indices 0..n-1 in random order,
        // so no additional shuffle pass is required.
        var perm = np.random.permutation((int)y.dims[0]);
        return (x[perm], y[perm]);
    }

    /// <summary>
    /// Extracts one mini-batch by slicing the same row range out of both arrays,
    /// as used when iterating over data for stochastic gradient descent.
    /// </summary>
    /// <param name="x">Feature array to slice along its first dimension.</param>
    /// <param name="y">Label array to slice along its first dimension.</param>
    /// <param name="start">Index of the first row of the batch (inclusive).</param>
    /// <param name="end">Index one past the last row of the batch (exclusive).</param>
    /// <returns>The <c>(x_batch, y_batch)</c> pair covering rows <c>[start, end)</c>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
    public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
    {
        if (x is null)
        {
            throw new ArgumentNullException(nameof(x));
        }
        if (y is null)
        {
            throw new ArgumentNullException(nameof(y));
        }

        var slice = new Slice(start, end);
        return (x[slice], y[slice]);
    }
}
}