X Tutup
Skip to content

Commit 2c331a1

Browse files
make @llnl.util.lang.memoized support kwargs (spack#21722)
* make memoized() support kwargs * add testing for @memoized
1 parent 916c94f commit 2c331a1

File tree

2 files changed

+96
-10
lines changed

2 files changed

+96
-10
lines changed

lib/spack/llnl/util/lang.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
import sys
1414
from datetime import datetime, timedelta
1515

16+
import six
1617
from six import string_types
1718

18-
from llnl.util.compat import Hashable, MutableMapping, zip_longest
19+
from llnl.util.compat import MutableMapping, zip_longest
1920

2021
# Ignore emacs backups when listing modules
2122
ignore_modules = [r'^\.#', '~$']
@@ -165,22 +166,43 @@ def union_dicts(*dicts):
165166
return result
166167

167168

169+
# Used as a sentinel that disambiguates tuples passed in *args from coincidentally
170+
# matching tuples formed from kwargs item pairs.
171+
_kwargs_separator = (object(),)
172+
173+
174+
def stable_args(*args, **kwargs):
175+
"""A key factory that performs a stable sort of the parameters."""
176+
key = args
177+
if kwargs:
178+
key += _kwargs_separator + tuple(sorted(kwargs.items()))
179+
return key
180+
181+
168182
def memoized(func):
169183
"""Decorator that caches the results of a function, storing them in
170184
an attribute of that function.
171185
"""
172186
func.cache = {}
173187

174188
@functools.wraps(func)
175-
def _memoized_function(*args):
176-
if not isinstance(args, Hashable):
177-
# Not hashable, so just call the function.
178-
return func(*args)
189+
def _memoized_function(*args, **kwargs):
190+
key = stable_args(*args, **kwargs)
179191

180-
if args not in func.cache:
181-
func.cache[args] = func(*args)
182-
183-
return func.cache[args]
192+
try:
193+
return func.cache[key]
194+
except KeyError:
195+
ret = func(*args, **kwargs)
196+
func.cache[key] = ret
197+
return ret
198+
except TypeError as e:
199+
# TypeError is raised when indexing into a dict if the key is unhashable.
200+
raise six.raise_from(
201+
UnhashableArguments(
202+
"args + kwargs '{}' was not hashable for function '{}'"
203+
.format(key, func.__name__),
204+
),
205+
e)
184206

185207
return _memoized_function
186208

@@ -930,3 +952,7 @@ def nullcontext(*args, **kwargs):
930952
TODO: replace with contextlib.nullcontext() if we ever require python 3.7.
931953
"""
932954
yield
955+
956+
957+
class UnhashableArguments(TypeError):
958+
"""Raise when an @memoized function receives unhashable arg or kwarg values."""

lib/spack/spack/test/llnl/util/lang.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pytest
1111

1212
import llnl.util.lang
13-
from llnl.util.lang import match_predicate, pretty_date
13+
from llnl.util.lang import match_predicate, memoized, pretty_date, stable_args
1414

1515

1616
@pytest.fixture()
@@ -205,3 +205,63 @@ def _cmp_key(self):
205205
assert hash(a) == hash(a2)
206206
assert hash(b) == hash(b)
207207
assert hash(b) == hash(b2)
208+
209+
210+
@pytest.mark.parametrize(
211+
"args1,kwargs1,args2,kwargs2",
212+
[
213+
# Ensure tuples passed in args are disambiguated from equivalent kwarg items.
214+
(('a', 3), {}, (), {'a': 3})
215+
],
216+
)
217+
def test_unequal_args(args1, kwargs1, args2, kwargs2):
218+
assert stable_args(*args1, **kwargs1) != stable_args(*args2, **kwargs2)
219+
220+
221+
@pytest.mark.parametrize(
222+
"args1,kwargs1,args2,kwargs2",
223+
[
224+
# Ensure that kwargs are stably sorted.
225+
((), {'a': 3, 'b': 4}, (), {'b': 4, 'a': 3}),
226+
],
227+
)
228+
def test_equal_args(args1, kwargs1, args2, kwargs2):
229+
assert stable_args(*args1, **kwargs1) == stable_args(*args2, **kwargs2)
230+
231+
232+
@pytest.mark.parametrize(
233+
"args, kwargs",
234+
[
235+
((1,), {}),
236+
((), {'a': 3}),
237+
((1,), {'a': 3}),
238+
],
239+
)
240+
def test_memoized(args, kwargs):
241+
@memoized
242+
def f(*args, **kwargs):
243+
return 'return-value'
244+
assert f(*args, **kwargs) == 'return-value'
245+
key = stable_args(*args, **kwargs)
246+
assert list(f.cache.keys()) == [key]
247+
assert f.cache[key] == 'return-value'
248+
249+
250+
@pytest.mark.parametrize(
251+
"args, kwargs",
252+
[
253+
(([1],), {}),
254+
((), {'a': [1]})
255+
],
256+
)
257+
def test_memoized_unhashable(args, kwargs):
258+
"""Check that an exception is raised clearly"""
259+
@memoized
260+
def f(*args, **kwargs):
261+
return None
262+
with pytest.raises(llnl.util.lang.UnhashableArguments) as exc_info:
263+
f(*args, **kwargs)
264+
exc_msg = str(exc_info.value)
265+
key = stable_args(*args, **kwargs)
266+
assert str(key) in exc_msg
267+
assert "function 'f'" in exc_msg

0 commit comments

Comments
 (0)
X Tutup