forked from jumpserver/jumpserver
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcache.py
More file actions
218 lines (172 loc) · 6.54 KB
/
cache.py
File metadata and controls
218 lines (172 loc) · 6.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import time
from common.utils.lock import DistributedLock
from common.utils.connection import get_redis_client
from common.utils import lazyproperty
from common.utils import get_logger
# Module-level logger for the cache helpers; ``get_logger`` is passed this file's path.
logger = get_logger(__file__)
class ComputeLock(DistributedLock):
    """
    Lock held while a cache is being rebuilt, to avoid computing it twice.

    :param key: the cache key being rebuilt; the lock name is that key
        prefixed with ``compute:``.
    """

    def __init__(self, key):
        super().__init__(name=f'compute:{key}')
class CacheFieldBase:
    """
    Declaration of one cached value on a ``Cache`` subclass.

    The value is produced either by ``queryset.count()`` or by calling a method
    on the cache instance. That method is ``compute_func_name`` when given,
    otherwise ``compute_<field name>``. At most one of ``queryset`` and
    ``compute_func_name`` may be supplied.

    :param queryset: object whose ``count()`` result becomes the value
    :param compute_func_name: name of the cache method computing the value
    :raises ValueError: if both ``queryset`` and ``compute_func_name`` are given
    """
    # Callable used to coerce computed values and values loaded from Redis
    field_type = str

    def __init__(self, queryset=None, compute_func_name=None):
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if queryset is not None and compute_func_name is not None:
            raise ValueError('queryset and compute_func_name can only have one')
        self.compute_func_name = compute_func_name
        self.queryset = queryset
class CharField(CacheFieldBase):
    """Cache field whose computed and Redis-loaded values are coerced with ``str``."""
    field_type = str
class IntegerField(CacheFieldBase):
    """Cache field whose computed and Redis-loaded values are coerced with ``int``."""
    field_type = int
class CacheType(type):
    """
    Metaclass for ``Cache``.

    Replaces every ``CacheFieldBase`` class attribute with a ``CacheValueDesc``
    descriptor. Also records all fields in the class attribute
    ``field_desc_mapper`` (field name -> descriptor).

    Fields declared on base classes are inherited into the subclass's
    ``field_desc_mapper``. Without this, lookups such as ``cache[name]``,
    ``compute_values`` and ``init_all_values`` would not know about inherited
    fields, and reading one would raise ``KeyError`` during ``refresh``.
    """

    def __new__(cls, name, bases, attrs: dict):
        field_desc_mapper = {}
        # Apply bases right-to-left so the left-most base wins on name clashes.
        for base in reversed(bases):
            field_desc_mapper.update(getattr(base, 'field_desc_mapper', {}))

        # Iterate over a snapshot because ``attrs`` is modified in the loop.
        for k, v in list(attrs.items()):
            if isinstance(v, CacheFieldBase):
                desc = CacheValueDesc(k, v)
                attrs[k] = desc
                field_desc_mapper[k] = desc
            elif k in field_desc_mapper:
                # A plain attribute/method in this class hides the inherited
                # descriptor, so it is no longer a cache field here.
                field_desc_mapper.pop(k)

        attrs['field_desc_mapper'] = field_desc_mapper
        return type.__new__(cls, name, bases, attrs)
