diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 96a1b420e3522..6e8d08c44f7a6 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -631,7 +631,7 @@ "google": { "deps": [ "PyOpenSSL>=23.0.0", - "apache-airflow-providers-common-compat>=1.3.0", + "apache-airflow-providers-common-compat>=1.4.0", "apache-airflow-providers-common-sql>=1.20.0", "apache-airflow>=2.9.0", "asgiref>=3.5.2", diff --git a/providers/src/airflow/providers/common/compat/openlineage/utils/sql.py b/providers/src/airflow/providers/common/compat/openlineage/utils/sql.py new file mode 100644 index 0000000000000..9a9618d7a2eb6 --- /dev/null +++ b/providers/src/airflow/providers/common/compat/openlineage/utils/sql.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +log = logging.getLogger(__name__) + +if TYPE_CHECKING: + from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql + +else: + try: + from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql + except ImportError: + + def get_openlineage_facets_with_sql( + hook, + sql: str | list[str], + conn_id: str, + database: str | None, + ): + try: + from airflow.providers.openlineage.sqlparser import SQLParser + except ImportError: + log.debug("SQLParser could not be imported from OpenLineage provider.") + return None + + try: + from airflow.providers.openlineage.utils.utils import should_use_external_connection + + use_external_connection = should_use_external_connection(hook) + except ImportError: + # OpenLineage provider release < 1.8.0 - we always use connection + use_external_connection = True + + connection = hook.get_connection(conn_id) + try: + database_info = hook.get_openlineage_database_info(connection) + except AttributeError: + log.debug("%s has no database info provided", hook) + database_info = None + + if database_info is None: + return None + + try: + sql_parser = SQLParser( + dialect=hook.get_openlineage_database_dialect(connection), + default_schema=hook.get_openlineage_default_schema(), + ) + except AttributeError: + log.debug("%s failed to get database dialect", hook) + return None + + operator_lineage = sql_parser.generate_openlineage_metadata_from_sql( + sql=sql, + hook=hook, + database_info=database_info, + database=database, + sqlalchemy_engine=hook.get_sqlalchemy_engine(), + use_connection=use_external_connection, + ) + + return operator_lineage + + +__all__ = ["get_openlineage_facets_with_sql"] diff --git a/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py b/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py index 79be9c556b028..b9419d49432ed 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -67,6 +67,8 @@ from google.cloud.secretmanager_v1 import AccessSecretVersionResponse from requests import Session + from airflow.providers.common.sql.hooks.sql import DbApiHook + UNIX_PATH_MAX = 108 # Time to sleep between active checks of the operation results @@ -1146,7 +1148,7 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner: gcp_conn_id=self.gcp_conn_id, ) - def get_database_hook(self, connection: Connection) -> BaseHook: + def get_database_hook(self, connection: Connection) -> DbApiHook: """ Retrieve database hook. @@ -1156,7 +1158,7 @@ def get_database_hook(self, connection: Connection) -> BaseHook: if self.database_type == "postgres": from airflow.providers.postgres.hooks.postgres import PostgresHook - db_hook: BaseHook = PostgresHook(connection=connection, database=self.database) + db_hook: DbApiHook = PostgresHook(connection=connection, database=self.database) else: from airflow.providers.mysql.hooks.mysql import MySqlHook diff --git a/providers/src/airflow/providers/google/cloud/operators/cloud_sql.py b/providers/src/airflow/providers/google/cloud/operators/cloud_sql.py index eb75817ae7f4b..18f3d9a443d12 100644 --- a/providers/src/airflow/providers/google/cloud/operators/cloud_sql.py +++ b/providers/src/airflow/providers/google/cloud/operators/cloud_sql.py @@ -20,6 +20,7 @@ from __future__ import annotations from collections.abc import Iterable, Mapping, Sequence +from contextlib import contextmanager from functools import cached_property from typing import TYPE_CHECKING, Any @@ -38,8 +39,7 @@ if TYPE_CHECKING: from airflow.models import Connection - from airflow.providers.mysql.hooks.mysql import MySqlHook - from airflow.providers.postgres.hooks.postgres import PostgresHook + from airflow.providers.openlineage.extractors import OperatorLineage from airflow.utils.context import Context @@ -1256,7 +1256,8 @@ def __init__( self.ssl_client_key = ssl_client_key self.ssl_secret_id = ssl_secret_id - def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook | MySqlHook) -> None: + @contextmanager + def cloud_sql_proxy_context(self, hook: CloudSQLDatabaseHook): cloud_sql_proxy_runner = None try: if hook.use_proxy: @@ -1266,27 +1267,27 @@ def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook # be taken over here by another bind(0). # It's quite unlikely to happen though! cloud_sql_proxy_runner.start_proxy() - self.log.info('Executing: "%s"', self.sql) - database_hook.run(self.sql, self.autocommit, parameters=self.parameters) + yield finally: if cloud_sql_proxy_runner: cloud_sql_proxy_runner.stop_proxy() def execute(self, context: Context): - self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id) - hook = self.hook hook.validate_ssl_certs() connection = hook.create_connection() hook.validate_socket_path_length() database_hook = hook.get_database_hook(connection=connection) try: - self._execute_query(hook, database_hook) + with self.cloud_sql_proxy_context(hook): + self.log.info('Executing: "%s"', self.sql) + database_hook.run(self.sql, self.autocommit, parameters=self.parameters) finally: hook.cleanup_database_hook() @cached_property def hook(self): + self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id) return CloudSQLDatabaseHook( gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id, gcp_conn_id=self.gcp_conn_id, @@ -1297,3 +1298,14 @@ def hook(self): ssl_key=self.ssl_client_key, ssl_secret_id=self.ssl_secret_id, ) + + def get_openlineage_facets_on_complete(self, _) -> OperatorLineage | None: + from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql + + with self.cloud_sql_proxy_context(self.hook): + return get_openlineage_facets_with_sql( + hook=self.hook.db_hook, + sql=self.sql, # type:ignore[arg-type] # Iterable[str] instead of list[str] + conn_id=self.gcp_cloudsql_conn_id, + database=self.hook.database, + ) diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index b253967472d53..c67c8432f4cec 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -101,7 +101,7 @@ versions: dependencies: - apache-airflow>=2.9.0 - - apache-airflow-providers-common-compat>=1.3.0 + - apache-airflow-providers-common-compat>=1.4.0 - apache-airflow-providers-common-sql>=1.20.0 - asgiref>=3.5.2 - dill>=0.2.3 diff --git a/providers/src/airflow/providers/openlineage/sqlparser.py b/providers/src/airflow/providers/openlineage/sqlparser.py index 9751af3f7941e..b4225909d7b20 100644 --- a/providers/src/airflow/providers/openlineage/sqlparser.py +++ b/providers/src/airflow/providers/openlineage/sqlparser.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import logging from typing import TYPE_CHECKING, Callable import sqlparse @@ -30,6 +31,7 @@ create_information_schema_query, get_table_schemas, ) +from airflow.providers.openlineage.utils.utils import should_use_external_connection from airflow.typing_compat import TypedDict from airflow.utils.log.logging_mixin import LoggingMixin @@ -38,6 +40,9 @@ from sqlalchemy.engine import Engine from airflow.hooks.base import BaseHook + from airflow.providers.common.sql.hooks.sql import DbApiHook + +log = logging.getLogger(__name__) DEFAULT_NAMESPACE = "default" DEFAULT_INFORMATION_SCHEMA_COLUMNS = [ @@ -397,3 +402,37 @@ def _get_tables_hierarchy( tables = schemas.setdefault(normalize_name(table.schema) if table.schema else None, []) tables.append(table.name) return hierarchy + + +def get_openlineage_facets_with_sql( + hook: DbApiHook, sql: str | list[str], conn_id: str, database: str | None +) -> OperatorLineage | None: + connection = hook.get_connection(conn_id) + try: + database_info = hook.get_openlineage_database_info(connection) + except AttributeError: + database_info = None + + if database_info is None: + log.debug("%s has no database info provided", hook) + return None + + try: + sql_parser = SQLParser( + dialect=hook.get_openlineage_database_dialect(connection), + default_schema=hook.get_openlineage_default_schema(), + ) + except AttributeError: + log.debug("%s failed to get database dialect", hook) + return None + + operator_lineage = sql_parser.generate_openlineage_metadata_from_sql( + sql=sql, + hook=hook, + database_info=database_info, + database=database, + sqlalchemy_engine=hook.get_sqlalchemy_engine(), + use_connection=should_use_external_connection(hook), + ) + + return operator_lineage diff --git a/providers/tests/google/cloud/operators/test_cloud_sql.py b/providers/tests/google/cloud/operators/test_cloud_sql.py index 50ed9afbd325d..fdfcd4ee804e1 100644 --- a/providers/tests/google/cloud/operators/test_cloud_sql.py +++ b/providers/tests/google/cloud/operators/test_cloud_sql.py @@ -19,11 +19,19 @@ import os from unittest import mock +from unittest.mock import MagicMock import pytest from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import Connection +from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + SchemaDatasetFacet, + SchemaDatasetFacetFields, + SQLJobFacet, +) +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.google.cloud.operators.cloud_sql import ( CloudSQLCloneInstanceOperator, CloudSQLCreateInstanceDatabaseOperator, @@ -822,3 +830,66 @@ def test_create_operator_with_too_long_unix_socket_path(self, get_connection): operator.execute(None) err = ctx.value assert "The UNIX socket path length cannot exceed" in str(err) + + @pytest.mark.parametrize( + "connection_port, default_port, expected_port", + [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)], + ) + def test_execute_openlineage_events(self, connection_port, default_port, expected_port): + class DBApiHookForTests(DbApiHook): + conn_name_attr = "sql_default" + get_conn = MagicMock(name="conn") + get_connection = MagicMock() + + def get_openlineage_database_info(self, connection): + from airflow.providers.openlineage.sqlparser import DatabaseInfo + + return DatabaseInfo( + scheme="sqlscheme", + authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port), + ) + + dbapi_hook = DBApiHookForTests() + + class CloudSQLExecuteQueryOperatorForTest(CloudSQLExecuteQueryOperator): + @property + def hook(self): + return MagicMock(db_hook=dbapi_hook, database="") + + sql = """CREATE TABLE IF NOT EXISTS popular_orders_day_of_week ( + order_day_of_week VARCHAR(64) NOT NULL, + order_placed_on TIMESTAMP NOT NULL, + orders_placed INTEGER NOT NULL + ); +FORGOT TO COMMENT""" + op = CloudSQLExecuteQueryOperatorForTest(task_id="task_id", sql=sql) + DB_SCHEMA_NAME = "PUBLIC" + rows = [ + (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"), + (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"), + (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"), + ] + dbapi_hook.get_connection.return_value = Connection( + conn_id="sql_default", conn_type="postgresql", host="host", port=connection_port + ) + dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []] + + lineage = op.get_openlineage_facets_on_complete(None) + assert len(lineage.inputs) == 0 + assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)} + assert lineage.run_facets["extractionError"].failedTasks == 1 + assert lineage.outputs == [ + Dataset( + namespace=f"sqlscheme://host:{expected_port}", + name="PUBLIC.popular_orders_day_of_week", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"), + SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"), + SchemaDatasetFacetFields(name="orders_placed", type="int4"), + ] + ) + }, + ) + ]