From 68617a41fede43297a0d2181788b4fa239421d60 Mon Sep 17 00:00:00 2001
From: jopemachine
Date: Tue, 24 Dec 2024 02:27:37 +0000
Subject: [PATCH] feat: Cache gpu_alloc_map, and Add ScanGPUAllocMap mutation

---
 src/ai/backend/manager/api/schema.graphql     |  7 ++
 src/ai/backend/manager/models/gql.py          |  2 +
 .../manager/models/gql_models/agent.py        | 61 +++++++++++++++++++
 src/ai/backend/manager/registry.py            | 16 ++++-
 4 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql
index 1683464d46..6214c094cb 100644
--- a/src/ai/backend/manager/api/schema.graphql
+++ b/src/ai/backend/manager/api/schema.graphql
@@ -1718,6 +1718,7 @@ type Mutations {
   This action cannot be undone.
   """
   purge_user(email: String!, props: PurgeUserInput!): PurgeUser
+  scan_gpu_alloc_maps(agent_id: String): ScanGPUAllocMap
   create_keypair(props: KeyPairInput!, user_id: String!): CreateKeyPair
   modify_keypair(access_key: String!, props: ModifyKeyPairInput!): ModifyKeyPair
   delete_keypair(access_key: String!): DeleteKeyPair
@@ -2112,6 +2113,12 @@ input PurgeUserInput {
   purge_shared_vfolders: Boolean
 }
 
+type ScanGPUAllocMap {
+  ok: Boolean
+  msg: String
+  task_id: UUID
+}
+
 type CreateKeyPair {
   ok: Boolean
   msg: String
diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py
index 81669804f1..a95bde42af 100644
--- a/src/ai/backend/manager/models/gql.py
+++ b/src/ai/backend/manager/models/gql.py
@@ -73,6 +73,7 @@
     AgentSummary,
     AgentSummaryList,
     ModifyAgent,
+    ScanGPUAllocMap,
 )
 from .gql_models.domain import (
     CreateDomainNode,
@@ -250,6 +251,7 @@ class Mutations(graphene.ObjectType):
     modify_user = ModifyUser.Field()
     delete_user = DeleteUser.Field()
     purge_user = PurgeUser.Field()
+    scan_gpu_alloc_maps = ScanGPUAllocMap.Field()
 
     # admin only
     create_keypair = CreateKeyPair.Field()
diff --git a/src/ai/backend/manager/models/gql_models/agent.py b/src/ai/backend/manager/models/gql_models/agent.py
index 532541a36a..6cb289f62b 100644
--- a/src/ai/backend/manager/models/gql_models/agent.py
+++ b/src/ai/backend/manager/models/gql_models/agent.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 import uuid
 from collections.abc import Iterable, Mapping, Sequence
 from typing import (
@@ -18,6 +19,7 @@
 from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
 
 from ai.backend.common import msgpack, redis_helper
+from ai.backend.common.bgtask import ProgressReporter
 from ai.backend.common.types import (
     AccessKey,
     AgentId,
@@ -878,3 +880,62 @@ async def mutate(
         update_query = sa.update(agents).values(data).where(agents.c.id == id)
 
         return await simple_db_mutate(cls, graph_ctx, update_query)
+
+
+# Starts a background task that asks agents for their GPU allocation maps over RPC
+# and caches each result as JSON in Redis under "gpu_alloc_map.{agent_id}".
+# Only ``agent_id`` is scanned when it is given; otherwise every agent yielded by
+# registry.enumerate_instances() is scanned.
+class ScanGPUAllocMap(graphene.Mutation):
+    allowed_roles = (UserRole.SUPERADMIN,)
+
+    class Arguments:
+        agent_id = graphene.String()
+
+    ok = graphene.Boolean()
+    msg = graphene.String()
+    task_id = graphene.UUID()
+
+    @classmethod
+    @privileged_mutation(
+        UserRole.SUPERADMIN,
+        # Mirrors mutate()'s optional agent_id argument; this mutation takes no id.
+        lambda agent_id=None, **kwargs: (None, agent_id),
+    )
+    async def mutate(
+        cls,
+        root,
+        info: graphene.ResolveInfo,
+        agent_id: Optional[str] = None,
+    ) -> ScanGPUAllocMap:
+        graph_ctx: GraphQueryContext = info.context
+
+        if agent_id:
+            agent_ids = [agent_id]
+        else:
+            agent_ids = [agent.id async for agent in graph_ctx.registry.enumerate_instances()]
+
+        async def _scan_alloc_map_task(reporter: ProgressReporter) -> None:
+            total = len(agent_ids)
+            # NOTE(review): an RPC or Redis failure on one agent aborts the remaining
+            # scans -- confirm whether the task should continue with the other agents.
+            for index, target_id in enumerate(agent_ids, start=1):
+                await reporter.update(
+                    increment=1,
+                    message=f"Agent {target_id} scanning... ({index}/{total})",
+                )
+
+                async with graph_ctx.registry.agent_cache.rpc_context(AgentId(target_id)) as rpc:
+                    alloc_map: Mapping[str, Any] = await rpc.call.scan_gpu_alloc_map()
+
+                # Same key that registry.scan_gpu_alloc_map() reads back.
+                key = f"gpu_alloc_map.{target_id}"
+                await redis_helper.execute(
+                    graph_ctx.registry.redis_stat,
+                    lambda r: r.set(name=key, value=json.dumps(alloc_map)),
+                )
+
+            await reporter.update(increment=1, message="GPU alloc map scanning completed")
+
+        task_id = await graph_ctx.background_task_manager.start(_scan_alloc_map_task)
+        return ScanGPUAllocMap(ok=True, msg="", task_id=task_id)
diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py
index 067f0d7316..4bc72c8de4 100644
--- a/src/ai/backend/manager/registry.py
+++ b/src/ai/backend/manager/registry.py
@@ -4,6 +4,7 @@
 import base64
 import copy
 import itertools
+import json
 import logging
 import re
 import secrets
@@ -436,9 +437,18 @@ async def gather_storage_hwinfo(self, vfolder_host: str) -> HardwareMetadata:
         )
 
     async def scan_gpu_alloc_map(self, instance_id: AgentId) -> Mapping[str, Any]:
-        agent = await self.get_instance(instance_id, agents.c.addr)
-        async with self.agent_cache.rpc_context(agent["id"]) as rpc:
-            return await rpc.call.scan_gpu_alloc_map()
+        """
+        Return the GPU allocation map cached in Redis for the given agent.
+
+        The value is read from the ``gpu_alloc_map.{instance_id}`` key and decoded as JSON;
+        an empty mapping is returned when the key is missing or holds an empty value.
+        """
+        raw_alloc_map = await redis_helper.execute(
+            self.redis_stat, lambda r: r.get(f"gpu_alloc_map.{instance_id}")
+        )
+        if not raw_alloc_map:
+            return {}
+        return json.loads(raw_alloc_map)
 
     async def create_session(
         self,