class Cache(metaclass=CacheType):
    """
    Base class for a group of values stored together in one Redis hash.

    Subclasses declare fields (``CharField`` / ``IntegerField``) as class
    attributes. They must also implement ``get_key_suffix``. The Redis key is
    ``cache.<module>.<ClassName>.<suffix>``.

    Values are loaded lazily. The first access to ``data`` reads the whole
    hash; if it is empty, every field is computed under a ``ComputeLock`` and
    written back. Note: methods named ``*_db`` operate on the Redis hash, not
    on a SQL database.
    """
    # Set by the CacheType metaclass: field name -> CacheValueDesc
    field_desc_mapper: dict
    # Optional TTL passed to redis ``expire`` after each write; a falsy value
    # (the default None) keeps the hash until it is explicitly expired.
    timeout = None

    def __init__(self):
        # Local snapshot of the Redis hash; None means "not loaded yet".
        self._data = None
        self.redis = get_redis_client()

    def __getitem__(self, item):
        return self.field_desc_mapper[item]

    def __contains__(self, item):
        return item in self.field_desc_mapper

    def get_field(self, name):
        return self.field_desc_mapper[name]

    @property
    def fields(self):
        return self.field_desc_mapper.values()

    @property
    def field_names(self):
        return self.field_desc_mapper.keys()

    @lazyproperty
    def key_suffix(self):
        return self.get_key_suffix()

    @property
    def key_prefix(self):
        clz = self.__class__
        return f'cache.{clz.__module__}.{clz.__name__}'

    @property
    def key(self):
        return f'{self.key_prefix}.{self.key_suffix}'

    @property
    def data(self):
        """
        Return the cached values as ``{field_name: value}``.

        Loads from Redis on first access. Computes all values if the hash is empty.
        """
        if self._data is None:
            data = self.load_data_from_db()
            if not data:
                with ComputeLock(self.key):
                    # Re-check under the lock: another worker may have filled it.
                    data = self.load_data_from_db()
                    if not data:
                        # Nothing cached yet: compute every value from the source.
                        self.init_all_values()
        return self._data

    def to_internal_value(self, data: dict):
        """
        Convert a raw Redis hash into ``{field_name: typed_value}``.

        Keys and values may be ``bytes`` (redis-py default) or ``str`` (client
        created with ``decode_responses=True``). Unknown fields are logged and dropped.
        """
        internal_data = {}
        for k, v in data.items():
            field = k.decode() if isinstance(k, bytes) else k
            if field not in self:
                logger.warning(f'Cache got invalid field: '
                               f'key={self.key} '
                               f'invalid_field={field} '
                               f'valid_fields={self.field_names}')
                continue
            raw_value = v.decode() if isinstance(v, bytes) else v
            internal_data[field] = self[field].to_internal_value(raw_value)
        return internal_data

    def load_data_from_db(self) -> dict:
        """Read the Redis hash, convert it, store it in ``self._data`` and return it."""
        data = self.redis.hgetall(self.key)
        logger.debug(f'Get data from cache: key={self.key} data={data}')
        if data:
            data = self.to_internal_value(data)
        self._data = data
        return data

    def save_data_to_db(self, data):
        """Write ``data`` into the Redis hash, then reload the local snapshot."""
        logger.debug(f'Set data to cache: key={self.key} data={data}')
        if data:
            # redis-py rejects HSET with an empty mapping (DataError).
            self.redis.hset(self.key, mapping=data)
            if self.timeout:
                self.redis.expire(self.key, self.timeout)
        self.load_data_from_db()

    def compute_values(self, *fields):
        """
        Compute fresh values for the named fields.

        :raises KeyError: for an unknown field name. This is raised before
            anything is computed.
        """
        field_objs = [self[field] for field in fields]
        return {
            field_obj.field_name: field_obj.compute_value(self)
            for field_obj in field_objs
        }

    def init_all_values(self):
        """Compute every declared field, save the result to Redis and return it."""
        t_start = time.time()
        logger.debug(f'Start init cache: key={self.key}')
        data = self.compute_values(*self.field_names)
        self.save_data_to_db(data)
        logger.debug(f'End init cache: cost={time.time()-t_start} key={self.key}')
        return data

    def refresh(self, *fields):
        """Recompute the given fields, or all fields when none are given."""
        if not fields:
            # No fields given: refresh every value.
            self.init_all_values()
            return

        data = self.load_data_from_db()
        if not data:
            # Nothing cached yet: set every value.
            self.init_all_values()
            return

        refresh_values = self.compute_values(*fields)
        self.save_data_to_db(refresh_values)

    def get_key_suffix(self):
        """Return the per-instance part of the Redis key; subclasses must override."""
        raise NotImplementedError

    def reload(self):
        """Drop the local snapshot so the next access re-reads Redis."""
        self._data = None

    def expire(self, *fields):
        """Delete the given fields (or the whole hash when none are given) from Redis."""
        self._data = None
        if not fields:
            self.redis.delete(self.key)
        else:
            self.redis.hdel(self.key, *fields)
        logger.debug(f'Expire cached fields: key={self.key} fields={fields}')
class CacheValueDesc:
    """
    Descriptor that ``CacheType`` installs in place of each declared cache field.

    Reading the attribute on a ``Cache`` instance returns the cached value. If
    the value is missing from the instance's data, it is computed and stored first.

    :param field_name: attribute name the field was declared under
    :param field_type: the ``CacheFieldBase`` declaration for this field
    """

    def __init__(self, field_name, field_type: CacheFieldBase):
        self.field_name = field_name
        self.field_type = field_type
        # NOTE(review): not read anywhere in this file; kept for compatibility.
        self._data = None

    def __repr__(self):
        clz = self.__class__
        return f'<{clz.__name__} {self.field_name} {self.field_type}>'

    def __get__(self, instance: Cache, owner):
        # Class-level access returns the descriptor itself.
        if instance is None:
            return self
        if self.field_name not in instance.data:
            instance.refresh(self.field_name)
        # .get() avoids an error in edge cases where the value is still missing.
        return instance.data.get(self.field_name)

    def compute_value(self, instance: Cache):
        """
        Compute a fresh value of this field for ``instance``.

        Uses ``queryset.count()`` when a queryset was declared. Otherwise it
        calls ``compute_func_name`` (default ``compute_<field_name>``) on the
        instance. The result is coerced with the field's ``field_type``.

        :raises NotImplementedError: if the compute method is not defined.
        """
        t_start = time.time()
        logger.debug(f'Start compute cache field: field={self.field_name} key={instance.key}')
        if self.field_type.queryset is not None:
            new_value = self.field_type.queryset.count()
        else:
            compute_func_name = self.field_type.compute_func_name or f'compute_{self.field_name}'
            compute_func = getattr(instance, compute_func_name, None)
            # Explicit raise rather than ``assert`` so the check survives ``python -O``.
            if compute_func is None:
                raise NotImplementedError(
                    f'Define `{compute_func_name}` method in {instance.__class__}'
                )
            new_value = compute_func()

        new_value = self.field_type.field_type(new_value)
        logger.debug(f'End compute cache field: cost={time.time()-t_start} field={self.field_name} value={new_value} key={instance.key}')
        return new_value

    def to_internal_value(self, value):
        """Coerce a value loaded from Redis with the field's ``field_type``."""
        return self.field_type.field_type(value)