forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMnistDataSet.cs
More file actions
86 lines (76 loc) · 3.05 KB
/
MnistDataSet.cs
File metadata and controls
86 lines (76 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using NumSharp;
using Tensorflow;
namespace Tensorflow.Hub
{
/// <summary>
/// In-memory MNIST split that holds image and label arrays and serves them as
/// mini-batches, optionally reshuffling them at every epoch boundary.
/// </summary>
public class MnistDataSet : DataSetBase
{
    /// <summary>Number of examples, taken from the first dimension of the images array.</summary>
    public int NumOfExamples { get; private set; }

    /// <summary>
    /// Number of finished epochs; incremented each time a batch request runs past the end of the data.
    /// </summary>
    public int EpochsCompleted { get; private set; }

    /// <summary>Index of the next example to serve within the current epoch.</summary>
    public int IndexInEpoch { get; private set; }

    /// <summary>
    /// Creates the data set, converting both arrays to <paramref name="dataType"/> and
    /// multiplying the images by 1/255.
    /// </summary>
    /// <param name="images">
    /// Image array whose first dimension is the example count. When <paramref name="reshape"/>
    /// is true it must have at least three dimensions and is flattened to
    /// (shape[0], shape[1] * shape[2]).
    /// </param>
    /// <param name="labels">Label array with the same first dimension as <paramref name="images"/>.</param>
    /// <param name="dataType">Element type both arrays are converted to.</param>
    /// <param name="reshape">Whether to flatten each image into a single row.</param>
    /// <exception cref="ArgumentNullException"><paramref name="images"/> or <paramref name="labels"/> is null.</exception>
    /// <exception cref="ArgumentException">The two arrays have different example counts.</exception>
    public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape)
    {
        // `is null` is used because NDArray overloads the == operator.
        if (images is null)
        {
            throw new ArgumentNullException(nameof(images));
        }
        if (labels is null)
        {
            throw new ArgumentNullException(nameof(labels));
        }
        if (images.shape[0] != labels.shape[0])
        {
            throw new ArgumentException(
                $"images has {images.shape[0]} examples but labels has {labels.shape[0]}.",
                nameof(labels));
        }

        EpochsCompleted = 0;
        IndexInEpoch = 0;
        NumOfExamples = images.shape[0];

        // Only flatten when asked; callers that want to keep the image grid pass false.
        if (reshape)
        {
            images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
        }

        // NOTE(review): presumably maps raw 0..255 pixel intensities to [0, 1]; confirm the input range.
        Data = np.multiply(images.astype(dataType), 1.0f / 255.0f);
        Labels = labels.astype(dataType);
    }

    /// <summary>
    /// Returns the next <paramref name="batch_size"/> examples. When the request runs past the
    /// end of the data, the examples left in the current epoch are joined with examples from the
    /// start of the next epoch, which is reshuffled first when <paramref name="shuffle"/> is true.
    /// </summary>
    /// <param name="batch_size">Number of examples to return.</param>
    /// <param name="fake_data">Currently unused.</param>
    /// <param name="shuffle">Whether to shuffle before the first epoch and at every epoch boundary.</param>
    /// <returns>The (images, labels) batch.</returns>
    public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true)
    {
        var start = IndexInEpoch;

        // Shuffle for the first epoch
        if (EpochsCompleted == 0 && start == 0 && shuffle)
        {
            ShuffleExamples();
        }

        // The whole batch fits inside the current epoch.
        if (start + batch_size <= NumOfExamples)
        {
            IndexInEpoch += batch_size;
            return TakeRange(start, IndexInEpoch);
        }

        // Finished epoch
        EpochsCompleted += 1;
        var restCount = NumOfExamples - start;

        if (restCount == 0)
        {
            // Nothing left over: the batch comes entirely from the next epoch.
            if (shuffle)
            {
                ShuffleExamples();
            }
            IndexInEpoch = batch_size;
            return TakeRange(0, IndexInEpoch);
        }

        // Capture this epoch's leftovers before reshuffling reorders Data/Labels.
        var (imagesRest, labelsRest) = TakeRange(start, NumOfExamples);
        if (shuffle)
        {
            ShuffleExamples();
        }

        IndexInEpoch = batch_size - restCount;
        var (imagesNew, labelsNew) = TakeRange(0, IndexInEpoch);

        // Join leftovers with the new epoch's head so the batch has batch_size examples.
        return (np.concatenate(new[] { imagesRest, imagesNew }, axis: 0),
                np.concatenate(new[] { labelsRest, labelsNew }, axis: 0));
    }

    /// <summary>Applies one random permutation to both Data and Labels so image/label pairs stay aligned.</summary>
    private void ShuffleExamples()
    {
        var perm = np.arange(NumOfExamples);
        np.random.shuffle(perm);
        Data = Data[perm];
        Labels = Labels[perm];
    }

    /// <summary>Returns examples [start, end) of Data and Labels.</summary>
    private (NDArray, NDArray) TakeRange(int start, int end)
    {
        var indices = np.arange(start, end);
        return (Data[indices], Labels[indices]);
    }
}
}