@@ -65,50 +65,45 @@ class Direct(BaseProtocol):
6565 pass
6666
6767class Trojan (BaseProtocol ):
68- async def guess (self , reader , auth , authtable , ** kw ):
68+ async def guess (self , reader , users , ** kw ):
6969 header = await reader .read_w (56 )
70- toauth = hashlib .sha224 (auth or b'' ).hexdigest ()
71- if header == toauth .encode ():
72- authtable .set_authed ()
73- return True
70+ if users :
71+ for user in users :
72+ if hashlib .sha224 (user ).hexdigest ().encode () == header :
73+ return user
74+ else :
75+ if hashlib .sha224 (b'' ).hexdigest ().encode () == header :
76+ return True
7477 reader .rollback (header )
75- async def accept (self , reader , ** kw ):
78+ async def accept (self , reader , user , ** kw ):
7679 assert await reader .read_n (3 ) == b'\x0d \x0a \x01 '
7780 host_name , port , _ = await socks_address_stream (reader , (await reader .read_n (1 ))[0 ])
7881 assert await reader .read_n (2 ) == b'\x0d \x0a '
79- return host_name , port
82+ return user , host_name , port
8083 async def connect (self , reader_remote , writer_remote , rauth , host_name , port , ** kw ):
8184 toauth = hashlib .sha224 (rauth or b'' ).hexdigest ().encode ()
8285 writer_remote .write (toauth + b'\x0d \x0a \x01 \x03 ' + packstr (host_name .encode ()) + port .to_bytes (2 , 'big' ) + b'\x0d \x0a ' )
8386
8487class SSR (BaseProtocol ):
85- async def guess (self , reader , auth , authtable , ** kw ):
86- if auth :
87- header = await reader .read_w (len (auth ))
88- if header != auth :
89- reader .rollback (header )
90- return False
91- authtable .set_authed ()
88+ async def guess (self , reader , users , ** kw ):
89+ if users :
90+ header = await reader .read_w (max (len (i ) for i in users ))
91+ reader .rollback (header )
92+ user = next (filter (lambda x : x == header [:len (x )], users ), None )
93+ if user is None :
94+ return
95+ await reader .read_n (len (user ))
96+ return user
9297 header = await reader .read_w (1 )
9398 reader .rollback (header )
94- return header [0 ] in (1 , 3 , 4 )
95- async def accept (self , reader , ** kw ):
99+ return header [0 ] in (1 , 3 , 4 , 17 , 19 , 20 )
100+ async def accept (self , reader , user , ** kw ):
96101 host_name , port , data = await socks_address_stream (reader , (await reader .read_n (1 ))[0 ])
97- return host_name , port
102+ return user , host_name , port
98103 async def connect (self , reader_remote , writer_remote , rauth , host_name , port , ** kw ):
99104 writer_remote .write (rauth + b'\x03 ' + packstr (host_name .encode ()) + port .to_bytes (2 , 'big' ))
100105
101- class SS (BaseProtocol ):
102- async def guess (self , reader , auth , authtable , ** kw ):
103- if auth :
104- header = await reader .read_w (len (auth ))
105- if header != auth :
106- reader .rollback (header )
107- return False
108- authtable .set_authed ()
109- header = await reader .read_w (1 )
110- reader .rollback (header )
111- return header [0 ] in (1 , 3 , 4 , 17 , 19 , 20 )
106+ class SS (SSR ):
112107 def patch_ota_reader (self , cipher , reader ):
113108 chunk_id , data_len , _buffer = 0 , None , bytearray ()
114109 def decrypt (s ):
@@ -143,7 +138,7 @@ def write(data, o=writer.write):
143138 chunk_id += 1
144139 return o (len (data ).to_bytes (2 , 'big' ) + checksum [:10 ] + data )
145140 writer .write = write
146- async def accept (self , reader , reader_cipher , ** kw ):
141+ async def accept (self , reader , user , reader_cipher , ** kw ):
147142 header = await reader .read_n (1 )
148143 ota = (header [0 ] & 0x10 == 0x10 )
149144 host_name , port , data = await socks_address_stream (reader , header [0 ])
@@ -152,7 +147,7 @@ async def accept(self, reader, reader_cipher, **kw):
152147 checksum = hmac .new (reader_cipher .iv + reader_cipher .key , header + data , hashlib .sha1 ).digest ()
153148 assert checksum [:10 ] == await reader .read_n (10 ), 'Unknown OTA checksum'
154149 self .patch_ota_reader (reader_cipher , reader )
155- return host_name , port
150+ return user , host_name , port
156151 async def connect (self , reader_remote , writer_remote , rauth , host_name , port , writer_cipher_r , ** kw ):
157152 writer_remote .write (rauth )
158153 if writer_cipher_r and writer_cipher_r .ota :
@@ -162,15 +157,19 @@ async def connect(self, reader_remote, writer_remote, rauth, host_name, port, wr
162157 self .patch_ota_writer (writer_cipher_r , writer_remote )
163158 else :
164159 writer_remote .write (b'\x03 ' + packstr (host_name .encode ()) + port .to_bytes (2 , 'big' ))
165- def udp_accept (self , data , auth , ** kw ):
160+ def udp_accept (self , data , users , ** kw ):
166161 reader = io .BytesIO (data )
167- if auth and reader .read (len (auth )) != auth :
168- return
162+ user = True
163+ if users :
164+ user = next (filter (lambda i : data [:len (i )]== i , users ), None )
165+ if user is None :
166+ return
167+ reader .read (len (user ))
169168 n = reader .read (1 )[0 ]
170169 if n not in (1 , 3 , 4 ):
171170 return
172171 host_name , port = socks_address (reader , n )
173- return host_name , port , reader .read ()
172+ return user , host_name , port , reader .read ()
174173 def udp_unpack (self , data ):
175174 reader = io .BytesIO (data )
176175 n = reader .read (1 )[0 ]
@@ -191,17 +190,20 @@ async def guess(self, reader, **kw):
191190 if header == b'\x04 ' :
192191 return True
193192 reader .rollback (header )
194- async def accept (self , reader , writer , auth , authtable , ** kw ):
193+ async def accept (self , reader , user , writer , users , authtable , ** kw ):
195194 assert await reader .read_n (1 ) == b'\x01 '
196195 port = int .from_bytes (await reader .read_n (2 ), 'big' )
197196 ip = await reader .read_n (4 )
198197 userid = (await reader .read_until (b'\x00 ' ))[:- 1 ]
199- if auth :
200- if auth != userid and not authtable .authed ():
201- raise Exception (f'Unauthorized SOCKS { auth } ' )
202- authtable .set_authed ()
198+ user = authtable .authed ()
199+ if users :
200+ if userid in users :
201+ user = userid
202+ elif not user :
203+ raise Exception (f'Unauthorized SOCKS { userid } ' )
204+ authtable .set_authed (user )
203205 writer .write (b'\x00 \x5a ' + port .to_bytes (2 , 'big' ) + ip )
204- return socket .inet_ntoa (ip ), port
206+ return user , socket .inet_ntoa (ip ), port
205207 async def connect (self , reader_remote , writer_remote , rauth , host_name , port , ** kw ):
206208 ip = socket .inet_aton ((await asyncio .get_event_loop ().getaddrinfo (host_name , port , family = socket .AF_INET ))[0 ][4 ][0 ])
207209 writer_remote .write (b'\x04 \x01 ' + port .to_bytes (2 , 'big' ) + ip + rauth + b'\x00 ' )
@@ -214,25 +216,31 @@ async def guess(self, reader, **kw):
214216 if header == b'\x05 ' :
215217 return True
216218 reader .rollback (header )
217- async def accept (self , reader , writer , auth , authtable , ** kw ):
219+ async def accept (self , reader , user , writer , users , authtable , ** kw ):
218220 methods = await reader .read_n ((await reader .read_n (1 ))[0 ])
219- if auth and (b'\x00 ' not in methods or not authtable .authed ()):
221+ user = authtable .authed ()
222+ if users and (not user or b'\x00 ' not in methods ):
223+ if b'\x02 ' not in methods :
224+ raise Exception (f'Unauthorized SOCKS' )
220225 writer .write (b'\x05 \x02 ' )
221226 assert (await reader .read_n (1 ))[0 ] == 1 , 'Unknown SOCKS auth'
222227 u = await reader .read_n ((await reader .read_n (1 ))[0 ])
223228 p = await reader .read_n ((await reader .read_n (1 ))[0 ])
224- if u + b':' + p != auth :
229+ user = u + b':' + p
230+ if user not in users :
225231 raise Exception (f'Unauthorized SOCKS { u } :{ p } ' )
226232 writer .write (b'\x01 \x00 ' )
233+ elif users and not user :
234+ raise Exception (f'Unauthorized SOCKS' )
227235 else :
228236 writer .write (b'\x05 \x00 ' )
229- if auth :
230- authtable .set_authed ()
237+ if users :
238+ authtable .set_authed (user )
231239 assert await reader .read_n (3 ) == b'\x05 \x01 \x00 ' , 'Unknown SOCKS protocol'
232240 header = await reader .read_n (1 )
233241 host_name , port , data = await socks_address_stream (reader , header [0 ])
234242 writer .write (b'\x05 \x00 \x00 ' + header + data )
235- return host_name , port
243+ return user , host_name , port
236244 async def connect (self , reader_remote , writer_remote , rauth , host_name , port , ** kw ):
237245 if rauth :
238246 writer_remote .write (b'\x05 \x01 \x02 ' )
@@ -254,7 +262,7 @@ def udp_accept(self, data, **kw):
254262 if n not in (1 , 3 , 4 ):
255263 return
256264 host_name , port = socks_address (reader , n )
257- return host_name , port , reader .read ()
265+ return True , host_name , port , reader .read ()
258266 def udp_connect (self , rauth , host_name , port , data , ** kw ):
259267 return b'\x00 \x00 \x00 \x03 ' + packstr (host_name .encode ()) + port .to_bytes (2 , 'big' ) + data
260268
@@ -263,7 +271,7 @@ async def guess(self, reader, **kw):
263271 header = await reader .read_w (4 )
264272 reader .rollback (header )
265273 return header in (b'GET ' , b'HEAD' , b'POST' , b'PUT ' , b'DELE' , b'CONN' , b'OPTI' , b'TRAC' , b'PATC' )
266- async def accept (self , reader , writer , auth , authtable , httpget = None , ** kw ):
274+ async def accept (self , reader , user , writer , users , authtable , httpget = None , ** kw ):
267275 lines = await reader .read_until (b'\r \n \r \n ' )
268276 headers = lines [:- 4 ].decode ().split ('\r \n ' )
269277 method , path , ver = HTTP_LINE .match (headers .pop (0 )).groups ()
@@ -273,24 +281,26 @@ async def accept(self, reader, writer, auth, authtable, httpget=None, **kw):
273281 if method == 'GET' and not url .hostname and httpget :
274282 for path , text in httpget .items ():
275283 if url .path == path :
276- authtable .set_authed ()
284+ # authtable.set_authed()
277285 if type (text ) is str :
278286 text = (text % dict (host = headers ["Host" ])).encode ()
279287 writer .write (f'{ ver } 200 OK\r \n Connection: close\r \n Content-Type: text/plain\r \n Cache-Control: max-age=900\r \n Content-Length: { len (text )} \r \n \r \n ' .encode () + text )
280288 await writer .drain ()
281289 raise Exception ('Connection closed' )
282290 raise Exception (f'404 { method } { url .path } ' )
283- if auth :
291+ if users :
284292 pauth = headers .get ('Proxy-Authorization' , None )
285- httpauth = 'Basic ' + base64 .b64encode (auth ).decode ()
286- if not authtable .authed () and pauth != httpauth :
287- writer .write (f'{ ver } 407 Proxy Authentication Required\r \n Connection: close\r \n Proxy-Authenticate: Basic realm="simple"\r \n \r \n ' .encode ())
288- raise Exception ('Unauthorized HTTP' )
289- authtable .set_authed ()
293+ user = authtable .authed ()
294+ if not user :
295+ user = next (filter (lambda i : ('Basic ' + base64 .b64encode (i ).decode ()) == pauth , users ), None )
296+ if user is None :
297+ writer .write (f'{ ver } 407 Proxy Authentication Required\r \n Connection: close\r \n Proxy-Authenticate: Basic realm="simple"\r \n \r \n ' .encode ())
298+ raise Exception ('Unauthorized HTTP' )
299+ authtable .set_authed (user )
290300 if method == 'CONNECT' :
291301 host_name , port = path .rsplit (':' , 1 )
292302 port = int (port )
293- return host_name , port , f'{ ver } 200 OK\r \n Connection: close\r \n \r \n ' .encode ()
303+ return user , host_name , port , f'{ ver } 200 OK\r \n Connection: close\r \n \r \n ' .encode ()
294304 else :
295305 url = urllib .parse .urlparse (path )
296306 if ':' in url .netloc :
@@ -299,7 +309,7 @@ async def accept(self, reader, writer, auth, authtable, httpget=None, **kw):
299309 else :
300310 host_name , port = url .netloc , 80
301311 newpath = url ._replace (netloc = '' , scheme = '' ).geturl ()
302- return host_name , port , b'' , f'{ method } { newpath } { ver } \r \n { lines } \r \n \r \n ' .encode ()
312+ return user , host_name , port , b'' , f'{ method } { newpath } { ver } \r \n { lines } \r \n \r \n ' .encode ()
303313 async def connect (self , reader_remote , writer_remote , rauth , host_name , port , myhost , ** kw ):
304314 writer_remote .write (f'CONNECT { host_name } :{ port } HTTP/1.1\r \n Host: { myhost } ' .encode () + (b'\r \n Proxy-Authorization: Basic ' + base64 .b64encode (rauth ) if rauth else b'' ) + b'\r \n \r \n ' )
305315 await reader_remote .read_until (b'\r \n \r \n ' )
@@ -358,12 +368,12 @@ class Transparent(BaseProtocol):
358368 async def guess (self , reader , sock , ** kw ):
359369 remote = self .query_remote (sock )
360370 return remote is not None and sock .getsockname () != remote
361- async def accept (self , reader , sock , ** kw ):
371+ async def accept (self , reader , user , sock , ** kw ):
362372 remote = self .query_remote (sock )
363- return remote [0 ], remote [1 ]
364- def udp_accept (self , data , auth , sock , ** kw ):
373+ return user , remote [0 ], remote [1 ]
374+ def udp_accept (self , data , sock , ** kw ):
365375 remote = self .query_remote (sock )
366- return remote [0 ], remote [1 ], data
376+ return True , remote [0 ], remote [1 ], data
367377
368378SO_ORIGINAL_DST = 80
369379SOL_IPV6 = 41
@@ -457,20 +467,22 @@ def write(data, o=writer.write):
457467 else :
458468 return o (b'\x02 ' + (bytes ([data_len ]) if data_len < 126 else b'\x7e ' + data_len .to_bytes (2 , 'big' ) if data_len < 65536 else b'\x7f ' + data_len .to_bytes (4 , 'big' )) + data )
459469 writer .write = write
460- async def accept (self , reader , writer , auth , authtable , sock , ** kw ):
470+ async def accept (self , reader , user , writer , users , authtable , sock , ** kw ):
461471 lines = await reader .read_until (b'\r \n \r \n ' )
462472 headers = lines [:- 4 ].decode ().split ('\r \n ' )
463473 method , path , ver = HTTP_LINE .match (headers .pop (0 )).groups ()
464474 lines = '\r \n ' .join (i for i in headers if not i .startswith ('Proxy-' ))
465475 headers = dict (i .split (': ' , 1 ) for i in headers if ': ' in i )
466476 url = urllib .parse .urlparse (path )
467- if auth :
477+ if users :
468478 pauth = headers .get ('Proxy-Authorization' , None )
469- httpauth = 'Basic ' + base64 .b64encode (auth ).decode ()
470- if not authtable .authed () and pauth != httpauth :
471- writer .write (f'{ ver } 407 Proxy Authentication Required\r \n Connection: close\r \n Proxy-Authenticate: Basic realm="simple"\r \n \r \n ' .encode ())
472- raise Exception ('Unauthorized WebSocket' )
473- authtable .set_authed ()
479+ user = authtable .authed ()
480+ if not user :
481+ user = next (filter (lambda i : ('Basic ' + base64 .b64encode (i ).decode ()) == pauth , users ), None )
482+ if user is None :
483+ writer .write (f'{ ver } 407 Proxy Authentication Required\r \n Connection: close\r \n Proxy-Authenticate: Basic realm="simple"\r \n \r \n ' .encode ())
484+ raise Exception ('Unauthorized WebSocket' )
485+ authtable .set_authed (user )
474486 if method != 'GET' :
475487 raise Exception (f'Unsupported method { method } ' )
476488 if headers .get ('Sec-WebSocket-Key' , None ) is None :
@@ -485,7 +497,7 @@ async def accept(self, reader, writer, auth, authtable, sock, **kw):
485497 dst = sock .getsockname ()
486498 host = host or dst [0 ]
487499 port = int (port ) if port else dst [1 ]
488- return host , port
500+ return user , host , port
489501 async def connect (self , reader_remote , writer_remote , rauth , host_name , port , myhost , ** kw ):
490502 seckey = base64 .b64encode (os .urandom (16 )).decode ()
491503 writer_remote .write (f'GET / HTTP/1.1\r \n Host: { myhost } \r \n Upgrade: websocket\r \n Connection: Upgrade\r \n Sec-WebSocket-Key: { seckey } \r \n Sec-WebSocket-Protocol: chat\r \n Sec-WebSocket-Version: 13' .encode () + (b'\r \n Proxy-Authorization: Basic ' + base64 .b64encode (rauth ) if rauth else b'' ) + b'\r \n \r \n ' )
@@ -578,12 +590,12 @@ def sendto(data):
578590async def accept (protos , reader , ** kw ):
579591 for proto in protos :
580592 try :
581- ok = await proto .guess (reader , ** kw )
593+ user = await proto .guess (reader , ** kw )
582594 except Exception :
583595 raise Exception ('Connection closed' )
584- if ok :
585- ret = await proto .accept (reader , ** kw )
586- while len (ret ) < 4 :
596+ if user :
597+ ret = await proto .accept (reader , user , ** kw )
598+ while len (ret ) < 5 :
587599 ret += (b'' ,)
588600 return (proto ,) + ret
589601 raise Exception (f'Unsupported protocol' )
0 commit comments