Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Replace sessions, kernels's status_history's type dict with list #3201

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3201.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Change the type of `status_history` from a mapping of status and timestamps to a list of log entries containing status and timestamps, to preserve timestamps when revisiting session/kernel statuses (e.g., after session restarts).
89 changes: 53 additions & 36 deletions src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import inquirer
import treelib
from async_timeout import timeout
from dateutil.parser import isoparse
from dateutil.tz import tzutc
from faker import Faker
from humanize import naturalsize
Expand All @@ -25,6 +24,9 @@
from ai.backend.cli.main import main
from ai.backend.cli.params import CommaSeparatedListType, OptionalType
from ai.backend.cli.types import ExitCode, Undefined, undefined
from ai.backend.client.cli.extensions import pass_ctx_obj
from ai.backend.client.cli.types import CLIContext
from ai.backend.client.utils import get_latest_timestamp_for_status
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH
from ai.backend.common.types import ClusterMode, SessionId

Expand Down Expand Up @@ -889,48 +891,63 @@ def logs(session_id: str, kernel: str | None) -> None:
sys.exit(ExitCode.FAILURE)


@session.command("status-history")
@session.command()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know much about `session.command`, but could you remove the explicit command-name string "status-history" from the decorator?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@main.group()
def session():
"""Set of compute session operations"""

Looking at the code, it seems that the decorator itself does not perform any operations on the arguments.
So, I think it would be better to leave the arguments empty, as is currently done in most of the other code.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect the explicit string argument here overrides the command name derived from the function for the final implementation — can you double-check? @jopemachine

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If no value is passed as an argument, the decorated function name is automatically used as the CLI command.
If we pass an arbitrary string, that string can be used as the CLI command instead of changing the function name.
So, let's leave the arguments empty here, just like with the other functions.

@pass_ctx_obj
@click.argument("session_id", metavar="SESSID")
def status_history(ctx: CLIContext, session_id: SessionId) -> None:
    """
    Shows the status transition history of the compute session.

    \b
    SESSID: Session ID or its alias given when creating the session.
    """

    async def cmd_main() -> None:
        # Fetch the status-history log, print it as a table, and report the
        # actual resource allocation time (PREPARING -> TERMINATED/now).
        async with AsyncSession() as session:
            print_wait("Retrieving status history...")
            kernel = session.ComputeSession(str(session_id))
            try:
                resp = await kernel.get_status_history()
                # The server returns {"result": [{"status": ..., "timestamp": ...}, ...]}.
                status_history = resp["result"]

                # Annotate each record (except the first) with the wall-clock
                # time spent in the previous status.
                prev_time = None
                for status_record in status_history:
                    timestamp = datetime.fromisoformat(status_record["timestamp"])
                    if prev_time:
                        time_diff = timestamp - prev_time
                        status_record["time_elapsed"] = str(time_diff)
                    prev_time = timestamp

                # NOTE(review): the first record never gets a "time_elapsed"
                # key — confirm that print_list tolerates the missing field.
                ctx.output.print_list(
                    status_history,
                    [FieldSpec("status"), FieldSpec("timestamp"), FieldSpec("time_elapsed")],
                )

                # Allocation time: zero if the session never reached PREPARING,
                # open-ended (until now) if it has not terminated yet.
                if (
                    preparing := get_latest_timestamp_for_status(status_history, "PREPARING")
                ) is None:
                    elapsed = timedelta()
                elif (
                    terminated := get_latest_timestamp_for_status(status_history, "TERMINATED")
                ) is None:
                    elapsed = datetime.now(tzutc()) - preparing
                else:
                    elapsed = terminated - preparing

                print_done(f"Actual Resource Allocation Time: {elapsed.total_seconds()}")
            except Exception as e:
                print_error(e)
                sys.exit(ExitCode.FAILURE)

    try:
        asyncio.run(cmd_main())
    except Exception as e:
        print_error(e)
        sys.exit(ExitCode.FAILURE)


@session.command()
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/client/output/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@
FieldSpec("created_user_id"),
FieldSpec("status"),
FieldSpec("status_info"),
FieldSpec("status_history"),
FieldSpec("status_history_log"),
FieldSpec("status_data", formatter=nested_dict_formatter),
FieldSpec("status_changed", "Last Updated"),
FieldSpec("created_at"),
Expand Down
13 changes: 13 additions & 0 deletions src/ai/backend/client/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import io
import os
from datetime import datetime
from typing import Optional

from dateutil.parser import parse as dtparse
from tqdm import tqdm


Expand Down Expand Up @@ -48,3 +51,13 @@ def readinto1(self, *args, **kwargs):
count = super().readinto1(*args, **kwargs)
self.tqdm.set_postfix(file=self._filename, refresh=False)
self.tqdm.update(count)


