-
Notifications
You must be signed in to change notification settings - Fork 141
Expand file tree
/
Copy pathauth.py
More file actions
executable file
·134 lines (116 loc) · 5.1 KB
/
auth.py
File metadata and controls
executable file
·134 lines (116 loc) · 5.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from typing import Optional, List
from databricks.sql.auth.authenticators import (
AuthProvider,
AccessTokenAuthProvider,
ExternalAuthProvider,
DatabricksOAuthProvider,
AzureServicePrincipalCredentialProvider,
)
from databricks.sql.auth.common import AuthType, ClientContext
from databricks.sql.auth.token_federation import TokenFederationProvider
def _databricks_oauth_provider(cfg: ClientContext, http_client, auth_type):
    """Build a DatabricksOAuthProvider from the OAuth fields of ``cfg``."""
    return DatabricksOAuthProvider(
        cfg.hostname,
        cfg.oauth_persistence,
        cfg.oauth_redirect_port_range,
        cfg.oauth_client_id,
        cfg.oauth_scopes,
        http_client,
        auth_type,
    )


def get_auth_provider(cfg: ClientContext, http_client):
    """Select an auth provider for ``cfg`` and wrap it in a TokenFederationProvider.

    The base provider is chosen by the first matching rule, in this order:

    1. ``cfg.credentials_provider`` is set -> ExternalAuthProvider around it.
    2. ``cfg.auth_type`` is AZURE_SP_M2M -> ExternalAuthProvider around an
       AzureServicePrincipalCredentialProvider built from the ``azure_*`` fields.
    3. ``cfg.auth_type`` is DATABRICKS_OAUTH or AZURE_OAUTH -> DatabricksOAuthProvider
       (the OAuth port range, client id and scopes must all be set).
    4. ``cfg.access_token`` is not None -> AccessTokenAuthProvider.
    5. ``cfg.use_cert_as_auth`` and ``cfg.tls_client_cert_file`` -> a no-op
       AuthProvider (the TLS client certificate authenticates instead of headers).
    6. OAuth port range, client id and scopes are all set -> DatabricksOAuthProvider
       using ``cfg.auth_type`` or, if that is falsy, DATABRICKS_OAUTH.

    Args:
        cfg: Resolved client settings.
        http_client: HTTP client passed through to providers that make requests.

    Returns:
        A TokenFederationProvider wrapping the selected base provider.

    Raises:
        AssertionError: rule 3 matched but an OAuth field is None.
        RuntimeError: no rule matched ("No valid authentication settings!").
    """
    base_provider: AuthProvider
    if cfg.credentials_provider:
        base_provider = ExternalAuthProvider(cfg.credentials_provider)
    elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
        base_provider = ExternalAuthProvider(
            AzureServicePrincipalCredentialProvider(
                cfg.hostname,
                cfg.azure_client_id,
                cfg.azure_client_secret,
                http_client,
                cfg.azure_tenant_id,
                cfg.azure_workspace_resource_id,
            )
        )
    elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
        assert cfg.oauth_redirect_port_range is not None
        assert cfg.oauth_client_id is not None
        assert cfg.oauth_scopes is not None
        base_provider = _databricks_oauth_provider(cfg, http_client, cfg.auth_type)
    elif cfg.access_token is not None:
        base_provider = AccessTokenAuthProvider(cfg.access_token)
    elif cfg.use_cert_as_auth and cfg.tls_client_cert_file:
        # no op authenticator. authentication is performed using ssl certificate outside of headers
        base_provider = AuthProvider()
    elif (
        cfg.oauth_redirect_port_range is not None
        and cfg.oauth_client_id is not None
        and cfg.oauth_scopes is not None
    ):
        base_provider = _databricks_oauth_provider(
            cfg, http_client, cfg.auth_type or AuthType.DATABRICKS_OAUTH.value
        )
    else:
        raise RuntimeError("No valid authentication settings!")

    # Every branch above either assigned base_provider or raised, so always wrap.
    # The wrap is unconditional: a truthiness test here could skip it for a
    # provider object that defines __bool__/__len__.
    # Original note: token federation falls back gracefully if not needed.
    return TokenFederationProvider(
        hostname=cfg.hostname,
        external_provider=base_provider,
        http_client=http_client,
        identity_federation_client_id=cfg.identity_federation_client_id,
    )
# Default OAuth scopes requested by the connector.
PYSQL_OAUTH_SCOPES = ["sql", "offline_access"]

