From 7100d3da5ff215aad60498182d78894bf6f43b52 Mon Sep 17 00:00:00 2001 From: Chris Riccomini Date: Wed, 20 Dec 2023 11:49:11 -0800 Subject: [PATCH] Add `array` type for PostgreSQL converter Recap's PostgreSQL converter now supports ARRAY column types. The implementation uses `UDT_NAME` in `information_schema.columns` to get the type of the array elements. There are some weird behaviors with the implementation because of the way PostgreSQL handles arrays. Importantly, the element size is ignored (int(4)[] is just treated as int[]). PostgreSQL also treats all arrays as having an arbitrary number of dimensions, so int[] is the same as int[][], and so on. See https://github.com/recap-build/recap/issues/264 for more discussion on the nuances of this implementation. Closes #264 --- recap/converters/postgresql.py | 66 +++++++++++++++++- tests/integration/clients/test_postgresql.py | 63 +++++++++++++++++- tests/unit/converters/test_postgresql.py | 70 +++++++++++++++++++- 3 files changed, 196 insertions(+), 3 deletions(-) diff --git a/recap/converters/postgresql.py b/recap/converters/postgresql.py index 75a9397..70d1dbc 100644 --- a/recap/converters/postgresql.py +++ b/recap/converters/postgresql.py @@ -2,16 +2,38 @@ from typing import Any from recap.converters.dbapi import DbapiConverter -from recap.types import BoolType, BytesType, FloatType, IntType, RecapType, StringType +from recap.types import ( + BoolType, + BytesType, + FloatType, + IntType, + ListType, + ProxyType, + RecapType, + RecapTypeRegistry, + StringType, + UnionType, +) MAX_FIELD_SIZE = 1073741824 +DEFAULT_NAMESPACE = "_root" +""" +Namespace to use when no namespace is specified in the schema. +""" + class PostgresqlConverter(DbapiConverter): + def __init__(self, namespace: str = DEFAULT_NAMESPACE) -> None: + self.namespace = namespace + self.registry = RecapTypeRegistry() + def _parse_type(self, column_props: dict[str, Any]) -> RecapType: + column_name = column_props["COLUMN_NAME"] data_type = column_props["DATA_TYPE"].lower() octet_length = column_props["CHARACTER_OCTET_LENGTH"] max_length = column_props["CHARACTER_MAXIMUM_LENGTH"] + udt_name = (column_props["UDT_NAME"] or "").lower() if data_type in ["bigint", "int8", "bigserial", "serial8"]: base_type = IntType(bits=64, signed=True) @@ -60,6 +82,48 @@ def _parse_type(self, column_props: dict[str, Any]) -> RecapType: precision=column_props["NUMERIC_PRECISION"], scale=column_props["NUMERIC_SCALE"], ) + elif data_type == "array": + # Remove _ for standard PG types like _int4 + nested_data_type = udt_name.lstrip("_") + # Recurse to get the array value type (the int4 in int4[]) + # Postgres arrays ignore value type octet lengths for varchars, bit + # lengths, etc. Thus, we only set DATA_TYPE here. Sigh. + value_type = self._parse_type( + { + "COLUMN_NAME": None, + "DATA_TYPE": nested_data_type, + # Default strings, bits, etc. to the max field size since + # information_schema doesn't contain lengths for array + # types. + # TODO Technically, we could consult pg_attribute and + # pg_type for this information, but that's not implemented + # right now. + "CHARACTER_OCTET_LENGTH": MAX_FIELD_SIZE, + # * 8 because bit columns use bits not bytes. + "CHARACTER_MAXIMUM_LENGTH": MAX_FIELD_SIZE * 8, + "UDT_NAME": None, + } + ) + column_name_without_periods = column_name.replace(".", "_") + base_type_alias = f"{self.namespace}.{column_name_without_periods}" + # Construct a self-referencing list comprised of the array's value + # type and a proxy to the list itself. 
This allows arrays to be an + # arbitrary number of dimensions, which is how PostgreSQL treats + # lists. See https://github.com/recap-build/recap/issues/264 for + # more details. + base_type = ListType( + alias=base_type_alias, + values=UnionType( + types=[ + value_type, + ProxyType( + alias=base_type_alias, + registry=self.registry, + ), + ], + ), + ) + self.registry.register_alias(base_type) else: raise ValueError(f"Unknown data type: {data_type}") diff --git a/tests/integration/clients/test_postgresql.py b/tests/integration/clients/test_postgresql.py index abbb99b..651d2f7 100644 --- a/tests/integration/clients/test_postgresql.py +++ b/tests/integration/clients/test_postgresql.py @@ -8,7 +8,9 @@ BytesType, FloatType, IntType, + ListType, NullType, + ProxyType, StringType, StructType, UnionType, @@ -46,7 +48,10 @@ def setup_class(cls): test_decimal DECIMAL(10,2), test_not_null INTEGER NOT NULL, test_not_null_default INTEGER NOT NULL DEFAULT 1, - test_default INTEGER DEFAULT 2 + test_default INTEGER DEFAULT 2, + test_int_array INTEGER[], + test_varchar_array VARCHAR(255)[] DEFAULT '{"Hello", "World"}', + test_bit_array BIT(8)[] ); """ ) @@ -153,6 +158,62 @@ def test_struct_method(self): name="test_default", types=[NullType(), IntType(bits=32, signed=True)], ), + UnionType( + default=None, + name="test_int_array", + types=[ + NullType(), + ListType( + alias="_root.test_int_array", + values=UnionType( + types=[ + IntType(bits=32), + ProxyType( + alias="_root.test_int_array", + registry=client.converter.registry, # type: ignore + ), + ] + ), + ), + ], + ), + UnionType( + default="'{Hello,World}'::character varying[]", + name="test_varchar_array", + types=[ + NullType(), + ListType( + alias="_root.test_varchar_array", + values=UnionType( + types=[ + StringType(bytes_=MAX_FIELD_SIZE), + ProxyType( + alias="_root.test_varchar_array", + registry=client.converter.registry, # type: ignore + ), + ] + ), + ), + ], + ), + UnionType( + name="test_bit_array", + types=[ + NullType(), + ListType( + alias="_root.test_bit_array", + values=UnionType( + types=[ + BytesType(bytes_=MAX_FIELD_SIZE), + ProxyType( + alias="_root.test_bit_array", + registry=client.converter.registry, # type: ignore + ), + ] + ), + ), + ], + ), ] # Going field by field to make debugging easier when test fails diff --git a/tests/unit/converters/test_postgresql.py b/tests/unit/converters/test_postgresql.py index febb9b3..5e7bf3b 100644 --- a/tests/unit/converters/test_postgresql.py +++ b/tests/unit/converters/test_postgresql.py @@ -1,7 +1,16 @@ import pytest from recap.converters.postgresql import MAX_FIELD_SIZE, PostgresqlConverter -from recap.types import BoolType, BytesType, FloatType, IntType, StringType +from recap.types import ( + BoolType, + BytesType, + FloatType, + IntType, + ListType, + ProxyType, + StringType, + UnionType, +) @pytest.mark.parametrize( @@ -9,153 +18,183 @@ [ ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "bigint", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, IntType(bits=64, signed=True), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "int", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, IntType(bits=32, signed=True), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "smallint", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, 
IntType(bits=16, signed=True), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "double precision", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, FloatType(bits=64), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "real", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, FloatType(bits=32), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "boolean", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, BoolType(), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "text", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": 65536, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, StringType(bytes_=65536, variable=True), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "character varying", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": 255, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, StringType(bytes_=255, variable=True), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "char", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": 255, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, StringType(bytes_=255, variable=False), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "bytea", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, BytesType(bytes_=MAX_FIELD_SIZE), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "bit", "CHARACTER_MAXIMUM_LENGTH": 1, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, BytesType(bytes_=1, variable=False), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "bit", "CHARACTER_MAXIMUM_LENGTH": 17, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, }, BytesType(bytes_=3, variable=False), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "timestamp", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, "DATETIME_PRECISION": 3, }, IntType(bits=64, logical="build.recap.Timestamp", unit="millisecond"), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "timestamp", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": None, "NUMERIC_SCALE": None, + "UDT_NAME": None, "DATETIME_PRECISION": 3, }, IntType(bits=64, logical="build.recap.Timestamp", unit="millisecond"), ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "decimal", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": 10, "NUMERIC_SCALE": 2, + "UDT_NAME": None, }, BytesType( logical="build.recap.Decimal", @@ -167,11 +206,13 @@ ), ( { + "COLUMN_NAME": "test_column", "DATA_TYPE": "numeric", "CHARACTER_MAXIMUM_LENGTH": None, "CHARACTER_OCTET_LENGTH": None, "NUMERIC_PRECISION": 5, "NUMERIC_SCALE": 0, + "UDT_NAME": None, }, BytesType( logical="build.recap.Decimal", @@ -186,3 +227,30 @@ def test_postgresql_converter(column_props, expected): result = PostgresqlConverter()._parse_type(column_props) assert result == expected + + +def test_postgresql_converter_array(): + converter = PostgresqlConverter() + column_props = { + "COLUMN_NAME": "test_column", + "DATA_TYPE": "array", + 
"CHARACTER_MAXIMUM_LENGTH": None, + "CHARACTER_OCTET_LENGTH": None, + "NUMERIC_PRECISION": 5, + "NUMERIC_SCALE": 0, + "UDT_NAME": "_int4", + } + expected = ListType( + alias="_root.test_column", + values=UnionType( + types=[ + IntType(bits=32, signed=True), + ProxyType( + alias="_root.test_column", + registry=converter.registry, + ), + ], + ), + ) + result = converter._parse_type(column_props) + assert result == expected