Add array type for PostgreSQL converter
Recap's PostgreSQL converter now supports ARRAY column types. The
implementation uses `UDT_NAME` in `information_schema.columns` to get the type
of the array elements.
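As a quick illustration of where that information lives (a minimal sketch, assuming psycopg2 and a throwaway database; the DSN and table name are made up), `information_schema.columns` only reports `ARRAY` as the data type, so the element type has to be recovered from the underscore-prefixed `udt_name`:

```python
import psycopg2

conn = psycopg2.connect("postgresql://localhost/testdb")  # hypothetical DSN
with conn, conn.cursor() as cur:
    cur.execute("CREATE TABLE IF NOT EXISTS udt_demo (tags VARCHAR(255)[])")
    cur.execute(
        """
        SELECT data_type, udt_name, character_maximum_length
        FROM information_schema.columns
        WHERE table_name = 'udt_demo' AND column_name = 'tags'
        """
    )
    print(cur.fetchone())
    # Typically prints ('ARRAY', '_varchar', None): the element type survives
    # only in udt_name, and the VARCHAR(255) length modifier is not reported.
```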

There are some quirks in the implementation because of the way PostgreSQL
handles arrays. Notably, element size modifiers are ignored (a VARCHAR(255)[]
is treated as VARCHAR[], since information_schema doesn't report element
lengths for arrays). PostgreSQL also treats all arrays as having an arbitrary
number of dimensions, so int[] is the same as int[][], and so on.
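For example (another small sketch with psycopg2 and a made-up DSN), a column declared as a one-dimensional `INT[]` happily stores arrays of any dimensionality:

```python
import psycopg2

conn = psycopg2.connect("postgresql://localhost/testdb")  # hypothetical DSN
with conn, conn.cursor() as cur:
    cur.execute("CREATE TEMP TABLE dims_demo (vals INT[])")  # declared 1-D
    cur.execute("INSERT INTO dims_demo VALUES (ARRAY[1, 2, 3])")
    cur.execute("INSERT INTO dims_demo VALUES (ARRAY[[1, 2], [3, 4]])")  # 2-D, also accepted
    cur.execute("SELECT array_ndims(vals) FROM dims_demo ORDER BY 1")
    print([row[0] for row in cur.fetchall()])  # e.g. [1, 2]
```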

See #264 for more discussion on the
nuances of this implementation.

Closes #264
criccomini committed Dec 20, 2023
1 parent fcab653 commit 97245ca
Showing 3 changed files with 196 additions and 3 deletions.
66 changes: 65 additions & 1 deletion recap/converters/postgresql.py
@@ -2,16 +2,38 @@
from typing import Any

from recap.converters.dbapi import DbapiConverter
from recap.types import BoolType, BytesType, FloatType, IntType, RecapType, StringType
from recap.types import (
    BoolType,
    BytesType,
    FloatType,
    IntType,
    ListType,
    ProxyType,
    RecapType,
    RecapTypeRegistry,
    StringType,
    UnionType,
)

MAX_FIELD_SIZE = 1073741824

DEFAULT_NAMESPACE = "_root"
"""
Namespace to use when no namespace is specified in the schema.
"""


class PostgresqlConverter(DbapiConverter):
    def __init__(self, namespace: str = DEFAULT_NAMESPACE) -> None:
        self.namespace = namespace
        self.registry = RecapTypeRegistry()

    def _parse_type(self, column_props: dict[str, Any]) -> RecapType:
        column_name = column_props["COLUMN_NAME"]
        data_type = column_props["DATA_TYPE"].lower()
        octet_length = column_props["CHARACTER_OCTET_LENGTH"]
        max_length = column_props["CHARACTER_MAXIMUM_LENGTH"]
        udt_name = (column_props["UDT_NAME"] or "").lower()

        if data_type in ["bigint", "int8", "bigserial", "serial8"]:
            base_type = IntType(bits=64, signed=True)
@@ -60,6 +82,48 @@ def _parse_type(self, column_props: dict[str, Any]) -> RecapType:
                precision=column_props["NUMERIC_PRECISION"],
                scale=column_props["NUMERIC_SCALE"],
            )
        elif data_type == "array":
            # Remove _ for standard PG types like _int4
            nested_data_type = udt_name.lstrip("_")
            # Recurse to get the array value type (the int4 in int4[])
            # Postgres arrays ignore value type octet lengths for varchars, bit
            # lengths, etc. Thus, we only set DATA_TYPE here. Sigh.
            value_type = self._parse_type(
                {
                    "COLUMN_NAME": None,
                    "DATA_TYPE": nested_data_type,
                    # Default strings, bits, etc. to the max field size since
                    # information_schema doesn't contain lengths for array
                    # types.
                    # TODO Technically, we could consult pg_attribute and
                    # pg_type for this information, but that's not implemented
                    # right now.
                    "CHARACTER_OCTET_LENGTH": MAX_FIELD_SIZE,
                    # * 8 because bit columns use bits not bytes.
                    "CHARACTER_MAXIMUM_LENGTH": MAX_FIELD_SIZE * 8,
                    "UDT_NAME": None,
                }
            )
            column_name_without_periods = column_name.replace(".", "_")
            base_type_alias = f"{self.namespace}.{column_name_without_periods}"
            # Construct a self-referencing list comprised of the array's value
            # type and a proxy to the list itself. This allows arrays to be an
            # arbitrary number of dimensions, which is how PostgreSQL treats
            # lists. See https://github.com/recap-build/recap/issues/264 for
            # more details.
            base_type = ListType(
                alias=base_type_alias,
                values=UnionType(
                    types=[
                        value_type,
                        ProxyType(
                            alias=base_type_alias,
                            registry=self.registry,
                        ),
                    ],
                ),
            )
            self.registry.register_alias(base_type)
        else:
            raise ValueError(f"Unknown data type: {data_type}")

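To make the self-referencing construction concrete, here is a short usage sketch (not part of this commit). It calls the converter's internal `_parse_type` with a hand-built `column_props` dict whose keys match the ones read above, and the expected shape mirrors the integration test below:

```python
from recap.converters.postgresql import PostgresqlConverter
from recap.types import ListType

converter = PostgresqlConverter()

# Roughly what information_schema.columns yields for a column declared
# "tags INTEGER[]": DATA_TYPE is 'ARRAY' (lower-cased by the converter) and
# UDT_NAME carries the element type.
recap_type = converter._parse_type(
    {
        "COLUMN_NAME": "tags",
        "DATA_TYPE": "ARRAY",
        "CHARACTER_OCTET_LENGTH": None,
        "CHARACTER_MAXIMUM_LENGTH": None,
        "UDT_NAME": "_int4",
    }
)

# Expected shape (mirroring the integration test below):
#
#   ListType(
#       alias="_root.tags",
#       values=UnionType(
#           types=[
#               IntType(bits=32),
#               ProxyType(alias="_root.tags", registry=converter.registry),
#           ]
#       ),
#   )
#
# The ProxyType points back at the list's own alias, so one- and
# multi-dimensional arrays all resolve to the same registered type.
assert isinstance(recap_type, ListType)
```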
63 changes: 62 additions & 1 deletion tests/integration/clients/test_postgresql.py
@@ -8,7 +8,9 @@
    BytesType,
    FloatType,
    IntType,
    ListType,
    NullType,
    ProxyType,
    StringType,
    StructType,
    UnionType,
@@ -46,7 +48,10 @@ def setup_class(cls):
                test_decimal DECIMAL(10,2),
                test_not_null INTEGER NOT NULL,
                test_not_null_default INTEGER NOT NULL DEFAULT 1,
                test_default INTEGER DEFAULT 2
                test_default INTEGER DEFAULT 2,
                test_int_array INTEGER[],
                test_varchar_array VARCHAR(255)[] DEFAULT '{"Hello", "World"}',
                test_bit_array BIT(8)[]
            );
            """
        )
@@ -153,6 +158,62 @@ def test_struct_method(self):
                name="test_default",
                types=[NullType(), IntType(bits=32, signed=True)],
            ),
            UnionType(
                default=None,
                name="test_int_array",
                types=[
                    NullType(),
                    ListType(
                        alias="_root.test_int_array",
                        values=UnionType(
                            types=[
                                IntType(bits=32),
                                ProxyType(
                                    alias="_root.test_int_array",
                                    registry=client.converter.registry,  # type: ignore
                                ),
                            ]
                        ),
                    ),
                ],
            ),
            UnionType(
                default='{"Hello", "World"}',
                name="test_varchar_array",
                types=[
                    NullType(),
                    ListType(
                        alias="_root.test_varchar_array",
                        values=UnionType(
                            types=[
                                StringType(bytes_=MAX_FIELD_SIZE),
                                ProxyType(
                                    alias="_root.test_varchar_array",
                                    registry=client.converter.registry,  # type: ignore
                                ),
                            ]
                        ),
                    ),
                ],
            ),
            UnionType(
                name="test_bit_array",
                types=[
                    NullType(),
                    ListType(
                        alias="_root.test_bit_array",
                        values=UnionType(
                            types=[
                                BytesType(bytes_=MAX_FIELD_SIZE),
                                ProxyType(
                                    alias="_root.test_bit_array",
                                    registry=client.converter.registry,  # type: ignore
                                ),
                            ]
                        ),
                    ),
                ],
            ),
        ]

        # Going field by field to make debugging easier when test fails
