diff --git a/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py b/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py
index 79be9c556b028..b9419d49432ed 100644
--- a/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -67,6 +67,8 @@
     from google.cloud.secretmanager_v1 import AccessSecretVersionResponse
     from requests import Session
 
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+
 UNIX_PATH_MAX = 108
 
 # Time to sleep between active checks of the operation results
@@ -1146,7 +1148,7 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner:
             gcp_conn_id=self.gcp_conn_id,
         )
 
-    def get_database_hook(self, connection: Connection) -> BaseHook:
+    def get_database_hook(self, connection: Connection) -> DbApiHook:
         """
         Retrieve database hook.
 
@@ -1156,7 +1158,7 @@ def get_database_hook(self, connection: Connection) -> BaseHook:
         if self.database_type == "postgres":
             from airflow.providers.postgres.hooks.postgres import PostgresHook
 
-            db_hook: BaseHook = PostgresHook(connection=connection, database=self.database)
+            db_hook: DbApiHook = PostgresHook(connection=connection, database=self.database)
         else:
             from airflow.providers.mysql.hooks.mysql import MySqlHook
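Note: narrowing the return type from BaseHook to DbApiHook is what makes the OpenLineage wiring further down type-safe: run() and get_sqlalchemy_engine() are defined on DbApiHook (the common SQL base class that both PostgresHook and MySqlHook extend), not on BaseHook. A short illustration of what the narrower annotation buys (run_statements is a hypothetical helper, not part of this change):

    from airflow.providers.common.sql.hooks.sql import DbApiHook

    def run_statements(db_hook: DbApiHook, statements: list[str]) -> None:
        # Valid under the new annotation: run() exists on DbApiHook, so no
        # cast or isinstance check is needed when the hook comes from
        # CloudSQLDatabaseHook.get_database_hook().
        db_hook.run(statements, autocommit=True)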
diff --git a/providers/src/airflow/providers/google/cloud/openlineage/mixins.py b/providers/src/airflow/providers/google/cloud/openlineage/mixins.py
index ce7a14e03ae32..1547646d432d0 100644
--- a/providers/src/airflow/providers/google/cloud/openlineage/mixins.py
+++ b/providers/src/airflow/providers/google/cloud/openlineage/mixins.py
@@ -20,9 +20,12 @@
 import copy
 import json
 import traceback
+from collections.abc import Iterable
 from typing import TYPE_CHECKING, cast
 
 if TYPE_CHECKING:
+    from logging import Logger
+
     from airflow.providers.common.compat.openlineage.facet import (
         Dataset,
         InputDataset,
@@ -31,12 +34,64 @@
         RunFacet,
         SchemaDatasetFacet,
     )
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
     from airflow.providers.google.cloud.openlineage.utils import BigQueryJobRunFacet
+    from airflow.providers.openlineage.extractors import OperatorLineage
 
 BIGQUERY_NAMESPACE = "bigquery"
 
 
+class _SQLOpenLineageMixin:
+    @staticmethod
+    def _get_openlineage_facets(
+        hook: DbApiHook, sql: str | Iterable[str], conn_id: str, database: str | None, logger: Logger
+    ) -> OperatorLineage | None:
+        try:
+            from airflow.providers.openlineage.sqlparser import SQLParser
+        except ImportError:
+            logger.debug("SQLParser could not be imported from OpenLineage provider.")
+            return None
+
+        try:
+            from airflow.providers.openlineage.utils.utils import should_use_external_connection
+
+            use_external_connection = should_use_external_connection(hook)
+        except ImportError:
+            # OpenLineage provider release < 1.8.0 - we always use connection
+            use_external_connection = True
+
+        connection = hook.get_connection(conn_id)
+        try:
+            database_info = hook.get_openlineage_database_info(connection)
+        except AttributeError:
+            logger.debug("%s has no database info provided", hook)
+            database_info = None
+
+        if database_info is None:
+            return None
+
+        try:
+            sql_parser = SQLParser(
+                dialect=hook.get_openlineage_database_dialect(connection),
+                default_schema=hook.get_openlineage_default_schema(),
+            )
+        except AttributeError:
+            logger.debug("%s failed to get database dialect", hook)
+            return None
+
+        operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
+            sql=sql,  # type:ignore[arg-type]  # we expect list[str] but get Iterable[str], it's ok
+            hook=hook,
+            database_info=database_info,
+            database=database,
+            sqlalchemy_engine=hook.get_sqlalchemy_engine(),
+            use_connection=use_external_connection,
+        )
+
+        return operator_lineage
+
+
 class _BigQueryOpenLineageMixin:
     def get_openlineage_facets_on_complete(self, _):
         """
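Note: _get_openlineage_facets is a static method that takes the hook, SQL, connection id, database and logger as plain arguments, so any operator that owns a DbApiHook can reuse it, and every OpenLineage import is deferred and failure-tolerant, so the helper degrades to returning None when the openlineage provider is missing or too old. A minimal consumer sketch (the operator class and its db_hook/conn_id/database attributes are hypothetical, for illustration only):

    from airflow.models.baseoperator import BaseOperator
    from airflow.providers.google.cloud.openlineage.mixins import _SQLOpenLineageMixin

    class SqlRunningOperator(BaseOperator, _SQLOpenLineageMixin):  # hypothetical
        def get_openlineage_facets_on_complete(self, _):
            # Delegate SQL parsing and dataset resolution to the shared
            # helper; it returns None when the OpenLineage provider cannot
            # be imported or the hook exposes no database info.
            return self._get_openlineage_facets(
                hook=self.db_hook,  # any DbApiHook subclass
                sql=self.sql,
                conn_id=self.conn_id,
                database=self.database,
                logger=self.log,
            )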
diff --git a/providers/src/airflow/providers/google/cloud/operators/cloud_sql.py b/providers/src/airflow/providers/google/cloud/operators/cloud_sql.py
index eb75817ae7f4b..1b3ee63c2011a 100644
--- a/providers/src/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/providers/src/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -20,6 +20,7 @@
 from __future__ import annotations
 
 from collections.abc import Iterable, Mapping, Sequence
+from contextlib import contextmanager
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
@@ -30,6 +31,7 @@
 from airflow.hooks.base import BaseHook
 from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
 from airflow.providers.google.cloud.links.cloud_sql import CloudSQLInstanceDatabaseLink, CloudSQLInstanceLink
+from airflow.providers.google.cloud.openlineage.mixins import _SQLOpenLineageMixin
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
 from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger
 from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator
@@ -38,8 +40,7 @@
 
 if TYPE_CHECKING:
     from airflow.models import Connection
-    from airflow.providers.mysql.hooks.mysql import MySqlHook
-    from airflow.providers.postgres.hooks.postgres import PostgresHook
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -1167,7 +1168,7 @@ def execute(self, context: Context) -> None:
         return hook.import_instance(project_id=self.project_id, instance=self.instance, body=self.body)
 
 
-class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
+class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator, _SQLOpenLineageMixin):
     """
     Perform DML or DDL query on an existing Cloud Sql instance.
 
@@ -1256,7 +1257,8 @@ def __init__(
         self.ssl_client_key = ssl_client_key
         self.ssl_secret_id = ssl_secret_id
 
-    def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook | MySqlHook) -> None:
+    @contextmanager
+    def cloud_sql_proxy_context(self, hook: CloudSQLDatabaseHook):
         cloud_sql_proxy_runner = None
         try:
             if hook.use_proxy:
@@ -1266,27 +1268,27 @@ def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook
                 # be taken over here by another bind(0).
                 # It's quite unlikely to happen though!
                 cloud_sql_proxy_runner.start_proxy()
-            self.log.info('Executing: "%s"', self.sql)
-            database_hook.run(self.sql, self.autocommit, parameters=self.parameters)
+            yield
         finally:
             if cloud_sql_proxy_runner:
                 cloud_sql_proxy_runner.stop_proxy()
 
     def execute(self, context: Context):
-        self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
-
         hook = self.hook
         hook.validate_ssl_certs()
         connection = hook.create_connection()
         hook.validate_socket_path_length()
         database_hook = hook.get_database_hook(connection=connection)
         try:
-            self._execute_query(hook, database_hook)
+            with self.cloud_sql_proxy_context(hook):
+                self.log.info('Executing: "%s"', self.sql)
+                database_hook.run(self.sql, self.autocommit, parameters=self.parameters)
         finally:
             hook.cleanup_database_hook()
 
     @cached_property
     def hook(self):
+        self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
         return CloudSQLDatabaseHook(
             gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
             gcp_conn_id=self.gcp_conn_id,
@@ -1297,3 +1299,13 @@ def hook(self):
             ssl_key=self.ssl_client_key,
             ssl_secret_id=self.ssl_secret_id,
         )
+
+    def get_openlineage_facets_on_complete(self, _) -> OperatorLineage | None:
+        with self.cloud_sql_proxy_context(self.hook):
+            return self._get_openlineage_facets(
+                hook=self.hook.db_hook,
+                sql=self.sql,
+                conn_id=self.gcp_cloudsql_conn_id,
+                database=self.hook.database,
+                logger=self.log,
+            )
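Note: turning _execute_query into the cloud_sql_proxy_context context manager means the Cloud SQL Auth Proxy lifecycle (start inside try, always stop in finally) now wraps an arbitrary block instead of hard-coding the query run. That is what lets get_openlineage_facets_on_complete() reuse it: schema extraction needs a live database connection through the proxy, just as execute() does. For reference, the generator-based pattern in isolation (managed_proxy and its runner argument are illustrative stand-ins, not part of this diff):

    from contextlib import contextmanager

    @contextmanager
    def managed_proxy(runner=None):
        # Code before `yield` runs on entering the `with` block; the
        # `finally` clause runs on exit, so teardown happens even if the
        # body raises.
        try:
            if runner:
                runner.start()  # mirrors start_proxy()
            yield
        finally:
            if runner:
                runner.stop()  # mirrors stop_proxy()

    with managed_proxy():
        pass  # body executes while the (optional) proxy is up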
diff --git a/providers/tests/google/cloud/operators/test_cloud_sql.py b/providers/tests/google/cloud/operators/test_cloud_sql.py
index 50ed9afbd325d..fdfcd4ee804e1 100644
--- a/providers/tests/google/cloud/operators/test_cloud_sql.py
+++ b/providers/tests/google/cloud/operators/test_cloud_sql.py
@@ -19,11 +19,19 @@
 
 import os
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.models import Connection
+from airflow.providers.common.compat.openlineage.facet import (
+    Dataset,
+    SchemaDatasetFacet,
+    SchemaDatasetFacetFields,
+    SQLJobFacet,
+)
+from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.cloud.operators.cloud_sql import (
     CloudSQLCloneInstanceOperator,
     CloudSQLCreateInstanceDatabaseOperator,
@@ -822,3 +830,66 @@ def test_create_operator_with_too_long_unix_socket_path(self, get_connection):
             operator.execute(None)
         err = ctx.value
         assert "The UNIX socket path length cannot exceed" in str(err)
+
+    @pytest.mark.parametrize(
+        "connection_port, default_port, expected_port",
+        [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
+    )
+    def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
+        class DBApiHookForTests(DbApiHook):
+            conn_name_attr = "sql_default"
+            get_conn = MagicMock(name="conn")
+            get_connection = MagicMock()
+
+            def get_openlineage_database_info(self, connection):
+                from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+                return DatabaseInfo(
+                    scheme="sqlscheme",
+                    authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
+                )
+
+        dbapi_hook = DBApiHookForTests()
+
+        class CloudSQLExecuteQueryOperatorForTest(CloudSQLExecuteQueryOperator):
+            @property
+            def hook(self):
+                return MagicMock(db_hook=dbapi_hook, database="")
+
+        sql = """CREATE TABLE IF NOT EXISTS popular_orders_day_of_week (
+            order_day_of_week VARCHAR(64) NOT NULL,
+            order_placed_on TIMESTAMP NOT NULL,
+            orders_placed INTEGER NOT NULL
+        );
+FORGOT TO COMMENT"""
+        op = CloudSQLExecuteQueryOperatorForTest(task_id="task_id", sql=sql)
+
+        DB_SCHEMA_NAME = "PUBLIC"
+        rows = [
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
+        ]
+        dbapi_hook.get_connection.return_value = Connection(
+            conn_id="sql_default", conn_type="postgresql", host="host", port=connection_port
+        )
+        dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+        lineage = op.get_openlineage_facets_on_complete(None)
+        assert len(lineage.inputs) == 0
+        assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)}
+        assert lineage.run_facets["extractionError"].failedTasks == 1
+        assert lineage.outputs == [
+            Dataset(
+                namespace=f"sqlscheme://host:{expected_port}",
+                name="PUBLIC.popular_orders_day_of_week",
+                facets={
+                    "schema": SchemaDatasetFacet(
+                        fields=[
+                            SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
+                            SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
+                            SchemaDatasetFacetFields(name="orders_placed", type="int4"),
+                        ]
+                    )
+                },
+            )
+        ]
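Note: the test exercises get_openlineage_facets_on_complete end to end against a stubbed DbApiHook. fetchall() first returns the information-schema rows for the created table and then an empty list; the deliberately invalid trailing statement (FORGOT TO COMMENT) must surface as an extractionError run facet with failedTasks == 1; and the parametrization checks that the dataset namespace uses the connection's port, falling back to default_port only when the connection defines none. For orientation, a sketch of the operator in a DAG where lineage would actually be emitted (dag id, connection id and SQL are placeholders; with apache-airflow-providers-openlineage installed, the OpenLineage listener calls get_openlineage_facets_on_complete automatically after the task finishes):

    import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.cloud_sql import (
        CloudSQLExecuteQueryOperator,
    )

    with DAG(
        dag_id="cloud_sql_lineage_example",  # placeholder
        start_date=datetime.datetime(2024, 1, 1),
        schedule=None,
    ) as dag:
        CloudSQLExecuteQueryOperator(
            task_id="run_query",
            gcp_cloudsql_conn_id="my_cloudsql_conn",  # placeholder conn id
            sql="INSERT INTO orders SELECT * FROM staging_orders",  # placeholder
        )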