-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathbatch.py
More file actions
333 lines (297 loc) · 10.7 KB
/
batch.py
File metadata and controls
333 lines (297 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""Batch collation
Authors
* Aku Rouhe 2020
"""
import collections
import torch
from torch.utils.data._utils.collate import default_convert
from torch.utils.data._utils.pin_memory import (
pin_memory as recursive_pin_memory,
)
from speechbrain.utils.data_utils import (
batch_pad_right,
mod_default_collate,
recursive_to,
)
# Container for one padded element of the batch: `data` is the padded batch
# tensor, `lengths` the per-example length tensor returned by the padding
# function (with the default `batch_pad_right`, the doctests below show
# these as fractions of the longest item, e.g. 0.5, 1.0).
PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])
class PaddedBatch:
    """Collate_fn when examples are dicts and have variable-length sequences.

    Different elements in the examples get matched by key.
    All numpy tensors get converted to Torch (PyTorch default_convert)
    Then, by default, all torch.Tensor valued elements get padded and support
    collective pin_memory() and to() calls.
    Regular Python data types are just collected in a list.

    Arguments
    ---------
    examples : list
        List of example dicts, as produced by Dataloader.
    padded_keys : list, None
        (Optional) List of keys to pad on. If None, pad all torch.Tensors
    device_prep_keys : list, None
        (Optional) Only these keys participate in collective memory pinning
        and moving with to().
        If None, defaults to all items with torch.Tensor values.
    padding_func : callable, optional
        Called with a list of tensors to be padded together. Needs to return
        two tensors: the padded data, and another tensor for the data lengths.
    padding_kwargs : dict, None
        (Optional) Extra kwargs to pass to padding_func. E.g. mode, value.
        This is used as the default padding configuration for all keys.
    per_key_padding_kwargs : dict, None
        (Optional) Per-key padding configuration. Keys in this dict should
        match the keys in the examples. Each value should be a dict with
        padding parameters (e.g., {'value': -100, 'mode': 'constant'}). If a
        key is not in this dict, the global padding_kwargs will be used.
    apply_default_convert : bool
        Whether to apply PyTorch default_convert (numpy to torch recursively,
        etc.) on all data. Default:True, usually does the right thing.
    nonpadded_stack : bool
        Whether to apply PyTorch-default_collate-like stacking on values that
        didn't get padded. This stacks if it can, but doesn't error out if it
        cannot. Default:True, usually does the right thing.

    Example
    -------
    >>> batch = PaddedBatch(
    ...     [
    ...         {"id": "ex1", "foo": torch.Tensor([1.0])},
    ...         {"id": "ex2", "foo": torch.Tensor([2.0, 1.0])},
    ...     ]
    ... )
    >>> # Attribute or key-based access:
    >>> batch.id
    ['ex1', 'ex2']
    >>> batch["id"]
    ['ex1', 'ex2']
    >>> # torch.Tensors get padded
    >>> type(batch.foo)
    <class 'speechbrain.dataio.batch.PaddedData'>
    >>> batch.foo.data
    tensor([[1., 0.],
            [2., 1.]])
    >>> batch.foo.lengths
    tensor([0.5000, 1.0000])
    >>> # Batch supports collective operations:
    >>> _ = batch.to(dtype=torch.half)
    >>> batch.foo.data
    tensor([[1., 0.],
            [2., 1.]], dtype=torch.float16)
    >>> batch.foo.lengths
    tensor([0.5000, 1.0000], dtype=torch.float16)
    >>> # Numpy tensors get converted to torch and padded as well:
    >>> import numpy as np
    >>> batch = PaddedBatch(
    ...     [{"wav": np.asarray([1, 2, 3, 4])}, {"wav": np.asarray([1, 2, 3])}]
    ... )
    >>> batch.wav  # doctest: +ELLIPSIS
    PaddedData(data=tensor([[1, 2,...
    >>> # Basic stacking collation deals with non padded data:
    >>> batch = PaddedBatch(
    ...     [
    ...         {
    ...             "spk_id": torch.tensor([1]),
    ...             "wav": torch.tensor([0.1, 0.0, 0.3]),
    ...         },
    ...         {
    ...             "spk_id": torch.tensor([2]),
    ...             "wav": torch.tensor([0.2, 0.3, -0.1]),
    ...         },
    ...     ],
    ...     padded_keys=["wav"],
    ... )
    >>> batch.spk_id
    tensor([[1],
            [2]])
    >>> # And some data is left alone:
    >>> batch = PaddedBatch(
    ...     [{"text": ["Hello"]}, {"text": ["How", "are", "you?"]}]
    ... )
    >>> batch.text
    [['Hello'], ['How', 'are', 'you?']]
    >>> # Per-key padding configuration:
    >>> batch = PaddedBatch(
    ...     [
    ...         {
    ...             "wav": torch.tensor([1, 2, 3]),
    ...             "labels": torch.tensor([1, 2]),
    ...         },
    ...         {"wav": torch.tensor([4, 5]), "labels": torch.tensor([3])},
    ...     ],
    ...     per_key_padding_kwargs={
    ...         "wav": {"value": 0},
    ...         "labels": {"value": -100},
    ...     },
    ... )
    >>> batch.wav.data
    tensor([[1, 2, 3],
            [4, 5, 0]])
    >>> batch.labels.data
    tensor([[   1,    2],
            [   3, -100]])
    """

    def __init__(
        self,
        examples,
        padded_keys=None,
        device_prep_keys=None,
        padding_func=batch_pad_right,
        padding_kwargs=None,
        per_key_padding_kwargs=None,
        apply_default_convert=True,
        nonpadded_stack=True,
    ):
        # Avoid mutable default arguments: resolve the None sentinels here.
        padding_kwargs = padding_kwargs if padding_kwargs is not None else {}
        per_key_padding_kwargs = (
            per_key_padding_kwargs if per_key_padding_kwargs is not None else {}
        )
        # Name-mangled (double-underscore) attributes keep the bookkeeping
        # out of the way of the example keys, which become plain attributes.
        self.__length = len(examples)
        # Key order is taken from the first example; all examples are
        # assumed to share the same keys.
        self.__keys = list(examples[0].keys())
        self.__padded_keys = []
        self.__device_prep_keys = []
        for key in self.__keys:
            values = [example[key] for example in examples]
            # Default convert usually does the right thing (numpy2torch etc.)
            if apply_default_convert:
                values = default_convert(values)
            if (padded_keys is not None and key in padded_keys) or (
                padded_keys is None and isinstance(values[0], torch.Tensor)
            ):
                # Padding and PaddedData
                self.__padded_keys.append(key)
                # Use per-key padding config if available, otherwise fall
                # back to global padding_kwargs
                if key in per_key_padding_kwargs:
                    key_padding_kwargs = per_key_padding_kwargs[key]
                else:
                    key_padding_kwargs = padding_kwargs
                padded = PaddedData(*padding_func(values, **key_padding_kwargs))
                setattr(self, key, padded)
            else:
                # Default PyTorch collate usually does the right thing
                # (convert lists of equal sized tensors to batch tensors, etc.)
                if nonpadded_stack:
                    values = mod_default_collate(values)
                setattr(self, key, values)
            if (device_prep_keys is not None and key in device_prep_keys) or (
                device_prep_keys is None and isinstance(values[0], torch.Tensor)
            ):
                self.__device_prep_keys.append(key)

    def __len__(self):
        """Returns the number of examples in the batch."""
        return self.__length

    def __getitem__(self, key):
        """Key-based access to a batch element, mirroring attribute access."""
        if key in self.__keys:
            return getattr(self, key)
        else:
            raise KeyError(f"Batch doesn't have key: {key}")

    def __iter__(self):
        """Iterates over the different elements of the batch.

        Returns
        -------
        Iterator over the batch.

        Example
        -------
        >>> batch = PaddedBatch(
        ...     [
        ...         {"id": "ex1", "val": torch.Tensor([1.0])},
        ...         {"id": "ex2", "val": torch.Tensor([2.0, 1.0])},
        ...     ]
        ... )
        >>> ids, vals = batch
        >>> ids
        ['ex1', 'ex2']
        """
        return iter(getattr(self, key) for key in self.__keys)

    def pin_memory(self):
        """In-place, moves relevant elements to pinned memory."""
        # Only keys selected at __init__ time (device_prep_keys or, by
        # default, all torch.Tensor-valued keys) are pinned.
        for key in self.__device_prep_keys:
            value = getattr(self, key)
            pinned = recursive_pin_memory(value)
            setattr(self, key, pinned)
        return self

    def to(self, *args, **kwargs):
        """In-place move/cast relevant elements.

        Passes all arguments to torch.Tensor.to, see its documentation.
        """
        for key in self.__device_prep_keys:
            value = getattr(self, key)
            moved = recursive_to(value, *args, **kwargs)
            setattr(self, key, moved)
        return self

    def at_position(self, pos):
        """Fetch the value of the key at the given position in the key order."""
        key = self.__keys[pos]
        return getattr(self, key)

    @property
    def batchsize(self):
        """Returns the batch size (number of examples)."""
        return self.__length
class BatchsizeGuesser:
    """Try to figure out the batchsize, but never error out

    If this cannot figure out anything else, will fallback to guessing 1

    Example
    -------
    >>> guesser = BatchsizeGuesser()
    >>> # Works with simple tensors:
    >>> guesser(torch.randn((2, 3)))
    2
    >>> # Works with sequences of tensors:
    >>> guesser((torch.randn((2, 3)), torch.randint(high=5, size=(2,))))
    2
    >>> # Works with PaddedBatch:
    >>> guesser(
    ...     PaddedBatch([{"wav": [1.0, 2.0, 3.0]}, {"wav": [4.0, 5.0, 6.0]}])
    ... )
    2
    >>> guesser("Even weird non-batches have a fallback")
    1
    """

    def __init__(self):
        # The guessing method that worked for the previous batch.
        # It is retried first on the next call; None means nothing cached yet.
        self.method = None

    def __call__(self, batch):
        """Guess the batch size of ``batch``, never raising."""
        try:
            # On the first call self.method is None, so calling it raises
            # TypeError, which conveniently routes us to discovery below.
            return self.method(batch)
        except Exception:
            return self.find_suitable_method(batch)

    def find_suitable_method(self, batch):
        """Try the different methods and note which worked"""
        candidates = (
            self.attr_based,
            self.torch_tensor_bs,
            self.len_of_first,
            self.len_of_iter_first,
        )
        for candidate in candidates:
            try:
                bs = candidate(batch)
            except Exception:
                continue
            self.method = candidate
            return bs
        # Last ditch fallback:
        bs = self.fallback(batch)
        # Cache the method itself. (The original code cached the method's
        # *return value* -- `self.fallback(batch)` -- so every subsequent
        # call tried to call the int 1 and re-ran the whole discovery.)
        self.method = self.fallback
        return bs

    def attr_based(self, batch):
        """Batch size from a ``batchsize`` attribute (e.g. PaddedBatch)."""
        return batch.batchsize

    def torch_tensor_bs(self, batch):
        """Batch size from the first dimension of a tensor-like ``.shape``."""
        return batch.shape[0]

    def len_of_first(self, batch):
        """Batch size as the length of the first element of ``batch``."""
        return len(batch[0])

    def len_of_iter_first(self, batch):
        """Batch size as the length of the first item yielded by iteration."""
        return len(next(iter(batch)))

    def fallback(self, batch):
        """Last resort: guess a batch size of 1."""
        return 1