forked from IronLanguages/ironpython3
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathExtensionMethodSet.cs
More file actions
490 lines (414 loc) · 19.1 KB
/
ExtensionMethodSet.cs
File metadata and controls
490 lines (414 loc) · 19.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Dynamic;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using IronPython.Runtime.Binding;
using IronPython.Runtime.Operations;
using IronPython.Runtime.Types;
using Microsoft.Scripting.Actions;
using Microsoft.Scripting.Utils;
namespace IronPython.Runtime {
/// <summary>
/// Represents the set of extension methods which are loaded into a module.
///
/// This set is immutable (as far as an external viewer is concerned). When a
/// new set of extension methods is loaded into a module we create a new ExtensionMethodSet object.
///
/// Multiple modules which have the same set of extension methods use the same set.
/// </summary>
internal sealed class ExtensionMethodSet : IEquatable<ExtensionMethodSet> {
// Binder for this set's extension methods; created on first use by GetBinder.
private PythonExtensionBinder _extBinder;
// Per-assembly load state. LoadAllTypes swaps in a freshly built dictionary instead of mutating this one in place.
private Dictionary<Assembly, AssemblyLoadInfo>/*!*/ _loadedAssemblies;
// Identifier assigned by the dictionary-taking constructor; holds OutOfIds once the counter has overflowed.
private readonly int _id;
// Most recently handed out identifier; advanced with Interlocked.Increment.
private static int _curId;
// Shared set with no assemblies. Built by the parameterless constructor, which leaves _id at its default value of 0.
public static readonly ExtensionMethodSet Empty = new ExtensionMethodSet();
// Sentinel stored in _id when no further identifiers are available.
public const int OutOfIds = Int32.MinValue;
/// <summary>
/// Creates a set over the given per-assembly load information and assigns it an id.
/// </summary>
/// <param name="dict">Load information keyed by assembly; the dictionary is stored as-is, not copied.</param>
private ExtensionMethodSet(Dictionary<Assembly, AssemblyLoadInfo> dict) {
    _loadedAssemblies = dict;

    // Start from the "no ids left" sentinel and only replace it when a
    // non-negative id could actually be reserved.
    int assigned = OutOfIds;
    if (_curId >= 0) {
        int reserved = Interlocked.Increment(ref _curId);
        if (reserved >= 0) {
            assigned = reserved;
        }
    }
    _id = assigned;
}
/// <summary>
/// Builds the binding restriction which checks that the code context is still using this extension method set.
/// </summary>
/// <param name="codeContext">Expression that is passed on to the PythonOps helper call.</param>
/// <returns>
/// An expression restriction calling PythonOps.IsExtensionSet with this set's id, or, when this
/// set has no usable id, an instance restriction comparing the result of
/// PythonOps.GetExtensionMethodSet against this object.
/// </returns>
public BindingRestrictions GetRestriction(Expression codeContext) {
    if (_id != ExtensionMethodSet.OutOfIds) {
        // Normal case: compare by id.
        MethodInfo isExtensionSet = typeof(PythonOps).GetMethod(nameof(PythonOps.IsExtensionSet));
        Expression idCheck = Expression.Call(isExtensionSet, codeContext, Expression.Constant(_id));
        return BindingRestrictions.GetExpressionRestriction(idCheck);
    }

    // No id available: fall back to comparing against this very instance.
    MethodInfo getExtensionSet = typeof(PythonOps).GetMethod(nameof(PythonOps.GetExtensionMethodSet));
    Expression currentSet = Expression.Call(getExtensionSet, codeContext);
    return BindingRestrictions.GetInstanceRestriction(currentSet, this);
}
/// <summary>
/// Creates a set which tracks no assemblies. No id is reserved, so _id keeps its default value.
/// </summary>
private ExtensionMethodSet() {
    var noAssemblies = new Dictionary<Assembly, AssemblyLoadInfo>();
    _loadedAssemblies = noAssemblies;
}
/// <summary>
/// Tracks which extension types of a single assembly belong to an extension method set.
///
/// Types can be registered in two ways:
///
/// Individual types are stored directly in <see cref="Types"/>.
///
/// Namespaces are recorded in <see cref="Namespaces"/> and resolved lazily:
/// <see cref="EnsureTypesLoaded"/> returns a new instance whose <see cref="Types"/>
/// holds the extension types found in those namespaces together with any
/// individually registered types.
///
/// An instance that has neither types nor namespaces resolves to an empty type set
/// which is flagged through <see cref="IsFullAssemblyLoaded"/>.
/// </summary>
private sealed class AssemblyLoadInfo : IEquatable<AssemblyLoadInfo> {
    private static readonly IEqualityComparer<HashSet<PythonType>> TypeComparer = CollectionUtils.CreateSetComparer<PythonType>();
    private static readonly IEqualityComparer<HashSet<string>> StringComparer = CollectionUtils.CreateSetComparer<string>();

