diff --git a/changes/3293.feature.md b/changes/3293.feature.md new file mode 100644 index 0000000000..35bf8c4669 --- /dev/null +++ b/changes/3293.feature.md @@ -0,0 +1 @@ +Cache `gpu_alloc_map` in Redis, and add a `RescanGPUAllocMaps` mutation to update the `gpu_alloc_map`s. diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql index 97d011a064..de55300a9f 100644 --- a/src/ai/backend/manager/api/schema.graphql +++ b/src/ai/backend/manager/api/schema.graphql @@ -1718,6 +1718,12 @@ type Mutations { This action cannot be undone. """ purge_user(email: String!, props: PurgeUserInput!): PurgeUser + + """Added in 25.1.0.""" + rescan_gpu_alloc_maps( + """Agent ID to rescan GPU alloc map, Pass None to rescan all agents""" + agent_id: String + ): RescanGPUAllocMaps create_keypair(props: KeyPairInput!, user_id: String!): CreateKeyPair modify_keypair(access_key: String!, props: ModifyKeyPairInput!): ModifyKeyPair delete_keypair(access_key: String!): DeleteKeyPair @@ -2112,6 +2118,13 @@ input PurgeUserInput { purge_shared_vfolders: Boolean } +"""Added in 25.1.0.""" +type RescanGPUAllocMaps { + ok: Boolean + msg: String + task_id: UUID +} + type CreateKeyPair { ok: Boolean msg: String diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index ba3a84fa15..ef43e09711 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -73,6 +73,7 @@ AgentSummary, AgentSummaryList, ModifyAgent, + RescanGPUAllocMaps, ) from .gql_models.domain import ( CreateDomainNode, @@ -250,6 +251,7 @@ class Mutations(graphene.ObjectType): modify_user = ModifyUser.Field() delete_user = DeleteUser.Field() purge_user = PurgeUser.Field() + rescan_gpu_alloc_maps = RescanGPUAllocMaps.Field(description="Added in 25.1.0.") # admin only create_keypair = CreateKeyPair.Field() diff --git a/src/ai/backend/manager/models/gql_models/agent.py b/src/ai/backend/manager/models/gql_models/agent.py index 
e858f43bde..37ebd3b566 100644 --- a/src/ai/backend/manager/models/gql_models/agent.py +++ b/src/ai/backend/manager/models/gql_models/agent.py @@ -1,5 +1,8 @@ from __future__ import annotations +import asyncio +import json +import logging import uuid from collections.abc import Iterable, Mapping, Sequence from typing import ( @@ -18,12 +21,14 @@ from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection from ai.backend.common import msgpack, redis_helper +from ai.backend.common.bgtask import ProgressReporter from ai.backend.common.types import ( AccessKey, AgentId, BinarySize, HardwareMetadata, ) +from ai.backend.logging.utils import BraceStyleAdapter from ..agent import ( AgentRow, @@ -61,6 +66,8 @@ if TYPE_CHECKING: from ..gql import GraphQueryContext +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + __all__ = ( "Agent", "AgentNode", @@ -181,7 +188,13 @@ async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Any: async def resolve_gpu_alloc_map(self, info: graphene.ResolveInfo) -> Mapping[str, int]: ctx: GraphQueryContext = info.context - return await ctx.registry.scan_gpu_alloc_map(self.id) + raw_alloc_map = await redis_helper.execute( + ctx.redis_stat, lambda r: r.get(f"gpu_alloc_map.{self.id}") + ) + if raw_alloc_map: + return json.loads(raw_alloc_map) + else: + return {} async def resolve_hardware_metadata( self, @@ -435,7 +448,13 @@ async def resolve_container_count(self, info: graphene.ResolveInfo) -> int: async def resolve_gpu_alloc_map(self, info: graphene.ResolveInfo) -> Mapping[str, int]: ctx: GraphQueryContext = info.context - return await ctx.registry.scan_gpu_alloc_map(self.id) + raw_alloc_map = await redis_helper.execute( + ctx.redis_stat, lambda r: r.get(f"gpu_alloc_map.{self.id}") + ) + if raw_alloc_map: + return json.loads(raw_alloc_map) + else: + return {} _queryfilter_fieldspec: Mapping[str, FieldSpecItem] = { "id": ("id", None), @@ -878,3 +897,73 @@ async def mutate( update_query = 
sa.update(agents).values(data).where(agents.c.id == id) return await simple_db_mutate(cls, graph_ctx, update_query) + + +class RescanGPUAllocMaps(graphene.Mutation): + allowed_roles = (UserRole.SUPERADMIN,) + + class Meta: + description = "Added in 25.1.0." + + class Arguments: + agent_id = graphene.String( + description="Agent ID to rescan GPU alloc map, Pass None to rescan all agents", + required=False, + ) + + ok = graphene.Boolean() + msg = graphene.String() + task_id = graphene.UUID() + + @classmethod + @privileged_mutation( + UserRole.SUPERADMIN, + lambda id, **kwargs: (None, id), + ) + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + agent_id: Optional[str] = None, + ) -> RescanGPUAllocMaps: + log.info("rescanning GPU alloc maps") + graph_ctx: GraphQueryContext = info.context + + if agent_id: + agent_ids = [agent_id] + else: + agent_ids = [agent.id async for agent in graph_ctx.registry.enumerate_instances()] + + async def _scan_single_agent(agent_id: str, reporter: ProgressReporter) -> None: + await reporter.update(message=f"Agent {agent_id} GPU alloc map scanning...") + + reporter_msg = "" + try: + alloc_map: Mapping[str, Any] = await graph_ctx.registry.scan_gpu_alloc_map( + AgentId(agent_id) + ) + key = f"gpu_alloc_map.{agent_id}" + await redis_helper.execute( + graph_ctx.registry.redis_stat, + lambda r: r.set(name=key, value=json.dumps(alloc_map)), + ) + except Exception as e: + reporter_msg = f"Failed to scan GPU alloc map for agent {agent_id}: {str(e)}" + log.error(reporter_msg) + else: + reporter_msg = f"Agent {agent_id} GPU alloc map scanned." 
+ + await reporter.update( + increment=1, + message=reporter_msg, + ) + + async def _rescan_alloc_map_task(reporter: ProgressReporter) -> None: + async with asyncio.TaskGroup() as tg: + for agent_id in agent_ids: + tg.create_task(_scan_single_agent(agent_id, reporter)) + + await reporter.update(message="GPU alloc map scanning completed") + + task_id = await graph_ctx.background_task_manager.start(_rescan_alloc_map_task) + return RescanGPUAllocMaps(ok=True, msg="", task_id=task_id) diff --git a/tests/manager/conftest.py b/tests/manager/conftest.py index 51233ca973..6b760eea2f 100644 --- a/tests/manager/conftest.py +++ b/tests/manager/conftest.py @@ -28,6 +28,7 @@ from unittest.mock import AsyncMock, MagicMock from urllib.parse import quote_plus as urlquote +import aiofiles.os import aiohttp import asyncpg import pytest @@ -172,6 +173,8 @@ def local_config( redis_addr = redis_container[1] postgres_addr = postgres_container[1] + build_root = Path(os.environ["BACKEND_BUILD_ROOT"]) + # Establish a self-contained config. 
cfg = LocalConfig({ **etcd_config_iv.check({ @@ -206,6 +209,7 @@ def local_config( "service-addr": HostPortPair("127.0.0.1", 29100 + get_parallel_slot() * 10), "allowed-plugins": set(), "disabled-plugins": set(), + "rpc-auth-manager-keypair": f"{build_root}/fixtures/manager/manager.key_secret", }, "debug": { "enabled": False, @@ -257,7 +261,15 @@ def etcd_fixture( "volumes": { "_mount": str(vfolder_mount), "_fsprefix": str(vfolder_fsprefix), - "_default_host": str(vfolder_host), + "default_host": str(vfolder_host), + "proxies": { + "local": { + "client_api": "http://127.0.0.1:6021", + "manager_api": "https://127.0.0.1:6022", + "secret": "some-secret-shared-with-storage-proxy", + "ssl_verify": "false", + } + }, }, "nodes": {}, "config": { @@ -420,7 +432,12 @@ async def database_engine(local_config, database): @pytest.fixture() -def database_fixture(local_config, test_db, database) -> Iterator[None]: +def extra_fixtures(): + return {} + + +@pytest.fixture() +def database_fixture(local_config, test_db, database, extra_fixtures) -> Iterator[None]: """ Populate the example data as fixtures to the database and delete them after use. 
@@ -431,12 +448,20 @@ def database_fixture(local_config, test_db, database) -> Iterator[None]: db_url = f"postgresql+asyncpg://{db_user}:{urlquote(db_pass)}@{db_addr}/{test_db}" build_root = Path(os.environ["BACKEND_BUILD_ROOT"]) + + extra_fixture_file = tempfile.NamedTemporaryFile(delete=False) + extra_fixture_file_path = Path(extra_fixture_file.name) + + with open(extra_fixture_file_path, "w") as f: + json.dump(extra_fixtures, f) + fixture_paths = [ build_root / "fixtures" / "manager" / "example-users.json", build_root / "fixtures" / "manager" / "example-keypairs.json", build_root / "fixtures" / "manager" / "example-set-user-main-access-keys.json", build_root / "fixtures" / "manager" / "example-resource-presets.json", build_root / "fixtures" / "manager" / "example-container-registries-harbor.json", + extra_fixture_file_path, ] async def init_fixture() -> None: @@ -461,6 +486,9 @@ async def init_fixture() -> None: yield async def clean_fixture() -> None: + if extra_fixture_file_path.exists(): + await aiofiles.os.remove(extra_fixture_file_path) + engine: SAEngine = sa.ext.asyncio.create_async_engine( db_url, connect_args=pgsql_connect_opts, diff --git a/tests/manager/models/gql_models/BUILD b/tests/manager/models/gql_models/BUILD new file mode 100644 index 0000000000..75b8f46de9 --- /dev/null +++ b/tests/manager/models/gql_models/BUILD @@ -0,0 +1 @@ +python_tests(name="tests") diff --git a/tests/manager/models/gql_models/test_agent.py b/tests/manager/models/gql_models/test_agent.py new file mode 100644 index 0000000000..de70c4ea4e --- /dev/null +++ b/tests/manager/models/gql_models/test_agent.py @@ -0,0 +1,192 @@ +import asyncio +import json +from unittest.mock import AsyncMock, patch + +import attr +import pytest +from graphene import Schema +from graphene.test import Client + +from ai.backend.common import redis_helper +from ai.backend.common.events import BgtaskDoneEvent, EventDispatcher +from ai.backend.common.types import AgentId +from 
ai.backend.manager.api.context import RootContext +from ai.backend.manager.models.agent import AgentStatus +from ai.backend.manager.models.gql import GraphQueryContext, Mutations, Queries +from ai.backend.manager.server import ( + agent_registry_ctx, + background_task_ctx, + database_ctx, + event_dispatcher_ctx, + hook_plugin_ctx, + monitoring_ctx, + network_plugin_ctx, + redis_ctx, + shared_config_ctx, + storage_manager_ctx, +) + + +@pytest.fixture(scope="module") +def client() -> Client: + return Client(Schema(query=Queries, mutation=Mutations, auto_camelcase=False)) + + +def get_graphquery_context(root_context: RootContext) -> GraphQueryContext: + return GraphQueryContext( + schema=None, # type: ignore + dataloader_manager=None, # type: ignore + local_config=None, # type: ignore + shared_config=None, # type: ignore + etcd=None, # type: ignore + user={"domain": "default", "role": "superadmin"}, + access_key="AKIAIOSFODNN7EXAMPLE", + db=root_context.db, # type: ignore + redis_stat=None, # type: ignore + redis_image=None, # type: ignore + redis_live=None, # type: ignore + manager_status=None, # type: ignore + known_slot_types=None, # type: ignore + background_task_manager=root_context.background_task_manager, # type: ignore + storage_manager=None, # type: ignore + registry=root_context.registry, # type: ignore + idle_checker_host=None, # type: ignore + network_plugin_ctx=None, # type: ignore + ) + + +def agent_template(id: str, status: AgentStatus): + return { + "id": id, + "status": status.name, + "scaling_group": "default", + "schedulable": True, + "region": "local", + "available_slots": {}, + "occupied_slots": {}, + "addr": "tcp://127.0.0.1:6011", + "public_host": "127.0.0.1", + "version": "24.12.0a1", + "architecture": "x86_64", + "compute_plugins": {}, + } + + +EXTRA_FIXTURES = { + "agents": [ + agent_template("i-ag1", AgentStatus.ALIVE), + agent_template("i-ag2", AgentStatus.ALIVE), + agent_template("i-ag3", AgentStatus.ALIVE), + agent_template("i-ag4", 
AgentStatus.LOST), + ] +} + + +@patch("ai.backend.manager.registry.AgentRegistry.scan_gpu_alloc_map", new_callable=AsyncMock) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_case, extra_fixtures", + [ + ( + { + "mock_agent_responses": [ + { + "00000000-0000-0000-0000-000000000001": "10.00", + "00000000-0000-0000-0000-000000000002": "5.00", + }, + { + "00000000-0000-0000-0000-000000000011": "15.00", + "00000000-0000-0000-0000-000000000012": "7.00", + }, + Exception("RPC call error"), # simulate an error + None, + ], + "expected": { + "redis": [ + { + "00000000-0000-0000-0000-000000000001": "10.00", + "00000000-0000-0000-0000-000000000002": "5.00", + }, + { + "00000000-0000-0000-0000-000000000011": "15.00", + "00000000-0000-0000-0000-000000000012": "7.00", + }, + None, + None, + ], + }, + }, + EXTRA_FIXTURES, + ), + ], +) +async def test_scan_gpu_alloc_maps( + mock_agent_responses, + client, + local_config, + etcd_fixture, + database_fixture, + create_app_and_client, + test_case, + extra_fixtures, +): + test_app, _ = await create_app_and_client( + [ + shared_config_ctx, + database_ctx, + redis_ctx, + monitoring_ctx, + hook_plugin_ctx, + event_dispatcher_ctx, + storage_manager_ctx, + network_plugin_ctx, + agent_registry_ctx, + background_task_ctx, + ], + [], + ) + + root_ctx: RootContext = test_app["_root.context"] + dispatcher: EventDispatcher = root_ctx.event_dispatcher + done_handler_ctx: dict = {} + done_event = asyncio.Event() + + async def done_sub( + context: None, + source: AgentId, + event: BgtaskDoneEvent, + ) -> None: + update_body = attr.asdict(event) # type: ignore + done_handler_ctx.update(**update_body) + done_event.set() + + dispatcher.subscribe(BgtaskDoneEvent, None, done_sub) + + mock_agent_responses.side_effect = test_case["mock_agent_responses"] + + context = get_graphquery_context(root_ctx) + query = """ + mutation { + rescan_gpu_alloc_maps { + ok + msg + task_id + } + } + """ + + res = await client.execute_async(query, variables={}, 
context_value=context) + assert res["data"]["rescan_gpu_alloc_maps"]["ok"] + await done_event.wait() + + assert str(done_handler_ctx["task_id"]) == res["data"]["rescan_gpu_alloc_maps"]["task_id"] + + alloc_map_keys = [f"gpu_alloc_map.{agent['id']}" for agent in extra_fixtures["agents"]] + raw_alloc_map_cache = await redis_helper.execute( + root_ctx.redis_stat, + lambda r: r.mget(*alloc_map_keys), + ) + alloc_map_cache = [ + json.loads(stat) if stat is not None else None for stat in raw_alloc_map_cache + ] + assert alloc_map_cache == test_case["expected"]["redis"]