support PG array dimensionality #411

Merged (7 commits) on Jan 3, 2024
Changes from 1 commit
23 changes: 23 additions & 0 deletions recap/clients/postgresql.py
@@ -5,6 +5,7 @@

from recap.clients.dbapi import Connection, DbapiClient
from recap.converters.postgresql import PostgresqlConverter
from recap.types import StructType

PSYCOPG2_CONNECT_ARGS = {
"host",
@@ -78,3 +79,25 @@ def ls_catalogs(self) -> list[str]:
            """
        )
        return [row[0] for row in cursor.fetchall()]

    def schema(self, catalog: str, schema: str, table: str) -> StructType:
        cursor = self.connection.cursor()
        cursor.execute(
            f"""
            SELECT
                information_schema.columns.*,
                pg_attribute.attndims
            FROM information_schema.columns
            JOIN pg_attribute ON information_schema.columns.column_name = pg_attribute.attname
Contributor:

I don't think this is enough. I think you need to join on table, schema, and column. It appears pg_attribute doesn't have those:

               Table "pg_catalog.pg_attribute"
     Column     |   Type    | Collation | Nullable | Default 
----------------+-----------+-----------+----------+---------
 attrelid       | oid       |           | not null | 
 attname        | name      |           | not null | 
 atttypid       | oid       |           | not null | 
 attstattarget  | integer   |           | not null | 
 attlen         | smallint  |           | not null | 
 attnum         | smallint  |           | not null | 
 attndims       | integer   |           | not null | 
 attcacheoff    | integer   |           | not null | 
 atttypmod      | integer   |           | not null | 
 attbyval       | boolean   |           | not null | 
 attalign       | "char"    |           | not null | 
 attstorage     | "char"    |           | not null | 
 attcompression | "char"    |           | not null | 
 attnotnull     | boolean   |           | not null | 
 atthasdef      | boolean   |           | not null | 
 atthasmissing  | boolean   |           | not null | 
 attidentity    | "char"    |           | not null | 
 attgenerated   | "char"    |           | not null | 
 attisdropped   | boolean   |           | not null | 
 attislocal     | boolean   |           | not null | 
 attinhcount    | integer   |           | not null | 
 attcollation   | oid       |           | not null | 
 attacl         | aclitem[] |           |          | 
 attoptions     | text[]    | C         |          | 
 attfdwoptions  | text[]    | C         |          | 
 attmissingval  | anyarray  |           |          | 

So perhaps we need another join here as well?

Per ChatGPT:

[Screenshot of a ChatGPT response, 2023-12-23 11:52 AM]
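
A rough sketch of what those extra joins could look like, going through pg_class and pg_namespace (illustrative only; not necessarily the query this PR lands on):

-- Sketch: tie pg_attribute rows to one specific table by joining through
-- pg_class (relations) and pg_namespace (schemas), rather than matching
-- on column name alone.
SELECT
    information_schema.columns.*,
    pg_attribute.attndims
FROM information_schema.columns
JOIN pg_namespace
    ON pg_namespace.nspname = information_schema.columns.table_schema
JOIN pg_class
    ON pg_class.relnamespace = pg_namespace.oid
   AND pg_class.relname = information_schema.columns.table_name
JOIN pg_attribute
    ON pg_attribute.attrelid = pg_class.oid
   AND pg_attribute.attname = information_schema.columns.column_name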

Contributor Author:

Ah, nice catch. I'll add in the other joins.

            WHERE table_name = {self.param_style}
                AND table_schema = {self.param_style}
                AND table_catalog = {self.param_style}
            ORDER BY ordinal_position ASC
            """,
            (table, schema, catalog),
        )
        names = [name[0].upper() for name in cursor.description]
        return self.converter.to_recap(
            # Make each row be a dict with the column names as keys
            [dict(zip(names, row)) for row in cursor.fetchall()]
        )
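
For context, calling the new method looks something like this (hypothetical snippet; assumes a PostgresqlClient already wired to a live connection):

# Hypothetical usage of the new schema() method:
struct = client.schema("testdb", "public", "test_types")
print(struct)  # a StructType with one field per column, arrays included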
48 changes: 15 additions & 33 deletions recap/converters/postgresql.py
@@ -8,32 +8,22 @@
    FloatType,
    IntType,
    ListType,
    ProxyType,
Contributor:

Niiiice. Took me a sec to grok why we didn't need this anymore. Fully walking the n_dimensions means we don't need self-references. Awesome.

One question/nuance here: the PG dimensions are just a suggestion.

The current implementation does not enforce the declared number of dimensions either. Arrays of a particular element type are all considered to be of the same type, regardless of size or number of dimensions. So, declaring the array size or number of dimensions in CREATE TABLE is simply documentation; it does not affect run-time behavior.

https://www.postgresql.org/docs/current/arrays.html

So the question is: do we want Recap to reflect the DB's data or its schema? My implementation (with ProxyType) reflected the data. Yours changes it to reflect the schema. Perhaps we want it configurable, with one as the default? WDYT?

Contributor Author:

I like to think the schema is the beacon of truth for what the user intends for the column. If users are leveraging the column differently than the schema's representation, they should fix the schema. But I could see past mistakes leading to a situation where this isn't true, which would then lead to Recap constructing a false narrative about the data. I think making it configurable makes sense. Maybe default to ProxyType since that's the safer assumption? Would we want to add config params to the PostgresqlConverter constructor?

Contributor:

Ya, can you add a param to the init to configure it? Defaulting to proxy is safer, as you say.
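
One hypothetical shape for that constructor param (the `enforce_array_dimensions` name is illustrative, not something this PR commits to):

# Sketch of a configurable converter; the flag name is hypothetical.
class PostgresqlConverter(DbapiConverter):
    def __init__(
        self,
        enforce_array_dimensions: bool = False,  # False = keep ProxyType (data-oriented) behavior
        namespace: str = DEFAULT_NAMESPACE,
    ) -> None:
        self.enforce_array_dimensions = enforce_array_dimensions
        self.namespace = namespace
        self.registry = RecapTypeRegistry()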

    NullType,
    RecapType,
    RecapTypeRegistry,
    StringType,
    UnionType,
)

MAX_FIELD_SIZE = 1073741824

DEFAULT_NAMESPACE = "_root"
"""
Namespace to use when no namespace is specified in the schema.
"""


class PostgresqlConverter(DbapiConverter):
    def __init__(self, namespace: str = DEFAULT_NAMESPACE) -> None:
        self.namespace = namespace
        self.registry = RecapTypeRegistry()

    def _parse_type(self, column_props: dict[str, Any]) -> RecapType:
        column_name = column_props["COLUMN_NAME"]
        data_type = column_props["DATA_TYPE"].lower()
        octet_length = column_props["CHARACTER_OCTET_LENGTH"]
        max_length = column_props["CHARACTER_MAXIMUM_LENGTH"]
        udt_name = (column_props["UDT_NAME"] or "").lower()
        ndims = column_props["ATTNDIMS"]

        if data_type in ["bigint", "int8", "bigserial", "serial8"]:
            base_type = IntType(bits=64, signed=True)
@@ -90,7 +80,6 @@ def _parse_type(self, column_props: dict[str, Any]) -> RecapType:
            # lengths, etc. Thus, we only set DATA_TYPE here. Sigh.
            value_type = self._parse_type(
                {
                    "COLUMN_NAME": None,
                    "DATA_TYPE": nested_data_type,
                    # Default strings, bits, etc. to the max field size since
                    # information_schema doesn't contain lengths for array
@@ -102,29 +91,22 @@ def _parse_type(self, column_props: dict[str, Any]) -> RecapType:
                    # * 8 because bit columns use bits not bytes.
                    "CHARACTER_MAXIMUM_LENGTH": MAX_FIELD_SIZE * 8,
                    "UDT_NAME": None,
                    "ATTNDIMS": 0,
                }
            )
            column_name_without_periods = column_name.replace(".", "_")
            base_type_alias = f"{self.namespace}.{column_name_without_periods}"
            # Construct a self-referencing list comprised of the array's value
            # type and a proxy to the list itself. This allows arrays to be an
            # arbitrary number of dimensions, which is how PostgreSQL treats
            # lists. See https://github.com/recap-build/recap/issues/264 for
            # more details.
            base_type = ListType(
                alias=base_type_alias,
                values=UnionType(
                    types=[
                        value_type,
                        ProxyType(
                            alias=base_type_alias,
                            registry=self.registry,
                        ),
                    ],
                ),
            )
            self.registry.register_alias(base_type)
            base_type = self._create_n_dimension_list(value_type, ndims)
        else:
            raise ValueError(f"Unknown data type: {data_type}")

        return base_type

    def _create_n_dimension_list(self, base_type: RecapType, ndims: int) -> RecapType:
        """
        Build a list type with `ndims` dimensions containing nullable
        `base_type` as the innermost value type.
        """
        if ndims == 0:
            return UnionType(types=[NullType(), base_type])
Contributor:

I'm curious about this one. It seems right, but I'm not 100% sure. As I read it, there are a few things:

  1. DbapiConverter handles root-level NULLABLE fields (https://github.com/recap-build/recap/blob/main/recap/converters/dbapi.py#L15-L16)
  2. This code here handles NULLABLE items in a PG ARRAY field.

I think this is the right behavior. But I'm curious: do PG arrays always allow NULLs in their dimensional values? I couldn't find good docs on this, and I haven't tested it out.

Contributor Author:

I did some testing and digging, and afaict the answer is yes: the innermost value can always be null. Enforcing non-nulls requires adding some sort of CHECK validation, per https://stackoverflow.com/a/59421233, which seems like a pretty challenging rabbit hole of digging through information_schema.check_constraints.
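
For instance, a quick session along these lines (illustrative, not from the PR) shows both behaviors:

-- PostgreSQL ignores declared dimensionality and allows NULL elements;
-- both inserts succeed.
CREATE TABLE dims_test (vals INTEGER[][]);
INSERT INTO dims_test VALUES (ARRAY[[1, NULL], [3, 4]]);  -- NULL element: accepted
INSERT INTO dims_test VALUES ('{1,2,3}');                 -- 1-D value in a 2-D column: accepted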

        else:
            return ListType(
                values=self._create_n_dimension_list(base_type, ndims - 1),
            )
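
To make the recursion concrete, here is a small sketch (not part of the diff) of how a 2-D INTEGER[][] column unrolls, assuming Recap types compare by structural equality:

from recap.converters.postgresql import PostgresqlConverter
from recap.types import IntType, ListType, NullType, UnionType

converter = PostgresqlConverter()

# ndims == 0 base case: nullable element type
inner = UnionType(types=[NullType(), IntType(bits=32)])
# each recursive step wraps one more ListType
one_d = ListType(values=inner)   # ndims == 1
two_d = ListType(values=one_d)   # ndims == 2

assert converter._create_n_dimension_list(IntType(bits=32), 2) == two_d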
61 changes: 44 additions & 17 deletions tests/integration/clients/test_postgresql.py
@@ -10,7 +10,6 @@
    IntType,
    ListType,
    NullType,
    ProxyType,
    StringType,
    StructType,
    UnionType,
@@ -51,7 +50,9 @@ def setup_class(cls):
                test_default INTEGER DEFAULT 2,
                test_int_array INTEGER[],
                test_varchar_array VARCHAR(255)[] DEFAULT '{"Hello", "World"}',
                test_bit_array BIT(8)[],
                test_int_array_2d INTEGER[][],
                test_text_array_3d TEXT[][][]
Contributor:

Do you mind adding a NOT NULL array as well? I realized we haven't tested that.
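
For reference, a column along those lines might look like this (hypothetical; note that NOT NULL constrains the array value itself, while elements inside it can still be NULL):

-- Hypothetical addition to the CREATE TABLE above:
test_not_null_array INTEGER[] NOT NULL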

            );
            """
        )
@@ -164,14 +165,10 @@ def test_struct_method(self):
                types=[
                    NullType(),
                    ListType(
                        alias="_root.test_int_array",
                        values=UnionType(
                            types=[
                                NullType(),
                                IntType(bits=32),
                                ProxyType(
                                    alias="_root.test_int_array",
                                    registry=client.converter.registry,  # type: ignore
                                ),
                            ]
                        ),
                    ),
@@ -183,14 +180,10 @@ def test_struct_method(self):
                types=[
                    NullType(),
                    ListType(
                        alias="_root.test_varchar_array",
                        values=UnionType(
                            types=[
                                NullType(),
                                StringType(bytes_=MAX_FIELD_SIZE),
                                ProxyType(
                                    alias="_root.test_varchar_array",
                                    registry=client.converter.registry,  # type: ignore
                                ),
                            ]
                        ),
                    ),
@@ -202,19 +195,53 @@ def test_struct_method(self):
                types=[
                    NullType(),
                    ListType(
                        alias="_root.test_bit_array",
                        values=UnionType(
                            types=[
                                NullType(),
                                BytesType(bytes_=MAX_FIELD_SIZE, variable=False),
                                ProxyType(
                                    alias="_root.test_bit_array",
                                    registry=client.converter.registry,  # type: ignore
                                ),
                            ]
                        ),
                    ),
                ],
            ),
            UnionType(
                default=None,
                name="test_int_array_2d",
                types=[
                    NullType(),
                    ListType(
                        values=ListType(
                            values=UnionType(
                                types=[
                                    NullType(),
                                    IntType(bits=32),
                                ]
                            )
                        ),
                    ),
                ],
            ),
            UnionType(
                default=None,
                name="test_text_array_3d",
                types=[
                    NullType(),
                    ListType(
                        values=ListType(
                            values=ListType(
                                values=UnionType(
                                    types=[
                                        NullType(),
                                        StringType(
                                            bytes_=MAX_FIELD_SIZE, variable=True
                                        ),
                                    ]
                                )
                            )
                        ),
                    ),
                ],
            ),
        ]

        # Going field by field to make debugging easier when test fails
25 changes: 19 additions & 6 deletions tests/unit/converters/test_postgresql.py
@@ -7,7 +7,7 @@
    FloatType,
    IntType,
    ListType,
    ProxyType,
    NullType,
    StringType,
    UnionType,
)
@@ -25,6 +25,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            IntType(bits=64, signed=True),
        ),
@@ -37,6 +38,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            IntType(bits=32, signed=True),
        ),
@@ -49,6 +51,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            IntType(bits=16, signed=True),
        ),
@@ -61,6 +64,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            FloatType(bits=64),
        ),
@@ -73,6 +77,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            FloatType(bits=32),
        ),
@@ -85,6 +90,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            BoolType(),
        ),
@@ -97,6 +103,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            StringType(bytes_=65536, variable=True),
        ),
@@ -109,6 +116,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            StringType(bytes_=255, variable=True),
        ),
@@ -121,6 +129,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            StringType(bytes_=255, variable=False),
        ),
@@ -133,6 +142,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            BytesType(bytes_=MAX_FIELD_SIZE),
        ),
@@ -145,6 +155,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            BytesType(bytes_=1, variable=False),
        ),
@@ -157,6 +168,7 @@
                "NUMERIC_PRECISION": None,
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            BytesType(bytes_=3, variable=False),
        ),
@@ -170,6 +182,7 @@
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "DATETIME_PRECISION": 3,
                "ATTNDIMS": 0,
            },
            IntType(bits=64, logical="build.recap.Timestamp", unit="millisecond"),
        ),
@@ -183,6 +196,7 @@
                "NUMERIC_SCALE": None,
                "UDT_NAME": None,
                "DATETIME_PRECISION": 3,
                "ATTNDIMS": 0,
            },
            IntType(bits=64, logical="build.recap.Timestamp", unit="millisecond"),
        ),
@@ -195,6 +209,7 @@
                "NUMERIC_PRECISION": 10,
                "NUMERIC_SCALE": 2,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            BytesType(
                logical="build.recap.Decimal",
@@ -213,6 +228,7 @@
                "NUMERIC_PRECISION": 5,
                "NUMERIC_SCALE": 0,
                "UDT_NAME": None,
                "ATTNDIMS": 0,
            },
            BytesType(
                logical="build.recap.Decimal",
@@ -239,16 +255,13 @@ def test_postgresql_converter_array():
"NUMERIC_PRECISION": 5,
"NUMERIC_SCALE": 0,
"UDT_NAME": "_int4",
"ATTNDIMS": 1,
}
expected = ListType(
alias="_root.test_column",
values=UnionType(
types=[
NullType(),
IntType(bits=32, signed=True),
ProxyType(
alias="_root.test_column",
registry=converter.registry,
),
],
),
)