1+ # -*- coding: utf-8 -*-
2+
3+ """
4+ Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved
5+
6+ Licensed under the Apache License, Version 2.0 (the "License");
7+ you may not use this file except in compliance with the License.
8+ You may obtain a copy of the License at
9+
10+ http://www.apache.org/licenses/LICENSE-2.0
11+
12+ Unless required by applicable law or agreed to in writing, software
13+ distributed under the License is distributed on an "AS IS" BASIS,
14+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+ See the License for the specific language governing permissions and
16+ limitations under the License.
17+ """
18+
19+ # stdlib
20+ import socket
21+ import ssl
22+ import threading
23+ import time
24+ import unittest
25+
26+ from SocketServer import StreamRequestHandler
27+ from xmlrpclib import Transport
28+
29+ # Spring Python
30+ from springpython .remoting .xmlrpc import SSLServer , SSLClient , RequestHandler , \
31+ SSLClientTransport , VerificationException
32+
33+ RESULT_OK = "All good"
34+
35+ server_key = "./support/pki/server-key.pem"
36+ server_cert = "./support/pki/server-cert.pem"
37+ client_key = "./support/pki/client-key.pem"
38+ client_cert = "./support/pki/client-cert.pem"
39+ ca_certs = "./support/pki/ca-chain.pem"
40+
41+ class MySSLServer (SSLServer ):
42+
43+ def test_server (self ):
44+ return RESULT_OK
45+
46+ def register_functions (self ):
47+ self .register_function (self .shutdown )
48+ self .register_function (self .test_server )
49+
50+ class _DummyServer (SSLServer ):
51+ pass
52+
53+ class _DummyRequest ():
54+ def recv (self , * ignored_args , ** ignored_kwargs ):
55+ pass
56+
57+ class _MyClientTransport (object ):
58+ def __init__ (self , ca_certs = None , keyfile = None , certfile = None , cert_reqs = None ,
59+ ssl_version = None , timeout = None , strict = None ):
60+ self .ca_certs = ca_certs
61+ self .keyfile = keyfile
62+ self .certfile = certfile
63+ self .cert_reqs = cert_reqs
64+ self .ssl_version = ssl_version
65+ self .timeout = timeout
66+ self .strict = strict
67+
68+ class TestInitDefaultArguments (unittest .TestCase ):
69+ def test_init_default_arguments (self ):
70+ """ Tests various defaults various and those passed to __init__'s.
71+ """
72+
73+ self .assertTrue (issubclass (VerificationException , Exception ))
74+ self .assertEqual (RequestHandler .rpc_paths , ("/" , "/RPC2" ))
75+ self .assertEqual (SSLClientTransport .user_agent ,
76+ "SSL XML-RPC Client (by http://springpython.webfactional.com)" )
77+
78+ server1 = MySSLServer ("127.0.0.1" , 8001 )
79+
80+ self .assertEqual (server1 .keyfile , None )
81+ self .assertEqual (server1 .certfile , None )
82+ self .assertEqual (server1 .ca_certs , None )
83+ self .assertEqual (server1 .cert_reqs , ssl .CERT_NONE )
84+ self .assertEqual (server1 .ssl_version , ssl .PROTOCOL_TLSv1 )
85+ self .assertEqual (server1 .do_handshake_on_connect , True )
86+ self .assertEqual (server1 .suppress_ragged_eofs , True )
87+ self .assertEqual (server1 .ciphers , None )
88+ self .assertEqual (server1 .verify_fields , None )
89+
90+ server_host = "127.0.0.1"
91+ server_port = 8002
92+ server_keyfile = "server_keyfile"
93+ server_certfile = "server_certfile"
94+ server_ca_certs = "server_ca_certs"
95+ server_cert_reqs = ssl .CERT_OPTIONAL
96+ server_ssl_version = ssl .PROTOCOL_SSLv3
97+ server_do_handshake_on_connect = False
98+ server_suppress_ragged_eofs = False
99+ server_ciphers = "ALL"
100+ server_verify_fields = {"commonName" : "Foo" , "organizationName" :"Baz" }
101+
102+ server2 = MySSLServer (server_host , server_port , server_keyfile ,
103+ server_certfile , server_ca_certs , server_cert_reqs ,
104+ server_ssl_version , server_do_handshake_on_connect ,
105+ server_suppress_ragged_eofs , server_ciphers ,
106+ verify_fields = server_verify_fields )
107+
108+ # inherited from SocketServer.BaseServer
109+ self .assertEqual (server2 .server_address , (server_host , server_port ))
110+
111+ self .assertEqual (server2 .keyfile , server_keyfile )
112+ self .assertEqual (server2 .certfile , server_certfile )
113+ self .assertEqual (server2 .ca_certs , server_ca_certs )
114+ self .assertEqual (server2 .cert_reqs , server_cert_reqs )
115+ self .assertEqual (server2 .ssl_version , server_ssl_version )
116+ self .assertEqual (server2 .do_handshake_on_connect , server_do_handshake_on_connect )
117+ self .assertEqual (server2 .suppress_ragged_eofs , server_suppress_ragged_eofs )
118+ self .assertEqual (server2 .ciphers , server_ciphers )
119+ self .assertEqual (sorted (server2 .verify_fields ), sorted (server_verify_fields ))
120+
121+ client_uri = "https://127.0.0.1:8000/RPC2"
122+ client_ca_certs = "client_ca_certs"
123+ client_keyfile = "client_keyfile"
124+ client_certfile = "client_certfile"
125+ client_cert_reqs = ssl .CERT_OPTIONAL
126+ client_ssl_version = ssl .PROTOCOL_SSLv23
127+ client_transport = _MyClientTransport
128+ client_encoding = "utf-16"
129+ client_verbose = 1
130+ client_allow_none = False
131+ client_use_datetime = False
132+ client_timeout = 13
133+ client_strict = True
134+
135+ client2 = SSLClient (client_uri , client_ca_certs , client_keyfile ,
136+ client_certfile , client_cert_reqs , client_ssl_version ,
137+ client_transport , client_encoding , client_verbose ,
138+ client_allow_none , client_use_datetime , client_timeout ,
139+ client_strict )
140+
141+ self .assertEqual (client2 ._ServerProxy__host , "127.0.0.1:8000" )
142+ self .assertEqual (client2 ._ServerProxy__transport .ca_certs , client_ca_certs )
143+ self .assertEqual (client2 ._ServerProxy__transport .keyfile , client_keyfile )
144+ self .assertEqual (client2 ._ServerProxy__transport .certfile , client_certfile )
145+ self .assertEqual (client2 ._ServerProxy__transport .cert_reqs , client_cert_reqs )
146+ self .assertEqual (client2 ._ServerProxy__transport .ssl_version , client_ssl_version )
147+ self .assertTrue (isinstance (client2 ._ServerProxy__transport , _MyClientTransport ))
148+ self .assertEqual (client2 ._ServerProxy__encoding , client_encoding )
149+ self .assertEqual (client2 ._ServerProxy__verbose , client_verbose )
150+ self .assertEqual (client2 ._ServerProxy__allow_none , client_allow_none )
151+ self .assertEqual (client2 ._ServerProxy__transport .timeout , client_timeout )
152+ self .assertEqual (client2 ._ServerProxy__transport .strict , client_strict )
153+
154+ self .assertRaises (NotImplementedError , _DummyServer , "127.0.0.1" , 8003 )
155+
156+ def test_request_handler (self ):
157+ request = _DummyRequest ()
158+ rh = RequestHandler (request , None , None )
159+ rh .setup ()
160+ self .assertTrue (rh .connection is request )
161+ self .assertTrue (isinstance (rh .rfile , socket ._fileobject ))
162+ self .assertTrue (isinstance (rh .wfile , socket ._fileobject ))
163+ self .assertTrue (rh .rfile ._sock is request )
164+ self .assertEqual (rh .rfile .mode , "rb" )
165+ self .assertEqual (rh .rfile .bufsize , socket ._fileobject .default_bufsize )
166+ self .assertTrue (rh .wfile ._sock is request )
167+ self .assertEqual (rh .wfile .mode , "wb" )
168+ self .assertEqual (rh .wfile .bufsize , StreamRequestHandler .wbufsize )
169+
170+ def xtest_imports (self ):
171+ raise NotImplemented ()
172+
173+ class TestSSL (unittest .TestCase ):
174+
175+ class _ClientServerContextManager (object ):
176+ def __init__ (self , server_port , cert_reqs = ssl .CERT_NONE , verify_fields = {}):
177+ self .server_port = server_port
178+ self .cert_reqs = cert_reqs
179+ self .verify_fields = verify_fields
180+
181+ def __enter__ (self ):
182+ server = MySSLServer ("127.0.0.1" , self .server_port , server_key ,
183+ server_cert , ca_certs , cert_reqs = self .cert_reqs ,
184+ verify_fields = self .verify_fields )
185+ self .server_thread = self ._start_server (server )
186+ time .sleep (0.5 )
187+
188+ def __exit__ (self , * ignored_args ):
189+ self .server_thread .server .shutdown ()
190+
191+ def _start_server (self , server ):
192+
193+ class _ServerController (threading .Thread ):
194+ def __init__ (self , server ):
195+ self .server = server
196+ self .isDaemon = False
197+ super (_ServerController , self ).__init__ ()
198+
199+ def run (self ):
200+ self .server .serve_forever ()
201+
202+ server_thread = _ServerController (server )
203+ server_thread .start ()
204+
205+ return server_thread
206+
207+
208+ def xtest_simple_ssl (self ):
209+ """ Server uses its cert, client uses none.
210+ """
211+ server_port = 9001
212+ with TestSSL ._ClientServerContextManager (server_port ):
213+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs )
214+ self .assertEqual (client .test_server (), RESULT_OK )
215+
216+ def xtest_client_cert (self ):
217+ """ Server & client use certs.
218+ """
219+ server_port = 9002
220+ with TestSSL ._ClientServerContextManager (server_port , ssl .CERT_REQUIRED ):
221+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs ,
222+ client_key , client_cert )
223+ self .assertEqual (client .test_server (), RESULT_OK )
224+
225+ def xtest_client_cert_ok (self ):
226+ """ Server & client use certs. Server succesfully validates client certificate's fields.
227+ """
228+ server_port = 9003
229+ verify_fields = {"commonName" :"My Client" , "countryName" :"US" ,
230+ "organizationalUnitName" :"My Unit" , "organizationName" :"My Company" ,
231+ "stateOrProvinceName" :"My State" }
232+
233+ with TestSSL ._ClientServerContextManager (server_port , ssl .CERT_REQUIRED , verify_fields ):
234+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs ,
235+ client_key , client_cert )
236+ self .assertEqual (client .test_server (), RESULT_OK )
237+
238+ def xtest_client_cert_failure_missing_field (self ):
239+ """ Server & client use certs. Server fails to validate client certificate's fields
240+ (a field is missing).
241+ """
242+ server_port = 9004
243+ verify_fields = {"commonName" :"My Client" , "countryName" :"US" ,
244+ "organizationalUnitName" :"My Unit" , "organizationName" :"My Company" ,
245+ "stateOrProvinceName" :"My State" , "FOO" : "BAR" }
246+
247+ with TestSSL ._ClientServerContextManager (server_port , ssl .CERT_REQUIRED , verify_fields ):
248+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs ,
249+ client_key , client_cert )
250+ self .assertRaises (Exception , client .test_server )
251+
252+ def xtest_client_cert_failure_field_incorrect_value (self ):
253+ """ Server & client use certs. Server fails to validate client certificate's fields
254+ (all fields are in place, but commonName has an incorrect value).
255+ """
256+ server_port = 9005
257+ verify_fields = {"commonName" :"Invalid" }
258+ with TestSSL ._ClientServerContextManager (server_port , ssl .CERT_REQUIRED , verify_fields ):
259+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs ,
260+ client_key , client_cert )
261+ self .assertRaises (Exception , client .test_server )
262+
263+ def test_client_cert_failure_no_client_cert (self ):
264+ """ Server optionally requires a client to send the certificate
265+ and validates its fields but client sends none.
266+ """
267+ server_port = 9006
268+ verify_fields = {"commonName" :"My Client" }
269+ with TestSSL ._ClientServerContextManager (server_port , ssl .CERT_OPTIONAL , verify_fields ):
270+ client = SSLClient ("https://localhost:%d/RPC2" % server_port , ca_certs )
271+ self .assertRaises (Exception , client .test_server )
272+
273+ if __name__ == "__main__" :
274+ unittest .main ()
0 commit comments