def get_latest_timestamp_for_status(
    status_history: list[dict[str, str]],
    status: str,
) -> Optional[datetime]:
    """Return the parsed timestamp of the most recent log entry matching *status*.

    Scans the status-history log from newest to oldest; returns ``None`` when
    the given status never occurred.
    """
    candidates = (
        dtparse(entry["timestamp"])
        for entry in reversed(status_history)
        if entry["status"] == status
    )
    return next(candidates, None)
2 changes: 1 addition & 1 deletion src/ai/backend/manager/api/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ async def _pipe_builder(r: Redis) -> RedisPipeline:
"status": row["status"].name,
"status_info": row["status_info"],
"status_changed": str(row["status_changed"]),
"status_history": row["status_history"] or {},
"status_history": row["status_history"],
"cluster_mode": row["cluster_mode"],
}
if group_id not in objs_per_group:
Expand Down
10 changes: 9 additions & 1 deletion src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ type ComputeContainer implements Item {
registry: String
status: String
status_changed: DateTime

"""Added in 24.12.0."""
status_history: JSONString
status_info: String
status_data: JSONString
created_at: DateTime
Expand Down Expand Up @@ -925,7 +928,10 @@ type ComputeSession implements Item {
status_changed: DateTime
status_info: String
status_data: JSONString
status_history: JSONString
status_history: JSONString @deprecated(reason: "Deprecated since 24.12.0; use `status_history_log`")

"""Added in 24.12.0"""
status_history_log: JSONString
created_at: DateTime
terminated_at: DateTime
starts_at: DateTime
Expand Down Expand Up @@ -1204,6 +1210,8 @@ type ComputeSessionNode implements Node {
status: String
status_info: String
status_data: JSONString

"""Added in 24.12.0."""
status_history: JSONString
created_at: DateTime
terminated_at: DateTime
Expand Down
30 changes: 30 additions & 0 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2333,6 +2333,35 @@ async def get_container_logs(
return web.json_response(resp, status=200)


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
    t.Dict({
        t.Key("owner_access_key", default=None): t.Null | t.String,
    })
)
async def get_status_history(request: web.Request, params: Any) -> web.Response:
    """Return the status-transition log of a compute session.

    Response body: ``{"result": <status_history of the session row>}``.
    """
    root_ctx: RootContext = request.app["_root.context"]
    session_name: str = request.match_info["session_name"]
    requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
    log.info(
        "GET_STATUS_HISTORY (ak:{}/{}, s:{})", requester_access_key, owner_access_key, session_name
    )

    async with root_ctx.db.begin_readonly_session() as db_sess:
        session_row = await SessionRow.get_session(
            db_sess,
            session_name,
            owner_access_key,
            allow_stale=True,
            kernel_loading_strategy=KernelLoadingStrategy.ALL_KERNELS,
        )

    resp: dict[str, Mapping] = {"result": session_row.status_history}
    return web.json_response(resp, status=200)


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
Expand Down Expand Up @@ -2472,6 +2501,7 @@ def create_app(
app.router.add_route("GET", "/{session_name}/direct-access-info", get_direct_access_info)
)
cors.add(app.router.add_route("GET", "/{session_name}/logs", get_container_logs))
cors.add(app.router.add_route("GET", "/{session_name}/status-history", get_status_history))
cors.add(app.router.add_route("POST", "/{session_name}/rename", rename_session))
cors.add(app.router.add_route("POST", "/{session_name}/interrupt", interrupt))
cors.add(app.router.add_route("POST", "/{session_name}/complete", complete))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Replace sessions, kernels's status_history's type map with list
Revision ID: 8c8e90aebacd
Revises: 0bb88d5a46bf
Create Date: 2024-12-05 11:19:23.075014
"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "8c8e90aebacd"
down_revision = "0bb88d5a46bf"
branch_labels = None
depends_on = None


def upgrade():
    """Convert ``status_history`` on kernels/sessions from a JSON object
    ({status: timestamp}) to a timestamp-ordered JSON array of
    {"status": ..., "timestamp": ...} entries, then make the column
    non-nullable (NULL becomes an empty list).
    """
    op.execute(
        """
        WITH data AS (
            SELECT id,
                   (jsonb_each(status_history)).key AS status,
                   (jsonb_each(status_history)).value AS timestamp
            FROM kernels
            WHERE jsonb_typeof(status_history) = 'object'
        )
        UPDATE kernels
        SET status_history = (
            SELECT jsonb_agg(
                jsonb_build_object('status', status, 'timestamp', timestamp)
                ORDER BY timestamp
            )
            FROM data
            WHERE data.id = kernels.id
        )
        WHERE jsonb_typeof(kernels.status_history) = 'object';
        """
        # The outer WHERE is required: without it the UPDATE touches every
        # row, and rows that are not objects (e.g. already-converted arrays
        # on a re-run) would be clobbered with NULL by the correlated
        # subquery.
    )
    # Rows whose object was empty ({}) also end up NULL above (jsonb_agg
    # over zero rows yields NULL), so this backfill normalizes both those
    # and originally-NULL rows to an empty list.
    op.execute("UPDATE kernels SET status_history = '[]'::jsonb WHERE status_history IS NULL;")
    # NOTE(review): `default` is a Python-side column default, not a
    # recognized `op.alter_column` parameter — `server_default` was probably
    # intended; verify against Alembic's documentation.
    op.alter_column("kernels", "status_history", nullable=False, default=[])

    op.execute(
        """
        WITH data AS (
            SELECT id,
                   (jsonb_each(status_history)).key AS status,
                   (jsonb_each(status_history)).value AS timestamp
            FROM sessions
            WHERE jsonb_typeof(status_history) = 'object'
        )
        UPDATE sessions
        SET status_history = (
            SELECT jsonb_agg(
                jsonb_build_object('status', status, 'timestamp', timestamp)
                ORDER BY timestamp
            )
            FROM data
            WHERE data.id = sessions.id
        )
        WHERE jsonb_typeof(sessions.status_history) = 'object';
        """
    )
    op.execute("UPDATE sessions SET status_history = '[]'::jsonb WHERE status_history IS NULL;")
    op.alter_column("sessions", "status_history", nullable=False, default=[])


def downgrade():
    """Revert ``status_history`` on kernels/sessions from the JSON array form
    back to a {status: timestamp} JSON object and make the column nullable
    again (empty lists become NULL).

    This is lossy by design: an array may contain several entries with the
    same status (e.g. after restarts), but jsonb_object_agg keeps only one
    timestamp per status key.
    """
    # NOTE(review): when duplicate statuses exist, which timestamp wins
    # depends on the row order produced by jsonb_array_elements under
    # GROUP BY, which PostgreSQL does not guarantee — verify if the "latest
    # entry wins" behavior matters here.
    op.execute(
        """
        WITH data AS (
            SELECT id,
                   jsonb_object_agg(
                       elem->>'status', elem->>'timestamp'
                   ) AS new_status_history
            FROM kernels,
                 jsonb_array_elements(status_history) AS elem
            WHERE jsonb_typeof(status_history) = 'array'
            GROUP BY id
        )
        UPDATE kernels
        SET status_history = data.new_status_history
        FROM data
        WHERE data.id = kernels.id
        AND jsonb_typeof(kernels.status_history) = 'array';
        """
    )
    # NOTE(review): `default` is a Python-side column default, not a
    # recognized `op.alter_column` parameter — probably meant
    # `server_default`; verify against Alembic's documentation.
    op.alter_column("kernels", "status_history", nullable=True, default=None)
    # Empty arrays carry no entries, so they map back to NULL rather than {}.
    op.execute("UPDATE kernels SET status_history = NULL WHERE status_history = '[]'::jsonb;")

    op.execute(
        """
        WITH data AS (
            SELECT id,
                   jsonb_object_agg(
                       elem->>'status', elem->>'timestamp'
                   ) AS new_status_history
            FROM sessions,
                 jsonb_array_elements(status_history) AS elem
            WHERE jsonb_typeof(status_history) = 'array'
            GROUP BY id
        )
        UPDATE sessions
        SET status_history = data.new_status_history
        FROM data
        WHERE data.id = sessions.id
        AND jsonb_typeof(sessions.status_history) = 'array';
        """
    )
    op.alter_column("sessions", "status_history", nullable=True, default=None)
    op.execute("UPDATE sessions SET status_history = NULL WHERE status_history = '[]'::jsonb;")
15 changes: 11 additions & 4 deletions src/ai/backend/manager/models/gql_models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TYPE_CHECKING,
Any,
Self,
cast,
)

import graphene
Expand All @@ -14,14 +15,15 @@

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.types import AgentId, KernelId, SessionId
from ai.backend.manager.models.base import (

from ..base import (
batch_multiresult_in_scalar_stream,
batch_multiresult_in_session,
)

from ..gql_relay import AsyncNode, Connection
from ..kernel import KernelRow, KernelStatus
from ..user import UserRole
from ..utils import get_latest_timestamp_for_status
from .image import ImageNode

if TYPE_CHECKING:
Expand Down Expand Up @@ -113,7 +115,12 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
hide_agents = False
else:
hide_agents = ctx.local_config["manager"]["hide-agents"]
status_history = row.status_history or {}

timestamp = get_latest_timestamp_for_status(
cast(list[dict[str, str]], row.status_history), KernelStatus.SCHEDULED
)
scheduled_at = str(timestamp) if timestamp is not None else None

return KernelNode(
id=row.id, # auto-converted to Relay global ID
row_id=row.id,
Expand All @@ -129,7 +136,7 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
created_at=row.created_at,
terminated_at=row.terminated_at,
starts_at=row.starts_at,
scheduled_at=status_history.get(KernelStatus.SCHEDULED.name),
scheduled_at=scheduled_at,
occupied_slots=row.occupied_slots.to_json(),
agent_id=row.agent if not hide_agents else None,
agent_addr=row.agent_addr if not hide_agents else None,
Expand Down
Loading
Loading