forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathc_api.eager.cs
More file actions
490 lines (424 loc) · 22 KB
/
c_api.eager.cs
File metadata and controls
490 lines (424 loc) · 22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
using Google.Protobuf;
using System;
using System.Runtime.InteropServices;
using Tensorflow.Contexts;
using Tensorflow.Device;
using Tensorflow.Eager;
using Tensorflow.Util;
namespace Tensorflow
{
    /// <summary>
    /// P/Invoke declarations for the TensorFlow eager (TFE_*) C API, plus a managed
    /// wrapper around <c>TFE_Execute</c> that converts raw return pointers into safe handles.
    /// </summary>
    public partial class c_api
    {
        /// <summary>
        /// Return a new options object.
        /// </summary>
        /// <returns>TFE_ContextOptions*</returns>
        [DllImport(TensorFlowLibName)]
        public static extern SafeContextOptionsHandle TFE_NewContextOptions();

        /// <summary>
        /// Set the config in TF_ContextOptions.options.
        /// config should be a serialized tensorflow.ConfigProto proto.
        /// If config was not parsed successfully as a ConfigProto, record the
        /// error information in *status.
        /// </summary>
        /// <param name="opts">TFE_ContextOptions*</param>
        /// <param name="proto">Serialized <c>tensorflow.ConfigProto</c> bytes (const void*).</param>
        /// <param name="proto_len">size_t: number of bytes in <paramref name="proto"/>.</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status);

        /// <summary>
        /// Adds a serialized <c>FunctionDef</c> protocol buffer to <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <param name="serialized_function_def">Serialized FunctionDef bytes (const char*).</param>
        /// <param name="size">size_t: number of bytes in <paramref name="serialized_function_def"/>.</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status);

        /// <summary>
        /// Sets the device placement policy on a context options object.
        /// </summary>
        /// <param name="opts">TFE_ContextOptions*</param>
        /// <param name="device_policy">TFE_ContextDevicePlacementPolicy</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy);

        /// <summary>
        /// Destroy an options object.
        /// </summary>
        /// <param name="options">TFE_ContextOptions*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_DeleteContextOptions(IntPtr options);

        /// <summary>
        /// Configure device placement policy logging for the eager executor. Note this
        /// policy is applied to any subsequent op executions.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <param name="enable">Whether placement logging is enabled.</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextSetLogDevicePlacement(SafeContextHandle ctx, bool enable, SafeStatusHandle status);

        /// <summary>
        /// Returns the type of attribute <paramref name="attr_name"/> on <paramref name="op"/>;
        /// <paramref name="is_list"/> is written non-zero when the attribute is a list.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="is_list">unsigned char*</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>TF_AttrType</returns>
        [DllImport(TensorFlowLibName)]
        public static extern TF_AttrType TFE_OpGetAttrType(SafeEagerOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status);

        /// <summary>
        /// Like <see cref="TFE_OpGetAttrType"/>, but looks the attribute up by op or function name
        /// instead of on an existing op instance.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <param name="op_or_function_name">const char*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="is_list">unsigned char*</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>TF_AttrType</returns>
        [DllImport(TensorFlowLibName)]
        public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status);

        /// <summary>
        /// Returns the length (number of tensors) of the input argument `input_name`
        /// found in the provided `op`.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="input_name">const char*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status);