    public HashSet<PythonType> Types;       // set of types loaded from the assembly
    public HashSet<string> Namespaces;      // namespaces whose extension types still have to be loaded from the assembly
    public bool IsFullAssemblyLoaded;       // all types in the assembly are loaded and stored in the Types set
    private readonly Assembly _asm;

    /// <summary>
    /// Creates an entry for <paramref name="asm"/> with no types or namespaces registered yet.
    /// </summary>
    public AssemblyLoadInfo(Assembly asm) {
        _asm = asm;
    }

    /// <summary>
    /// Hashes on the assembly only, so entries for the same assembly share a hash code
    /// regardless of their load state.
    /// </summary>
    public override int GetHashCode() {
        return _asm != null ? _asm.GetHashCode() : 0;
    }

    public override bool Equals(object obj) {
        return obj is AssemblyLoadInfo asmLoadInfo && Equals(asmLoadInfo);
    }

    /// <summary>
    /// Resolves pending namespaces into concrete extension types.
    /// </summary>
    /// <returns>
    /// This instance when there is nothing left to resolve (types present and no namespaces
    /// pending); otherwise a new instance for the same assembly whose Types set is populated
    /// and whose Namespaces set is left unset.
    /// </returns>
    public AssemblyLoadInfo EnsureTypesLoaded() {
        HashSet<string> ns = Namespaces;
        if (ns == null && Types != null) {
            return this;
        }

        // Start from the individually registered types so they survive the resolution;
        // previously they were dropped whenever namespaces were pending for the same assembly.
        HashSet<PythonType> loadedTypes = Types != null
            ? new HashSet<PythonType>(Types)
            : new HashSet<PythonType>();

        if (ns != null) {
            foreach (var type in _asm.GetExportedTypes()) {
                if (type.IsExtension() && ns.Contains(type.Namespace)) {
                    loadedTypes.Add(DynamicHelpers.GetPythonTypeFromType(type));
                }
            }

#if FEATURE_ASSEMBLY_GETFORWARDEDTYPES
            Type[] forwardedTypes;
            try {
                forwardedTypes = _asm.GetForwardedTypes();
            } catch (ReflectionTypeLoadException ex) {
                // Use whatever could be loaded; entries that failed are null and skipped below.
                forwardedTypes = ex.Types;
            }

            foreach (var type in forwardedTypes) {
                if (type != null && type.IsExtension() && ns.Contains(type.Namespace)) {
                    loadedTypes.Add(DynamicHelpers.GetPythonTypeFromType(type));
                }
            }
#endif
        }

        var info = new AssemblyLoadInfo(_asm);
        info.Types = loadedTypes;
        if (ns == null) {
            // Neither types nor namespaces were registered for this entry.
            info.IsFullAssemblyLoaded = true;
        }
        return info;
    }

    #region IEquatable<AssemblyLoadInfo> Members

