Skip to content

Commit

Permalink
feat: Cache gpu_alloc_map, and Add ScanGPUAllocMap mutation
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Dec 24, 2024
1 parent cc44683 commit 68617a4
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,7 @@ type Mutations {
This action cannot be undone.
"""
purge_user(email: String!, props: PurgeUserInput!): PurgeUser
scan_gpu_alloc_maps(agent_id: String): ScanGPUAllocMap

Check failure on line 1721 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

New fields must include a description with a version number in the format "Added in XX.XX.X.", Field 'scan_gpu_alloc_maps' was added to object type 'Mutations'

New fields must include a description with a version number in the format "Added in XX.XX.X."
create_keypair(props: KeyPairInput!, user_id: String!): CreateKeyPair
modify_keypair(access_key: String!, props: ModifyKeyPairInput!): ModifyKeyPair
delete_keypair(access_key: String!): DeleteKeyPair
Expand Down Expand Up @@ -2112,6 +2113,12 @@ input PurgeUserInput {
purge_shared_vfolders: Boolean
}

type ScanGPUAllocMap {

Check failure on line 2116 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

New types must include a description with a version number in the format "Added in XX.XX.X.", Type 'ScanGPUAllocMap' was added

New types must include a description with a version number in the format "Added in XX.XX.X."
ok: Boolean
msg: String
task_id: UUID
}

type CreateKeyPair {
ok: Boolean
msg: String
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
AgentSummary,
AgentSummaryList,
ModifyAgent,
ScanGPUAllocMap,
)
from .gql_models.domain import (
CreateDomainNode,
Expand Down Expand Up @@ -250,6 +251,7 @@ class Mutations(graphene.ObjectType):
modify_user = ModifyUser.Field()
delete_user = DeleteUser.Field()
purge_user = PurgeUser.Field()
scan_gpu_alloc_maps = ScanGPUAllocMap.Field()

# admin only
create_keypair = CreateKeyPair.Field()
Expand Down
51 changes: 51 additions & 0 deletions src/ai/backend/manager/models/gql_models/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import uuid
from collections.abc import Iterable, Mapping, Sequence
from typing import (
Expand All @@ -18,6 +19,7 @@
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.types import (
AccessKey,
AgentId,
Expand Down Expand Up @@ -878,3 +880,52 @@ async def mutate(

update_query = sa.update(agents).values(data).where(agents.c.id == id)
return await simple_db_mutate(cls, graph_ctx, update_query)


class ScanGPUAllocMap(graphene.Mutation):
    # Superadmin-only mutation that starts a background task which asks agents
    # for their GPU allocation map over RPC and stores each result in Redis
    # under the key ``gpu_alloc_map.<agent_id>`` as a JSON string.
    #
    # The mutation returns as soon as the task is registered; ``task_id``
    # identifies the background task.
    #
    # NOTE(review): this is written as comments rather than a class docstring
    # on purpose -- graphene may use the class docstring as the GraphQL type
    # description, which would change the generated schema. Confirm before
    # converting it to a docstring.

    allowed_roles = (UserRole.SUPERADMIN,)

    class Arguments:
        # When omitted (or empty), every agent yielded by
        # ``registry.enumerate_instances()`` is scanned.
        agent_id = graphene.String()

    ok = graphene.Boolean()
    msg = graphene.String()
    task_id = graphene.UUID()

    @classmethod
    @privileged_mutation(
        UserRole.SUPERADMIN,
        # The parameter is named after this mutation's only argument so the
        # callable does not fail when invoked with the mutation's keyword
        # arguments (the previous ``id`` parameter did not match any argument
        # and shadowed the builtin).
        # NOTE(review): presumably unused for SUPERADMIN-only mutations --
        # verify against privileged_mutation().
        lambda agent_id=None, **kwargs: (None, agent_id),
    )
    async def mutate(
        cls,
        root,
        info: graphene.ResolveInfo,
        agent_id: Optional[str] = None,
    ) -> ScanGPUAllocMap:
        """Register the background scan task and return its ID.

        :param agent_id: ID of the single agent to scan.  If falsy, all agents
            returned by ``registry.enumerate_instances()`` are scanned.
        :returns: A payload with ``ok=True``, an empty ``msg`` and the
            ``task_id`` of the started background task.
        """
        graph_ctx: GraphQueryContext = info.context

        if agent_id:
            agent_ids = [agent_id]
        else:
            agent_ids = [agent.id async for agent in graph_ctx.registry.enumerate_instances()]

        async def _scan_alloc_map_task(reporter: ProgressReporter) -> None:
            """Scan each target agent in order and cache its alloc map in Redis."""
            total = len(agent_ids)
            for index, target_id in enumerate(agent_ids, start=1):
                # Progress is reported before the RPC call, so the message
                # names the agent that is about to be scanned.
                await reporter.update(
                    increment=1,
                    message=f"Agent {target_id} scanning... ({index}/{total})",
                )

                async with graph_ctx.registry.agent_cache.rpc_context(AgentId(target_id)) as rpc:
                    alloc_map: Mapping[str, Any] = await rpc.call.scan_gpu_alloc_map()

                key = f"gpu_alloc_map.{target_id}"
                serialized = json.dumps(alloc_map)
                # NOTE(review): the key is stored without an expiry, so a stale
                # entry stays until the next scan overwrites it -- confirm this
                # is intended.
                await redis_helper.execute(
                    graph_ctx.registry.redis_stat,
                    lambda r: r.set(name=key, value=serialized),
                )

            # NOTE(review): together with the per-agent updates this yields
            # len(agent_ids) + 1 increments in total -- confirm this matches
            # what progress consumers expect.
            await reporter.update(increment=1, message="GPU alloc map scanning completed")

        task_id = await graph_ctx.background_task_manager.start(_scan_alloc_map_task)
        return ScanGPUAllocMap(ok=True, msg="", task_id=task_id)
11 changes: 8 additions & 3 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import base64
import copy
import itertools
import json
import logging
import re
import secrets
Expand Down Expand Up @@ -436,9 +437,13 @@ async def gather_storage_hwinfo(self, vfolder_host: str) -> HardwareMetadata:
)

async def scan_gpu_alloc_map(self, instance_id: AgentId) -> Mapping[str, Any]:
    """Return the cached GPU allocation map of an agent.

    The map is read from ``self.redis_stat`` under the key
    ``gpu_alloc_map.<instance_id>`` and decoded from JSON.  The agent itself
    is not contacted here; the cached value is written elsewhere (in this
    change set, by the ``ScanGPUAllocMap`` GraphQL mutation).

    :param instance_id: ID of the agent whose allocation map is requested.
    :returns: The decoded allocation map, or an empty dict when no value is
        cached for the agent.
    """
    raw_alloc_map = await redis_helper.execute(
        self.redis_stat, lambda r: r.get(f"gpu_alloc_map.{instance_id}")
    )
    # A missing key (or an empty value) means the agent has not been scanned.
    if not raw_alloc_map:
        return {}
    return json.loads(raw_alloc_map)

async def create_session(
self,
Expand Down

0 comments on commit 68617a4

Please sign in to comment.