From de7d26e2c823313b350ec316ad501535f5543c79 Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Tue, 28 Feb 2023 12:41:02 +0100 Subject: [PATCH] Support setting session properties on an individual statement --- README.md | 23 ++++++++++ tests/integration/test_dbapi_integration.py | 16 +++++++ tests/unit/test_client.py | 2 +- trino/client.py | 27 ++++++++++-- trino/dbapi.py | 48 ++++++++++++++++----- 5 files changed, 101 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 41992438..6de22d58 100644 --- a/README.md +++ b/README.md @@ -373,6 +373,29 @@ conn = trino.dbapi.connect( ) ``` +## Session properties + +Session properties can be set on the connection: + +```python +import trino +conn = trino.dbapi.connect( + ..., + session_properties={"query_max_run_time": "1d"} +) +``` + +### Statement properties + +It's also possible to set session properties for individual statements by passing `statement_properties` when creating the cursor. They apply to every statement executed through that cursor and override connection-level session properties of the same name. This is especially handy for Hive partitions, e.g. to overwrite existing partitions for a single `INSERT`. + +```python +import trino +conn = trino.dbapi.connect() +cur = conn.cursor(statement_properties={"hive.insert_existing_partitions_behavior": "OVERWRITE"}) +cur.execute("INSERT INTO hive_partitioned_table SELECT * from another_table") +``` + ## Timezone The time zone for the session can be explicitly set using the IANA time zone diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 0f442bb1..ea8d5e31 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -1381,6 +1381,22 @@ def test_rowcount_insert(trino_connection): assert cur.rowcount == 1 +def test_statement_properties(trino_connection): + exchange_compression_statement = "SHOW SESSION LIKE 'exchange_compression'" + cur = trino_connection.cursor() + cur.execute(exchange_compression_statement) + result = cur.fetchall() + assert result[0][1] == "false" + cur = trino_connection.cursor(statement_properties={"exchange_compression":
True}) + cur.execute(exchange_compression_statement) + result = cur.fetchall() + assert result[0][1] == "True" + cur = trino_connection.cursor() + cur.execute(exchange_compression_statement) + result = cur.fetchall() + assert result[0][1] == "false" + + def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None): assert cur.description[0][1] == trino_type assert cur.description[0][2] is None diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 123ddf86..87a84ee4 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1011,7 +1011,7 @@ def json(self): result = query.execute(additional_http_headers=additional_headers) # Validate the the post function was called with the right argguments - mock_post.assert_called_once_with(sql, additional_headers) + mock_post.assert_called_once_with(sql, additional_headers, None) # Validate the result is an instance of TrinoResult assert isinstance(result, TrinoResult) diff --git a/trino/client.py b/trino/client.py index 9f9b75f8..845c8cc3 100644 --- a/trino/client.py +++ b/trino/client.py @@ -452,6 +452,9 @@ def transaction_id(self, value): @property def http_headers(self) -> Dict[str, str]: + return self._create_headers() + + def _create_headers(self, statement_properties: Dict[str, Any] = None): headers = {} headers[constants.HEADER_CATALOG] = self._client_session.catalog @@ -469,10 +472,13 @@ def http_headers(self) -> Dict[str, str]: if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0: headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags) + session_properties = copy.deepcopy(self._client_session.properties) + if statement_properties is not None: + session_properties.update(statement_properties) headers[constants.HEADER_SESSION] = ",".join( # ``name`` must not contain ``=`` "{}={}".format(name, urllib.parse.quote(str(value))) - for name, value in self._client_session.properties.items() + for name, 
value in session_properties.items() ) if len(self._client_session.prepared_statements) != 0: @@ -506,6 +512,9 @@ def http_headers(self) -> Dict[str, str]: return headers + def with_statement_properties(self, statement_properties: Optional[Dict[str, Any]]): + return self._create_headers(statement_properties) + @property def max_attempts(self) -> int: return self._max_attempts @@ -546,11 +555,15 @@ def statement_url(self) -> str: def next_uri(self) -> Optional[str]: return self._next_uri - def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None): + def post( + self, sql: str, + additional_http_headers: Optional[Dict[str, Any]] = None, + statement_properties: Optional[Dict[str, Any]] = None, + ): data = sql.encode("utf-8") # Deep copy of the http_headers dict since they may be modified for this # request by the provided additional_http_headers - http_headers = copy.deepcopy(self.http_headers) + http_headers = copy.deepcopy(self.with_statement_properties(statement_properties)) # Update the request headers with the additional_http_headers http_headers.update(additional_http_headers or {}) @@ -737,6 +750,7 @@ def __init__( request: TrinoRequest, query: str, legacy_primitive_types: bool = False, + statement_properties: Optional[Dict[str, Any]] = None, ) -> None: self._query_id: Optional[str] = None self._stats: Dict[Any, Any] = {} @@ -752,6 +766,7 @@ def __init__( self._query = query self._result: Optional[TrinoResult] = None self._legacy_primitive_types = legacy_primitive_types + self._statement_properties = statement_properties self._row_mapper: Optional[RowMapper] = None @property @@ -806,7 +821,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult: if self.cancelled: raise exceptions.TrinoUserError("Query has been cancelled", self.query_id) - response = self._request.post(self._query, additional_http_headers) + response = self._request.post( + self._query, + additional_http_headers, + self._statement_properties, + ) status = 
self._request.process(response) self._info_uri = status.info_uri self._query_id = status.id diff --git a/trino/dbapi.py b/trino/dbapi.py index 82e394e9..ec1a520b 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -213,7 +213,10 @@ def _create_request(self): self.request_timeout, ) - def cursor(self, legacy_primitive_types: bool = None): + def cursor( + self, legacy_primitive_types: bool = None, + statement_properties: Optional[Dict[str, Any]] = None, + ): """Return a new :py:class:`Cursor` object using the connection.""" if self.isolation_level != IsolationLevel.AUTOCOMMIT: if self.transaction is None: @@ -226,7 +229,8 @@ def cursor(self, legacy_primitive_types: bool = None): self, request, # if legacy_primitive_types is not explicitly set in Cursor, take from Connection - legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types + legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types, + statement_properties ) @@ -277,7 +281,13 @@ class Cursor(object): """ - def __init__(self, connection, request, legacy_primitive_types: bool = False): + def __init__( + self, + connection, + request, + legacy_primitive_types: bool = False, + statement_properties: Optional[Dict[str, Any]] = None + ): if not isinstance(connection, Connection): raise ValueError( "connection must be a Connection object: {}".format(type(connection)) @@ -289,6 +299,7 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False): self._iterator = None self._query = None self._legacy_primitive_types = legacy_primitive_types + self._statement_properties = statement_properties def __iter__(self): return self._iterator @@ -376,8 +387,12 @@ def _prepare_statement(self, statement: str, name: str) -> None: :param name: name that will be assigned to the prepared statement. 
""" sql = f"PREPARE {name} FROM {statement}" - query = trino.client.TrinoQuery(self.connection._create_request(), query=sql, - legacy_primitive_types=self._legacy_primitive_types) + query = trino.client.TrinoQuery( + self.connection._create_request(), + query=sql, + legacy_primitive_types=self._legacy_primitive_types, + statement_properties=self._statement_properties, + ) query.execute() def _execute_prepared_statement( @@ -386,7 +401,12 @@ def _execute_prepared_statement( params ): sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params)) - return trino.client.TrinoQuery(self._request, query=sql, legacy_primitive_types=self._legacy_primitive_types) + return trino.client.TrinoQuery( + self._request, + query=sql, + legacy_primitive_types=self._legacy_primitive_types, + statement_properties=self._statement_properties, + ) def _format_prepared_param(self, param): """ @@ -475,8 +495,12 @@ def _format_prepared_param(self, param): def _deallocate_prepared_statement(self, statement_name: str) -> None: sql = 'DEALLOCATE PREPARE ' + statement_name - query = trino.client.TrinoQuery(self.connection._create_request(), query=sql, - legacy_primitive_types=self._legacy_primitive_types) + query = trino.client.TrinoQuery( + self.connection._create_request(), + query=sql, + legacy_primitive_types=self._legacy_primitive_types, + statement_properties=self._statement_properties, + ) query.execute() def _generate_unique_statement_name(self): @@ -507,8 +531,12 @@ def execute(self, operation, params=None): self._deallocate_prepared_statement(statement_name) else: - self._query = trino.client.TrinoQuery(self._request, query=operation, - legacy_primitive_types=self._legacy_primitive_types) + self._query = trino.client.TrinoQuery( + self._request, + query=operation, + legacy_primitive_types=self._legacy_primitive_types, + statement_properties=self._statement_properties, + ) self._iterator = iter(self._query.execute()) return self