forked from localstack/localstack
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path proxy_server.py
More file actions
201 lines (164 loc) · 6.71 KB
/
proxy_server.py
File metadata and controls
201 lines (164 loc) · 6.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import gzip
import logging
import os
import select
import socket
from typing import Any, Dict, Tuple, Union
import requests
from localstack.constants import BIND_HOST, HEADER_ACCEPT_ENCODING, LOCALHOST_IP
from localstack.services.generic_proxy import ProxyListener, start_proxy_server
from localstack.utils.async_utils import ensure_event_loop
from localstack.utils.common import (
TMP_THREADS,
is_number,
new_tmp_file,
run_safe,
save_file,
start_worker_thread,
to_bytes,
)
from localstack.utils.run import FuncThread
LOG = logging.getLogger(__name__)

# number of bytes read per recv() call by the TCP proxy
BUFFER_SIZE = 1024

# proxy target: presumably either a port number or a host/URL string
PortOrUrl = Union[str, int]
def start_tcp_proxy(src, dst, handler, **kwargs):
    """Run a simple TCP proxy that tunnels raw connections from ``src`` to ``dst``.

    An optional message handler can intercept messages coming from the client and
    either forward a (possibly modified) payload to the destination or answer the
    client directly with a predefined response.

    Arguments:
    src -- source address as 'host:port' (e.g. '127.0.0.1:8000'), or a bare port
           number, in which case the proxy binds to BIND_HOST
    dst -- destination address as 'host:port' (e.g. '127.0.0.1:8888'), or a bare
           port number, in which case LOCALHOST_IP is used as host
    handler -- optional function called with each chunk received from the client.
           It returns a tuple (forward_value, response_value). If forward_value is
           not None, it is sent to the destination. Otherwise, if response_value is
           not None, it is sent back to the client and the connection is closed.
    _thread -- (keyword) controlling thread object; the accept loop runs while
           ``_thread.running`` is true. If omitted, the loop runs until the process
           exits.
    """
    src = "%s:%s" % (BIND_HOST, src) if is_number(src) else src
    dst = "%s:%s" % (LOCALHOST_IP, dst) if is_number(dst) else dst
    thread = kwargs.get("_thread")

    def ip_to_tuple(address):
        host, port = address.split(":")
        return host, int(port)

    def is_running(controlling_thread):
        # without a controlling thread there is nothing that could stop the loop
        return controlling_thread is None or controlling_thread.running

    def handle_request(s_src, worker_thread):
        s_dst = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sockets = [s_src, s_dst]
        try:
            # connect inside the try block, so that both sockets are closed
            # even if the destination is unreachable
            s_dst.connect(ip_to_tuple(dst))
            while worker_thread.running:
                readable, _, _ = select.select(sockets, [], [])
                for sock in readable:
                    data = sock.recv(BUFFER_SIZE)
                    if not data:
                        # peer closed the connection
                        return
                    if sock is s_src:
                        forward, response = data, None
                        if handler:
                            forward, response = handler(data)
                        if forward is not None:
                            s_dst.sendall(forward)
                        elif response is not None:
                            # answer the client directly and end this connection
                            s_src.sendall(response)
                            return
                    elif sock is s_dst:
                        s_src.sendall(data)
        finally:
            run_safe(s_src.close)
            run_safe(s_dst.close)

    def make_worker(client_socket):
        # Bind the accepted socket now. A lambda referencing the loop variable
        # would see whichever socket was accepted last by the time the worker
        # thread actually runs.
        def worker(*args, _thread):
            return handle_request(client_socket, _thread)

        return worker

    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        server_socket.bind(ip_to_tuple(src))
        server_socket.listen(1)
        # periodic timeout, so the loop condition is re-checked regularly
        server_socket.settimeout(10)
        while is_running(thread):
            try:
                client_socket, _ = server_socket.accept()
            except socket.timeout:
                continue
            start_worker_thread(make_worker(client_socket))
    finally:
        # release the listening port when the proxy stops (or fails to bind)
        run_safe(server_socket.close)
def start_ssl_proxy(
    port: int,
    target: PortOrUrl,
    target_ssl=False,
    client_cert_key: Tuple[str, str] = None,
    asynchronous: bool = False,
    fix_encoding: bool = False,
):
    """Start a proxy server that accepts SSL requests and forwards requests to a backend (either SSL or non-SSL)

    With ``client_cert_key`` or ``fix_encoding`` set, a requests-based proxy listener
    is used; otherwise a pproxy tunnel is started. With ``asynchronous=False`` the
    call blocks until the server stops; otherwise the server/thread is returned.
    """
    uses_listener = bool(client_cert_key or fix_encoding)

    if uses_listener:
        # custom proxy listener (needed e.g. for client-certificate authentication)
        server = (
            _do_start_ssl_proxy_with_client_auth(port, target, client_cert_key=client_cert_key)
            if client_cert_key
            else _do_start_ssl_proxy_with_listener(port, target)
        )
        if not asynchronous:
            server.join()
        return server

    def _run(*args):
        return _do_start_ssl_proxy(port, target, target_ssl=target_ssl)

    if asynchronous:
        proxy = FuncThread(_run)
        TMP_THREADS.append(proxy)
        proxy.start()
        return proxy
    return _run()
