diff --git a/recap/clients/postgresql.py b/recap/clients/postgresql.py index 0427193..d35a8d0 100644 --- a/recap/clients/postgresql.py +++ b/recap/clients/postgresql.py @@ -5,6 +5,7 @@ from recap.clients.dbapi import Connection, DbapiClient from recap.converters.postgresql import PostgresqlConverter +from recap.types import StructType PSYCOPG2_CONNECT_ARGS = { "host", @@ -48,8 +49,12 @@ class PostgresqlClient(DbapiClient): - def __init__(self, connection: Connection) -> None: - super().__init__(connection, PostgresqlConverter()) + def __init__( + self, + connection: Connection, + converter: PostgresqlConverter = PostgresqlConverter(), + ) -> None: + super().__init__(connection, converter) @staticmethod @contextmanager @@ -78,3 +83,32 @@ def ls_catalogs(self) -> list[str]: """ ) return [row[0] for row in cursor.fetchall()] + + def schema(self, catalog: str, schema: str, table: str) -> StructType: + cursor = self.connection.cursor() + cursor.execute( + f""" + SELECT + information_schema.columns.*, + pg_attribute.attndims + FROM information_schema.columns + JOIN pg_catalog.pg_namespace + ON pg_catalog.pg_namespace.nspname = information_schema.columns.table_schema + JOIN pg_catalog.pg_class + ON pg_catalog.pg_class.relname = information_schema.columns.table_name + AND pg_catalog.pg_class.relnamespace = pg_catalog.pg_namespace.oid + JOIN pg_catalog.pg_attribute + ON pg_catalog.pg_attribute.attrelid = pg_catalog.pg_class.oid + AND pg_catalog.pg_attribute.attname = information_schema.columns.column_name + WHERE table_name = {self.param_style} + AND table_schema = {self.param_style} + AND table_catalog = {self.param_style} + ORDER BY ordinal_position ASC + """, + (table, schema, catalog), + ) + names = [name[0].upper() for name in cursor.description] + return self.converter.to_recap( + # Make each row be a dict with the column names as keys + [dict(zip(names, row)) for row in cursor.fetchall()] + ) diff --git a/recap/converters/postgresql.py b/recap/converters/postgresql.py index 70d1dbc..e5f11f5 100644 --- a/recap/converters/postgresql.py +++ b/recap/converters/postgresql.py @@ -8,6 +8,7 @@ FloatType, IntType, ListType, + NullType, ProxyType, RecapType, RecapTypeRegistry, @@ -24,7 +25,15 @@ class PostgresqlConverter(DbapiConverter): - def __init__(self, namespace: str = DEFAULT_NAMESPACE) -> None: + def __init__( + self, + enforce_array_dimensions: bool = False, + namespace: str = DEFAULT_NAMESPACE, + ): + # since array dimensionality is not enforced by PG schemas: + # if `enforce_array_dimensions = False` then read arrays irrespective of how many dimensions they have + # if `enforce_array_dimensions = True` then read arrays as nested lists + self.enforce_array_dimensions = enforce_array_dimensions self.namespace = namespace self.registry = RecapTypeRegistry() @@ -34,6 +43,7 @@ def _parse_type(self, column_props: dict[str, Any]) -> RecapType: octet_length = column_props["CHARACTER_OCTET_LENGTH"] max_length = column_props["CHARACTER_MAXIMUM_LENGTH"] udt_name = (column_props["UDT_NAME"] or "").lower() + ndims = column_props["ATTNDIMS"] if data_type in ["bigint", "int8", "bigserial", "serial8"]: base_type = IntType(bits=64, signed=True) @@ -102,29 +112,44 @@ def _parse_type(self, column_props: dict[str, Any]) -> RecapType: # * 8 because bit columns use bits not bytes. "CHARACTER_MAXIMUM_LENGTH": MAX_FIELD_SIZE * 8, "UDT_NAME": None, + "ATTNDIMS": 0, } ) - column_name_without_periods = column_name.replace(".", "_") - base_type_alias = f"{self.namespace}.{column_name_without_periods}" - # Construct a self-referencing list comprised of the array's value - # type and a proxy to the list itself. This allows arrays to be an - # arbitrary number of dimensions, which is how PostgreSQL treats - # lists. See https://github.com/recap-build/recap/issues/264 for - # more details. - base_type = ListType( - alias=base_type_alias, - values=UnionType( - types=[ - value_type, - ProxyType( - alias=base_type_alias, - registry=self.registry, - ), - ], - ), - ) - self.registry.register_alias(base_type) + if self.enforce_array_dimensions: + base_type = self._create_n_dimension_list(value_type, ndims) + else: + column_name_without_periods = column_name.replace(".", "_") + base_type_alias = f"{self.namespace}.{column_name_without_periods}" + # Construct a self-referencing list comprised of the array's value + # type and a proxy to the list itself. This allows arrays to be an + # arbitrary number of dimensions, which is how PostgreSQL treats + # lists. See https://github.com/recap-build/recap/issues/264 for + # more details. + base_type = ListType( + alias=base_type_alias, + values=UnionType( + types=[ + value_type, + ProxyType( + alias=base_type_alias, + registry=self.registry, + ), + ], + ), + ) + self.registry.register_alias(base_type) else: raise ValueError(f"Unknown data type: {data_type}") return base_type + + def _create_n_dimension_list(self, base_type: RecapType, ndims: int) -> RecapType: + """ + Build a list type with `ndims` dimensions containing nullable `base_type` as the innermost value type. + """ + if ndims == 0: + return UnionType(types=[NullType(), base_type]) + else: + return ListType( + values=self._create_n_dimension_list(base_type, ndims - 1), + ) diff --git a/tests/integration/clients/test_postgresql.py b/tests/integration/clients/test_postgresql.py index 6668e3c..27c70e0 100644 --- a/tests/integration/clients/test_postgresql.py +++ b/tests/integration/clients/test_postgresql.py @@ -2,7 +2,7 @@ from recap.clients import create_client from recap.clients.postgresql import PostgresqlClient -from recap.converters.postgresql import MAX_FIELD_SIZE +from recap.converters.postgresql import MAX_FIELD_SIZE, PostgresqlConverter from recap.types import ( BoolType, BytesType, @@ -11,6 +11,7 @@ ListType, NullType, ProxyType, + RecapType, StringType, StructType, UnionType, @@ -51,7 +52,10 @@ def setup_class(cls): test_default INTEGER DEFAULT 2, test_int_array INTEGER[], test_varchar_array VARCHAR(255)[] DEFAULT '{"Hello", "World"}', - test_bit_array BIT(8)[] + test_bit_array BIT(8)[], + test_not_null_array INTEGER[] NOT NULL, + test_int_array_2d INTEGER[][], + test_text_array_3d TEXT[][][] ); """ ) @@ -67,14 +71,10 @@ def teardown_class(cls): # Close the connection cls.connection.close() - def test_struct_method(self): - # Initiate the PostgresqlClient class - client = PostgresqlClient(self.connection) # type: ignore - - # Test 'test_types' table + def test_struct_method_arrays_no_enforce_dimensions(self): + client = PostgresqlClient(self.connection, PostgresqlConverter(False)) test_types_struct = client.schema("testdb", "public", "test_types") - # Define the expected output for 'test_types' table expected_fields = [ UnionType( default=None, @@ -215,13 +215,241 @@ def test_struct_method(self): ), ], ), + ListType( + name="test_not_null_array", + alias="_root.test_not_null_array", + values=UnionType( + types=[ + IntType(bits=32), + ProxyType( + alias="_root.test_not_null_array", + registry=client.converter.registry, # type: ignore + ), + ] + ), + ), + UnionType( + default=None, + name="test_int_array_2d", + types=[ + NullType(), + ListType( + alias="_root.test_int_array_2d", + values=UnionType( + types=[ + IntType(bits=32), + ProxyType( + alias="_root.test_int_array_2d", + registry=client.converter.registry, # type: ignore + ), + ] + ), + ), + ], + ), + UnionType( + default=None, + name="test_text_array_3d", + types=[ + NullType(), + ListType( + alias="_root.test_text_array_3d", + values=UnionType( + types=[ + StringType(bytes_=MAX_FIELD_SIZE, variable=True), + ProxyType( + alias="_root.test_text_array_3d", + registry=client.converter.registry, # type: ignore + ), + ] + ), + ), + ], + ), ] + validate_results(test_types_struct, expected_fields) - # Going field by field to make debugging easier when test fails - for field, expected_field in zip(test_types_struct.fields, expected_fields): - assert field == expected_field + def test_struct_method_arrays_enforce_dimensions(self): + client = PostgresqlClient(self.connection, PostgresqlConverter(True)) # type: ignore + test_types_struct = client.schema("testdb", "public", "test_types") - assert test_types_struct == StructType(fields=expected_fields) + expected_fields = [ + UnionType( + default=None, + name="test_bigint", + types=[NullType(), IntType(bits=64, signed=True)], + ), + UnionType( + default=None, + name="test_integer", + types=[NullType(), IntType(bits=32, signed=True)], + ), + UnionType( + default=None, + name="test_smallint", + types=[NullType(), IntType(bits=16, signed=True)], + ), + UnionType( + default=None, + name="test_float", + types=[NullType(), FloatType(bits=64)], + ), + UnionType( + default=None, + name="test_real", + types=[NullType(), FloatType(bits=32)], + ), + UnionType( + default=None, + name="test_boolean", + types=[NullType(), BoolType()], + ), + UnionType( + default=None, + name="test_text", + types=[NullType(), StringType(bytes_=MAX_FIELD_SIZE, variable=True)], + ), + UnionType( + default=None, + name="test_char", + # 40 = max of 4 bytes in a UTF-8 encoded unicode character * 10 chars + types=[NullType(), StringType(bytes_=40, variable=False)], + ), + UnionType( + default=None, + name="test_bytea", + types=[NullType(), BytesType(bytes_=MAX_FIELD_SIZE, variable=True)], + ), + UnionType( + default=None, + name="test_bit", + types=[NullType(), BytesType(bytes_=2, variable=False)], + ), + UnionType( + default=None, + name="test_timestamp", + types=[ + NullType(), + IntType( + bits=64, logical="build.recap.Timestamp", unit="microsecond" + ), + ], + ), + UnionType( + default=None, + name="test_decimal", + types=[ + NullType(), + BytesType( + logical="build.recap.Decimal", + bytes_=32, + variable=False, + precision=10, + scale=2, + ), + ], + ), + IntType(bits=32, signed=True, name="test_not_null"), + IntType(bits=32, signed=True, name="test_not_null_default", default="1"), + UnionType( + default="2", + name="test_default", + types=[NullType(), IntType(bits=32, signed=True)], + ), + UnionType( + default=None, + name="test_int_array", + types=[ + NullType(), + ListType( + values=UnionType( + types=[ + NullType(), + IntType(bits=32), + ] + ), + ), + ], + ), + UnionType( + default="'{Hello,World}'::character varying[]", + name="test_varchar_array", + types=[ + NullType(), + ListType( + values=UnionType( + types=[ + NullType(), + StringType(bytes_=MAX_FIELD_SIZE), + ] + ), + ), + ], + ), + UnionType( + default=None, + name="test_bit_array", + types=[ + NullType(), + ListType( + values=UnionType( + types=[ + NullType(), + BytesType(bytes_=MAX_FIELD_SIZE, variable=False), + ] + ), + ), + ], + ), + ListType( + name="test_not_null_array", + values=UnionType( + types=[ + NullType(), + IntType(bits=32), + ] + ), + ), + UnionType( + default=None, + name="test_int_array_2d", + types=[ + NullType(), + ListType( + values=ListType( + values=UnionType( + types=[ + NullType(), + IntType(bits=32), + ] + ) + ), + ), + ], + ), + UnionType( + default=None, + name="test_text_array_3d", + types=[ + NullType(), + ListType( + values=ListType( + values=ListType( + values=UnionType( + types=[ + NullType(), + StringType( + bytes_=MAX_FIELD_SIZE, variable=True + ), + ] + ) + ) + ), + ), + ], + ), + ] + validate_results(test_types_struct, expected_fields) def test_create_client(self): postgresql_url = "postgresql://postgres:password@localhost:5432/testdb" @@ -235,3 +463,13 @@ def test_create_client(self): "information_schema", ] assert client.ls("testdb", "public") == ["test_types"] + + +def validate_results( + test_types_struct: StructType, expected_fields: list[RecapType] +) -> None: + # Going field by field to make debugging easier when test fails + for field, expected_field in zip(test_types_struct.fields, expected_fields): + assert field == expected_field + + assert test_types_struct == StructType(fields=expected_fields) diff --git a/tests/unit/converters/test_postgresql.py b/tests/unit/converters/test_postgresql.py index 5e7bf3b..4970cf8 100644 --- a/tests/unit/converters/test_postgresql.py +++ b/tests/unit/converters/test_postgresql.py @@ -7,6 +7,7 @@ FloatType, IntType, ListType, + NullType, ProxyType, StringType, UnionType, @@ -25,6 +26,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, IntType(bits=64, signed=True), ), @@ -37,6 +39,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, IntType(bits=32, signed=True), ), @@ -49,6 +52,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, IntType(bits=16, signed=True), ), @@ -61,6 +65,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, FloatType(bits=64), ), @@ -73,6 +78,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, FloatType(bits=32), ), @@ -85,6 +91,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, BoolType(), ), @@ -97,6 +104,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, StringType(bytes_=65536, variable=True), ), @@ -109,6 +117,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, StringType(bytes_=255, variable=True), ), @@ -121,6 +130,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, StringType(bytes_=255, variable=False), ), @@ -133,6 +143,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, BytesType(bytes_=MAX_FIELD_SIZE), ), @@ -145,6 +156,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, BytesType(bytes_=1, variable=False), ), @@ -157,6 +169,7 @@ "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, "UDT_NAME": None, + "ATTNDIMS": 0, }, BytesType(bytes_=3, variable=False), ), @@ -170,6 +183,7 @@ "NUMERIC_SCALE": None, "UDT_NAME": None, "DATETIME_PRECISION": 3, + "ATTNDIMS": 0, }, IntType(bits=64, logical="build.recap.Timestamp", unit="millisecond"), ), @@ -183,6 +197,7 @@ "NUMERIC_SCALE": None, "UDT_NAME": None, "DATETIME_PRECISION": 3, + "ATTNDIMS": 0, }, IntType(bits=64, logical="build.recap.Timestamp", unit="millisecond"), ), @@ -195,6 +210,7 @@ "NUMERIC_PRECISION": 10, "NUMERIC_SCALE": 2, "UDT_NAME": None, + "ATTNDIMS": 0, }, BytesType( logical="build.recap.Decimal", @@ -213,6 +229,7 @@ "NUMERIC_PRECISION": 5, "NUMERIC_SCALE": 0, "UDT_NAME": None, + "ATTNDIMS": 0, }, BytesType( logical="build.recap.Decimal", @@ -229,8 +246,8 @@ def test_postgresql_converter(column_props, expected): assert result == expected -def test_postgresql_converter_array(): - converter = PostgresqlConverter() +def test_postgresql_converter_array_enforce_dimensions(): + converter = PostgresqlConverter(True) column_props = { "COLUMN_NAME": "test_column", "DATA_TYPE": "array", @@ -239,6 +256,31 @@ def test_postgresql_converter_array(): "NUMERIC_PRECISION": 5, "NUMERIC_SCALE": 0, "UDT_NAME": "_int4", + "ATTNDIMS": 1, + } + expected = ListType( + values=UnionType( + types=[ + NullType(), + IntType(bits=32, signed=True), + ], + ), + ) + result = converter._parse_type(column_props) + assert result == expected + + +def test_postgresql_converter_array_no_enforce_dimensions(): + converter = PostgresqlConverter(False) + column_props = { + "COLUMN_NAME": "test_column", + "DATA_TYPE": "array", + "CHARACTER_MAXIMUM_LENGTH": None, + "CHARACTER_OCTET_LENGTH": None, + "NUMERIC_PRECISION": 5, + "NUMERIC_SCALE": 0, + "UDT_NAME": "_int4", + "ATTNDIMS": 1, } expected = ListType( alias="_root.test_column",