        /// <summary>
        /// Returns the length (number of tensors) of the output argument `output_name`
        /// found in the provided `op`.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="input_name">
        /// const char*: the name of the <b>output</b> argument. The parameter keeps its historical
        /// name so that callers using named arguments are not broken.
        /// </param>
        /// <param name="status">TF_Status*</param>
        /// <returns>Number of tensors in the output argument.</returns>
        [DllImport(TensorFlowLibName)]
        public static extern int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status);

        /// <summary>
        /// Adds the first <paramref name="num_inputs"/> handles of <paramref name="inputs"/> to
        /// <paramref name="op"/> as a single list input.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the upstream eager <c>c_api.h</c> appears to declare this function as
        /// returning <c>void</c>. If so, the <c>int</c> result here is meaningless. Confirm before relying on it.
        /// </remarks>
        /// <param name="op">TFE_Op*</param>
        /// <param name="inputs">TFE_TensorHandle**</param>
        /// <param name="num_inputs">int</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>See remarks.</returns>
        [DllImport(TensorFlowLibName)]
        public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeEagerTensorHandle[] inputs, int num_inputs, SafeStatusHandle status);

        /// <summary>
        /// Creates a new eager context configured by <paramref name="opts"/>.
        /// </summary>
        /// <param name="opts">const TFE_ContextOptions*</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>TFE_Context*</returns>
        [DllImport(TensorFlowLibName)]
        public static extern SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status);

        /// <summary>
        /// Adds a function (created from TF_GraphToFunction or
        /// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with
        /// TFE_Execute by creating an op with the same name as the function.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <param name="function">TF_Function*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, SafeFuncGraphHandle function, SafeStatusHandle status);

        /// <summary>
        /// Removes a function from the context. Once removed, you can no longer
        /// TFE_Execute it or TFE_Execute any TFE_Op which has it as an attribute or any
        /// other function which calls it as an attribute.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <param name="name">const char*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextRemoveFunction(SafeContextHandle ctx, string name, SafeStatusHandle status);

        /// <summary>
        /// Checks whether a function is registered under `name`.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <param name="name">const char*</param>
        /// <returns><c>true</c> if a function with that name is registered.</returns>
        // The native function returns a one-byte `unsigned char`. Without U1, the default
        // marshaling reads a 4-byte Win32 BOOL, and garbage in the upper bytes of the return
        // register can turn a `false` result into `true`.
        [DllImport(TensorFlowLibName)]
        [return: MarshalAs(UnmanagedType.U1)]
        public static extern bool TFE_ContextHasFunction(SafeContextHandle ctx, string name);

        /// <summary>
        /// Binding for the native <c>TFE_ContextStartStep</c>.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextStartStep(SafeContextHandle ctx);

        /// <summary>
        /// Binding for the native <c>TFE_ContextEndStep</c>.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextEndStep(SafeContextHandle ctx);

        /// <summary>
        /// Destroys an eager context.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_DeleteContext(IntPtr ctx);

        /// <summary>
        /// Execute the operation defined by <paramref name="op"/> and return handles to computed
        /// tensors in <paramref name="retvals"/>.
        /// </summary>
        /// <remarks>
        /// Upon successful return, the first <paramref name="num_retvals"/> slots in <paramref name="retvals"/> will
        /// contain handle instances which the caller is responsible for disposing once they are no longer in use.
        /// <paramref name="num_retvals"/> never exceeds <c>retvals.Length</c>.
        /// </remarks>
        /// <param name="op">TFE_Op*</param>
        /// <param name="retvals">Caller-allocated output array. Its length is the capacity passed to native code.</param>
        /// <param name="num_retvals">Number of leading slots of <paramref name="retvals"/> that were populated.</param>
        /// <param name="status">TF_Status*</param>
        public static void TFE_Execute(SafeEagerOpHandle op, SafeEagerTensorHandle[] retvals, out int num_retvals, SafeStatusHandle status)
        {
            unsafe
            {
                // The buffer length doubles as the in/out count handed to the native call.
                var capacity = retvals?.Length ?? 0;
                var rawReturns = stackalloc IntPtr[capacity];
                num_retvals = capacity;
                TFE_Execute(op, rawReturns, ref num_retvals, status);

                // Never trust the native count beyond the buffer that was actually allocated.
                // A larger value would read past `rawReturns` and index past `retvals`.
                var produced = Math.Max(0, Math.Min(num_retvals, capacity));
                num_retvals = produced;
                for (var i = 0; i < produced; i++)
                {
                    // A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be
                    // non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return
                    // values.
                    retvals[i] = new SafeEagerTensorHandle(rawReturns[i]);
                }
            }
        }

        /// <summary>
        /// Execute the operation defined by 'op' and return handles to computed
        /// tensors in `retvals`.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="retvals">TFE_TensorHandle**</param>
        /// <param name="num_retvals">int*: capacity of <paramref name="retvals"/> on input, produced count on output.</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        private static unsafe extern void TFE_Execute(SafeEagerOpHandle op, IntPtr* retvals, ref int num_retvals, SafeStatusHandle status);

        /// <summary>
        /// Creates a new op (or function call) named <paramref name="op_or_function_name"/> in <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <param name="op_or_function_name">const char*</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>TFE_Op*</returns>
        [DllImport(TensorFlowLibName)]
        public static extern SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status);

        /// <summary>
        /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
        /// is a performance optimization that reuses an existing unused op rather than
        /// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
        /// does not set the device name. If it's not `NULL`, it attempts to parse
        /// and set the device name. It is effectively `TFE_OpSetDevice`, but faster
        /// than calling it separately: if the existing op already has the same
        /// `raw_device_name`, parsing is skipped and the device is left as is.
        /// </summary>
        /// <param name="op_to_reset">TFE_Op*</param>
        /// <param name="op_or_function_name">const char*</param>
        /// <param name="raw_device_name">const char*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpReset(SafeEagerOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status);

        /// <summary>
        /// Destroys an eager op.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_DeleteOp(IntPtr op);

        /// <summary>
        /// Sets a type-valued attribute on <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="value">TF_DataType</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value);

        /// <summary>
        /// Sets an int64 attribute on <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="value">int64_t</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrInt(SafeEagerOpHandle op, string attr_name, long value);

        /// <summary>
        /// Sets a float attribute on <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="value">float</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrFloat(SafeEagerOpHandle op, string attr_name, float value);

        /// <summary>
        /// Sets a shape attribute on <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="dims">const int64_t*</param>
        /// <param name="num_dims">const int</param>
        /// <param name="out_status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrShape(SafeEagerOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status);

        /// <summary>
        /// Sets a list-of-shapes attribute on <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="dims">const int64_t**: one pointer per shape.</param>
        /// <param name="num_dims">const int*: rank of each shape.</param>
        /// <param name="num_values">int: number of shapes.</param>
        /// <param name="out_status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrShapeList(SafeEagerOpHandle op, string attr_name, IntPtr[] dims, int[] num_dims, int num_values, SafeStatusHandle out_status);

        /// <summary>
        /// Sets a list-of-strings attribute on <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="values">const void* const*</param>
        /// <param name="lengths">const size_t*: byte length of each entry.</param>
        /// <param name="num_values">int</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrStringList(SafeEagerOpHandle op, string attr_name, string[] values, ulong[] lengths, int num_values);

        /// <summary>
        /// Sets a boolean attribute on <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="value">unsigned char</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrBool(SafeEagerOpHandle op, string attr_name, bool value);

        /// <summary>
        /// Sets a function-valued attribute on <paramref name="op"/>, referring to the function by name.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="data">const char*: the function name.</param>
        /// <param name="length">size_t: length of <paramref name="data"/>.</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrFunctionName(SafeEagerOpHandle op, string attr_name, string data, int length);

        /// <summary>
        /// Sets a string attribute on <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="value">const void*</param>
        /// <param name="length">size_t</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrString(SafeEagerOpHandle op, string attr_name, string value, ulong length);

        /// <summary>
        /// Sets a list-of-types attribute on <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="values">const TF_DataType*</param>
        /// <param name="num_values">int</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrTypeList(SafeEagerOpHandle op, string attr_name, TF_DataType[] values, int num_values);

        /// <summary>
        /// Sets a list-of-int64 attribute on <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="values">const int64_t*</param>
        /// <param name="num_values">int</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values);

        /// <summary>
        /// Sets an attribute on <paramref name="op"/> from a serialized <c>AttrValue</c> proto.
        /// </summary>
        /// <param name="op">TFE_Op* (raw pointer, unlike the other attribute setters)</param>
        /// <param name="attr_name">const char*</param>
        /// <param name="proto">const void*: serialized proto bytes.</param>
        /// <param name="proto_len">size_t</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status);

        /// <summary>
        /// Sets the device on which <paramref name="op"/> should run.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="device_name">const char*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status);

        /// <summary>
        /// Appends a single input tensor handle to <paramref name="op"/>.
        /// </summary>
        /// <param name="op">TFE_Op*</param>
        /// <param name="h">TFE_TensorHandle*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeEagerTensorHandle h, SafeStatusHandle status);

        /// <summary>
        /// Creates an eager tensor handle that wraps the tensor <paramref name="t"/>.
        /// </summary>
        /// <param name="t">const TF_Tensor*</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>TFE_TensorHandle*</returns>
        [DllImport(TensorFlowLibName)]
        public static extern SafeEagerTensorHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status);

        /// <summary>
        /// Binding for the native <c>TFE_EagerTensorHandle</c> export.
        /// </summary>
        /// <param name="t">Raw pointer passed through to native code.</param>
        /// <returns>TFE_TensorHandle*</returns>
        [DllImport(TensorFlowLibName)]
        public static extern SafeEagerTensorHandle TFE_EagerTensorHandle(IntPtr t);

        /// <summary>
        /// Sets the default execution mode (sync/async). Note that this can be
        /// overridden per thread using TFE_ContextSetExecutorForThread.
        /// </summary>
        /// <param name="opts">TFE_ContextOptions*</param>
        /// <param name="enable">unsigned char</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextOptionsSetAsync(SafeContextOptionsHandle opts, byte enable);

        /// <summary>
        /// Returns the element data type of the tensor behind <paramref name="h"/>.
        /// </summary>
        /// <param name="h">TFE_TensorHandle*</param>
        /// <returns>TF_DataType</returns>
        [DllImport(TensorFlowLibName)]
        public static extern TF_DataType TFE_TensorHandleDataType(SafeEagerTensorHandle h);

        /// <summary>
        /// This function will block till the operation that produces `h` has
        /// completed. The memory returned might alias the internal memory used by
        /// TensorFlow.
        /// </summary>
        /// <param name="h">TFE_TensorHandle*</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>TF_Tensor*</returns>
        [DllImport(TensorFlowLibName)]
        public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeEagerTensorHandle h, SafeStatusHandle status);

        /// <summary>
        /// Returns the rank of the tensor behind <paramref name="h"/>.
        /// This function will block till the operation that produces `h` has completed.
        /// </summary>
        /// <param name="h">TFE_TensorHandle*</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>Number of dimensions.</returns>
        [DllImport(TensorFlowLibName)]
        public static extern int TFE_TensorHandleNumDims(SafeEagerTensorHandle h, SafeStatusHandle status);

        /// <summary>
        /// Returns the size of dimension <paramref name="dim"/> of the tensor behind <paramref name="h"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): upstream <c>c_api.h</c> appears to declare the native return type as
        /// <c>int64_t</c>. If so, this <c>int</c> binding truncates dimensions of 2^31 or larger.
        /// It is kept as <c>int</c> for source compatibility with existing callers. Confirm.
        /// </remarks>
        /// <param name="h">TFE_TensorHandle*</param>
        /// <param name="dim">int: dimension index.</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>Dimension size.</returns>
        [DllImport(TensorFlowLibName)]
        public static extern int TFE_TensorHandleDim(SafeEagerTensorHandle h, int dim, SafeStatusHandle status);

        /// <summary>
        /// Returns the device of the operation that produced `h`. If `h` was produced by
        /// a copy, returns the destination device of the copy. Note that the returned
        /// device name is not always the device holding the tensor handle's memory. If
        /// you want the latter, use TFE_TensorHandleBackingDeviceName. This function
        /// will block till the operation that produces `h` has completed.
        /// </summary>
        /// <param name="h">TFE_TensorHandle*</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>const char* (unmarshaled native string pointer)</returns>
        [DllImport(TensorFlowLibName)]
        public static extern IntPtr TFE_TensorHandleDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status);

        /// <summary>
        /// Returns the name of the device in whose memory `h` resides.
        /// </summary>
        /// <param name="h">TFE_TensorHandle*</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>const char* (unmarshaled native string pointer)</returns>
        [DllImport(TensorFlowLibName)]
        public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeEagerTensorHandle h, SafeStatusHandle status);

        /// <summary>
        /// Lists the devices available to <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <param name="status">TF_Status*</param>
        /// <returns>TF_DeviceList*</returns>
        [DllImport(TensorFlowLibName)]
        public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status);

        /// <summary>
        /// Clears the internal caches in the TFE context. Useful when reseeding random ops.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextClearCaches(SafeContextHandle ctx);

        /// <summary>
        /// Destroys an eager tensor handle.
        /// </summary>
        /// <param name="h">TFE_TensorHandle*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_DeleteTensorHandle(IntPtr h);

        /// <summary>
        /// Binding for the native <c>TFE_DeleteEagerTensor</c> export.
        /// </summary>
        /// <param name="h">TFE_TensorHandle*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_DeleteEagerTensor(IntPtr h);

        /// <summary>
        /// Binding for the native <c>TF_DeleteBindingArray</c> export.
        /// </summary>
        /// <param name="h">Native array pointer.</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TF_DeleteBindingArray(IntPtr h);

        /// <summary>
        /// Binding for the native <c>TFE_DeleteBindingTensorArray</c> export.
        /// </summary>
        /// <param name="h">Native array pointer.</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_DeleteBindingTensorArray(IntPtr h);

        /// <summary>
        /// Creates a new eager Executor. Nodes in one executor are guaranteed to be
        /// executed in sequence. Assigning nodes to different executors allows executing
        /// nodes in parallel.
        /// </summary>
        /// <param name="is_async">Whether the executor runs nodes asynchronously.</param>
        /// <returns>TFE_Executor*</returns>
        [DllImport(TensorFlowLibName)]
        public static extern SafeExecutorHandle TFE_NewExecutor(bool is_async);

        /// <summary>
        /// Deletes the eager Executor without waiting for enqueued nodes. Please call
        /// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to
        /// make sure all nodes are finished.
        /// </summary>
        /// <param name="executor">TFE_Executor*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_DeleteExecutor(IntPtr executor);

        /// <summary>
        /// Causes the calling thread to block till all ops dispatched in this executor
        /// have been executed. Note that "execution" here refers to kernel execution /
        /// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
        /// that lower level device queues (like GPU streams) have been flushed.
        ///
        /// This call may not block for execution of ops enqueued concurrently with this
        /// call.
        /// </summary>
        /// <param name="executor">TFE_Executor*</param>
        /// <param name="status">TF_Status*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status);

        /// <summary>
        /// Sets a custom Executor for current thread. All nodes created by this thread
        /// will be added to this Executor. It will override current executor.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <param name="executor">TFE_Executor*</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, SafeExecutorHandle executor);

        /// <summary>
        /// Returns the Executor for current thread.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <returns>TFE_Executor*</returns>
        [DllImport(TensorFlowLibName)]
        public static extern SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx);

        /// <summary>
        /// Binding for the native <c>TFE_TapeSetNew</c> export.
        /// </summary>
        /// <param name="persistent">Forwarded to native code.</param>
        /// <param name="watch_accessed_variables">Forwarded to native code.</param>
        /// <returns>Opaque native tape pointer.</returns>
        [DllImport(TensorFlowLibName)]
        public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables);

        /// <summary>
        /// Binding for the native <c>TFE_TapeSetRemove</c> export.
        /// </summary>
        /// <param name="tape">Opaque native tape pointer.</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_TapeSetRemove(IntPtr tape);

        /// <summary>
        /// Binding for the native <c>TFE_TapeWatch</c> export.
        /// </summary>
        /// <param name="tape">Opaque native tape pointer.</param>
        /// <param name="variable">Native pointer to the watched value.</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_TapeWatch(IntPtr tape, IntPtr variable);

        /// <summary>
        /// Binding for the native <c>TFE_TapeVariableAccessed</c> export.
        /// </summary>
        /// <param name="variable">Native variable pointer.</param>
        [DllImport(TensorFlowLibName)]
        public static extern void TFE_TapeVariableAccessed(IntPtr variable);

        /// <summary>
        /// Binding for the native <c>TFE_TapeWatchedVariables</c> export.
        /// </summary>
        /// <param name="tape">Opaque native tape pointer.</param>
        /// <returns>Opaque native pointer.</returns>
        [DllImport(TensorFlowLibName)]
        public static extern IntPtr TFE_TapeWatchedVariables(IntPtr tape);

        /// <summary>
        /// Binding for the native <c>ResourceVariable_Handle</c> export.
        /// </summary>
        /// <param name="variable">Native variable pointer.</param>
        /// <returns>Opaque native pointer.</returns>
        [DllImport(TensorFlowLibName)]
        public static extern IntPtr ResourceVariable_Handle(IntPtr variable);

        /// <summary>
        /// Binding for the native <c>TFE_TapeGradient</c> export.
        /// </summary>
        /// <param name="tape">Opaque native tape pointer.</param>
        /// <param name="target">Target pointers.</param>
        /// <param name="target_size">Number of entries in <paramref name="target"/>.</param>
        /// <param name="sources">Source pointers.</param>
        /// <param name="source_size">Number of entries in <paramref name="sources"/>.</param>
        /// <param name="outputs">Output pointer buffer.</param>
        /// <param name="output_size">Number of entries in <paramref name="outputs"/>.</param>
        /// <returns>TF_Status*</returns>
        [DllImport(TensorFlowLibName)]
        public static extern SafeStatusHandle TFE_TapeGradient(IntPtr tape,
            IntPtr[] target, int target_size,
            IntPtr[] sources, int source_size,
            IntPtr[] outputs, int output_size);

        /// <summary>
        /// Returns whether <paramref name="device_name"/> names a custom device registered in <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">TFE_Context*</param>
        /// <param name="device_name">const char*</param>
        /// <returns><c>true</c> for a custom device.</returns>
        // Native C `bool` is one byte. U1 prevents the default 4-byte BOOL marshaling
        // from picking up garbage from the upper bytes of the return register.
        [DllImport(TensorFlowLibName)]
        [return: MarshalAs(UnmanagedType.U1)]
        public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name);
    }
}