Skip to content

Commit

Permalink
feat: Add OpenLineage support for CloudSQLExecuteQueryOperator
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Muda <[email protected]>
  • Loading branch information
kacpermuda committed Dec 27, 2024
1 parent 60cd5ad commit 3d8df61
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 12 deletions.
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@
"google": {
"deps": [
"PyOpenSSL>=23.0.0",
"apache-airflow-providers-common-compat>=1.3.0",
"apache-airflow-providers-common-compat>=1.4.0",
"apache-airflow-providers-common-sql>=1.20.0",
"apache-airflow>=2.9.0",
"asgiref>=3.5.2",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

log = logging.getLogger(__name__)

if TYPE_CHECKING:
from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql

else:
try:
from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql
except ImportError:

def get_openlineage_facets_with_sql(
hook,
sql: str | list[str],
conn_id: str,
database: str | None,
):
try:
from airflow.providers.openlineage.sqlparser import SQLParser
except ImportError:
log.debug("SQLParser could not be imported from OpenLineage provider.")
return None

try:
from airflow.providers.openlineage.utils.utils import should_use_external_connection

use_external_connection = should_use_external_connection(hook)
except ImportError:
# OpenLineage provider release < 1.8.0 - we always use connection
use_external_connection = True

connection = hook.get_connection(conn_id)
try:
database_info = hook.get_openlineage_database_info(connection)
except AttributeError:
log.debug("%s has no database info provided", hook)
database_info = None

if database_info is None:
return None

try:
sql_parser = SQLParser(
dialect=hook.get_openlineage_database_dialect(connection),
default_schema=hook.get_openlineage_default_schema(),
)
except AttributeError:
log.debug("%s failed to get database dialect", hook)
return None

operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
sql=sql,
hook=hook,
database_info=database_info,
database=database,
sqlalchemy_engine=hook.get_sqlalchemy_engine(),
use_connection=use_external_connection,
)

return operator_lineage


__all__ = ["get_openlineage_facets_with_sql"]
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
from google.cloud.secretmanager_v1 import AccessSecretVersionResponse
from requests import Session

from airflow.providers.common.sql.hooks.sql import DbApiHook

UNIX_PATH_MAX = 108

# Time to sleep between active checks of the operation results
Expand Down Expand Up @@ -1146,7 +1148,7 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner:
gcp_conn_id=self.gcp_conn_id,
)

