Skip to content

Commit

Permalink
fix: Broken scheduled_at field in queryfilter
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Dec 13, 2024
1 parent 6fcf96b commit db07dd3
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 9 deletions.
10 changes: 9 additions & 1 deletion src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
VFolderMount,
)
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.minilang import JSONArrayFieldItem

from ..api.exceptions import (
BackendError,
Expand Down Expand Up @@ -992,7 +993,14 @@ async def resolve_abusing_report(
"created_at": ("created_at", dtparse),
"status_changed": ("status_changed", dtparse),
"terminated_at": ("terminated_at", dtparse),
"scheduled_at": ("scheduled_at", None),
"scheduled_at": (
JSONArrayFieldItem(
column_name="status_history",
conditions={"status": KernelStatus.SCHEDULED.name},
key_name="timestamp",
),
dtparse,
),
}

_queryorder_colmap: ColumnMapType = {
Expand Down
12 changes: 10 additions & 2 deletions src/ai/backend/manager/models/minilang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ class JSONFieldItem(NamedTuple):
key_name: str


class JSONArrayFieldItem(NamedTuple):
column_name: str
conditions: dict[str, str]
key_name: str


TEnum = TypeVar("TEnum", bound=Enum)


Expand All @@ -22,10 +28,12 @@ class EnumFieldItem(NamedTuple, Generic[TEnum]):


FieldSpecItem = tuple[
str | ArrayFieldItem | JSONFieldItem | EnumFieldItem, Callable[[str], Any] | None
str | ArrayFieldItem | JSONFieldItem | EnumFieldItem | JSONArrayFieldItem,
Callable[[str], Any] | None,
]
OrderSpecItem = tuple[
str | ArrayFieldItem | JSONFieldItem | EnumFieldItem, Callable[[sa.Column], Any] | None
str | ArrayFieldItem | JSONFieldItem | EnumFieldItem | JSONArrayFieldItem,
Callable[[sa.Column], Any] | None,
]


Expand Down
6 changes: 5 additions & 1 deletion src/ai/backend/manager/models/minilang/ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lark import Lark, LarkError, Transformer
from lark.lexer import Token

from . import JSONFieldItem, OrderSpecItem, get_col_from_table
from . import JSONArrayFieldItem, JSONFieldItem, OrderSpecItem, get_col_from_table

__all__ = (
"ColumnMapType",
Expand Down Expand Up @@ -56,6 +56,10 @@ def _get_col(self, col_name: str) -> sa.Column:
case JSONFieldItem(_col, _key):
_column = get_col_from_table(self._sa_table, _col)
matched_col = _column.op("->>")(_key)
case JSONArrayFieldItem(_col_name, _conditions, _key_name):
# TODO: Implement this.
pass
# ...
case _:
raise ValueError("Invalid type of field name", col_name)
col = func(matched_col) if func is not None else matched_col
Expand Down
38 changes: 37 additions & 1 deletion src/ai/backend/manager/models/minilang/queryfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@
import sqlalchemy as sa
from lark import Lark, LarkError, Transformer, Tree
from lark.lexer import Token
from sqlalchemy.dialects.postgresql import JSONB

from . import ArrayFieldItem, EnumFieldItem, FieldSpecItem, JSONFieldItem, get_col_from_table
from . import (
ArrayFieldItem,
EnumFieldItem,
FieldSpecItem,
JSONArrayFieldItem,
JSONFieldItem,
get_col_from_table,
)

__all__ = (
"FieldSpecType",
Expand Down Expand Up @@ -172,6 +180,34 @@ def build_expr(op: str, col, val):
# to retrieve the value used in the expression.
col = get_col_from_table(self._sa_table, col_name).op("->>")(obj_key)
expr = build_expr(op, col, val)
case JSONArrayFieldItem(col_name, conditions, key_name):
col = get_col_from_table(self._sa_table, col_name)
json_array = sa.func.jsonb_array_elements(col.cast(JSONB)).alias("item")

condition_list = []
for key, expected_value in conditions.items():
condition_list.append(
sa.column("item").op("->>")(key) == expected_value
)

element_timestamp = (
sa.column("item")
.op("->>")(key_name)
.cast(sa.types.TIMESTAMP(timezone=True))
)

combined_conditions = sa.and_(*condition_list)

subq = (
sa.select([sa.literal(1)])
.select_from(json_array)
.where(
sa.and_(combined_conditions, build_expr(op, element_timestamp, val))
)
)

expr = sa.exists(subq)

case EnumFieldItem(col_name, enum_cls):
col = get_col_from_table(self._sa_table, col_name)
# allow both key and value of enum to be specified on variable `val`
Expand Down
15 changes: 11 additions & 4 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
from .group import GroupRow
from .image import ImageRow
from .kernel import ComputeContainer, KernelRow, KernelStatus
from .minilang import ArrayFieldItem
from .minilang import ArrayFieldItem, JSONArrayFieldItem
from .minilang.ordering import ColumnMapType, QueryOrderParser
from .minilang.queryfilter import FieldSpecType, QueryFilterParser, enum_field_getter
from .rbac import (
Expand Down Expand Up @@ -1570,7 +1570,7 @@ class Meta:
status_info = graphene.String()
status_data = graphene.JSONString()
status_history = graphene.JSONString(
deprecation_reason="Deprecated since 24.12.0; Use `status_history_log`"
deprecation_reason="Deprecated since 24.12.0; use `status_history_log`"
)
status_history_log = graphene.JSONString(description="Added in 24.12.0")

Expand Down Expand Up @@ -1757,7 +1757,14 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> Mapping[str,
"created_at": ("sessions_created_at", dtparse),
"terminated_at": ("sessions_terminated_at", dtparse),
"starts_at": ("sessions_starts_at", dtparse),
"scheduled_at": ("scheduled_at", None),
"scheduled_at": (
JSONArrayFieldItem(
column_name="sessions_status_history",
conditions={"status": SessionStatus.SCHEDULED.name},
key_name="timestamp",
),
dtparse,
),
"startup_command": ("sessions_startup_command", None),
}

Expand Down Expand Up @@ -1786,7 +1793,7 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> Mapping[str,
"created_at": ("sessions_created_at", None),
"terminated_at": ("sessions_terminated_at", None),
"starts_at": ("sessions_starts_at", None),
"scheduled_at": ("scheduled_at", None),
"scheduled_at": ("sessions_scheduled_at", None),
}

@classmethod
Expand Down

0 comments on commit db07dd3

Please sign in to comment.