def _do_start_ssl_proxy(port: int, target: PortOrUrl, target_ssl=False):
    """Run a blocking pproxy-based SSL tunnel from local ``port`` to ``target``.

    :param port: local port to listen on (bound on 0.0.0.0)
    :param target: target port or 'host:port'; a value without ':' is treated as a
        port on 127.0.0.1
    :param target_ssl: whether to speak SSL to the target ("ssl+tunnel") or plain
        TCP ("tunnel")

    The function runs the event loop until it is stopped or interrupted. It always
    closes the server and the loop before returning or propagating an exception.
    """
    import pproxy

    from localstack.services.generic_proxy import GenericProxy

    if ":" not in str(target):
        target = "127.0.0.1:%s" % target
    LOG.debug("Starting SSL proxy server %s -> %s", port, target)

    # create server and remote connection
    server = pproxy.Server("secure+tunnel://0.0.0.0:%s" % port)
    target_proto = "ssl+tunnel" if target_ssl else "tunnel"
    remote = pproxy.Connection("%s://%s" % (target_proto, target))
    args = dict(rserver=[remote], verbose=print)

    # set SSL contexts
    _, cert_file_name, key_file_name = GenericProxy.create_ssl_cert()
    for context in pproxy.server.sslcontexts:
        context.load_cert_chain(cert_file_name, key_file_name)

    loop = ensure_event_loop()
    handler = loop.run_until_complete(server.start_server(args))
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        print("exit!")
    finally:
        # clean up on every exit path, not only after a KeyboardInterrupt
        handler.close()
        loop.run_until_complete(handler.wait_closed())
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()
def _do_start_ssl_proxy_with_client_auth(
    port: int, target: PortOrUrl, client_cert_key: Tuple[str, str]
):
    """Start a listener-based SSL proxy that presents a client certificate to ``target``.

    ``client_cert_key`` is a (cert, key) pair. Each entry is either the path of an
    existing file, or content that is written to a new temporary file. The resulting
    file paths are passed to ``requests`` as its ``cert`` argument.
    """

    def as_file(value):
        # TODO: check whether/how we can pass cert strings to requests.request(..) directly
        if os.path.exists(value):
            return value
        tmp_path = new_tmp_file()
        save_file(tmp_path, value)
        return tmp_path

    # cert is materialized before key (tuple items evaluate left to right)
    cert_params = (as_file(client_cert_key[0]), as_file(client_cert_key[1]))
    return _do_start_ssl_proxy_with_listener(
        port, target, requests_kwargs={"cert": cert_params}
    )
def _fix_response_encoding(response, accept_encoding):
    """Make the body and encoding headers of a proxied ``requests`` response consistent.

    ``requests`` transparently decodes gzip/deflate response bodies but keeps the
    upstream ``Content-Encoding`` and ``Content-Length`` headers.

    - If the client accepts gzip, the decoded body is re-compressed with gzip.
    - Otherwise, a gzip/deflate ``Content-Encoding`` header left over from upstream
      is removed, because the body no longer carries that encoding.
    - Empty bodies (e.g. HEAD, 204, 304) are left untouched, so that no body is
      added to them.

    Returns the (mutated) response.
    """
    content = to_bytes(response.content)
    if not content:
        return response
    if "gzip" in (accept_encoding or "").lower():
        response._content = gzip.compress(content)
        response.headers["Content-Length"] = str(len(response._content))
        response.headers["Content-Encoding"] = "gzip"
        return response
    upstream_encoding = (response.headers.get("Content-Encoding") or "").strip().lower()
    # encodings that requests/urllib3 always decode automatically
    if upstream_encoding in ("gzip", "x-gzip", "deflate"):
        del response.headers["Content-Encoding"]
        response.headers["Content-Length"] = str(len(content))
    return response


def _do_start_ssl_proxy_with_listener(
    port: int, target: PortOrUrl, requests_kwargs: Dict[str, Any] = None
):
    """Start an SSL proxy server on ``port`` that forwards each request to ``target``.

    :param port: local port for the SSL proxy server
    :param target: an int port, or a numeric string port, meaning
        ``http://localhost:<port>``; or a URL/host string. A string without a scheme
        is prefixed with ``https://``.
    :param requests_kwargs: extra keyword arguments passed to ``requests.request``
        (e.g. ``cert`` for client authentication)
    :return: the object returned by ``start_proxy_server`` (presumably the proxy
        server thread; callers ``join()`` it)
    """
    if isinstance(target, int) or str(target).isdigit():
        target = f"http://localhost:{target}"
    base_url = f"{'https://' if '://' not in target else ''}{target.rstrip('/')}"
    requests_kwargs = requests_kwargs or {}

    # define forwarding listener
    class Listener(ProxyListener):
        def forward_request(self, method, path, data, headers):
            # send request to target (TLS certificate verification of the target is disabled)
            url = f"{base_url}{path}"
            response = requests.request(
                method=method, url=url, data=data, headers=headers, verify=False, **requests_kwargs
            )
            # fix encoding of response, based on the client's Accept-Encoding header
            return _fix_response_encoding(response, headers.get(HEADER_ACCEPT_ENCODING, ""))

    proxy_thread = start_proxy_server(port, update_listener=Listener(), use_ssl=True)
    return proxy_thread