Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support setting session properties on a individual statement #339

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,29 @@ conn = trino.dbapi.connect(
)
```

## Session properties

Session properties can be set on the connection

```python
import trino
conn = trino.dbapi.connect(
...,
session_properties={"query_max_run_time": "1d"}
)
```

### Statement properties

It's also possible to set a session property for a specific statement by setting it on the Cursor. This is especially handy in the case of hive partitions.

```python
import trino
conn = trino.dbapi.connect()
cur = conn.cursor(statement_properties={"hive.insert_existing_partitions_behavior": "OVERWRITE"})
cur.execute("INSERT INTO hive_partitioned_table SELECT * from another_table")
```

## Timezone

The time zone for the session can be explicitly set using the IANA time zone
Expand Down
16 changes: 16 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,22 @@ def test_rowcount_insert(trino_connection):
assert cur.rowcount == 1


def test_statement_properties(trino_connection):
exchange_compression_statement = "SHOW SESSION LIKE 'exchange_compression'"
cur = trino_connection.cursor()
cur.execute(exchange_compression_statement)
result = cur.fetchall()
assert result[0][1] == "false"
cur = trino_connection.cursor(statement_properties={"exchange_compression": True})
cur.execute(exchange_compression_statement)
result = cur.fetchall()
assert result[0][1] == "True"
cur = trino_connection.cursor()
cur.execute(exchange_compression_statement)
result = cur.fetchall()
assert result[0][1] == "false"


def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None):
assert cur.description[0][1] == trino_type
assert cur.description[0][2] is None
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ def json(self):
result = query.execute(additional_http_headers=additional_headers)

# Validate the the post function was called with the right argguments
mock_post.assert_called_once_with(sql, additional_headers)
mock_post.assert_called_once_with(sql, additional_headers, None)

# Validate the result is an instance of TrinoResult
assert isinstance(result, TrinoResult)
Expand Down
27 changes: 23 additions & 4 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,9 @@ def transaction_id(self, value):

@property
def http_headers(self) -> Dict[str, str]:
return self._create_headers()

def _create_headers(self, statement_properties: Dict[str, Any] = None):
headers = {}

headers[constants.HEADER_CATALOG] = self._client_session.catalog
Expand All @@ -469,10 +472,13 @@ def http_headers(self) -> Dict[str, str]:
if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0:
headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags)

session_properties = copy.deepcopy(self._client_session.properties)
if statement_properties is not None:
session_properties.update(statement_properties)
headers[constants.HEADER_SESSION] = ",".join(
# ``name`` must not contain ``=``
"{}={}".format(name, urllib.parse.quote(str(value)))
for name, value in self._client_session.properties.items()
for name, value in session_properties.items()
)

if len(self._client_session.prepared_statements) != 0:
Expand Down Expand Up @@ -506,6 +512,9 @@ def http_headers(self) -> Dict[str, str]:

return headers

def with_statement_properties(self, statement_properties: Optional[Dict[str, Any]]):
return self._create_headers(statement_properties)

@property
def max_attempts(self) -> int:
return self._max_attempts
Expand Down Expand Up @@ -546,11 +555,15 @@ def statement_url(self) -> str:
def next_uri(self) -> Optional[str]:
return self._next_uri

def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None):
def post(
self, sql: str,
additional_http_headers: Optional[Dict[str, Any]] = None,
statement_properties: Optional[Dict[str, Any]] = None,
):
data = sql.encode("utf-8")
# Deep copy of the http_headers dict since they may be modified for this
# request by the provided additional_http_headers
http_headers = copy.deepcopy(self.http_headers)
http_headers = copy.deepcopy(self.with_statement_properties(statement_properties))

# Update the request headers with the additional_http_headers
http_headers.update(additional_http_headers or {})
Expand Down Expand Up @@ -737,6 +750,7 @@ def __init__(
request: TrinoRequest,
query: str,
legacy_primitive_types: bool = False,
statement_properties: Optional[Dict[str, Any]] = None,
) -> None:
self._query_id: Optional[str] = None
self._stats: Dict[Any, Any] = {}
Expand All @@ -752,6 +766,7 @@ def __init__(
self._query = query
self._result: Optional[TrinoResult] = None
self._legacy_primitive_types = legacy_primitive_types
self._statement_properties = statement_properties
self._row_mapper: Optional[RowMapper] = None

@property
Expand Down Expand Up @@ -806,7 +821,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
if self.cancelled:
raise exceptions.TrinoUserError("Query has been cancelled", self.query_id)

response = self._request.post(self._query, additional_http_headers)
response = self._request.post(
self._query,
additional_http_headers,
self._statement_properties,
)
status = self._request.process(response)
self._info_uri = status.info_uri
self._query_id = status.id
Expand Down
48 changes: 38 additions & 10 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,10 @@ def _create_request(self):
self.request_timeout,
)

def cursor(self, legacy_primitive_types: bool = None):
def cursor(
self, legacy_primitive_types: bool = None,
statement_properties: Optional[Dict[str, Any]] = None,
):
"""Return a new :py:class:`Cursor` object using the connection."""
if self.isolation_level != IsolationLevel.AUTOCOMMIT:
if self.transaction is None:
Expand All @@ -226,7 +229,8 @@ def cursor(self, legacy_primitive_types: bool = None):
self,
request,
# if legacy_primitive_types is not explicitly set in Cursor, take from Connection
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types,
statement_properties
)


Expand Down Expand Up @@ -277,7 +281,13 @@ class Cursor(object):

"""

def __init__(self, connection, request, legacy_primitive_types: bool = False):
def __init__(
self,
connection,
request,
legacy_primitive_types: bool = False,
statement_properties: Optional[Dict[str, Any]] = None
):
if not isinstance(connection, Connection):
raise ValueError(
"connection must be a Connection object: {}".format(type(connection))
Expand All @@ -289,6 +299,7 @@ def __init__(self, connection, request, legacy_primitive_types: bool = False):
self._iterator = None
self._query = None
self._legacy_primitive_types = legacy_primitive_types
self._statement_properties = statement_properties

def __iter__(self):
return self._iterator
Expand Down Expand Up @@ -376,8 +387,12 @@ def _prepare_statement(self, statement: str, name: str) -> None:
:param name: name that will be assigned to the prepared statement.
"""
sql = f"PREPARE {name} FROM {statement}"
query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
legacy_primitive_types=self._legacy_primitive_types)
query = trino.client.TrinoQuery(
self.connection._create_request(),
query=sql,
legacy_primitive_types=self._legacy_primitive_types,
statement_properties=self._statement_properties,
)
query.execute()

def _execute_prepared_statement(
Expand All @@ -386,7 +401,12 @@ def _execute_prepared_statement(
params
):
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))
return trino.client.TrinoQuery(self._request, query=sql, legacy_primitive_types=self._legacy_primitive_types)
return trino.client.TrinoQuery(
self._request,
query=sql,
legacy_primitive_types=self._legacy_primitive_types,
statement_properties=self._statement_properties,
)

def _format_prepared_param(self, param):
"""
Expand Down Expand Up @@ -475,8 +495,12 @@ def _format_prepared_param(self, param):

def _deallocate_prepared_statement(self, statement_name: str) -> None:
sql = 'DEALLOCATE PREPARE ' + statement_name
query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
legacy_primitive_types=self._legacy_primitive_types)
query = trino.client.TrinoQuery(
self.connection._create_request(),
query=sql,
legacy_primitive_types=self._legacy_primitive_types,
statement_properties=self._statement_properties,
)
query.execute()

def _generate_unique_statement_name(self):
Expand Down Expand Up @@ -507,8 +531,12 @@ def execute(self, operation, params=None):
self._deallocate_prepared_statement(statement_name)

else:
self._query = trino.client.TrinoQuery(self._request, query=operation,
legacy_primitive_types=self._legacy_primitive_types)
self._query = trino.client.TrinoQuery(
self._request,
query=operation,
legacy_primitive_types=self._legacy_primitive_types,
statement_properties=self._statement_properties,
)
self._iterator = iter(self._query.execute())
return self

Expand Down