X Tutup
Skip to content

Commit 8762754

Browse files
committed
add tf.pad() to pads a tensor.
1 parent 1d6f3de commit 8762754

File tree

4 files changed

+56
-0
lines changed

4 files changed

+56
-0
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,18 @@ public Tensor one_hot(Tensor indices, int depth,
9999
int axis = -1,
100100
string name = null) => array_ops.one_hot(indices, depth, dtype: dtype, axis: axis, name: name);
101101

102+
/// <summary>
103+
/// Pads a tensor
104+
/// </summary>
105+
/// <param name="tensor"></param>
106+
/// <param name="paddings"></param>
107+
/// <param name="mode"></param>
108+
/// <param name="name"></param>
109+
/// <param name="constant_values"></param>
110+
/// <returns></returns>
111+
public Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", string name = null, int constant_values = 0)
112+
=> array_ops.pad(tensor, paddings, mode: mode, name: name, constant_values: constant_values);
113+
102114
/// <summary>
103115
/// A placeholder op that passes through `input` when its output is not fed.
104116
/// </summary>

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,40 @@ public static Tensor stack(object values, int axis = 0, string name = "stack")
552552
throw new NotImplementedException("array_ops.stack");
553553
}
554554

555+
public static Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", string name = null, int constant_values = 0)
556+
{
557+
Tensor result = null;
558+
mode = mode.ToUpper();
559+
if(mode == "CONSTANT")
560+
{
561+
if (constant_values != 0)
562+
throw new NotImplementedException("gen_array_ops.pad_v2");
563+
else
564+
result = gen_array_ops.pad(tensor, paddings, name: name);
565+
}
566+
567+
// Restore shape information where possible.
568+
var paddings_constant = tensor_util.constant_value(
569+
result.op.inputs[1], partial: true);
570+
var input_shape = result.op.inputs[0].TensorShape;
571+
if (input_shape.ndim > -1 &&
572+
!result.TensorShape.is_fully_defined() &&
573+
!(paddings_constant is null))
574+
{
575+
var new_shape = new List<int>();
576+
foreach((NDArray padding, int dim) in zip(paddings_constant.GetNDArrays(), np.array(input_shape.dims).GetNDArrays()))
577+
{
578+
if (padding is null || dim == -1 || padding.GetData<int>().Contains(-1))
579+
new_shape.Add(-1);
580+
else
581+
new_shape.Add(np.sum(padding) + dim);
582+
}
583+
result.set_shape(new_shape.ToArray());
584+
}
585+
586+
return result;
587+
}
588+
555589
public static Tensor placeholder(TF_DataType dtype)
556590
{
557591
throw new NotImplementedException("array_ops.placeholder");

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ public static Tensor gather_v2(Tensor @params, Tensor indices, int axis, string
9999
return _op.outputs[0];
100100
}
101101

102+
public static Tensor pad(Tensor input, Tensor paddings, string name = null)
103+
{
104+
var _op = _op_def_lib._apply_op_helper("Pad", name: name, args: new { input, paddings });
105+
106+
return _op.output;
107+
}
108+
102109
public static Tensor pack(Tensor[] values, int axis = 0, string name = null)
103110
{
104111
var _op = _op_def_lib._apply_op_helper("Pack", name: name, args: new { values, axis });

test/TensorFlowNET.Examples/ImageProcessing/YOLO/common.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tenso
1919

2020
if (downsample)
2121
{
22+
(int pad_h, int pad_w) = ((int)Math.Floor((filters_shape[0] - 2) / 2.0f) + 1, (int)Math.Floor((filters_shape[1] - 2) / 2.0f) + 1);
23+
var paddings = tf.constant(new int[,] { { 0, 0 }, { pad_h, pad_h }, { pad_w, pad_w }, { 0, 0 } });
24+
input_data = tf.pad(input_data, paddings, "CONSTANT");
2225
throw new NotImplementedException("");
2326
}
2427
else

0 commit comments

Comments
 (0)
X Tutup