    /// <summary>
    /// Two entries are equal when they refer to the same assembly and either both are flagged
    /// as fully loaded, or their Types and Namespaces sets compare equal.
    /// </summary>
    public bool Equals(AssemblyLoadInfo other) {
        if (ReferenceEquals(this, other)) {
            return true;
        }
        if (other is null || _asm != other._asm) {
            return false;
        }

        if (IsFullAssemblyLoaded && other.IsFullAssemblyLoaded) {
            // full assembly is loaded for both
            return true;
        }

        return TypeComparer.Equals(Types, other.Types) && StringComparer.Equals(Namespaces, other.Namespaces);
    }

    #endregion
}
/// <summary>
/// Returns all of the extension methods with the given name.
/// </summary>
/// <param name="name">Name of the extension methods to look up; must not be null.</param>
/// <returns>
/// The methods registered under <paramref name="name"/> on every type tracked by this set;
/// empty when there are none.
/// </returns>
public IEnumerable<MethodInfo> GetExtensionMethods(string/*!*/ name) {
    Assert.NotNull(name);

    // The results are collected eagerly instead of being yielded from inside the lock.
    // An iterator would keep the monitor held while the caller's code runs between items
    // and would never release it if the enumerator were abandoned without being disposed.
    var matches = new List<MethodInfo>();

    lock (this) {
        EnsureLoaded();

        foreach (AssemblyLoadInfo info in _loadedAssemblies.Values) {
            Debug.Assert(info.Types != null);

            foreach (var type in info.Types) {
                List<MethodInfo> methods;
                if (type.ExtensionMethods.TryGetValue(name, out methods)) {
                    matches.AddRange(methods);
                }
            }
        }
    }

    return matches;
}
/// <summary>
/// Resolves all assemblies via LoadAllTypes if at least one of them still has pending
/// namespaces or no type set yet; does nothing otherwise.
/// </summary>
private void EnsureLoaded() {
    bool needsLoad = false;

    foreach (AssemblyLoadInfo entry in _loadedAssemblies.Values) {
        if (entry.Types == null || entry.Namespaces != null) {
            // one unresolved entry is enough, no need to look at the rest
            needsLoad = true;
            break;
        }
    }

    if (!needsLoad) {
        return;
    }

    LoadAllTypes();
}
/// <summary>
/// Returns all of the extension methods which are applicable for the given type.
/// </summary>
/// <param name="type">The type whose underlying system type is matched against each method's first parameter.</param>
/// <returns>
/// Every tracked extension method that has at least one parameter and whose first parameter
/// type is accepted by PythonExtensionBinder.IsApplicableExtensionMethod; empty when there are none.
/// </returns>
public IEnumerable<MethodInfo> GetExtensionMethods(PythonType type) {
    // The results are collected eagerly instead of being yielded from inside the lock.
    // An iterator would keep the monitor held while the caller's code runs between items
    // and would never release it if the enumerator were abandoned without being disposed.
    var applicable = new List<MethodInfo>();

    lock (this) {
        EnsureLoaded();

        foreach (AssemblyLoadInfo info in _loadedAssemblies.Values) {
            Debug.Assert(info.Types != null);

            foreach (var containingType in info.Types) {
                foreach (var methodList in containingType.ExtensionMethods.Values) {
                    foreach (var method in methodList) {
                        var methodParams = method.GetParameters();
                        if (methodParams.Length == 0) {
                            // nothing to match the receiver type against
                            continue;
                        }

                        if (PythonExtensionBinder.IsApplicableExtensionMethod(type.UnderlyingSystemType, methodParams[0].ParameterType)) {
                            applicable.Add(method);
                        }
                    }
                }
            }
        }
    }

    return applicable;
}
/// <summary>
/// Replaces the assembly table with one in which every entry has gone through
/// AssemblyLoadInfo.EnsureTypesLoaded. The previous dictionary is not modified.
/// </summary>
private void LoadAllTypes() {
    Dictionary<Assembly, AssemblyLoadInfo> current = _loadedAssemblies;
    var resolved = new Dictionary<Assembly, AssemblyLoadInfo>(current.Count);

    foreach (KeyValuePair<Assembly, AssemblyLoadInfo> entry in current) {
        resolved[entry.Key] = entry.Value.EnsureTypesLoaded();
    }

    _loadedAssemblies = resolved;
}
/// <summary>
/// Produces the extension method set that results from adding a single type to an existing set.
/// </summary>
/// <param name="context">Context whose UniqifyExtensions is applied to a newly created set.</param>
/// <param name="existingSet">The set to extend; it is locked for the duration of the call and not modified.</param>
/// <param name="type">The type to add.</param>
/// <returns>
/// <paramref name="existingSet"/> itself when the type is already covered (its assembly is fully
/// loaded, the type is registered, or its namespace is registered); otherwise the result of
/// passing a new set containing the type to context.UniqifyExtensions.
/// </returns>
public static ExtensionMethodSet AddType(PythonContext context, ExtensionMethodSet/*!*/ existingSet, PythonType/*!*/ type) {
    Assert.NotNull(existingSet, type);

    lock (existingSet) {
        AssemblyLoadInfo known;
        if (existingSet._loadedAssemblies.TryGetValue(type.UnderlyingSystemType.Assembly, out known)) {
            bool coveredByAssembly = known.IsFullAssemblyLoaded;
            bool coveredByType = known.Types != null && known.Types.Contains(type);

            if (coveredByAssembly || coveredByType ||
                (known.Namespaces != null && known.Namespaces.Contains(type.UnderlyingSystemType.Namespace))) {
                // type is already in this set.
                return existingSet;
            }
        }

        // Work on a copy so the existing set stays untouched.
        Dictionary<Assembly, AssemblyLoadInfo> copy = NewInfoOrCopy(existingSet);

        AssemblyLoadInfo target;
        if (!copy.TryGetValue(type.UnderlyingSystemType.Assembly, out target)) {
            target = new AssemblyLoadInfo(type.UnderlyingSystemType.Assembly);
            copy[type.UnderlyingSystemType.Assembly] = target;
        }

        if (target.Types == null) {
            target.Types = new HashSet<PythonType>();
        }
        target.Types.Add(type);

        var extended = new ExtensionMethodSet(copy);
        return context.UniqifyExtensions(extended);
    }
}
/// <summary>
/// Produces the extension method set that results from adding a namespace to an existing set.
/// The namespace name is registered for every assembly listed in ns.PackageAssemblies.
/// </summary>
/// <param name="context">Context whose UniqifyExtensions is applied to a newly created set.</param>
/// <param name="existingSet">The set to extend; it is locked for the duration of the call and not modified.</param>
/// <param name="ns">The namespace to add.</param>
/// <returns>
/// <paramref name="existingSet"/> itself when no assembly needed a change; otherwise the result
/// of passing a new set with the namespace registered to context.UniqifyExtensions.
/// </returns>
public static ExtensionMethodSet AddNamespace(PythonContext context, ExtensionMethodSet/*!*/ existingSet, NamespaceTracker/*!*/ ns) {
    Assert.NotNull(existingSet, ns);

    lock (existingSet) {
        // Created on demand: stays null as long as nothing has to change.
        Dictionary<Assembly, AssemblyLoadInfo> updated = null;

        foreach (var assembly in ns.PackageAssemblies) {
            AssemblyLoadInfo known;
            if (existingSet != null && existingSet._loadedAssemblies.TryGetValue(assembly, out known)) {
                if (known.IsFullAssemblyLoaded) {
                    // full assembly is already in this set.
                    continue;
                }

                if (known.Namespaces != null && known.Namespaces.Contains(ns.Name)) {
                    // namespace is already registered for this assembly.
                    continue;
                }

                if (updated == null) {
                    updated = NewInfoOrCopy(existingSet);
                }

                AssemblyLoadInfo copied = updated[assembly];
                if (copied.Namespaces == null) {
                    copied.Namespaces = new HashSet<string>();
                }
                copied.Namespaces.Add(ns.Name);
            } else {
                if (updated == null) {
                    updated = NewInfoOrCopy(existingSet);
                }

                // Assembly not tracked by the existing set: start a fresh entry holding just this namespace.
                var fresh = new AssemblyLoadInfo(assembly);
                updated[assembly] = fresh;
                fresh.Namespaces = new HashSet<string>();
                fresh.Namespaces.Add(ns.Name);
            }
        }

        if (updated == null) {
            return existingSet;
        }

        return context.UniqifyExtensions(new ExtensionMethodSet(updated));
    }
}
/// <summary>
/// Value equality: two null references are equal; otherwise the comparison is delegated to
/// <see cref="Equals(ExtensionMethodSet)"/> on the left operand.
/// </summary>
public static bool operator == (ExtensionMethodSet set1, ExtensionMethodSet set2) {
    if (ReferenceEquals(set1, null)) {
        // left side is null: equal only if the right side is null as well
        return ReferenceEquals(set2, null);
    }

    return set1.Equals(set2);
}
/// <summary>
/// Negation of the equality operator.
/// </summary>
public static bool operator != (ExtensionMethodSet set1, ExtensionMethodSet set2) {
    bool areEqual = set1 == set2;
    return !areEqual;
}
/// <summary>
/// Returns the extension binder for this set, creating it from context.Binder on first use
/// and caching it for subsequent calls.
/// </summary>
/// <param name="context">Context that supplies the binder; must not be null.</param>
public PythonExtensionBinder GetBinder(PythonContext/*!*/ context) {
    Debug.Assert(context != null);

    if (_extBinder != null) {
        return _extBinder;
    }

    _extBinder = new PythonExtensionBinder(context.Binder, this);
    return _extBinder;
}
/// <summary>
/// Creates a dictionary of load information which can be modified without affecting
/// <paramref name="existingSet"/>.
/// </summary>
/// <param name="existingSet">The set to copy from; when null an empty dictionary is returned.</param>
/// <returns>
/// A new dictionary holding a new AssemblyLoadInfo per assembly, each with its own copies of
/// the Namespaces and Types sets and the same IsFullAssemblyLoaded flag as the source entry.
/// </returns>
private static Dictionary<Assembly, AssemblyLoadInfo> NewInfoOrCopy(ExtensionMethodSet/*!*/ existingSet) {
    var dict = new Dictionary<Assembly, AssemblyLoadInfo>();

    // Reference check on purpose: the overloaded != operator would run a full Equals comparison.
    if ((object)existingSet == null) {
        return dict;
    }

    foreach (var keyValue in existingSet._loadedAssemblies) {
        AssemblyLoadInfo source = keyValue.Value;
        var assemblyLoadInfo = new AssemblyLoadInfo(keyValue.Key);

        if (source.Namespaces != null) {
            assemblyLoadInfo.Namespaces = new HashSet<string>(source.Namespaces);
        }
        if (source.Types != null) {
            assemblyLoadInfo.Types = new HashSet<PythonType>(source.Types);
        }

        // Carry the flag over as well; otherwise the copy forgets that the assembly was
        // already fully loaded and callers treat it as only partially loaded.
        assemblyLoadInfo.IsFullAssemblyLoaded = source.IsFullAssemblyLoaded;

        dict[keyValue.Key] = assemblyLoadInfo;
    }

    return dict;
}
/// <summary>
/// The identifier assigned to this set at construction (OutOfIds when none was available).
/// </summary>
public int Id => _id;
/// <summary>
/// Returns true when <paramref name="obj"/> is an ExtensionMethodSet which is equal to this
/// one according to <see cref="Equals(ExtensionMethodSet)"/>; false for null and other types.
/// </summary>
public override bool Equals(object obj) {
    return obj is ExtensionMethodSet other && this.Equals(other);
}
#region IEquatable<ExtensionMethodSet> Members
/// <summary>
/// Determines whether another set tracks the same assemblies with equal load information.
/// </summary>
/// <param name="other">The set to compare against; may be null.</param>
/// <returns>
/// True when <paramref name="other"/> is this instance, or when both sets contain the same
/// assemblies and the AssemblyLoadInfo entries for each assembly are equal by value.
/// </returns>
public bool Equals(ExtensionMethodSet other) {
    // Reference checks on purpose: the overloaded == operator would call back into this method.
    if ((object)other == null) {
        return false;
    }
    if ((object)this == (object)other) {
        // object identity
        return true;
    }
    if (_loadedAssemblies.Count != other._loadedAssemblies.Count) {
        // different assembly set
        return false;
    }

    foreach (var asmAndInfo in _loadedAssemblies) {
        AssemblyLoadInfo otherAsmInfo;
        if (!other._loadedAssemblies.TryGetValue(asmAndInfo.Key, out otherAsmInfo)) {
            // different assembly set.
            return false;
        }

        // AssemblyLoadInfo defines no == / != operators, so comparing with != would only test
        // reference identity; separately built sets hold distinct instances and would never
        // match. Compare by value instead.
        if (!asmAndInfo.Value.Equals(otherAsmInfo)) {
            return false;
        }
    }

    return true;
}
/// <summary>
/// Combines the hash codes of the tracked assemblies with XOR, so the result does not depend
/// on enumeration order or on the load state of the individual entries.
/// </summary>
public override int GetHashCode() {
    int hash = 6551;

    foreach (KeyValuePair<Assembly, AssemblyLoadInfo> entry in _loadedAssemblies) {
        hash ^= entry.Key.GetHashCode();
    }

    return hash;
}
#endregion
#if FALSE
// simpler but less lazy...
// Alternative implementation which is only compiled when the FALSE symbol is defined
// (presumably never - confirm against the build configuration). It relies on the
// _types / _assemblies / _namespaces fields declared below, which the active code
// in this class does not use.
public ExtensionMethodSet(HashSet<Type> newTypes, HashSet<NamespaceTracker> newNspaces, HashSet<Assembly> newAssms) {
_types = newTypes;
_namespaces = newNspaces;
_assemblies = newAssms;
}
private HashSet<Type> _types;
private HashSet<Assembly> _assemblies;
private HashSet<NamespaceTracker> _namespaces;
// Collects the declaring types of the extension methods contributed by the given
// PythonType, Assembly and NamespaceTracker objects. Returns a new set (merged with the
// contents of the existing one) when at least one new type was found, and the existing
// set otherwise.
internal static ExtensionMethodSet AddExtensions(ExtensionMethodSet extensionMethodSet, object[] extensions) {
var res = extensionMethodSet;
var newTypes = new HashSet<Type>();
var newAssms = new HashSet<Assembly>();
var newNspaces = new HashSet<NamespaceTracker>();
foreach (object o in extensions) {
PythonType type = o as PythonType;
if (type != null) {
if (res._types != null && res._types.Contains(type)) {
continue;
}
newTypes.Add(type.UnderlyingSystemType);
}
Assembly asm = o as Assembly;
if (asm != null) {
if (res._assemblies != null && res._assemblies.Contains(asm)) {
continue;
}
foreach (var method in ReflectionUtils.GetVisibleExtensionMethods(asm)) {
if (newTypes.Contains(method.DeclaringType) ||
(res._types != null && res._types.Contains(method.DeclaringType)) ) {
continue;
}
newTypes.Add(method.DeclaringType);
}
}
NamespaceTracker ns = o as NamespaceTracker;
if (ns != null) {
if (res._namespaces != null && res._namespaces.Contains(ns)) {
continue;
}
foreach (var packageAsm in ns.PackageAssemblies) {
foreach (var method in ReflectionUtils.GetVisibleExtensionMethods(packageAsm)) {
if (newTypes.Contains(method.DeclaringType) ||
(res._types != null && res._types.Contains(method.DeclaringType))) {
continue;
}
newTypes.Add(method.DeclaringType);
}
}
}
}
// Only build a new set when something was actually added.
if (newTypes.Count > 0) {
if (res._types != null) {
newTypes.UnionWith(res._types);
}
if (res._namespaces != null) {
newNspaces.UnionWith(res._namespaces);
}
if (res._assemblies != null) {
newAssms.UnionWith(res._assemblies);
}
return new ExtensionMethodSet(newTypes, newNspaces, newAssms);
}
return res;
}
#endif
}
}