X Tutup
from collections import namedtuple, OrderedDict
from collections.abc import Iterable
import datetime
import decimal
from enum import Enum
from typing import Dict

import pyarrow

from databricks.sql import exc


class ArrowQueue:
    def __init__(
        self, arrow_table: pyarrow.Table, n_valid_rows: int, start_row_index: int = 0
    ):
        """
        A queue-like wrapper over an Arrow table

        :param arrow_table: The Arrow table from which we want to take rows
        :param n_valid_rows: Exclusive upper bound on the row indices that may be
            returned, i.e. the number of valid rows at the start of the table.
            Rows at index >= n_valid_rows are never handed out.
        :param start_row_index: The first row in the table we should start fetching from
        """
        self.cur_row_index = start_row_index
        self.arrow_table = arrow_table
        self.n_valid_rows = n_valid_rows

    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
        """Get up to the next ``num_rows`` rows of the Arrow table.

        The cursor advances by the number of rows actually returned, which may
        be fewer than ``num_rows`` (or zero) once the valid rows run out.

        :param num_rows: Maximum number of rows to return.
        :returns: A slice of the underlying Arrow table.
        """
        # Clamp at zero so a negative request, or a cursor already past the
        # valid range, produces an empty slice instead of a negative length.
        length = max(0, min(num_rows, self.n_valid_rows - self.cur_row_index))
        # Note that the table.slice API is not the same as Python's slice:
        # the second argument is a length, not an end index.
        table_slice = self.arrow_table.slice(self.cur_row_index, length)
        self.cur_row_index += table_slice.num_rows
        return table_slice

    def remaining_rows(self) -> pyarrow.Table:
        """Return every valid row not yet consumed and advance the cursor past them.

        :returns: A slice of the underlying Arrow table (possibly empty).
        """
        # Same zero clamp as next_n_rows: never hand slice() a negative length.
        length = max(0, self.n_valid_rows - self.cur_row_index)
        table_slice = self.arrow_table.slice(self.cur_row_index, length)
        self.cur_row_index += table_slice.num_rows
        return table_slice


# Result of executing a command; the field list is part of the public shape.
ExecuteResponse = namedtuple(
    "ExecuteResponse",
    "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation "
    "command_handle arrow_queue arrow_schema_bytes",
)


def _bound(min_x, max_x, x):
    """Bound x by [min_x, max_x]

    min_x or max_x being None means unbounded in that respective side.
    """
    if min_x is None and max_x is None:
        return x
    if min_x is None:
        return min(max_x, x)
    if max_x is None:
        return max(min_x, x)
    return min(max_x, max(min_x, x))


class NoRetryReason(Enum):
    """Reason why a request will not be retried; the value is used in log output."""

    OUT_OF_TIME = "out of time"
    OUT_OF_ATTEMPTS = "out of attempts"
    NOT_RETRYABLE = "non-retryable error"


class RequestErrorInfo(
    namedtuple(
        "RequestErrorInfo_", "error error_message retry_delay http_code method request"
    )
):
    """Details about a failed request, with helpers for logging and user messages."""

    @property
    def request_session_id(self):
        """The request's ``sessionHandle.sessionId.guid``, or None if it has no sessionHandle."""
        if hasattr(self.request, "sessionHandle"):
            return self.request.sessionHandle.sessionId.guid
        else:
            return None

    @property
    def request_query_id(self):
        """The request's ``operationHandle.operationId.guid``, or None if it has no operationHandle."""
        if hasattr(self.request, "operationHandle"):
            return self.request.operationHandle.operationId.guid
        else:
            return None

    def full_info_logging_context(
        self, no_retry_reason, attempt, max_attempts, elapsed, max_duration
    ):
        """Build an ordered mapping of everything known about this failure, for logging.

        :param no_retry_reason: A NoRetryReason, or a falsy value if a retry may follow.
        :param attempt: The current attempt number.
        :param max_attempts: The maximum number of attempts allowed.
        :param elapsed: Seconds elapsed so far.
        :param max_duration: The maximum duration allowed, in seconds.
        :returns: An OrderedDict keyed by log field name.
        """
        log_base_data_dict = OrderedDict(
            [
                ("method", self.method),
                ("session-id", self.request_session_id),
                ("query-id", self.request_query_id),
                ("http-code", self.http_code),
                ("error-message", self.error_message),
                ("original-exception", str(self.error)),
            ]
        )

        # Falsy reasons (e.g. None) are logged as-is; otherwise use the enum's text.
        log_base_data_dict["no-retry-reason"] = (
            no_retry_reason and no_retry_reason.value
        )
        log_base_data_dict["bounded-retry-delay"] = self.retry_delay
        log_base_data_dict["attempt"] = "{}/{}".format(attempt, max_attempts)
        log_base_data_dict["elapsed-seconds"] = "{}/{}".format(elapsed, max_duration)

        return log_base_data_dict

    def user_friendly_error_message(self, no_retry_reason, attempt, elapsed):
        """Return a short error message suitable for showing to an end user.

        ``no_retry_reason``, ``attempt`` and ``elapsed`` are accepted for
        interface compatibility but are not currently included in the message.
        """
        # This should be kept at the level that is appropriate to return to a Redash user
        message = "Error during request to server"
        if self.error_message:
            message = "{}: {}".format(message, self.error_message)
        return message


