forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfunction_test.py
More file actions
366 lines (309 loc) · 13.2 KB
/
function_test.py
File metadata and controls
366 lines (309 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,g-bad-import-order
import tensorflow.python.platform
# pylint: enable=unused-import,g-bad-import-order
import time
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import function
# pylint: disable=unused-import
from tensorflow.python.ops import functional_ops
# pylint: enable=unused-import
class FunctionTest(tf.test.TestCase):
  """Tests for defining, calling and differentiating graph functions."""

  def _mat(self, x):
    """Returns the scalar `x` as a 1x1 float32 numpy matrix."""
    return np.array([x]).astype("float32").reshape([1, 1])

  def testBasic(self):
    """Builds `foo` from a standalone graph and calls it twice via create_op."""
    g = tf.Graph()

    # Define a function
    #   foo(a:float, b:float, c:float)->u:float,v:float,w:float
    #     u = matmul(a, b) + c
    #     v = u^2
    #     w = u + v
    # TODO(zhifengc): replaces w/ a nicer @decorator sugar.
    foo = tf.Graph()
    with foo.as_default():
      a = tf.placeholder(tf.float32, name="a")
      b = tf.placeholder(tf.float32, name="b")
      c = tf.placeholder(tf.float32, name="c")
      u = tf.add(tf.matmul(a, b), c, name="u")
      v = tf.square(u, name="v")
      w = tf.add_n([u, v], name="w")
    fdef = function.graph_to_function_def(foo, "foo", [a, b, c], [u, v, w])
    # Register the function definition in `g`'s library so ops of type
    # "foo" can be created in it.
    g._add_function(fdef)

    # Compute 2 * 3 + 4 and its square.
    with g.as_default(), tf.Session() as sess:
      two = tf.constant(self._mat(2.0), name="two")
      three = tf.constant(self._mat(3.0), name="three")
      four = tf.constant(self._mat(4.0), name="four")
      # TODO(zhifengc): w/ @decorator sugar, we will just do:
      #   y, s, t = foo_func(two, three, four)

      # The graph contains two ops each of which calls foo.
      u0, v0, w0 = g.create_op("foo",
                               [two, three, four],
                               [tf.float32, tf.float32, tf.float32],
                               compute_shapes=False).outputs
      u1, v1, w1 = g.create_op("foo",
                               [four, two, three],
                               [tf.float32, tf.float32, tf.float32],
                               compute_shapes=False).outputs

      # Checks some property of the graph def.
      gdef = g.as_graph_def()
      self.assertEqual(len(gdef.node), 5)  # 5 nodes added.
      self.assertEqual(len(gdef.library.function), 1)  # 1 function is defined.

      # `range` rather than `xrange`: xrange does not exist on Python 3 and
      # this module does not import a compatibility shim for it.
      for _ in range(10):
        # Run the graph, which is basically two function calls.
        ans_u0, ans_v0, ans_w0, ans_u1, ans_v1, ans_w1 = sess.run(
            [u0, v0, w0, u1, v1, w1])
        self.assertAllEqual(ans_u0, self._mat(10.0))  # 2 * 3 + 4 = 10
        self.assertAllEqual(ans_v0, self._mat(100.0))  # 10^2 = 100
        self.assertAllEqual(ans_w0, self._mat(110.0))  # 100 + 10 = 110
        self.assertAllEqual(ans_u1, self._mat(11.0))  # 4 * 2 + 3 = 11
        self.assertAllEqual(ans_v1, self._mat(121.0))  # 11^2 = 121
        self.assertAllEqual(ans_w1, self._mat(132.0))  # 11 + 121 = 132

  def testDefineFunction2Args(self):
    """Defines and calls a two-argument function; checks name and value."""

    def APlus2B(a, b):
      return a + b * 2

    with tf.Graph().as_default():
      f_def = function.define_function(APlus2B, {"a": tf.float32,
                                                 "b": tf.float32})
      one = tf.constant([1.0])
      two = tf.constant([2.0])
      call = function.call_function(f_def, one, two)
      self.assertEqual("APlus2B", call.op.name)
      with tf.Session() as sess:
        self.assertAllEqual([5.0], sess.run(call))

  def testGradientFunc(self):
    """Calls a function and a hand-written gradient function built on it."""

    def XSquarePlusOne(x):
      return x * x + 1.0

    def XSquarePlusOneGrad(x, dy):
      dx = functional_ops._symbolic_gradient(input=[x, dy],
                                             Tout=[tf.float32],
                                             f="XSquarePlusOne",
                                             name="dx")
      return dx

    g = tf.Graph()
    with g.as_default():
      # Distinct names: the original rebound `g` (the Graph) to the gradient
      # function definition while still inside `g.as_default()`.
      f_def = function.define_function(XSquarePlusOne, {"x": tf.float32})
      grad_def = function.define_function(XSquarePlusOneGrad,
                                          {"x": tf.float32,
                                           "dy": tf.float32})
      epsilon = tf.constant([0.1])
      two = tf.constant([2.0])
      call_f = function.call_function(f_def, two)
      call_g = function.call_function(grad_def, two, epsilon)

      with tf.Session() as sess:
        self.assertAllClose([5.0], sess.run(call_f))  # 2 * 2 + 1
        self.assertAllClose([0.4], sess.run(call_g))  # 2 * 2 * 0.1

  def testSymGradShape(self):
    """Checks SymbolicGradient outputs take the shapes of the matching inputs."""
    g = tf.Graph()
    with g.as_default():
      x = tf.placeholder(tf.float32, [25, 4])
      y = tf.placeholder(tf.float32, [200, 100])
      dz = tf.placeholder(tf.float32, [1])
      # We assume Foo is a function of (x, y) -> (z) Then, Foo's
      # gradient function is (x, y, dz) -> (dx, dy).  dx's shape
      # should be the same as x's; and dy's shape should be the same
      # as y's.
      dx, dy = functional_ops._symbolic_gradient(input=[x, y, dz],
                                                 Tout=[tf.float32] * 2,
                                                 f="Foo")
      self.assertEqual(x.get_shape(), dx.get_shape())
      self.assertEqual(y.get_shape(), dy.get_shape())

  def testDefineFunctionNoArgs(self):
    """Defines and calls a function that takes no arguments."""

    def AConstant():
      return tf.constant([42])

    with tf.Graph().as_default():
      f_def = function.define_function(AConstant, {})
      call = function.call_function(f_def)
      self.assertEqual("AConstant", call.op.name)
      with tf.Session() as sess:
        self.assertAllEqual([42], sess.run(call))

  def testDefineFunctionNames(self):
    """Checks op naming of repeated, explicitly named and scoped calls."""

    def Foo(a):
      return a + 1

    with tf.Graph().as_default():
      f_def = function.define_function(Foo, {"a": tf.float32})
      one = tf.constant([1.0])
      call1 = function.call_function(f_def, one)
      self.assertEqual("Foo", call1.op.name)
      call2 = function.call_function(f_def, one)
      self.assertEqual("Foo_1", call2.op.name)
      call3 = function.call_function(f_def, one, name="mine")
      self.assertEqual("mine", call3.op.name)
      with tf.name_scope("my"):
        call4 = function.call_function(f_def, one, name="precious")
        self.assertEqual("my/precious", call4.op.name)

  def testDefineErrors(self):
    """Checks define_function's ValueErrors for unsupported definitions."""

    def NoResult():
      pass

    def VarArgs(*unused_b):
      return tf.constant([1])

    def DefaultArg(unused_a=12):
      return tf.constant([1])

    def KwArgs(**unused_kwargs):
      return tf.constant([1])

    def PlusMinus(a, b):
      return a + b, b - a

    with tf.Graph().as_default():
      # assertRaisesRegexp (not assertRaisesRegex) keeps Python 2.7 working.
      with self.assertRaisesRegexp(ValueError, "return at least one tensor"):
        function.define_function(NoResult, {})
      with self.assertRaisesRegexp(ValueError, "plain arglists are supported"):
        function.define_function(VarArgs, {})
      with self.assertRaisesRegexp(ValueError, "plain arglists are supported"):
        function.define_function(DefaultArg, {})
      with self.assertRaisesRegexp(ValueError, "plain arglists are supported"):
        function.define_function(KwArgs, {})
      with self.assertRaisesRegexp(ValueError, "specified input types"):
        function.define_function(PlusMinus, {})
      with self.assertRaisesRegexp(ValueError, "specified input types"):
        function.define_function(PlusMinus, {"c": tf.float32})
      with self.assertRaisesRegexp(ValueError, "type for argument: b"):
        function.define_function(PlusMinus, {"a": tf.float32,
                                             "c": tf.float32})
      with self.assertRaisesRegexp(ValueError, "specified input types"):
        function.define_function(PlusMinus, {"a": tf.float32,
                                             "b": tf.float32,
                                             "c": tf.float32})

  def testCallErrors(self):
    """Checks call_function's ValueErrors for wrong arity or keywords."""

    def Const():
      return tf.constant(1)

    def PlusOne(a):
      return a + 1

    def PlusMinus(a, b):
      return a + b, b - a

    with tf.Graph().as_default():
      one = tf.constant([1])
      two = tf.constant([2])
      const = function.define_function(Const, {})
      plus_one = function.define_function(PlusOne, {"a": tf.int32})
      plus_minus = function.define_function(PlusMinus, {"a": tf.int32,
                                                        "b": tf.int32})

      function.call_function(const)
      with self.assertRaisesRegexp(ValueError, "arguments: 0"):
        function.call_function(const, one)
      with self.assertRaisesRegexp(ValueError, "arguments: 0"):
        function.call_function(const, one, two)

      with self.assertRaisesRegexp(ValueError, "arguments: 1"):
        function.call_function(plus_one)
      function.call_function(plus_one, one)
      with self.assertRaisesRegexp(ValueError, "arguments: 1"):
        function.call_function(plus_one, one, two)

      with self.assertRaisesRegexp(ValueError, "arguments: 2"):
        function.call_function(plus_minus)
      with self.assertRaisesRegexp(ValueError, "arguments: 2"):
        function.call_function(plus_minus, one)
      function.call_function(plus_minus, one, two)

      function.call_function(plus_one, one, name="p1")
      with self.assertRaisesRegexp(ValueError, "Unknown keyword arguments"):
        function.call_function(plus_one, one, device="/gpu:0")

  def testFunctionDecorator(self):
    """Defines a function with @Defun and calls it with and without a name."""

    with tf.Graph().as_default():

      @function.Defun(b=tf.int32)
      def Minus1(b):
        return b - 1

      two = tf.constant([2])
      call1 = Minus1(two)
      self.assertEqual("Minus1", call1.op.name)
      # pylint: disable=unexpected-keyword-arg
      call2 = Minus1(call1, name="next")
      # pylint:enable=unexpected-keyword-arg
      self.assertEqual("next", call2.op.name)
      with tf.Session() as sess:
        self.assertAllEqual([1], sess.run(call1))
        self.assertAllEqual([0], sess.run(call2))

  def testNestedFunction(self):
    """Calls one @Defun function from inside another."""
    with tf.Graph().as_default():

      @function.Defun(x=tf.float32)
      def Cube(x):
        return x * x * x

      @function.Defun(x=tf.float32, y=tf.float32)
      def CubeXPlusY(x, y):
        return Cube(x) + y

      z = CubeXPlusY(tf.constant(3.0), tf.constant(-2.0))
      with self.test_session():
        self.assertAllEqual(z.eval(), 25.0)  # 3^3 - 2

  # Helper to construct a LSTM cell graph.
  @classmethod
  def LSTMCell(cls, x, mprev, cprev, weights):
    """Builds one LSTM step; returns the new (m, c) state pair."""
    xm = tf.concat(1, [x, mprev])
    # One matmul produces all four gate pre-activations, split along dim 1.
    i_i, i_g, f_g, o_g = tf.split(1, 4, tf.matmul(xm, weights))
    new_c = tf.sigmoid(f_g) * cprev + tf.sigmoid(i_g) * tf.tanh(i_i)
    new_c = tf.clip_by_value(new_c, -50.0, 50.0)
    new_m = tf.sigmoid(o_g) * tf.tanh(new_c)
    return new_m, new_c

  def _BuildForward(self, use_func=True, num_unroll=100):
    """Builds an LSTM forward pass unrolled for `num_unroll` steps.

    Args:
      use_func: if True, wrap LSTMCell with function.Defun so each step is a
        function call instead of inlined ops.
      num_unroll: number of unrolled time steps.

    Returns:
      A (weights, m, c) tuple: the weight tensor and the final m and c states.
    """
    batch_size = 16
    lstm_dims = 32
    cell = FunctionTest.LSTMCell
    if use_func:
      cell = function.Defun(x=tf.float32,
                            mprev=tf.float32,
                            cprev=tf.float32,
                            weights=tf.float32)(cell)
    m = tf.zeros(shape=[batch_size, lstm_dims])
    c = tf.zeros(shape=[batch_size, lstm_dims])
    # Fixed seeds so the inlined and Defun variants see identical values.
    weights = tf.random_uniform(
        [2 * lstm_dims, 4 * lstm_dims],
        -1,
        1,
        seed=123456)
    inputs = tf.random_uniform([num_unroll, batch_size, lstm_dims], seed=654321)
    for step_input in tf.unpack(inputs):
      m, c = cell(step_input, m, c, weights)
    return weights, m, c

  def testUnrollLSTM(self):
    """Checks the Defun-based and inlined unrolled LSTMs agree (forward)."""

    # Build and run the unrolled lstm forward graph.
    def RunForward(use_func):
      g = tf.Graph()
      start = time.time()
      with g.as_default():
        _, m, c = self._BuildForward(use_func)
      gdef = g.as_graph_def()
      finish = time.time()
      print("time: ", finish - start, " txt size: ", len(str(gdef)),
            "gdef bin size: ", len(gdef.SerializeToString()))
      with g.as_default(), tf.Session() as sess:
        mv, cv = sess.run([m, c])
      return mv, cv

    mv0, cv0 = RunForward(use_func=False)
    mv1, cv1 = RunForward(use_func=True)
    self.assertAllClose(mv0, mv1)
    self.assertAllClose(cv0, cv1)

  def testUnrollLSTMGrad(self):
    """Checks the Defun-based and inlined unrolled LSTMs agree (gradients)."""

    # Build and run the unrolled lstm graph forward and backward.
    def RunForwardBackward(use_func):
      g = tf.Graph()
      start = time.time()
      with g.as_default():
        w, m, c = self._BuildForward(use_func)
        loss = tf.reduce_sum(m) + tf.reduce_sum(c)
        dw = tf.gradients([loss], [w])
      gdef = g.as_graph_def()
      finish = time.time()
      print("time: ", finish - start, " txt size: ", len(str(gdef)),
            "gdef bin size: ", len(gdef.SerializeToString()))
      with g.as_default(), tf.Session() as sess:
        ans = sess.run(dw)
      return ans

    ans0 = RunForwardBackward(use_func=False)
    ans1 = RunForwardBackward(use_func=True)
    self.assertAllClose(ans0, ans1)
if __name__ == "__main__":
  # Script entry point: hand control to TensorFlow's test runner.
  tf.test.main()