-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy path__init__.py
More file actions
93 lines (72 loc) · 3.17 KB
/
__init__.py
File metadata and controls
93 lines (72 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Adds async capabilities to the base product object"""
import asyncio
from copy import deepcopy
from httpx import AsyncClient
from domaintools.base_results import Results
from domaintools.constants import RTTF_PRODUCTS_LIST, OutputFormat, HEADER_ACCEPT_KEY_CSV_FORMAT
from domaintools.exceptions import ServiceUnavailableException
class _AIter:
    """Async-iterator adapter around an awaitable results object.

    The wrapped object is awaited lazily, on the first ``__anext__`` call;
    after that, the values produced by its ``_items()`` are handed out one
    per call until they run out.
    """

    __slots__ = ("results", "iterator")

    def __init__(self, results):
        # The synchronous iterator is only created once the results have
        # been awaited, so it starts out unset.
        self.results = results
        self.iterator = None

    def __aiter__(self):
        """Return the adapter itself; it is its own async iterator."""
        return self

    async def __anext__(self):
        """Return the next item, awaiting the wrapped results on first use.

        Raises:
            StopAsyncIteration: when the underlying items are exhausted.
        """
        pending = self.iterator
        if pending is None:
            # First call: resolve the results, then start iterating them.
            await self.results
            pending = self.iterator = iter(self.results._items())
        try:
            item = next(pending)
        except StopIteration:
            # Translate the sync end-of-iteration signal into the async one.
            raise StopAsyncIteration
        return item
class AsyncResults(Results):
    """The base (abstract) DomainTools product definition with Async capabilities built in.

    Instances can be awaited (``await results``), iterated with ``async for``
    and used in ``async with``. Each of these goes through
    :meth:`__awaitable__`, which performs the HTTP request only while
    ``self._data`` is still ``None``.
    """

    def __await__(self):
        """Make the object awaitable by delegating to :meth:`__awaitable__`."""
        return self.__awaitable__().__await__()

    async def _make_async_request(self, session):
        """Send this product's request through ``session`` and store the response.

        The HTTP verb is chosen from ``self.product``: POST (form data) for the
        Iris investigate/enrich/escalate products, PATCH (JSON body) for the
        watchlist-management product, and GET for everything else.

        Args:
            session: the ``httpx.AsyncClient`` used to issue the request.
        """
        session_params_and_headers = self._get_session_params_and_headers()
        headers = session_params_and_headers.get("headers")

        if self.product in ("iris-investigate", "iris-enrich", "iris-detect-escalate-domains"):
            post_data = self.kwargs.copy()
            post_data.update(self.api.extra_request_params)
            results = await session.post(url=self.url, data=post_data, headers=headers)
        elif self.product in ("iris-detect-manage-watchlist-domains",):
            patch_data = self.kwargs.copy()
            patch_data.update(self.api.extra_request_params)
            # Headers belong on the request itself, not inside the JSON body.
            results = await session.patch(url=self.url, json=patch_data, headers=headers)
        else:
            parameters = session_params_and_headers.get("parameters")
            results = await session.get(
                url=self.url,
                params=parameters,
                headers=headers,
                **self.api.extra_request_params,
            )

        if results:
            self.setStatus(results.status_code, results)
            if self.kwargs.get("format", "json") == "json":
                self._data = results.json()
            else:
                # httpx exposes the decoded body as the ``text`` property
                # (it is not a method, unlike ``json()``).
                self._data = results.text
            self.check_limit_exceeded()

    async def __awaitable__(self):
        """Fetch the data if it has not been fetched yet and return ``self``.

        When ``_wait_time()`` reports a delay, that delay is slept before the
        request is sent. Otherwise the request is sent immediately and, if it
        raises ``ServiceUnavailableException``, retried once after 60 seconds.

        Returns:
            This results object, with ``self._data`` populated by the request
            unless it was already set.
        """
        if self._data is None:
            # timeout=None disables httpx's default request timeouts.
            async with AsyncClient(
                verify=self.api.verify_ssl,
                proxy=self.api.proxy_url,
                timeout=None,
            ) as session:
                wait_time = self._wait_time()
                if wait_time is not None:
                    await asyncio.sleep(wait_time)
                    await self._make_async_request(session)
                else:
                    try:
                        await self._make_async_request(session)
                    except ServiceUnavailableException:
                        await asyncio.sleep(60)
                        # Return value intentionally ignored; presumably called
                        # for its rate-limit bookkeeping side effect -- TODO confirm.
                        self._wait_time()
                        await self._make_async_request(session)
        return self

    def __aiter__(self):
        """Return an async iterator over the result items."""
        return _AIter(self)

    async def __aenter__(self):
        """Resolve the results on entering ``async with`` and return them."""
        return await self

    async def __aexit__(self, *args):
        """Nothing to clean up; returns ``None`` so exceptions propagate."""
        return