# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""## Script Language Operators.
TensorFlow allows you to wrap python/numpy functions as
TensorFlow operators.
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_script_ops
class FuncRegistry(object):
  """A helper class to keep track of registered py functions.

  FuncRegistry keeps a map from unique tokens (string) to python
  functions, which take numpy arrays and output numpy arrays.
  """

  def __init__(self):
    # Counter used to mint tokens. It only ever grows, so a token is not
    # handed out again after its entry has been removed.
    self._unique_id = 0
    # Maps token (string) -> registered python function.
    self._funcs = {}

  def insert(self, func):
    """Registers `func` and returns a unique token for this entry."""
    token = self._next_unique_token()
    self._funcs[token] = func
    return token

  def remove(self, token):
    """Removes the registered function corresponding to `token`.

    Tokens that are not registered are silently ignored.
    """
    self._funcs.pop(token, None)

  def __call__(self, token, args):
    """Calls the registered function for `token` with args.

    Args:
      token: A token previously returned by `insert`.
      args: A sequence of positional arguments; it is unpacked into the
        call of the registered function.

    Returns:
      A list of np arrays if the function returned a list, otherwise a
      single np array built from the function's return value.

    Raises:
      ValueError: If no function is registered under `token`.
    """
    # Use get() rather than indexing so that an unknown (or already
    # removed) token reaches the ValueError below instead of escaping as
    # a bare KeyError from the dict lookup.
    func = self._funcs.get(token)
    if func is None:
      raise ValueError("callback %s is not found" % token)
    ret = func(*args)
    # Ensures that we return either a single np array or a list of
    # np arrays.
    if isinstance(ret, list):
      ret = [np.array(x) for x in ret]
    else:
      ret = np.array(ret)
    return ret

  def size(self):
    """Returns how many functions are currently registered."""
    return len(self._funcs)

  def _next_unique_token(self):
    """Returns a unique token of the form "pyfunc_<n>"."""
    uid = self._unique_id
    self._unique_id += 1
    return "pyfunc_%d" % uid
# Global registry for py functions. It is created once at import time and
# handed to pywrap_tensorflow.InitializePyTrampoline; presumably the native
# side then calls back into it with a token and args (see
# FuncRegistry.__call__) -- TODO confirm against the trampoline code.
_py_funcs = FuncRegistry()
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
class CleanupFunc(object):
  """Unregisters a py function from _py_funcs when this object is destroyed."""

  def __init__(self, token):
    # Token of the registry entry this object is responsible for.
    self._token = token

  def __del__(self):
    # Finalizer: drop the entry for our token from the global registry.
    registered_token = self._token
    _py_funcs.remove(registered_token)
def py_func(func, inp, Tout, name=None):
  """Wraps a python function and uses it as a tensorflow op.

  `func` is a python function that takes numpy arrays as its inputs and
  returns numpy arrays as its outputs. For example:

      def my_func(x):
        return np.sinh(x)
      inp = tf.placeholder(..., tf.float32)
      y = py_func(my_func, [inp], [tf.float32])

  The snippet above constructs a tf graph which invokes a numpy
  sinh(x) as an op in the graph.

  Args:
    func: A python function.
    inp: A list of `Tensor`.
    Tout: A list of tensorflow data types indicating what `func`
          returns.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` which `func` computes.
  """
  token = _py_funcs.insert(func)
  # The registration is tied to the life-time of the current default
  # graph: the CleanupFunc created here is stored on the graph object,
  # so once the graph is destroyed the CleanupFunc is destroyed too and
  # its __del__ removes `token` from the funcs registry.
  remover = CleanupFunc(token)
  graph = ops.get_default_graph()

  # pylint: disable=protected-access
  #
  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its member.
  try:
    holders = graph._cleanup_py_funcs_used_in_graph
  except AttributeError:
    # No py_func has attached a cleanup list to this graph yet.
    holders = []
    graph._cleanup_py_funcs_used_in_graph = holders
  holders.append(remover)

  return gen_script_ops._py_func(input=inp, token=token, Tout=Tout, name=name)
  # pylint: enable=protected-access
# Shape and gradient registration for the "PyFunc" op: its shape function is
# common_shapes.unknown_shape, and it is registered as having no gradient.
ops.RegisterShape("PyFunc")(common_shapes.unknown_shape)
ops.NoGradient("PyFunc")