11using System ;
22using System . Collections . Generic ;
3+ using System . Linq ;
34using System . Text ;
5+ using static Tensorflow . Binding ;
46
57namespace Tensorflow . Train
68{
@@ -11,6 +13,7 @@ public class ExponentialMovingAverage
1113 bool _zero_debias ;
1214 string _name ;
1315 public string name => _name ;
16+ List < VariableV1 > _averages ;
1417
1518 public ExponentialMovingAverage ( float decay , int ? num_updates = null , bool zero_debias = false ,
1619 string name = "ExponentialMovingAverage" )
@@ -19,18 +22,31 @@ public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_
1922 _num_updates = num_updates ;
2023 _zero_debias = zero_debias ;
2124 _name = name ;
25+ _averages = new List < VariableV1 > ( ) ;
2226 }
2327
2428 /// <summary>
2529 /// Maintains moving averages of variables.
2630 /// </summary>
2731 /// <param name="var_list"></param>
2832 /// <returns></returns>
29- public Operation apply ( VariableV1 [ ] var_list = null )
33+ public Operation apply ( RefVariable [ ] var_list = null )
3034 {
31- throw new NotImplementedException ( "" ) ;
32- }
35+ if ( var_list == null )
36+ var_list = variables . trainable_variables ( ) as RefVariable [ ] ;
3337
38+ foreach ( var var in var_list )
39+ {
40+ if ( ! _averages . Contains ( var ) )
41+ {
42+ ops . init_scope ( ) ;
43+ var slot = new SlotCreator ( ) ;
44+ var . initialized_value ( ) ;
45+ // var avg = slot.create_zeros_slot
46+ }
47+ }
3448
49+ throw new NotImplementedException ( "" ) ;
50+ }
3551 }
3652}
0 commit comments