From db07dd37d4745f4906882c95a4a55ec034a6e358 Mon Sep 17 00:00:00 2001 From: Gyubong Lee Date: Tue, 10 Dec 2024 06:41:31 +0000 Subject: [PATCH] fix: Broken scheduled_at field in queryfilter --- src/ai/backend/manager/models/kernel.py | 10 ++++- .../manager/models/minilang/__init__.py | 12 +++++- .../manager/models/minilang/ordering.py | 6 ++- .../manager/models/minilang/queryfilter.py | 38 ++++++++++++++++++- src/ai/backend/manager/models/session.py | 15 ++++++-- 5 files changed, 72 insertions(+), 9 deletions(-) diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py index 3db283f899..47d1ddb193 100644 --- a/src/ai/backend/manager/models/kernel.py +++ b/src/ai/backend/manager/models/kernel.py @@ -47,6 +47,7 @@ VFolderMount, ) from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.models.minilang import JSONArrayFieldItem from ..api.exceptions import ( BackendError, @@ -992,7 +993,14 @@ async def resolve_abusing_report( "created_at": ("created_at", dtparse), "status_changed": ("status_changed", dtparse), "terminated_at": ("terminated_at", dtparse), - "scheduled_at": ("scheduled_at", None), + "scheduled_at": ( + JSONArrayFieldItem( + column_name="status_history", + conditions={"status": KernelStatus.SCHEDULED.name}, + key_name="timestamp", + ), + dtparse, + ), } _queryorder_colmap: ColumnMapType = { diff --git a/src/ai/backend/manager/models/minilang/__init__.py b/src/ai/backend/manager/models/minilang/__init__.py index 74f3c0ea99..ff046868c4 100644 --- a/src/ai/backend/manager/models/minilang/__init__.py +++ b/src/ai/backend/manager/models/minilang/__init__.py @@ -13,6 +13,12 @@ class JSONFieldItem(NamedTuple): key_name: str +class JSONArrayFieldItem(NamedTuple): + column_name: str + conditions: dict[str, str] + key_name: str + + TEnum = TypeVar("TEnum", bound=Enum) @@ -22,10 +28,12 @@ class EnumFieldItem(NamedTuple, Generic[TEnum]): FieldSpecItem = tuple[ - str | ArrayFieldItem | JSONFieldItem | 
EnumFieldItem, Callable[[str], Any] | None + str | ArrayFieldItem | JSONFieldItem | EnumFieldItem | JSONArrayFieldItem, + Callable[[str], Any] | None, ] OrderSpecItem = tuple[ - str | ArrayFieldItem | JSONFieldItem | EnumFieldItem, Callable[[sa.Column], Any] | None + str | ArrayFieldItem | JSONFieldItem | EnumFieldItem | JSONArrayFieldItem, + Callable[[sa.Column], Any] | None, ] diff --git a/src/ai/backend/manager/models/minilang/ordering.py b/src/ai/backend/manager/models/minilang/ordering.py index 0643392628..63c54fefa4 100644 --- a/src/ai/backend/manager/models/minilang/ordering.py +++ b/src/ai/backend/manager/models/minilang/ordering.py @@ -5,7 +5,7 @@ from lark import Lark, LarkError, Transformer from lark.lexer import Token -from . import JSONFieldItem, OrderSpecItem, get_col_from_table +from . import JSONArrayFieldItem, JSONFieldItem, OrderSpecItem, get_col_from_table __all__ = ( "ColumnMapType", @@ -56,6 +56,10 @@ def _get_col(self, col_name: str) -> sa.Column: case JSONFieldItem(_col, _key): _column = get_col_from_table(self._sa_table, _col) matched_col = _column.op("->>")(_key) + case JSONArrayFieldItem(_col_name, _conditions, _key_name): + # TODO: Implement ordering by a value stored inside a JSON array column. + # Fail explicitly until then: falling through would leave `matched_col` unbound below. + raise ValueError("Ordering by JSON array field is not supported yet", col_name) case _: raise ValueError("Invalid type of field name", col_name) col = func(matched_col) if func is not None else matched_col diff --git a/src/ai/backend/manager/models/minilang/queryfilter.py b/src/ai/backend/manager/models/minilang/queryfilter.py index 4973ffd54c..761b63f8d0 100644 --- a/src/ai/backend/manager/models/minilang/queryfilter.py +++ b/src/ai/backend/manager/models/minilang/queryfilter.py @@ -4,8 +4,16 @@ import sqlalchemy as sa from lark import Lark, LarkError, Transformer, Tree from lark.lexer import Token +from sqlalchemy.dialects.postgresql import JSONB -from . import ArrayFieldItem, EnumFieldItem, FieldSpecItem, JSONFieldItem, get_col_from_table +from . 
import ( + ArrayFieldItem, + EnumFieldItem, + FieldSpecItem, + JSONArrayFieldItem, + JSONFieldItem, + get_col_from_table, +) __all__ = ( "FieldSpecType", @@ -172,6 +180,34 @@ def build_expr(op: str, col, val): # to retrieve the value used in the expression. col = get_col_from_table(self._sa_table, col_name).op("->>")(obj_key) expr = build_expr(op, col, val) + case JSONArrayFieldItem(col_name, conditions, key_name): + col = get_col_from_table(self._sa_table, col_name) + json_array = sa.func.jsonb_array_elements(col.cast(JSONB)).alias("item") + + condition_list = [] + for key, expected_value in conditions.items(): + condition_list.append( + sa.column("item").op("->>")(key) == expected_value + ) + + element_timestamp = ( + sa.column("item") + .op("->>")(key_name) + .cast(sa.types.TIMESTAMP(timezone=True)) + ) + + combined_conditions = sa.and_(*condition_list) + + subq = ( + sa.select([sa.literal(1)]) + .select_from(json_array) + .where( + sa.and_(combined_conditions, build_expr(op, element_timestamp, val)) + ) + ) + + expr = sa.exists(subq) + case EnumFieldItem(col_name, enum_cls): col = get_col_from_table(self._sa_table, col_name) # allow both key and value of enum to be specified on variable `val` diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py index 8b0a8f297d..e0dff6ef8a 100644 --- a/src/ai/backend/manager/models/session.py +++ b/src/ai/backend/manager/models/session.py @@ -89,7 +89,7 @@ from .group import GroupRow from .image import ImageRow from .kernel import ComputeContainer, KernelRow, KernelStatus -from .minilang import ArrayFieldItem +from .minilang import ArrayFieldItem, JSONArrayFieldItem from .minilang.ordering import ColumnMapType, QueryOrderParser from .minilang.queryfilter import FieldSpecType, QueryFilterParser, enum_field_getter from .rbac import ( @@ -1570,7 +1570,7 @@ class Meta: status_info = graphene.String() status_data = graphene.JSONString() status_history = graphene.JSONString( - 
deprecation_reason="Deprecated since 24.12.0; Use `status_history_log`" + deprecation_reason="Deprecated since 24.12.0; use `status_history_log`" ) status_history_log = graphene.JSONString(description="Added in 24.12.0") @@ -1757,7 +1757,14 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> Mapping[str, "created_at": ("sessions_created_at", dtparse), "terminated_at": ("sessions_terminated_at", dtparse), "starts_at": ("sessions_starts_at", dtparse), - "scheduled_at": ("scheduled_at", None), + "scheduled_at": ( + JSONArrayFieldItem( + column_name="sessions_status_history", + conditions={"status": SessionStatus.SCHEDULED.name}, + key_name="timestamp", + ), + dtparse, + ), "startup_command": ("sessions_startup_command", None), } @@ -1786,7 +1793,7 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> Mapping[str, "created_at": ("sessions_created_at", None), "terminated_at": ("sessions_terminated_at", None), "starts_at": ("sessions_starts_at", None), - "scheduled_at": ("scheduled_at", None), + "scheduled_at": ("sessions_scheduled_at", None), } @classmethod