import time
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence, BinaryIO
import pandas
try:
import pyarrow
except ImportError:
pyarrow = None
import json
import logging
import os
import decimal
from urllib.parse import urlparse
from uuid import UUID
from databricks.sql import __version__
from databricks.sql import *
from databricks.sql.exc import (
OperationalError,
SessionAlreadyClosedError,
CursorAlreadyClosedError,
InterfaceError,
NotSupportedError,
ProgrammingError,
TransactionError,
DatabaseError,
)
from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.utils import (
ParamEscaper,
inject_parameters,
transform_paramstyle,
ColumnTable,
ColumnQueue,
build_client_context,
get_session_config_value,
serialize_query_tags,
)
from databricks.sql.parameters.native import (
DbsqlParameterBase,
TDbsqlParameter,
TParameterDict,
TParameterSequence,
TParameterCollection,
ParameterStructure,
dbsql_parameter_from_primitive,
ParameterApproach,
)
from databricks.sql.result_set import ResultSet, ThriftResultSet
from databricks.sql.types import Row, SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
from databricks.sql.session import Session
from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId
from databricks.sql.auth.common import ClientContext
from databricks.sql.common.unified_http_client import UnifiedHttpClient
from databricks.sql.common.http import HttpMethod
from databricks.sql.thrift_api.TCLIService.ttypes import (
TOpenSessionResp,
TSparkParameter,
TOperationState,
)
from databricks.sql.telemetry.telemetry_client import (
TelemetryHelper,
TelemetryClientFactory,
)
from databricks.sql.telemetry.models.enums import DatabricksClientType
from databricks.sql.telemetry.models.event import (
DriverConnectionParameters,
HostDetails,
)
from databricks.sql.telemetry.latency_logger import log_latency
from databricks.sql.telemetry.models.enums import StatementType
logger = logging.getLogger(__name__)
if pyarrow is None:
logger.warning(
"[WARN] pyarrow is not installed by default since databricks-sql-connector 4.0.0, "
"so any Arrow-specific API (e.g. fetchmany_arrow) and cloud fetch will be disabled. "
"If you need these features, run `pip install pyarrow` or `pip install databricks-sql-connector[pyarrow]`."
)
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600
DEFAULT_ARRAY_SIZE = 100000
NO_NATIVE_PARAMS: List = []
# Transaction isolation level constants (extension to PEP 249)
TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ"
class Connection:
def __init__(
self,
server_hostname: str,
http_path: str,
access_token: Optional[str] = None,
http_headers: Optional[List[Tuple[str, str]]] = None,
session_configuration: Optional[Dict[str, Any]] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
_use_arrow_native_complex_types: Optional[bool] = True,
ignore_transactions: bool = True,
query_tags: Optional[Dict[str, Optional[str]]] = None,
**kwargs,
) -> None:
"""
Connect to a Databricks SQL endpoint or a Databricks cluster.
Parameters:
:param use_sea: `bool`, optional (default is False)
Use the SEA backend instead of the Thrift backend.
:param use_hybrid_disposition: `bool`, optional (default is False)
Use the hybrid disposition instead of the inline disposition.
:param server_hostname: Databricks instance host name.
:param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
:param access_token: `str`, optional
Http Bearer access token, e.g. Databricks Personal Access Token.
Unless you use auth_type=`databricks-oauth`, you need to pass `access_token`.
Examples:
```
connection = sql.connect(
server_hostname='dbc-12345.staging.cloud.databricks.com',
http_path='sql/protocolv1/o/6789/12abc567',
access_token='dabpi12345678'
)
```
:param http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request
:param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
Execute the SQL command `SET -v` to get a full list of available parameters.
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+
:param schema: An optional initial schema to use. Requires DBR version 9.0+
Other Parameters:
use_inline_params: `bool` | `str`, optional (default is False)
When True, parameterized calls to cursor.execute() will try to render parameter values inline with the
query text instead of using native bound parameters supported in DBR 14.1 and above. This connector will attempt to
sanitise parameterized inputs to prevent SQL injection. The inline parameter approach is maintained for
legacy purposes and will be deprecated in a future release. When this parameter is `True` you will see
a warning log message. To suppress this log message, set `use_inline_params="silent"`.
auth_type: `str`, optional (default is databricks-oauth if neither `access_token` nor `tls_client_cert_file` is set)
`databricks-oauth` : to use Databricks OAuth with fine-grained permission scopes, set to `databricks-oauth`.
`azure-oauth` : to use Microsoft Entra ID OAuth flow, set to `azure-oauth`.
oauth_client_id: `str`, optional
custom oauth client_id. If not specified, it will use the built-in client_id of databricks-sql-python.
oauth_redirect_port: `int`, optional
port of the oauth redirect uri (localhost). This is required when a custom oauth
client_id (`oauth_client_id`) is set.
user_agent_entry: `str`, optional
A custom tag to append to the User-Agent header. This is typically used by partners to identify their applications. If not specified, the default user agent PyDatabricksSqlConnector is used.
experimental_oauth_persistence: configures the preferred storage for persisting oauth tokens.
This has to be a class implementing `OAuthPersistence`.
When `auth_type` is set to `databricks-oauth` or `azure-oauth` and no persistence store is
configured, the oauth tokens are only kept in memory, so if the Python process restarts
the end user will have to log in again.
Note this is beta (private preview).
For persisting the oauth token in a production environment you should subclass and implement OAuthPersistence:
from databricks.sql.experimental.oauth_persistence import OAuthPersistence, OAuthToken
class MyCustomImplementation(OAuthPersistence):
def __init__(self, file_path):
self._file_path = file_path
def persist(self, token: OAuthToken):
    # implement this method to persist token.refresh_token and token.access_token
    ...
def read(self) -> Optional[OAuthToken]:
    # implement this method to return an instance of the persisted token
    ...
connection = sql.connect(
server_hostname='dbc-12345.staging.cloud.databricks.com',
http_path='sql/protocolv1/o/6789/12abc567',
auth_type="databricks-oauth",
experimental_oauth_persistence=MyCustomImplementation("~/oauth_token.json")
)
For development purposes you can use the existing `DevOnlyFilePersistence`, which stores the
raw oauth token at the provided file path. Note this is only for development; for production you should provide your
own implementation of OAuthPersistence.
Examples:
```
# for development only
from databricks.sql.experimental.oauth_persistence import DevOnlyFilePersistence
connection = sql.connect(
server_hostname='dbc-12345.staging.cloud.databricks.com',
http_path='sql/protocolv1/o/6789/12abc567',
auth_type="databricks-oauth",
experimental_oauth_persistence=DevOnlyFilePersistence("~/dev-oauth.json")
)
```
:param _use_arrow_native_complex_types: `bool`, optional
Controls whether a complex type field value is returned as a string or as a native Arrow type. Defaults to True.
When True:
MAP is returned as List[Tuple[str, Any]]
STRUCT is returned as Dict[str, Any]
ARRAY is returned as numpy.ndarray
When False, complex types are returned as strings. These are generally deserializable as JSON.
:param enable_metric_view_metadata: `bool`, optional (default is False)
When True, enables metric view metadata support by setting the
spark.sql.thriftserver.metadata.metricview.enabled session configuration.
This allows
1. cursor.tables() to return METRIC_VIEW table type
2. cursor.columns() to return "measure" column type
:param fetch_autocommit_from_server: `bool`, optional (default is False)
When True, the connection.autocommit property queries the server for current state
using SET AUTOCOMMIT instead of returning cached value.
Set to True if autocommit might be changed by external means (e.g., external SQL commands).
When False (default), uses cached state for better performance.
:param ignore_transactions: `bool`, optional (default is True)
When True, transaction-related operations behave as follows:
- commit(): no-op (does nothing)
- rollback(): raises NotSupportedError
- autocommit setter: no-op (does nothing)
When False, transaction operations execute normally.
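Example (illustrative sketch; assumes a workspace where transactions are available so
`ignore_transactions=False` takes effect, and a hypothetical table `my_table`):
```
connection = sql.connect(
    server_hostname='dbc-12345.staging.cloud.databricks.com',
    http_path='sql/protocolv1/o/6789/12abc567',
    ignore_transactions=False,
)
connection.autocommit = False
try:
    with connection.cursor() as cursor:
        cursor.execute("INSERT INTO my_table VALUES (1)")
    connection.commit()
except Exception:
    connection.rollback()
    raise
```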
"""
# Internal arguments in **kwargs:
# _use_cert_as_auth
# Use a TLS cert instead of a token
# _enable_ssl
# Connect over HTTP instead of HTTPS
# _port
# Which port to connect to
# _skip_routing_headers:
# Don't set routing headers if set to True (for use when connecting directly to server)
# _tls_no_verify
# Set to True (Boolean) to completely disable SSL verification.
# _tls_verify_hostname
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
# _tls_trusted_ca_file
# Set to the path of the file containing trusted CA certificates for server certificate
# verification. If not provided, the system truststore is used.
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
# Set client SSL certificate.
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
# _retry_stop_after_attempts_count
# The maximum number of attempts during a request retry sequence (defaults to 24)
# _socket_timeout
# The timeout in seconds for socket send, recv and connect operations. Defaults to None for
# no timeout. Should be a positive float or integer.
# _disable_pandas
# In case the deserialisation through pandas causes any issues, it can be disabled with
# this flag.
# _use_arrow_native_decimals
# Databricks runtime will return native Arrow types for decimals instead of Arrow strings
# (True by default)
# _use_arrow_native_timestamps
# Databricks runtime will return native Arrow types for timestamps instead of Arrow strings
# (True by default)
# use_cloud_fetch
# Enable use of cloud fetch to extract large query results in parallel via cloud storage
logger.debug(
"Connection.__init__(server_hostname=%s, http_path=%s)",
server_hostname,
http_path,
)
if access_token:
access_token_kv = {"access_token": access_token}
kwargs = {**kwargs, **access_token_kv}
enable_metric_view_metadata = kwargs.get("enable_metric_view_metadata", False)
if enable_metric_view_metadata:
if session_configuration is None:
session_configuration = {}
session_configuration[
"spark.sql.thriftserver.metadata.metricview.enabled"
] = "true"
if query_tags is not None:
if session_configuration is None:
session_configuration = {}
serialized = serialize_query_tags(query_tags)
if serialized:
session_configuration["QUERY_TAGS"] = serialized
else:
session_configuration.pop("QUERY_TAGS", None)
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self._cursors = [] # type: List[Cursor]
self.telemetry_batch_size = kwargs.get(
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
)
client_context = build_client_context(server_hostname, __version__, **kwargs)
self.http_client = UnifiedHttpClient(client_context)
try:
self.session = Session(
server_hostname,
http_path,
self.http_client,
http_headers,
session_configuration,
catalog,
schema,
_use_arrow_native_complex_types,
**kwargs,
)
self.session.open()
except Exception as e:
# Respect user's telemetry preference even during connection failure
enable_telemetry = kwargs.get("enable_telemetry", True)
TelemetryClientFactory.connection_failure_log(
error_name="Exception",
error_message=str(e),
host_url=server_hostname,
http_path=http_path,
port=kwargs.get("_port", 443),
client_context=client_context,
user_agent=self.session.useragent_header
if hasattr(self, "session")
else None,
enable_telemetry=enable_telemetry,
)
raise e
self.use_inline_params = self._set_use_inline_params_with_warning(
kwargs.get("use_inline_params", False)
)
self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)
self._fetch_autocommit_from_server = kwargs.get(
"fetch_autocommit_from_server", False
)
self.ignore_transactions = ignore_transactions
self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False)
self.enable_telemetry = kwargs.get("enable_telemetry", True)
self.telemetry_enabled = TelemetryHelper.is_telemetry_enabled(self)
TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=self.telemetry_enabled,
session_id_hex=self.get_session_id_hex(),
auth_provider=self.session.auth_provider,
host_url=self.session.host,
batch_size=self.telemetry_batch_size,
client_context=client_context,
)
self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
host_url=self.session.host
)
# Determine proxy usage
use_proxy = self.http_client.using_proxy()
proxy_host_info = None
if (
use_proxy
and self.http_client.proxy_uri
and isinstance(self.http_client.proxy_uri, str)
):
parsed = urlparse(self.http_client.proxy_uri)
proxy_host_info = HostDetails(
host_url=parsed.hostname or self.http_client.proxy_uri,
port=parsed.port or 8080,
)
driver_connection_params = DriverConnectionParameters(
http_path=http_path,
mode=DatabricksClientType.SEA
if self.session.use_sea
else DatabricksClientType.THRIFT,
host_info=HostDetails(host_url=server_hostname, port=self.session.port),
auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider),
auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider),
socket_timeout=kwargs.get("_socket_timeout", None),
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id", None),
azure_tenant_id=kwargs.get("azure_tenant_id", None),
use_proxy=use_proxy,
use_system_proxy=use_proxy,
proxy_host_info=proxy_host_info,
use_cf_proxy=False, # CloudFlare proxy not yet supported in Python
cf_proxy_host_info=None, # CloudFlare proxy not yet supported in Python
non_proxy_hosts=None,
allow_self_signed_support=kwargs.get("_tls_no_verify", False),
use_system_trust_store=True, # Python uses system SSL by default
enable_arrow=pyarrow is not None,
enable_direct_results=True, # Always enabled in Python
enable_sea_hybrid_results=kwargs.get("use_hybrid_disposition", False),
http_connection_pool_size=kwargs.get("pool_maxsize", None),
rows_fetched_per_block=DEFAULT_ARRAY_SIZE,
async_poll_interval_millis=2000, # Default polling interval
support_many_parameters=True, # Native parameters supported
enable_complex_datatype_support=_use_arrow_native_complex_types,
allowed_volume_ingestion_paths=self.staging_allowed_local_path,
query_tags=get_session_config_value(session_configuration, "query_tags"),
)
self._telemetry_client.export_initial_telemetry_log(
driver_connection_params=driver_connection_params,
user_agent=self.session.useragent_header,
session_id=self.get_session_id_hex(),
)
def _set_use_inline_params_with_warning(self, value: Union[bool, str]):
"""Valid values are True, False, and "silent"
False: Use native parameters
True: Use inline parameters and log a warning
"silent": Use inline parameters and don't log a warning
"""
if value is False:
return False
if value not in [True, "silent"]:
raise ValueError(
f"Invalid value for use_inline_params: {value}. "
+ 'Valid values are True, False, and "silent"'
)
if value is True:
logger.warning(
"Parameterised queries executed with this client will use the inline parameter approach."
"This approach will be deprecated in a future release. Consider using native parameters."
"Learn more: https://github.com/databricks/databricks-sql-python/tree/main/docs/parameters.md"
'To suppress this warning, set use_inline_params="silent"'
)
return value
# The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
def __enter__(self) -> "Connection":
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __del__(self):
if self.open:
logger.debug(
"Closing unclosed connection for session "
"{}".format(self.get_session_id_hex())
)
try:
self._close(close_cursors=False)
except OperationalError as e:
# Close on best-effort basis.
logger.debug("Couldn't close unclosed connection: {}".format(e.message))
def get_session_id(self):
"""Get the raw session ID (backend-specific)"""
return self.session.guid
def get_session_id_hex(self):
"""Get the session ID in hex format"""
return self.session.guid_hex
@staticmethod
def server_parameterized_queries_enabled(protocolVersion):
"""Check if parameterized queries are enabled for the given protocol version"""
return Session.server_parameterized_queries_enabled(protocolVersion)
@property
def protocol_version(self):
"""Get the protocol version from the Session object"""
return self.session.protocol_version
@staticmethod
def get_protocol_version(openSessionResp: TOpenSessionResp):
"""Get the protocol version from the OpenSessionResp object"""
properties = (
{"serverProtocolVersion": openSessionResp.serverProtocolVersion}
if openSessionResp.serverProtocolVersion
else {}
)
session_id = SessionId.from_thrift_handle(
openSessionResp.sessionHandle, properties
)
return Session.get_protocol_version(session_id)
@property
def open(self) -> bool:
"""Return whether the connection is open by checking if the session is open."""
return self.session.is_open
def cursor(
self,
arraysize: int = DEFAULT_ARRAY_SIZE,
buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
row_limit: Optional[int] = None,
) -> "Cursor":
"""
Args:
arraysize: The maximum number of rows in direct results.
buffer_size_bytes: The maximum number of bytes in direct results.
row_limit: The maximum number of rows in the result.
Return a new Cursor object using the connection.
Will throw an Error if the connection has been closed.
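Example (illustrative; assumes an open connection):
```
with connection.cursor() as cursor:
    cursor.execute("SELECT 1 AS x")
    print(cursor.fetchone())
```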
"""
if not self.open:
raise InterfaceError(
"Cannot create cursor from closed connection",
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
)
cursor = Cursor(
self,
self.session.backend,
arraysize=arraysize,
result_buffer_size_bytes=buffer_size_bytes,
row_limit=row_limit,
)
self._cursors.append(cursor)
return cursor
def close(self) -> None:
"""Close the underlying session and mark all associated cursors as closed."""
self._close()
def _close(self, close_cursors=True) -> None:
if close_cursors:
for cursor in self._cursors:
cursor.close()
try:
self.session.close()
except Exception as e:
logger.error(f"Attempt to close session raised a local exception: {e}")
TelemetryClientFactory.close(host_url=self.session.host)
# Close HTTP client that was created by this connection
if self.http_client:
self.http_client.close()
@property
def autocommit(self) -> bool:
"""
Get auto-commit mode for this connection.
Extension to PEP 249. Returns cached value by default.
If fetch_autocommit_from_server=True was set during connection,
queries server for current state.
Returns:
bool: True if auto-commit is enabled, False otherwise
Raises:
InterfaceError: If connection is closed
TransactionError: If fetch_autocommit_from_server=True and query fails
"""
if not self.open:
raise InterfaceError(
"Cannot get autocommit on closed connection",
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
)
if self._fetch_autocommit_from_server:
return self._fetch_autocommit_state_from_server()
return self.session.get_autocommit()
@autocommit.setter
def autocommit(self, value: bool) -> None:
"""
Set auto-commit mode for this connection.
Extension to PEP 249. Executes SET AUTOCOMMIT command on server.
Args:
value: True to enable auto-commit, False to disable
When ignore_transactions is True:
- This method is a no-op (does nothing)
Raises:
InterfaceError: If connection is closed
TransactionError: If server rejects the change
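Example (illustrative; assumes ignore_transactions=False was passed to sql.connect,
otherwise the setter is a no-op):
```
connection.autocommit = False  # take over transaction control
assert connection.autocommit is False
connection.autocommit = True   # return to auto-commit mode
```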
"""
# No-op when ignore_transactions is True
if self.ignore_transactions:
return
if not self.open:
raise InterfaceError(
"Cannot set autocommit on closed connection",
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
)
# Create internal cursor for transaction control
cursor = None
try:
cursor = self.cursor()
sql = f"SET AUTOCOMMIT = {'TRUE' if value else 'FALSE'}"
cursor.execute(sql)
# Update cached state on success
self.session.set_autocommit(value)
except DatabaseError as e:
# Wrap in TransactionError with context
raise TransactionError(
f"Failed to set autocommit to {value}: {e.message}",
context={
**e.context,
"operation": "set_autocommit",
"autocommit_value": value,
},
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
) from e
finally:
if cursor:
cursor.close()
def _fetch_autocommit_state_from_server(self) -> bool:
"""
Query server for current autocommit state using SET AUTOCOMMIT.
Returns:
bool: Server's autocommit state
Raises:
TransactionError: If query fails
"""
cursor = None
try:
cursor = self.cursor()
cursor.execute("SET AUTOCOMMIT")
# Fetch result: should return row with value column
result = cursor.fetchone()
if result is None:
raise TransactionError(
"No result returned from SET AUTOCOMMIT query",
context={"operation": "fetch_autocommit"},
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
)
# Parse value (first column should be "true" or "false")
value_str = str(result[0]).lower()
autocommit_state = value_str == "true"
# Update cache
self.session.set_autocommit(autocommit_state)
return autocommit_state
except TransactionError:
# Re-raise TransactionError as-is
raise
except DatabaseError as e:
# Wrap other DatabaseErrors
raise TransactionError(
f"Failed to fetch autocommit state from server: {e.message}",
context={**e.context, "operation": "fetch_autocommit"},
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
) from e
finally:
if cursor:
cursor.close()
def commit(self) -> None:
"""
Commit the current transaction.
Per PEP 249. Should be called only when autocommit is disabled.
When autocommit is False:
- Commits the current transaction
- Server automatically starts new transaction
When autocommit is True:
- Server may throw error if no active transaction
When ignore_transactions is True:
- This method is a no-op (does nothing)
Raises:
InterfaceError: If connection is closed
TransactionError: If commit fails (e.g., no active transaction)
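Example (illustrative; assumes ignore_transactions=False and a hypothetical table `t`):
```
connection.autocommit = False
with connection.cursor() as cursor:
    cursor.execute("UPDATE t SET x = 1")
connection.commit()  # persists the UPDATE; the server starts a new transaction
```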
"""
# No-op when ignore_transactions is True
if self.ignore_transactions:
return
if not self.open:
raise InterfaceError(
"Cannot commit on closed connection",
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
)
cursor = None
try:
cursor = self.cursor()
cursor.execute("COMMIT")
except DatabaseError as e:
raise TransactionError(
f"Failed to commit transaction: {e.message}",
context={**e.context, "operation": "commit"},
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
) from e
finally:
if cursor:
cursor.close()
def rollback(self) -> None:
"""
Rollback the current transaction.
Per PEP 249. Should be called only when autocommit is disabled.
When autocommit is False:
- Rolls back the current transaction
- Server automatically starts new transaction
When autocommit is True:
- ROLLBACK is forgiving (no-op, doesn't throw exception)
When ignore_transactions is True:
- Raises NotSupportedError
Note: ROLLBACK is safe to call even without active transaction.
Raises:
InterfaceError: If connection is closed
NotSupportedError: If ignore_transactions is True
TransactionError: If rollback fails
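Example (illustrative; assumes ignore_transactions=False and a hypothetical table `t`):
```
try:
    with connection.cursor() as cursor:
        cursor.execute("DELETE FROM t WHERE x = 1")
    connection.commit()
except Exception:
    connection.rollback()  # safe even if no transaction is active
    raise
```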
"""
# Raise NotSupportedError when ignore_transactions is True
if self.ignore_transactions:
raise NotSupportedError(
"Transactions are not supported on Databricks",
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
)
if not self.open:
raise InterfaceError(
"Cannot rollback on closed connection",
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
)
cursor = None
try:
cursor = self.cursor()
cursor.execute("ROLLBACK")
except DatabaseError as e:
raise TransactionError(
f"Failed to rollback transaction: {e.message}",
context={**e.context, "operation": "rollback"},
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
) from e
finally:
if cursor:
cursor.close()
def get_transaction_isolation(self) -> str:
"""
Get the transaction isolation level.
Extension to PEP 249.
Databricks supports REPEATABLE_READ isolation level (Snapshot Isolation),
which is the default and only supported level.
Returns:
str: "REPEATABLE_READ" - the transaction isolation level constant
Raises:
InterfaceError: If connection is closed
"""
if not self.open:
raise InterfaceError(
"Cannot get transaction isolation on closed connection",
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
)
return TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ
def set_transaction_isolation(self, level: str) -> None:
"""
Set transaction isolation level.
Extension to PEP 249.
Databricks supports only REPEATABLE_READ isolation level (Snapshot Isolation).
This method validates that the requested level is supported but does not
execute any SQL, as REPEATABLE_READ is the default server behavior.
Args:
level: Isolation level. Must be "REPEATABLE_READ" or "REPEATABLE READ"
(case-insensitive, underscores and spaces are interchangeable)
Raises:
InterfaceError: If connection is closed
NotSupportedError: If isolation level not supported
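Example (illustrative; mirrors the normalization logic below):
```
connection.set_transaction_isolation("repeatable read")  # accepted (normalized)
connection.set_transaction_isolation("REPEATABLE_READ")  # accepted
connection.set_transaction_isolation("SERIALIZABLE")     # raises NotSupportedError
```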
"""
if not self.open:
raise InterfaceError(
"Cannot set transaction isolation on closed connection",
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
)
# Normalize and validate isolation level
normalized_level = level.upper().replace("_", " ")
if normalized_level != TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ.replace(
"_", " "
):
raise NotSupportedError(
f"Setting transaction isolation level '{level}' is not supported. "
f"Only {TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ} is supported.",
host_url=self.session.host,
session_id_hex=self.get_session_id_hex(),
)
class Cursor:
def __init__(
self,
connection: Connection,
backend: DatabricksClient,
result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
arraysize: int = DEFAULT_ARRAY_SIZE,
row_limit: Optional[int] = None,
) -> None:
"""
These objects represent a database cursor, which is used to manage the context of a fetch
operation.
Cursors are not isolated, i.e., any changes made to the database by a cursor are immediately
visible to other cursors and connections.
"""
self.connection: Connection = connection
self.rowcount: int = -1 # Return -1 as this is not supported
self.buffer_size_bytes: int = result_buffer_size_bytes
self.active_result_set: Union[ResultSet, None] = None
self.arraysize: int = arraysize
self.row_limit: Optional[int] = row_limit
# Note that Cursor closed => active result set closed, but not vice versa
self.open: bool = True
self.executing_command_id: Optional[CommandId] = None
self.backend: DatabricksClient = backend
self.active_command_id: Optional[CommandId] = None
self.escaper = ParamEscaper()
self.lastrowid = None
self.ASYNC_DEFAULT_POLLING_INTERVAL = 2
# The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
def __enter__(self) -> "Cursor":
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __iter__(self):
if self.active_result_set:
for row in self.active_result_set:
yield row
else:
raise ProgrammingError(
"There is no active result set",
host_url=self.connection.session.host,
session_id_hex=self.connection.get_session_id_hex(),
)
def _determine_parameter_approach(
self, params: Optional[TParameterCollection]
) -> ParameterApproach:
"""Encapsulates the logic for choosing whether to send parameters in native vs inline mode
If params is None then ParameterApproach.NONE is returned.
If self.use_inline_params is True then inline mode is used.
If self.use_inline_params is False, then check if the server supports them and proceed.
Else raise an exception.
Returns a ParameterApproach enumeration or raises an exception
If inline approach is used when the server supports native approach, a warning is logged
"""
if params is None:
return ParameterApproach.NONE
if self.connection.use_inline_params:
return ParameterApproach.INLINE
else:
return ParameterApproach.NATIVE
def _all_dbsql_parameters_are_named(self, params: List[TDbsqlParameter]) -> bool:
"""Return True if all members of the list have a non-null .name attribute"""
return all([i.name is not None for i in params])
def _normalize_tparametersequence(
self, params: TParameterSequence
) -> List[TDbsqlParameter]:
"""Retains the same order as the input list."""
output: List[TDbsqlParameter] = []
for p in params:
if isinstance(p, DbsqlParameterBase):
output.append(p)
else:
output.append(dbsql_parameter_from_primitive(value=p))
return output
def _normalize_tparameterdict(
self, params: TParameterDict
) -> List[TDbsqlParameter]:
return [
dbsql_parameter_from_primitive(value=value, name=name)
for name, value in params.items()
]
def _normalize_tparametercollection(
self, params: Optional[TParameterCollection]
) -> List[TDbsqlParameter]:
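"""Normalize any supported parameter collection (None, dict, or sequence) into a
list of TDbsqlParameter, wrapping primitive values via dbsql_parameter_from_primitive.
Illustrative sketch of the mapping:
    None     -> []
    {"x": 1} -> [dbsql_parameter_from_primitive(value=1, name="x")]
    [1, "a"] -> [dbsql_parameter_from_primitive(value=1),
                 dbsql_parameter_from_primitive(value="a")]
"""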
if params is None:
return []
if isinstance(params, dict):
return self._normalize_tparameterdict(params)
if isinstance(params, Sequence):
return self._normalize_tparametersequence(list(params))
def _determine_parameter_structure(
self,
parameters: List[TDbsqlParameter],
) -> ParameterStructure:
all_named = self._all_dbsql_parameters_are_named(parameters)
if all_named:
return ParameterStructure.NAMED
else:
return ParameterStructure.POSITIONAL
def _prepare_inline_parameters(
self, stmt: str, params: Optional[Union[Sequence, Dict[str, Any]]]
) -> Tuple[str, List]:
"""Return a statement and list of native parameters to be passed to thrift_backend for execution
:stmt:
A string SQL query containing parameter markers of PEP-249 paramstyle `pyformat`.
For example `%(param)s`.
:params:
An iterable of parameter values to be rendered inline. If passed as a Dict, the keys
must match the names of the markers included in :stmt:. If passed as a List, its length
must equal the count of parameter markers in :stmt:.
Returns a tuple of:
stmt: the passed statement with the param markers replaced by literal rendered values
params: an empty list representing the native parameters to be passed with this query.
The list is always empty because native parameters are never used under the inline approach
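Example (illustrative; the statement and the `id` marker are hypothetical):
    stmt, native = cursor._prepare_inline_parameters(
        "SELECT * FROM t WHERE id = %(id)s", {"id": 42}
    )
    # stmt   -> "SELECT * FROM t WHERE id = 42"  (value rendered inline)
    # native -> []  (always empty under the inline approach)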
"""
escaped_values = self.escaper.escape_args(params)
rendered_statement = inject_parameters(stmt, escaped_values)
return rendered_statement, NO_NATIVE_PARAMS
def _prepare_native_parameters(
self,
stmt: str,
params: List[TDbsqlParameter],
param_structure: ParameterStructure,
) -> Tuple[str, List[TSparkParameter]]:
"""Return a statement and a list of native parameters to be passed to thrift_backend for execution
:stmt:
A string SQL query containing parameter markers of PEP-249 paramstyle `named`.
For example `:param`.