# Default OAuth client ids; get_client_id_and_redirect_port picks the Azure one
# when Azure OAuth is in use and the plain one otherwise.
PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python"
PYSQL_OAUTH_AZURE_CLIENT_ID = "96eecda7-19ea-49cc-abb5-240097d554f5"

# Candidate local redirect ports: 8020 through 8024 inclusive for the default
# flow, and a single fixed port for Azure.
PYSQL_OAUTH_REDIRECT_PORT_RANGE = [8020, 8021, 8022, 8023, 8024]
PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE = [8030]
def normalize_host_name(hostname: str) -> str:
    """Return ``hostname`` as an ``https://`` URL ending in exactly one ``/``.

    An ``https://`` prefix is added unless one is already present. The check
    ignores case, because URL schemes are case-insensitive; ``"HTTPS://h"`` is
    therefore not double-prefixed. A trailing ``/`` is appended when missing.
    The original casing of the input is preserved.

    >>> normalize_host_name("example.com")
    'https://example.com/'
    >>> normalize_host_name("https://example.com/")
    'https://example.com/'
    """
    scheme = "" if hostname.lower().startswith("https://") else "https://"
    trailing_slash = "" if hostname.endswith("/") else "/"
    return f"{scheme}{hostname}{trailing_slash}"
def get_client_id_and_redirect_port(use_azure_auth: bool):
    """Return the default ``(client_id, redirect_port_range)`` pair.

    Args:
        use_azure_auth: Truthy to select the Azure OAuth defaults.

    Returns:
        ``(PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE)``
        when ``use_azure_auth`` is truthy. Otherwise
        ``(PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE)``.
    """
    if use_azure_auth:
        return PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE
    return PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE
def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs):
    """Build a ClientContext from connection kwargs and resolve its auth provider.

    Args:
        hostname: Workspace hostname; passed through normalize_host_name.
        http_client: HTTP client forwarded to get_auth_provider.
        **kwargs: Connection options. This function reads:
            ``username``, ``password``, ``auth_type``, ``access_token``,
            ``_use_cert_as_auth``, ``_tls_client_cert_file``, ``oauth_client_id``,
            ``oauth_redirect_port``, ``azure_client_id``, ``azure_client_secret``,
            ``azure_tenant_id``, ``azure_workspace_resource_id``,
            ``experimental_oauth_persistence``, ``credentials_provider`` and
            ``identity_federation_client_id``.

    Returns:
        Whatever get_auth_provider returns for the assembled context.

    Raises:
        ValueError: ``username`` or ``password`` was supplied (no longer supported).
        RuntimeError: raised by get_auth_provider when no usable settings exist.
    """
    # TODO : unify all the auth mechanisms with the Python SDK
    if kwargs.get("username") or kwargs.get("password"):
        raise ValueError(
            "Username/password authentication is no longer supported. "
            "Please use OAuth or access token instead."
        )

    auth_type = kwargs.get("auth_type")
    is_azure_oauth = auth_type == AuthType.AZURE_OAUTH.value
    default_client_id, default_port_range = get_client_id_and_redirect_port(
        is_azure_oauth
    )

    custom_client_id = kwargs.get("oauth_client_id")
    custom_redirect_port = kwargs.get("oauth_redirect_port")
    # A caller-chosen redirect port is honoured only together with a
    # caller-chosen OAuth client id; otherwise the flavour defaults apply.
    if custom_client_id and custom_redirect_port:
        redirect_ports = [custom_redirect_port]
    else:
        redirect_ports = default_port_range

    context = ClientContext(
        hostname=normalize_host_name(hostname),
        auth_type=auth_type,
        access_token=kwargs.get("access_token"),
        use_cert_as_auth=kwargs.get("_use_cert_as_auth"),
        tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
        oauth_scopes=PYSQL_OAUTH_SCOPES,
        oauth_client_id=custom_client_id or default_client_id,
        azure_client_id=kwargs.get("azure_client_id"),
        azure_client_secret=kwargs.get("azure_client_secret"),
        azure_tenant_id=kwargs.get("azure_tenant_id"),
        azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
        oauth_redirect_port_range=redirect_ports,
        oauth_persistence=kwargs.get("experimental_oauth_persistence"),
        credentials_provider=kwargs.get("credentials_provider"),
        identity_federation_client_id=kwargs.get("identity_federation_client_id"),
    )
    return get_auth_provider(context, http_client)