forked from databricks/databricks-sql-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsession.py
More file actions
216 lines (183 loc) · 7.61 KB
/
session.py
File metadata and controls
216 lines (183 loc) · 7.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import logging
from typing import Dict, Tuple, List, Optional, Any, Type
from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.types import SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.auth.common import ClientContext
from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError
from databricks.sql import __version__
from databricks.sql import USER_AGENT_NAME
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
from databricks.sql.backend.sea.backend import SeaDatabricksClient
from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.backend.types import SessionId, BackendType
from databricks.sql.common.unified_http_client import UnifiedHttpClient
# Module-level logger, named after this module; used by Session for all log output.
logger = logging.getLogger(__name__)
class Session:
    """A single session against a Databricks backend.

    The constructor only prepares local state (headers, TLS options, auth
    provider, backend client); the session is actually established by
    `open()` and torn down by `close()`.
    """

    def __init__(
        self,
        server_hostname: str,
        http_path: str,
        http_client: UnifiedHttpClient,
        http_headers: Optional[List[Tuple[str, str]]] = None,
        session_configuration: Optional[Dict[str, Any]] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        _use_arrow_native_complex_types: Optional[bool] = True,
        **kwargs,
    ) -> None:
        """
        Create a session to a Databricks SQL endpoint or a Databricks cluster.

        This class handles all session-related behavior and communication with the backend.

        Args:
            server_hostname: Host to connect to.
            http_path: HTTP path handed to the backend client.
            http_client: HTTP client shared with the auth provider and backend.
            http_headers: Extra headers; the User-Agent header is appended
                after them.
            session_configuration: Passed to the backend when the session is
                opened.
            catalog: Initial catalog passed to the backend on open.
            schema: Initial schema passed to the backend on open.
            _use_arrow_native_complex_types: Forwarded to the backend client.
            **kwargs: Additional options. Keys read here are ``_port``
                (default 443), ``user_agent_entry``, the deprecated
                ``_user_agent_entry``, ``use_sea`` and the ``_tls_*`` options.
                The full mapping is also forwarded to the auth provider
                factory and to the backend client constructor.
        """
        self.is_open = False
        self.host = server_hostname
        self.port = kwargs.get("_port", 443)

        self.session_configuration = session_configuration
        self.catalog = catalog
        self.schema = schema
        self.http_path = http_path

        # Initialize autocommit state (JDBC default is True)
        self._autocommit = True

        # Prefer the public name; fall back to the deprecated underscore
        # variant and warn when that one is actually supplied.
        user_agent_entry = kwargs.get("user_agent_entry")
        if user_agent_entry is None:
            user_agent_entry = kwargs.get("_user_agent_entry")
            if user_agent_entry is not None:
                logger.warning(
                    "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
                    "This parameter will be removed in the upcoming releases."
                )

        if user_agent_entry:
            self.useragent_header = "{}/{} ({})".format(
                USER_AGENT_NAME, __version__, user_agent_entry
            )
        else:
            self.useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)

        base_headers = [("User-Agent", self.useragent_header)]
        all_headers = (http_headers or []) + base_headers

        self.ssl_options = SSLOptions(
            # Double negation is generally a bad thing, but we have to keep backward compatibility
            tls_verify=not kwargs.get(
                "_tls_no_verify", False
            ),  # by default - verify cert and host
            tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
            tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
            tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
            tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
            tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
        )

        # Use the provided HTTP client (created in Connection)
        self.http_client = http_client

        # Create auth provider with HTTP client context
        self.auth_provider = get_python_sql_connector_auth_provider(
            server_hostname, http_client=self.http_client, **kwargs
        )

        self.backend = self._create_backend(
            server_hostname,
            http_path,
            all_headers,
            self.auth_provider,
            _use_arrow_native_complex_types,
            kwargs,
        )

        # Filled in by open() from the session id returned by the backend.
        self.protocol_version = None

    def _create_backend(
        self,
        server_hostname: str,
        http_path: str,
        all_headers: List[Tuple[str, str]],
        auth_provider,
        _use_arrow_native_complex_types: Optional[bool],
        kwargs: dict,
    ) -> DatabricksClient:
        """Create and return the appropriate backend client.

        Selects the SEA client when ``kwargs["use_sea"]`` is truthy and the
        Thrift client otherwise, recording the choice on ``self.use_sea``.
        Relies on ``self.port``, ``self.ssl_options`` and ``self.http_client``
        having been set already.
        """
        self.use_sea = kwargs.get("use_sea", False)

        databricks_client_class: Type[DatabricksClient]
        if self.use_sea:
            logger.debug("Creating SEA backend client")
            databricks_client_class = SeaDatabricksClient
        else:
            logger.debug("Creating Thrift backend client")
            databricks_client_class = ThriftDatabricksClient

        # `kwargs` is spread last, so any key it shares with the explicit
        # entries above takes precedence over them.
        common_args = {
            "server_hostname": server_hostname,
            "port": self.port,
            "http_path": http_path,
            "http_headers": all_headers,
            "auth_provider": auth_provider,
            "ssl_options": self.ssl_options,
            "http_client": self.http_client,
            "_use_arrow_native_complex_types": _use_arrow_native_complex_types,
            **kwargs,
        }
        return databricks_client_class(**common_args)

    def open(self):
        """Open the session on the backend and record its id and protocol version."""
        self._session_id = self.backend.open_session(
            session_configuration=self.session_configuration,
            catalog=self.catalog,
            schema=self.schema,
        )
        self.protocol_version = self.get_protocol_version(self._session_id)
        self.is_open = True
        logger.info("Successfully opened session %s", self.guid_hex)

    @staticmethod
    def get_protocol_version(session_id: SessionId):
        """Return the protocol version carried by ``session_id``."""
        return session_id.protocol_version

    @staticmethod
    def server_parameterized_queries_enabled(protocolVersion):
        """Return True when ``protocolVersion`` is at least protocol V8.

        A falsy version (e.g. ``None``) yields False.
        """
        return bool(
            protocolVersion
            and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
        )

    @property
    def session_id(self) -> SessionId:
        """Get the normalized session ID"""
        return self._session_id

    @property
    def guid(self) -> Any:
        """Get the raw session ID (backend-specific)"""
        return self._session_id.guid

    @property
    def guid_hex(self) -> str:
        """Get the session ID in hex format"""
        return self._session_id.hex_guid

    def get_autocommit(self) -> bool:
        """
        Get the cached autocommit state for this session.

        Returns:
            bool: True if autocommit is enabled, False otherwise
        """
        return self._autocommit

    def set_autocommit(self, value: bool) -> None:
        """
        Update the cached autocommit state for this session.

        Args:
            value: True to cache autocommit as enabled, False as disabled
        """
        self._autocommit = value

    def close(self) -> None:
        """Close the underlying session.

        Best effort: errors raised while closing are logged rather than
        propagated, and the session is marked closed afterwards regardless.
        Calling this on a session that is not open is a no-op.
        """
        # Check the state before touching the session id: `_session_id` is
        # only assigned by open(), so a never-opened session has no
        # `guid_hex` to log and would raise AttributeError here.
        if not self.is_open:
            logger.debug("Session appears to have been closed already")
            return

        logger.info("Closing session %s", self.guid_hex)

        try:
            self.backend.close_session(self._session_id)
        except RequestError as e:
            # The underlying cause is looked up in the second positional
            # argument; guard the index so a RequestError built with fewer
            # arguments cannot raise IndexError out of close().
            cause = e.args[1] if len(e.args) > 1 else None
            if isinstance(cause, SessionAlreadyClosedError):
                logger.info("Session was closed by a prior request")
            else:
                logger.warning(
                    "Attempt to close session failed with a request error: %s", e
                )
        except DatabaseError as e:
            if "Invalid SessionHandle" in str(e):
                logger.warning(
                    "Attempted to close session that was already closed: %s", e
                )
            else:
                logger.warning(
                    "Attempt to close session raised an exception at the server: %s", e
                )
        except Exception as e:
            logger.error("Attempt to close session raised a local exception: %s", e)

        self.is_open = False