X Tutup
Skip to content

Commit ecbda0c

Browse files
committed
add RefVariable.read_value()
1 parent af73e3c commit ecbda0c

File tree

8 files changed

+67
-8
lines changed

8 files changed

+67
-8
lines changed

src/TensorFlowNET.Core/APIs/tf.train.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public Optimizer GradientDescentOptimizer(float learning_rate)
3131
public Optimizer AdamOptimizer(float learning_rate, string name = "Adam")
3232
=> new AdamOptimizer(learning_rate, name: name);
3333

34-
public object ExponentialMovingAverage(float decay)
34+
public ExponentialMovingAverage ExponentialMovingAverage(float decay)
3535
=> new ExponentialMovingAverage(decay);
3636

3737
public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list);

src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
5+
using static Tensorflow.Binding;
46

57
namespace Tensorflow.Train
68
{
@@ -11,6 +13,7 @@ public class ExponentialMovingAverage
1113
bool _zero_debias;
1214
string _name;
1315
public string name => _name;
16+
List<VariableV1> _averages;
1417

1518
public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false,
1619
string name = "ExponentialMovingAverage")
@@ -19,18 +22,31 @@ public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_
1922
_num_updates = num_updates;
2023
_zero_debias = zero_debias;
2124
_name = name;
25+
_averages = new List<VariableV1>();
2226
}
2327

2428
/// <summary>
2529
/// Maintains moving averages of variables.
2630
/// </summary>
2731
/// <param name="var_list"></param>
2832
/// <returns></returns>
29-
public Operation apply(VariableV1[] var_list = null)
33+
public Operation apply(RefVariable[] var_list = null)
3034
{
31-
throw new NotImplementedException("");
32-
}
35+
if (var_list == null)
36+
var_list = variables.trainable_variables() as RefVariable[];
3337

38+
foreach(var var in var_list)
39+
{
40+
if (!_averages.Contains(var))
41+
{
42+
ops.init_scope();
43+
var slot = new SlotCreator();
44+
var.initialized_value();
45+
// var avg = slot.create_zeros_slot
46+
}
47+
}
3448

49+
throw new NotImplementedException("");
50+
}
3551
}
3652
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,5 +308,28 @@ public RefVariable from_proto(VariableDef proto, string import_scope)
308308
{
309309
throw new NotImplementedException();
310310
}
311+
312+
/// <summary>
313+
/// Returns the value of this variable, read in the current context.
314+
/// </summary>
315+
/// <returns></returns>
316+
private ITensorOrOperation read_value()
317+
{
318+
return array_ops.identity(_variable, name: "read");
319+
}
320+
321+
public Tensor is_variable_initialized(RefVariable variable)
322+
{
323+
return state_ops.is_variable_initialized(variable);
324+
}
325+
326+
public Tensor initialized_value()
327+
{
328+
ops.init_scope();
329+
throw new NotImplementedException("");
330+
/*return control_flow_ops.cond(is_variable_initialized(this),
331+
read_value,
332+
() => initial_value);*/
333+
}
311334
}
312335
}

src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System;
1718
using System.Collections.Generic;
1819
using Tensorflow.Eager;
1920

@@ -145,5 +146,10 @@ public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor update
145146
var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking });
146147
return _op.outputs[0];
147148
}
149+
150+
public static Tensor is_variable_initialized(RefVariable @ref, string name = null)
151+
{
152+
throw new NotImplementedException("");
153+
}
148154
}
149155
}

src/TensorFlowNET.Core/Variables/state_ops.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,5 +106,13 @@ public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor update
106106

107107
throw new NotImplementedException("scatter_add");
108108
}
109+
110+
public static Tensor is_variable_initialized(RefVariable @ref, string name = null)
111+
{
112+
if (@ref.dtype.is_ref_dtype())
113+
return gen_state_ops.is_variable_initialized(@ref: @ref, name: name);
114+
throw new NotImplementedException("");
115+
//return @ref.is_initialized(name: name);
116+
}
109117
}
110118
}

test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,12 @@ public Graph BuildGraph()
9191

9292
tf_with(tf.name_scope("define_loss"), scope =>
9393
{
94-
model = new YOLOv3(cfg, input_data, trainable);
94+
// model = new YOLOv3(cfg, input_data, trainable);
95+
});
96+
97+
tf_with(tf.name_scope("define_weight_decay"), scope =>
98+
{
99+
var moving_ave = tf.train.ExponentialMovingAverage(moving_ave_decay).apply((RefVariable[])tf.trainable_variables());
95100
});
96101

97102
return graph;

test/TensorFlowNET.UnitTest/ImageTest.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
namespace TensorFlowNET.UnitTest
1010
{
11+
/// <summary>
12+
/// Find more examples in https://www.programcreek.com/python/example/90444/tensorflow.read_file
13+
/// </summary>
1114
[TestClass]
1215
public class ImageTest
1316
{

test/TensorFlowNET.UnitTest/NameScopeTest.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ public void NestedNameScope_Using()
6969
Assert.AreEqual("scope1", g._name_stack);
7070
var const3 = tf.constant(2.0);
7171
Assert.AreEqual("scope1/Const_1:0", const3.name);
72-
}
73-
74-
;
72+
};
7573

7674
g.Dispose();
7775

0 commit comments

Comments
 (0)
X Tutup