def get_database_hook(self, connection: Connection) -> BaseHook:
def get_database_hook(self, connection: Connection) -> DbApiHook:
"""
Retrieve database hook.
Expand All @@ -1156,7 +1158,7 @@ def get_database_hook(self, connection: Connection) -> BaseHook:
if self.database_type == "postgres":
from airflow.providers.postgres.hooks.postgres import PostgresHook

db_hook: BaseHook = PostgresHook(connection=connection, database=self.database)
db_hook: DbApiHook = PostgresHook(connection=connection, database=self.database)
else:
from airflow.providers.mysql.hooks.mysql import MySqlHook

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
from contextlib import contextmanager
from functools import cached_property
from typing import TYPE_CHECKING, Any

Expand All @@ -38,8 +39,7 @@

if TYPE_CHECKING:
from airflow.models import Connection
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.utils.context import Context


Expand Down Expand Up @@ -1256,7 +1256,8 @@ def __init__(
self.ssl_client_key = ssl_client_key
self.ssl_secret_id = ssl_secret_id

def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook | MySqlHook) -> None:
@contextmanager
def cloud_sql_proxy_context(self, hook: CloudSQLDatabaseHook):
cloud_sql_proxy_runner = None
try:
if hook.use_proxy:
Expand All @@ -1266,27 +1267,27 @@ def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook
# be taken over here by another bind(0).
# It's quite unlikely to happen though!
cloud_sql_proxy_runner.start_proxy()
self.log.info('Executing: "%s"', self.sql)
database_hook.run(self.sql, self.autocommit, parameters=self.parameters)
yield
finally:
if cloud_sql_proxy_runner:
cloud_sql_proxy_runner.stop_proxy()

def execute(self, context: Context):
self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)

hook = self.hook
hook.validate_ssl_certs()
connection = hook.create_connection()
hook.validate_socket_path_length()
database_hook = hook.get_database_hook(connection=connection)
try:
self._execute_query(hook, database_hook)
with self.cloud_sql_proxy_context(hook):
self.log.info('Executing: "%s"', self.sql)
database_hook.run(self.sql, self.autocommit, parameters=self.parameters)
finally:
hook.cleanup_database_hook()

@cached_property
def hook(self):
self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
return CloudSQLDatabaseHook(
gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
gcp_conn_id=self.gcp_conn_id,
Expand All @@ -1297,3 +1298,14 @@ def hook(self):
ssl_key=self.ssl_client_key,
ssl_secret_id=self.ssl_secret_id,
)

def get_openlineage_facets_on_complete(self, _) -> OperatorLineage | None:
from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql

with self.cloud_sql_proxy_context(self.hook):
return get_openlineage_facets_with_sql(
hook=self.hook.db_hook,
sql=self.sql, # type:ignore[arg-type] # Iterable[str] instead of list[str]
conn_id=self.gcp_cloudsql_conn_id,
database=self.hook.database,
)
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/google/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ versions:

dependencies:
- apache-airflow>=2.9.0
- apache-airflow-providers-common-compat>=1.3.0
- apache-airflow-providers-common-compat>=1.4.0
- apache-airflow-providers-common-sql>=1.20.0
- asgiref>=3.5.2
- dill>=0.2.3
Expand Down
39 changes: 39 additions & 0 deletions providers/src/airflow/providers/openlineage/sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Callable

import sqlparse
Expand All @@ -30,6 +31,7 @@
create_information_schema_query,
get_table_schemas,
)
from airflow.providers.openlineage.utils.utils import should_use_external_connection
from airflow.typing_compat import TypedDict
from airflow.utils.log.logging_mixin import LoggingMixin

Expand All @@ -38,6 +40,9 @@
from sqlalchemy.engine import Engine

from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook

log = logging.getLogger(__name__)

DEFAULT_NAMESPACE = "default"
DEFAULT_INFORMATION_SCHEMA_COLUMNS = [
Expand Down Expand Up @@ -397,3 +402,37 @@ def _get_tables_hierarchy(
tables = schemas.setdefault(normalize_name(table.schema) if table.schema else None, [])
tables.append(table.name)
return hierarchy


def get_openlineage_facets_with_sql(
hook: DbApiHook, sql: str | list[str], conn_id: str, database: str | None
) -> OperatorLineage | None:
connection = hook.get_connection(conn_id)
try:
database_info = hook.get_openlineage_database_info(connection)
except AttributeError:
database_info = None

if database_info is None:
log.debug("%s has no database info provided", hook)
return None

try:
sql_parser = SQLParser(
dialect=hook.get_openlineage_database_dialect(connection),
default_schema=hook.get_openlineage_default_schema(),
)
except AttributeError:
log.debug("%s failed to get database dialect", hook)
return None

operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
sql=sql,
hook=hook,
database_info=database_info,
database=database,
sqlalchemy_engine=hook.get_sqlalchemy_engine(),
use_connection=should_use_external_connection(hook),
)

return operator_lineage
71 changes: 71 additions & 0 deletions providers/tests/google/cloud/operators/test_cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,19 @@

import os
from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import Connection
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
SchemaDatasetFacet,
SchemaDatasetFacetFields,
SQLJobFacet,
)
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.cloud.operators.cloud_sql import (
CloudSQLCloneInstanceOperator,
CloudSQLCreateInstanceDatabaseOperator,
Expand Down Expand Up @@ -822,3 +830,66 @@ def test_create_operator_with_too_long_unix_socket_path(self, get_connection):
operator.execute(None)
err = ctx.value
assert "The UNIX socket path length cannot exceed" in str(err)

@pytest.mark.parametrize(
"connection_port, default_port, expected_port",
[(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
)
def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
class DBApiHookForTests(DbApiHook):
conn_name_attr = "sql_default"
get_conn = MagicMock(name="conn")
get_connection = MagicMock()

def get_openlineage_database_info(self, connection):
from airflow.providers.openlineage.sqlparser import DatabaseInfo

return DatabaseInfo(
scheme="sqlscheme",
authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
)

dbapi_hook = DBApiHookForTests()

class CloudSQLExecuteQueryOperatorForTest(CloudSQLExecuteQueryOperator):
@property
def hook(self):
return MagicMock(db_hook=dbapi_hook, database="")

sql = """CREATE TABLE IF NOT EXISTS popular_orders_day_of_week (
order_day_of_week VARCHAR(64) NOT NULL,
order_placed_on TIMESTAMP NOT NULL,
orders_placed INTEGER NOT NULL
);
FORGOT TO COMMENT"""
op = CloudSQLExecuteQueryOperatorForTest(task_id="task_id", sql=sql)
DB_SCHEMA_NAME = "PUBLIC"
rows = [
(DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
(DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
(DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
]
dbapi_hook.get_connection.return_value = Connection(
conn_id="sql_default", conn_type="postgresql", host="host", port=connection_port
)
dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]

lineage = op.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 0
assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)}
assert lineage.run_facets["extractionError"].failedTasks == 1
assert lineage.outputs == [
Dataset(
namespace=f"sqlscheme://host:{expected_port}",
name="PUBLIC.popular_orders_day_of_week",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
SchemaDatasetFacetFields(name="orders_placed", type="int4"),
]
)
},
)
]

0 comments on commit 3d8df61

Please sign in to comment.