forked from databricks/databricks-sql-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path test_endpoint.py
More file actions
57 lines (47 loc) · 2.71 KB
/
test_endpoint.py
File metadata and controls
57 lines (47 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import unittest
import os
import pytest
from unittest.mock import patch
from databricks.sql.auth.endpoint import infer_cloud_from_host, CloudType, get_oauth_endpoints, \
AzureOAuthEndpointCollection
# Sample workspace hostnames, one per cloud, used as fixtures by EndpointTest.
aws_host, azure_host = (
    "foo-bar.cloud.databricks.com",
    "foo-bar.1.azuredatabricks.net",
)
class EndpointTest(unittest.TestCase):
    """Tests for cloud inference and OAuth endpoint resolution in databricks.sql.auth.endpoint."""

    def test_infer_cloud_from_host(self):
        """infer_cloud_from_host maps a known host (bare or inside a URL) to its cloud, else None."""
        param_list = [
            (CloudType.AWS, aws_host),
            (CloudType.AZURE, azure_host),
            (None, "foo.example.com"),
        ]
        for expected_type, host in param_list:
            # Record the host as a subtest parameter so a failure names the input that broke.
            with self.subTest(expected_type or "None", expected_type=expected_type, host=host):
                self.assertEqual(infer_cloud_from_host(host), expected_type)
                # The same host embedded in a full URL must resolve to the same cloud.
                self.assertEqual(
                    infer_cloud_from_host(f"https://{host}/to/path"), expected_type
                )

    def test_oauth_endpoint(self):
        """Each cloud's endpoint collection yields the expected URLs and scope mappings.

        The same two scope lists are fed to every cloud; only the expected
        mappings differ per cloud.
        """
        scopes = ["offline_access", "sql", "admin"]
        scopes_without_offline = ["sql", "admin"]
        azure_scope = f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation"
        param_list = [
            (
                CloudType.AWS,
                aws_host,
                f"https://{aws_host}/oidc/oauth2/v2.0/authorize",
                f"https://{aws_host}/oidc/.well-known/oauth-authorization-server",
                scopes,
                scopes_without_offline,
            ),
            (
                CloudType.AZURE,
                azure_host,
                f"https://{azure_host}/oidc/oauth2/v2.0/authorize",
                "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration",
                [azure_scope, "offline_access"],
                [azure_scope],
            ),
        ]
        # The Azure expectations above assume the DATATRICKS_AZURE_APP prefix.
        # DATABRICKS_AZURE_TENANT_ID changes that prefix (see
        # test_azure_oauth_scope_mappings_from_different_tenant_id), so drop any
        # value inherited from the ambient environment. patch.dict restores
        # os.environ when the block exits.
        with patch.dict(os.environ):
            os.environ.pop("DATABRICKS_AZURE_TENANT_ID", None)
            for (
                cloud_type,
                host,
                expected_auth_url,
                expected_config_url,
                expected_scopes,
                expected_scopes_without_offline,
            ) in param_list:
                with self.subTest(cloud_type):
                    endpoint = get_oauth_endpoints(cloud_type)
                    self.assertEqual(endpoint.get_authorization_url(host), expected_auth_url)
                    self.assertEqual(endpoint.get_openid_config_url(host), expected_config_url)
                    self.assertEqual(endpoint.get_scopes_mapping(scopes), expected_scopes)
                    self.assertEqual(
                        endpoint.get_scopes_mapping(scopes_without_offline),
                        expected_scopes_without_offline,
                    )

    @patch.dict(os.environ, {'DATABRICKS_AZURE_TENANT_ID': '052ee82f-b79d-443c-8682-3ec1749e56b0'})
    def test_azure_oauth_scope_mappings_from_different_tenant_id(self):
        """With DATABRICKS_AZURE_TENANT_ID set, the Azure impersonation scope uses that ID as its prefix.

        Only the impersonation scope and "offline_access" are expected back;
        the other requested scopes ("sql", "all") do not appear in the mapping.
        """
        scopes = ["offline_access", "sql", "all"]
        endpoint = get_oauth_endpoints(CloudType.AZURE)
        self.assertEqual(
            endpoint.get_scopes_mapping(scopes),
            ['052ee82f-b79d-443c-8682-3ec1749e56b0/user_impersonation', "offline_access"],
        )