# Taken from PyHive
class ParamEscaper:
    """Render Python values as SQL literal text for client-side parameter injection."""

    _DATE_FORMAT = "%Y-%m-%d"
    _TIME_FORMAT = "%H:%M:%S.%f"
    _DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT)

    def escape_args(self, parameters):
        """Escape every value of a dict (keeping its keys) or of a list/tuple (returning a tuple).

        :raises exc.ProgrammingError: if ``parameters`` is neither a dict nor a list/tuple.
        """
        if isinstance(parameters, dict):
            return {k: self.escape_item(v) for k, v in parameters.items()}
        elif isinstance(parameters, (list, tuple)):
            return tuple(self.escape_item(x) for x in parameters)
        else:
            raise exc.ProgrammingError(
                "Unsupported param format: {}".format(parameters)
            )

    def escape_number(self, item):
        """Numbers are emitted unchanged (not converted to str)."""
        return item

    def escape_string(self, item):
        """Return ``item`` as a single-quoted SQL string literal."""
        # Need to decode UTF-8 because of old sqlalchemy.
        # Newer SQLAlchemy checks dialect.supports_unicode_binds before encoding Unicode strings
        # as byte strings. The old version always encodes Unicode as byte strings, which breaks
        # string formatting here.
        if isinstance(item, bytes):
            item = item.decode("utf-8")
        # Backslashes are doubled first (so the backslashes added for quotes are
        # not themselves doubled), then every single quote is escaped with a
        # backslash, and the result is wrapped in single quotes.
        return "'{}'".format(item.replace("\\", "\\\\").replace("'", "\\'"))

    def escape_sequence(self, item):
        """Escape each element and join them as a parenthesised, comma-separated list."""
        escaped = ",".join(str(self.escape_item(element)) for element in item)
        return "(" + escaped + ")"

    def escape_datetime(self, item, format, cutoff=0):
        """Format ``item`` with ``format`` and single-quote it.

        When ``cutoff`` is non-zero and ``format`` ends with ``.%f``, the last
        ``cutoff`` characters (trailing fractional-second digits) are dropped.
        """
        dt_str = item.strftime(format)
        formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str
        return "'{}'".format(formatted)

    def escape_decimal(self, item):
        """Decimals are emitted via ``str()`` unquoted."""
        return str(item)

    def escape_item(self, item):
        """Dispatch ``item`` to the matching escape method by type.

        :raises exc.ProgrammingError: for unsupported types.
        """
        if item is None:
            return "NULL"
        elif isinstance(item, (int, float)):
            return self.escape_number(item)
        # bytes must be handled before the Iterable check below: otherwise
        # b"ab" would be rendered as the integer sequence "(97,98)" and the
        # UTF-8 decoding in escape_string would never run.
        elif isinstance(item, (str, bytes)):
            return self.escape_string(item)
        elif isinstance(item, Iterable):
            return self.escape_sequence(item)
        # datetime is a subclass of date, so it must be checked first.
        elif isinstance(item, datetime.datetime):
            return self.escape_datetime(item, self._DATETIME_FORMAT)
        elif isinstance(item, datetime.date):
            return self.escape_datetime(item, self._DATE_FORMAT)
        elif isinstance(item, decimal.Decimal):
            return self.escape_decimal(item)
        else:
            raise exc.ProgrammingError("Unsupported object {}".format(item))


def inject_parameters(operation: str, parameters: Dict[str, str]):
    """Substitute ``parameters`` into ``operation`` using ``%`` string formatting.

    NOTE(review): values are interpolated verbatim, so callers presumably pass
    the output of ParamEscaper.escape_args; confirm at the call sites.
    """
    return operation % parameters
X Tutup