forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRnnUtils.cs
More file actions
103 lines (96 loc) · 4.64 KB
/
RnnUtils.cs
File metadata and controls
103 lines (96 loc) · 4.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Layers;
using Tensorflow.Common.Extensions;
namespace Tensorflow.Keras.Utils
{
internal static class RnnUtils
{
/// <summary>
/// Creates zero-filled state tensor(s). Each tensor's shape is `batch_size_tensor`
/// followed by the dimensions of one entry of `state_size`.
/// </summary>
/// <param name="batch_size_tensor">Tensor used as the leading dimension of every generated state. Must not be null.</param>
/// <param name="state_size">Nested structure of state sizes. Must not be null.</param>
/// <param name="dtype">Data type of the generated zero tensors.</param>
/// <returns>
/// When `state_size.TotalNestedCount` is greater than one, a `Tensors` holding one zero tensor
/// per flattened entry of `state_size`; otherwise the single zero tensor built from the first
/// flattened entry.
/// </returns>
/// <exception cref="ArgumentNullException">`batch_size_tensor` or `state_size` is null.</exception>
internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, INestStructure<long> state_size, TF_DataType dtype)
{
    // Fail fast with a descriptive error instead of letting a null leak into the
    // shape array and surface later as an obscure failure while building the zeros op.
    if (batch_size_tensor is null)
    {
        throw new ArgumentNullException(nameof(batch_size_tensor),
            "batch_size cannot be null while constructing the initial state.");
    }
    if (state_size is null)
    {
        throw new ArgumentNullException(nameof(state_size));
    }

    // Builds one zero tensor of shape [batch_size_tensor, dims of a single state entry...].
    Tensor create_zeros(long unnested_state_size)
    {
        var flat_dims = new Shape(unnested_state_size).dims;
        var init_state_size = new Tensor[] { batch_size_tensor }
            .Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32)))
            .ToArray();
        return array_ops.zeros(init_state_size, dtype: dtype);
    }

    // TODO(Rinne): map structure with nested tensors.
    if (state_size.TotalNestedCount > 1)
    {
        return new Tensors(state_size.Flatten().Select(s => create_zeros(s)).ToArray());
    }

    // Single-state case. Note: First() throws if Flatten() yields no elements.
    return create_zeros(state_size.Flatten().First());
}
/// <summary>
/// Generates the zero-filled state for `cell` by forwarding `cell.StateSize` to
/// `generate_zero_filled_state`.
/// </summary>
/// <param name="cell">Cell whose `StateSize` describes the state(s) to create.</param>
/// <param name="inputs">Optional inputs. When not null, the batch size and dtype are taken from them
/// and the `batch_size` / `dtype` arguments are ignored.</param>
/// <param name="batch_size">Batch size tensor used when `inputs` is null.</param>
/// <param name="dtype">Data type used when `inputs` is null.</param>
/// <returns>The result of `generate_zero_filled_state` for the resolved batch size and dtype.</returns>
internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype)
{
    // Without inputs, rely entirely on the explicitly supplied batch size and dtype.
    if (inputs is null)
    {
        return generate_zero_filled_state(batch_size, cell.StateSize, dtype);
    }

    // Otherwise both values come from the inputs: element 0 of their shape and their dtype.
    Tensor inferred_batch_size = array_ops.shape(inputs)[0];
    TF_DataType inferred_dtype = inputs.dtype;
    return generate_zero_filled_state(inferred_batch_size, cell.StateSize, inferred_dtype);
}
/// <summary>
/// Standardizes `__call__` arguments by separating a combined input list into the
/// real inputs, the initial states and the constants.
///
/// When running a model loaded from a file, the input tensors `initial_state` and
/// `constants` can be passed to `RNN.__call__()` as part of `inputs` instead of through
/// the dedicated keyword arguments. This method makes sure the arguments are separated.
/// </summary>
/// <param name="inputs">One or more tensors, which may include initial states and constants.
/// In that case `num_constants` must be specified.</param>
/// <param name="initial_state">Initial state tensors, or null. Expected to be null when
/// `inputs` holds more than one tensor.</param>
/// <param name="constants">Constant tensors, or null. Expected to be null when `inputs`
/// holds more than one tensor.</param>
/// <param name="num_constants">Expected number of constants (if constants are passed as
/// part of the `inputs` list). Values of zero or less mean no constants are extracted.</param>
/// <returns>The tuple (inputs, initial_state, constants) after separation. The arguments are
/// returned unchanged when `inputs` holds at most one tensor.</returns>
internal static (Tensors, Tensors, Tensors) standardize_args(Tensors inputs, Tensors initial_state, Tensors constants, int num_constants)
{
    // A single tensor cannot carry embedded states or constants: nothing to split.
    if (inputs.Length <= 1)
    {
        return (inputs, initial_state, constants);
    }

    // Several situations lead here:
    //  - Graph mode: __call__ is only called once, and initial_state / constants
    //    may be embedded in inputs (e.g. after loading from a file).
    //  - Eager mode: __call__ is called twice, first through
    //    rnn_layer(inputs=input_t, constants=c_t, ...) and then through
    //    model.fit/train_on_batch/predict with real np data. In that second call
    //    inputs contains initial_state and constants as eager tensors.
    //
    // Either way the layout is: the real input first (possibly a nested structure
    // itself), then the initial states (a list, or a list of lists for complex
    // state structures), and finally the constants as a flat list.
    //
    // Debug.Assert is only evaluated in builds that define DEBUG.
    Debug.Assert(initial_state is null && constants is null);

    // Peel the constants off the tail first so they are not mistaken for states.
    if (num_constants > 0)
    {
        var trailing_constants = inputs.TakeLast(num_constants).ToArray();
        constants = trailing_constants.ToTensors();

        var without_constants = inputs.SkipLast(num_constants).ToArray();
        inputs = without_constants.ToTensors();
    }

    // Whatever follows the first remaining element is the initial state.
    if (inputs.Length > 1)
    {
        var state_part = inputs.Skip(1).ToArray();
        initial_state = state_part.ToTensors();

        var input_part = inputs.Take(1).ToArray();
        inputs = input_part.ToTensors();
    }

    return (inputs, initial_state, constants);
}
/// <summary>
/// Tells whether `state_size` describes more than one state.
/// </summary>
/// <param name="state_size">Nested structure of state sizes to inspect.</param>
/// <returns>True when `state_size.TotalNestedCount` is greater than one; otherwise false.</returns>
public static bool is_multiple_state(INestStructure<long> state_size)
    => state_size.TotalNestedCount > 1;
}
}