forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph_util_test.py
More file actions
175 lines (156 loc) · 7.02 KB
/
graph_util_test.py
File metadata and controls
175 lines (156 loc) · 7.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.client.graph_util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.python.platform
import tensorflow as tf
from tensorflow.python.client import graph_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import data_flow_ops
# pylint: disable=unused-import
from tensorflow.python.ops import math_ops
# pylint: enable=unused-import
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import googletest
class DeviceFunctionsTest(googletest.TestCase):
  """Tests for the device functions and sub-graph helpers in graph_util."""

  def testPinToCpu(self):
    """Every op created under `pin_to_cpu` is placed on /device:CPU:0."""
    with ops.Graph().as_default() as g, g.device(graph_util.pin_to_cpu):
      const_a = constant_op.constant(5.0)
      const_b = constant_op.constant(10.0)
      add_c = const_a + const_b
      var_v = state_ops.variable_op([], dtype=dtypes.float32)
      assign_c_to_v = state_ops.assign(var_v, add_c)
      const_string = constant_op.constant("on a cpu")
      # Both an int and a float DynamicStitch are built so that each
      # dtype variant is checked for CPU placement.
      dynamic_stitch_int_result = data_flow_ops.dynamic_stitch(
          [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
      dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
          [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
    self.assertEqual(const_a.device, "/device:CPU:0")
    self.assertEqual(const_b.device, "/device:CPU:0")
    self.assertEqual(add_c.device, "/device:CPU:0")
    self.assertEqual(var_v.device, "/device:CPU:0")
    self.assertEqual(assign_c_to_v.device, "/device:CPU:0")
    self.assertEqual(const_string.device, "/device:CPU:0")
    self.assertEqual(dynamic_stitch_int_result.device, "/device:CPU:0")
    self.assertEqual(dynamic_stitch_float_result.device, "/device:CPU:0")

  def testPinRequiredOpsOnCPU(self):
    """`pin_variables_on_cpu` places only variable-related ops on the CPU."""
    with ops.Graph().as_default() as g, g.device(
        graph_util.pin_variables_on_cpu):
      const_a = constant_op.constant(5.0)
      const_b = constant_op.constant(10.0)
      add_c = const_a + const_b
      var_v = state_ops.variable_op([], dtype=dtypes.float32)
      assign_c_to_v = state_ops.assign(var_v, add_c)
      dynamic_stitch_int_result = data_flow_ops.dynamic_stitch(
          [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
      dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
          [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
    # Non-variable ops should not specify a device.
    self.assertIsNone(const_a.device)
    self.assertIsNone(const_b.device)
    self.assertIsNone(add_c.device)
    # DynamicStitch is not a variable op either, so it should also be left
    # unplaced. These results were previously built but never checked.
    self.assertIsNone(dynamic_stitch_int_result.device)
    self.assertIsNone(dynamic_stitch_float_result.device)
    # Variable ops specify a device.
    self.assertEqual(var_v.device, "/device:CPU:0")
    self.assertEqual(assign_c_to_v.device, "/device:CPU:0")

  def testTwoDeviceFunctions(self):
    """A device function applies only inside its scope and yields to inner ones."""
    with ops.Graph().as_default() as g:
      var_0 = state_ops.variable_op([1], dtype=dtypes.float32)
      with g.device(graph_util.pin_variables_on_cpu):
        var_1 = state_ops.variable_op([1], dtype=dtypes.float32)
      # Outside the scope again: no device is assigned.
      var_2 = state_ops.variable_op([1], dtype=dtypes.float32)
      var_3 = state_ops.variable_op([1], dtype=dtypes.float32)
      with g.device(graph_util.pin_variables_on_cpu):
        var_4 = state_ops.variable_op([1], dtype=dtypes.float32)
        # An explicit inner device is kept rather than overridden with CPU.
        with g.device("/device:GPU:0"):
          var_5 = state_ops.variable_op([1], dtype=dtypes.float32)
        # Back in the pin_variables_on_cpu scope after the GPU block exits.
        var_6 = state_ops.variable_op([1], dtype=dtypes.float32)
    self.assertIsNone(var_0.device)
    self.assertEqual(var_1.device, "/device:CPU:0")
    self.assertIsNone(var_2.device)
    self.assertIsNone(var_3.device)
    self.assertEqual(var_4.device, "/device:CPU:0")
    self.assertEqual(var_5.device, "/device:GPU:0")
    self.assertEqual(var_6.device, "/device:CPU:0")

  def testExplicitDevice(self):
    """Explicit device strings are recorded verbatim on the ops they scope."""
    with ops.Graph().as_default() as g:
      const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/job:ps"):
        const_5 = constant_op.constant(5.0)
    self.assertIsNone(const_0.device)
    self.assertEqual(const_1.device, "/device:GPU:0")
    self.assertEqual(const_2.device, "/device:GPU:1")
    self.assertEqual(const_3.device, "/device:CPU:0")
    self.assertEqual(const_4.device, "/device:CPU:1")
    self.assertEqual(const_5.device, "/job:ps")

  def testDefaultDevice(self):
    """Explicit devices on non-variable ops survive an outer pin_variables_on_cpu."""
    with ops.Graph().as_default() as g, g.device(
        graph_util.pin_variables_on_cpu):
      with g.device("/job:ps"):
        const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/replica:0"):
        const_5 = constant_op.constant(5.0)
    self.assertEqual(const_0.device, "/job:ps")
    self.assertEqual(const_1.device, "/device:GPU:0")
    self.assertEqual(const_2.device, "/device:GPU:1")
    self.assertEqual(const_3.device, "/device:CPU:0")
    self.assertEqual(const_4.device, "/device:CPU:1")
    self.assertEqual(const_5.device, "/replica:0")

  def testExtractSubGraph(self):
    """extract_sub_graph keeps exactly the nodes the destination depends on."""
    graph_def = tf.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    n1.input.extend(["n5"])
    n2 = graph_def.node.add()
    n2.name = "n2"
    # Take the first output of the n1 node as the input.
    n2.input.extend(["n1:0"])
    n3 = graph_def.node.add()
    n3.name = "n3"
    # Add a control input (which isn't really needed by the kernel, but
    # rather to enforce execution order between nodes).
    n3.input.extend(["^n2"])
    # n4 is not reachable from n3, so it must be pruned from the sub-graph.
    n4 = graph_def.node.add()
    n4.name = "n4"
    # It is fine to have loops in the graph as well (n1 <-> n5).
    n5 = graph_def.node.add()
    n5.name = "n5"
    n5.input.extend(["n1"])
    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    # Compare the full name list so that a stray n4 (or any extra node) is
    # reported, not just a mismatch in the first four positions.
    self.assertEqual(["n1", "n2", "n3", "n5"],
                     [node.name for node in sub_graph.node])
if __name__ == "__main__":
  # Run the tests in this module through the googletest runner when the file
  # is executed directly.
  googletest.main()