/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
using Google.Protobuf;
using Google.Protobuf.Collections;
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Graphs;
using Tensorflow.Util;
using static Tensorflow.Binding;
using static Tensorflow.CppShapeInferenceResult.Types;
namespace Tensorflow
{
public partial class ops
{
/// <summary>
/// Returns the identifier of <paramref name="tensor"/>, i.e. the value of its <c>Id</c> property.
/// </summary>
/// <param name="tensor">The tensor whose identifier is requested.</param>
/// <returns>The value of <c>tensor.Id</c>.</returns>
public static long tensor_id(Tensor tensor) => tensor.Id;
/// <summary>
/// Adds <paramref name="value"/> to the collection called <paramref name="name"/>
/// on the graph returned by <c>tf.get_default_graph()</c>.
/// </summary>
/// <typeparam name="T">Type of the value being stored.</typeparam>
/// <param name="name">Name of the collection to add to.</param>
/// <param name="value">The value to add.</param>
public static void add_to_collection<T>(string name, T value)
{
    // The generic parameter must be declared on the method; without it `T` is an unknown type.
    var graph = tf.get_default_graph();
    graph.add_to_collection(name, value);
}
/// <summary>
/// Forwards <paramref name="names"/> and <paramref name="value"/> to
/// <c>add_to_collections</c> on the graph returned by <c>tf.get_default_graph()</c>.
/// </summary>
/// <typeparam name="T">Type of the value being stored.</typeparam>
/// <param name="names">
/// Names of the target collections.
/// NOTE(review): element type restored as <c>string</c> to match the single-name
/// overload <c>add_to_collection(string, T)</c> - confirm against <c>Graph.add_to_collections</c>.
/// </param>
/// <param name="value">The value to add.</param>
public static void add_to_collections<T>(List<string> names, T value)
{
    // Both the method's type parameter and the list's type argument must be spelled out;
    // a bare `List` / undeclared `T` does not compile.
    var graph = tf.get_default_graph();
    graph.add_to_collections(names, value);
}
/// <summary>
/// Wrapper for `Graph.get_collection()` using the default graph.
/// </summary>
/// <param name="key">
/// The key for the collection. For example, the `GraphKeys` class
/// contains many standard names for collections.
/// </param>
/// <param name="scope">
/// Optional scope, passed through unchanged to `Graph.get_collection()`.
/// </param>
/// <returns>
/// The list of values in the collection with the given `name`, or
/// an empty list if no value has been added to that collection. The
/// list contains the values in the order under which they were
/// collected.
/// </returns>
public static object get_collection(string key, string scope = null)
    => get_default_graph().get_collection(key, scope);
/// <summary>
/// Strongly typed variant of <c>get_collection</c>: looks up the collection
/// stored under <paramref name="key"/> on the default graph.
/// </summary>
/// <typeparam name="T">Element type of the requested collection.</typeparam>
/// <param name="key">The key for the collection.</param>
/// <param name="scope">Optional scope, passed through unchanged to the graph.</param>
/// <returns>Whatever the graph's typed <c>get_collection</c> returns for the key and scope.</returns>
public static List<T> get_collection<T>(string key, string scope = null)
{
    // The type parameter distinguishes this overload from the object-returning one
    // (two methods may not differ by return type only) and gives `List` its element type.
    // NOTE(review): assumes Graph exposes a generic get_collection<T>(key, scope) - confirm.
    return get_default_graph().get_collection<T>(key, scope);
}
/// <summary>
/// Wrapper for <c>Graph.get_collection_ref</c> using the default graph.
/// </summary>
/// <typeparam name="T">Element type of the requested collection.</typeparam>
/// <param name="key">The key for the collection.</param>
/// <returns>The list returned by the default graph's <c>get_collection_ref</c> for <paramref name="key"/>.</returns>
public static List<T> get_collection_ref<T>(string key)
{
    // `List` needs an element type; it is supplied through the method's type parameter.
    // NOTE(review): assumes Graph exposes a generic get_collection_ref<T>(key) - confirm.
    return get_default_graph().get_collection_ref<T>(key);
}
/// <summary>
/// Picks the graph to use for an operation from its inputs.
/// </summary>
/// <param name="op_input_list">Candidate inputs; only elements that are <c>Tensor</c> instances are inspected.</param>
/// <returns>
/// The default graph when it is building a function; otherwise the first non-null
/// <c>graph</c> found on a <c>Tensor</c> input; otherwise the default graph.
/// </returns>
public static Graph _get_graph_from_inputs(params object[] op_input_list)
{
    var defaultGraph = get_default_graph();

    // While a function is being built, the default graph always wins.
    if (!defaultGraph.building_function)
    {
        foreach (var candidate in op_input_list)
        {
            if (candidate is Tensor tensor)
            {
                // The first tensor that carries a graph decides the result.
                Graph found = tensor.graph;
                if (!(found is null))
                    return found;
            }
        }
    }

    return defaultGraph;
}
/// <summary>
/// Convenience overload: resolves the graph for <paramref name="op_input_list"/>
/// without an explicitly supplied graph.
/// </summary>
/// <param name="op_input_list">The input tensors of the operation.</param>
/// <returns>The result of the <c>(Tensors, Graph)</c> overload called with a null graph.</returns>
public static Graph _get_graph_from_inputs(Tensors op_input_list)
{
    // Naming the `graph` argument selects the (Tensors, Graph) overload,
    // the only one that declares a parameter with that name.
    return _get_graph_from_inputs(op_input_list: op_input_list, graph: null);
}
/// <summary>
/// Returns the graph to use for the given inputs. The current implementation
/// always returns the default graph.
/// </summary>
/// <param name="op_input_list">The input tensors of the operation. Currently not inspected.</param>
/// <param name="graph">Optional explicit graph. Currently ignored.</param>
/// <returns>The graph returned by <c>get_default_graph()</c>.</returns>
public static Graph _get_graph_from_inputs(Tensors op_input_list, Graph graph = null)
{
    // TODO: determine whether each element of op_input_list is a valid graph element.
    // The former placeholder loop had an empty body, so it only enumerated the inputs
    // (and threw on a null list) without influencing the result; it has been removed.
    return get_default_graph();
}
///
/// Converts the given `value` to a `Tensor`.
///
///
///
///
///
public static Tensor convert_to_tensor(object value,
TF_DataType dtype = TF_DataType.DtInvalid,
string name = null,
bool as_ref = false,
TF_DataType preferred_dtype = TF_DataType.DtInvalid,
Context ctx = null)
{
if (dtype == TF_DataType.DtInvalid)
dtype = preferred_dtype;
if (dtype == TF_DataType.DtInvalid)
dtype = value.GetDataType();
if (value is EagerTensor eager_tensor)
{
if (tf.executing_eagerly())
{
if (dtype != TF_DataType.DtInvalid && dtype != eager_tensor.dtype)
return gen_math_ops.cast(eager_tensor, dtype.as_base_dtype(), name: name);
return eager_tensor;
}
else
{
var graph = get_default_graph();
if (graph is FuncGraph funcGraph)
{
return funcGraph.capture(eager_tensor, name: name);
}
if (!graph.building_function)
{
// throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
return eager_tensor.AsPlaceholder(name: name);
}
}
}
else if (value is KerasTensor kt)
{
if (kt.inferred_value != null)
{
return convert_to_tensor(kt.inferred_value, dtype: kt.dtype, name: name);
}
}
// graph mode
Tensor ret = value switch
{
NDArray nd => constant_op.constant(nd, dtype: dtype, name: name),
EagerTensor tensor => tensor.dtype == TF_DataType.TF_RESOURCE
? tensor.AsPlaceholder(name: name)
: tensor.AsConstant(name: name),
Tensor tensor => tensor,
IEnumerable tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name),
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
Axis ts => constant_op.constant(ts, dtype: dtype, name: name),
Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
string str => constant_op.constant(str, dtype: tf.@string, name: name),
string[] str => constant_op.constant(str, dtype: tf.@string, name: name),
IEnumerable