X Tutup
Skip to content

Commit 0c6f8cd

Browse files
committed
fix internal_convert_to_tensor shape exception for scalar value.
1 parent 1c5731f commit 0c6f8cd

File tree

4 files changed

+54
-20
lines changed

4 files changed

+54
-20
lines changed

docs/source/Variable.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,19 @@
11
# Chapter. Variable
22

3+
Variables in TensorFlow are mainly used to represent the mutable parameter values of a machine learning model. Variables can be initialized by the `tf.Variable` function. During graph computation, variables are modified by other operations. Variables exist in the session; as long as they are in the same session, other computing nodes on the network can access the same variable value. Variables use lazy loading and only request memory space when they are used.
4+
5+
TensorFlow中变量主要用来表示机器学习模型中的可变参数值,变量可以通过`tf.Variable` 类进行初始化。在图运行过程中,通过各种操作对变量进行修改。变量存在于会话当中,只要是在同一个会话里,网络上的其它计算节点都可以访问到相同的变量值。变量采用延迟加载的方式,只有使用的时候才会申请内存空间。
6+
7+
```csharp
8+
var x = tf.Variable(10, name: "x");
9+
using (var session = tf.Session())
10+
{
11+
session.run(x.initializer);
12+
var result = session.run(x);
13+
Console.Write(result); // should be 10
14+
}
15+
```
16+
17+
The above code first creates a variable operation, initializes the variable, then runs the session, and finally gets the result. This code is very simple, but it shows the complete process of how TensorFlow operates on variables. When creating a variable, you pass a `tensor` as the initial value to the function `Variable()`. TensorFlow provides a series of operators to initialize the tensor; the initial value is either a constant or a random value.
18+
19+
以上代码先创建变量操作,初始化变量,再运行会话,最后得到结果。这段代码非常简单,但是它体现了整个TensorFlow对变量操作的完整流程。当创建一个变量时,你将一个`张量`作为初始值传入函数`Variable()`。TensorFlow提供了一系列操作符来初始化张量,初始值是常量或是随机值。

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,13 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
4444
var input_types = new List<TF_DataType>();
4545
var base_types = new List<TF_DataType>();
4646

47-
Operation op = null;
48-
Python.with<ops.name_scope>(new ops.name_scope(name), scope =>
47+
return Python.with<ops.name_scope, Operation>(new ops.name_scope(name), scope =>
4948
{
5049
// Perform input type inference
5150
foreach (var input_arg in op_def.InputArg)
5251
{
53-
var input_name = input_arg.Name;
54-
var values = keywords[input_name];
52+
var input_arg_name = input_arg.Name;
53+
var values = keywords[input_arg_name];
5554
// Goals:
5655
// * Convert values to Tensors if it contains constants.
5756
// * Verify that values is a list if that matches the input_arg's
@@ -64,13 +63,13 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
6463
// * If the input_arg has an explicit type, make sure the input
6564
// conforms.
6665

66+
DataType dtype = DataType.DtInvalid;
67+
DataType default_dtype = DataType.DtInvalid;
68+
6769
if (_IsListParameter(input_arg))
6870
{
69-
DataType dtype = DataType.DtInvalid;
70-
DataType default_dtype = DataType.DtInvalid;
71-
7271
if (!_IsListValue(values))
73-
throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}.");
72+
throw new TypeError($"Expected list for '{input_arg_name}' argument to '{op_type_name}' Op, not {values}.");
7473
if(input_arg.Type != DataType.DtInvalid)
7574
{
7675
dtype = input_arg.Type;
@@ -87,19 +86,22 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
8786
}
8887
else
8988
{
90-
if (keywords[input_name] is Tensor)
89+
if (default_type_attr_map.ContainsKey(input_arg.TypeAttr))
90+
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
91+
92+
if (keywords[input_arg_name] is Tensor)
9193
{
9294
}
9395
else
9496
{
95-
keywords[input_name] = ops.internal_convert_to_tensor(values, name: input_name);
97+
keywords[input_arg_name] = ops.internal_convert_to_tensor(values, name: input_arg_name);
9698
}
9799

98100
if (!String.IsNullOrEmpty(input_arg.TypeAttr))
99101
{
100-
attrs[input_arg.TypeAttr] = (keywords[input_name] as Tensor).dtype;
102+
attrs[input_arg.TypeAttr] = (keywords[input_arg_name] as Tensor).dtype;
101103
}
102-
values = new Tensor[] { keywords[input_name] as Tensor };
104+
values = new Tensor[] { keywords[input_arg_name] as Tensor };
103105
}
104106

105107
inputs.AddRange(values as Tensor[]);
@@ -122,7 +124,7 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
122124
{
123125
var key = attr_def.Name;
124126
if (!attrs.ContainsKey(key))
125-
Console.WriteLine($"{key} not found in attr_def.");
127+
Console.WriteLine($"_apply_op_helper: key '{key}' is not found in '{op_def.Name}' operation's attr_def.");
126128
var value = attrs[key];
127129
var attr_value = new AttrValue();
128130

@@ -165,14 +167,14 @@ public Operation _apply_op_helper(string op_type_name, string name = "", dynamic
165167
}
166168

167169
// Add Op to graph
168-
op = g.create_op(op_type_name, inputs, output_types.ToArray(),
170+
var op = g.create_op(op_type_name, inputs, output_types.ToArray(),
169171
name: scope,
170172
input_types: input_types.ToArray(),
171173
attrs: attr_protos,
172174
op_def: op_def);
173-
});
174175

175-
return op;
176+
return op;
177+
});
176178
}
177179

178180
public DataType _MakeType(TF_DataType v, AttrDef attr_def)

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ public static Tensor reduced_shape(Tensor input_shape, Tensor axes)
2020
var input_rank = array_ops.size(input_shape);
2121
axes = (axes + input_rank) % input_rank;
2222
var axes_shape = array_ops.shape(axes);
23-
var a1 = new Tensor[] { input_rank, axes };
24-
var a2 = new Tensor[] { input_shape, gen_array_ops.fill(axes_shape, 1) };
23+
var rng = math_ops.range(input_rank);
24+
var a1 = new Tensor[] { rng, axes };
25+
var fill = gen_array_ops.fill(axes_shape, 1);
26+
var a2 = new Tensor[] { input_shape, fill };
2527

2628
return gen_data_flow_ops.dynamic_stitch(a1, a2);
2729
}
@@ -80,8 +82,17 @@ private static Tensor _ReductionDims(Tensor x, Tensor axis)
8082
}
8183
}
8284

83-
public static Tensor range(object start, Tensor limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range" )
85+
public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range" )
8486
{
87+
if(limit == null)
88+
{
89+
limit = start;
90+
start = 0;
91+
}
92+
93+
if (delta == null)
94+
delta = 1;
95+
8596
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope =>
8697
{
8798
name = scope;

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,12 @@ public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = Dat
333333
{
334334
case "Tensor":
335335
return value as Tensor;
336+
case "Int32":
337+
return constant_op.constant(Convert.ToInt32(value), name);
338+
case "Double":
339+
return constant_op.constant(Convert.ToDouble(value), name);
336340
default:
337-
return constant_op.constant(np.array(value), name);
341+
throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {typeof(T).Name} to Tensor");
338342
}
339343
}
340344
}

0 commit comments

Comments
 (0)
X Tutup