Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BA-460): Cache gpu_alloc_map in Redis, and Add RescanGPUAllocMaps mutation #3293

Open
wants to merge 27 commits into
base: topic/06-13-feat_support_scanning_gpu_allocation
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
bc08613
feat: Cache gpu_alloc_map, and Add ScanGPUAllocMap mutation
jopemachine Dec 24, 2024
ce43cce
chore: Add news fragment
jopemachine Dec 24, 2024
cd07a40
chore: Add news fragment
jopemachine Dec 24, 2024
7757955
chore: fix typo
jopemachine Dec 24, 2024
4934ee2
chore: Improve news fragment
jopemachine Dec 24, 2024
97237dd
fix: Add milestone comment
jopemachine Dec 24, 2024
557d749
fix: Wrong impl of AgentRegistry.scan_gpu_alloc_map
jopemachine Dec 24, 2024
2046e50
fix: Add `extra_fixtures`
jopemachine Dec 24, 2024
c97b426
feat: Add `test_scan_gpu_alloc_maps` test case
jopemachine Dec 24, 2024
d9e75f7
feat: Add update call count check
jopemachine Dec 24, 2024
7cef622
fix: Improve `test_scan_gpu_alloc_maps`
jopemachine Dec 26, 2024
d743951
fix: Improve `test_scan_gpu_alloc_maps`
jopemachine Dec 26, 2024
f097a0f
chore: Rename variables
jopemachine Dec 26, 2024
7c7a974
fix: `ScanGPUAllocMaps` -> `RescanGPUAllocMaps`
jopemachine Dec 26, 2024
d5f6bbf
fix: Broken test
jopemachine Dec 26, 2024
c55c329
fix: Remove useless `_default_host`
jopemachine Dec 26, 2024
9e0beaf
chore: Rename news fragment
jopemachine Dec 26, 2024
16a6c8b
feat: Improve error handling
jopemachine Dec 26, 2024
60bab91
fix: Improve exception handling and test case
jopemachine Dec 26, 2024
ef503e5
fix: Replace useless `mock_agent_registry_ctx` with local_config's `r…
jopemachine Dec 26, 2024
3901605
fix: Wrong reference to `redis_stat`
jopemachine Dec 26, 2024
141e103
docs: Add description about agent_id
jopemachine Dec 26, 2024
532b2df
chore: update GraphQL schema dump
jopemachine Dec 26, 2024
fbf5115
feat: Call agent rpc call in parallel
jopemachine Dec 26, 2024
486089c
fix: Update milestone
jopemachine Jan 8, 2025
3b9acd2
fix: lint
jopemachine Jan 8, 2025
e5fbcb9
fix: Update milestone
jopemachine Jan 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3293.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache `gpu_alloc_map` in Redis, and Add `RescanGPUAllocMaps` mutation for update the `gpu_alloc_map`s.
13 changes: 13 additions & 0 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,12 @@
This action cannot be undone.
"""
purge_user(email: String!, props: PurgeUserInput!): PurgeUser

"""Added in 25.1.0."""
rescan_gpu_alloc_maps(

Check notice on line 1723 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'rescan_gpu_alloc_maps' was added to object type 'Mutations'

Field 'rescan_gpu_alloc_maps' was added to object type 'Mutations'
"""Agent ID to rescan GPU alloc map, Pass None to rescan all agents"""
agent_id: String
): RescanGPUAllocMaps
create_keypair(props: KeyPairInput!, user_id: String!): CreateKeyPair
modify_keypair(access_key: String!, props: ModifyKeyPairInput!): ModifyKeyPair
delete_keypair(access_key: String!): DeleteKeyPair
Expand Down Expand Up @@ -2112,6 +2118,13 @@
purge_shared_vfolders: Boolean
}

"""Added in 25.1.0."""
type RescanGPUAllocMaps {

Check notice on line 2122 in src/ai/backend/manager/api/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Type 'RescanGPUAllocMaps' was added

Type 'RescanGPUAllocMaps' was added
ok: Boolean
msg: String
task_id: UUID
}

type CreateKeyPair {
ok: Boolean
msg: String
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
AgentSummary,
AgentSummaryList,
ModifyAgent,
RescanGPUAllocMaps,
)
from .gql_models.domain import (
CreateDomainNode,
Expand Down Expand Up @@ -250,6 +251,7 @@ class Mutations(graphene.ObjectType):
modify_user = ModifyUser.Field()
delete_user = DeleteUser.Field()
purge_user = PurgeUser.Field()
rescan_gpu_alloc_maps = RescanGPUAllocMaps.Field(description="Added in 25.1.0.")

# admin only
create_keypair = CreateKeyPair.Field()
Expand Down
93 changes: 91 additions & 2 deletions src/ai/backend/manager/models/gql_models/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import asyncio
import json
import logging
import uuid
from collections.abc import Iterable, Mapping, Sequence
from typing import (
Expand All @@ -18,12 +21,14 @@
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.bgtask import ProgressReporter
from ai.backend.common.types import (
AccessKey,
AgentId,
BinarySize,
HardwareMetadata,
)
from ai.backend.logging.utils import BraceStyleAdapter

from ..agent import (
AgentRow,
Expand Down Expand Up @@ -61,6 +66,8 @@
if TYPE_CHECKING:
from ..gql import GraphQueryContext

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

__all__ = (
"Agent",
"AgentNode",
Expand Down Expand Up @@ -181,7 +188,13 @@ async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Any:

async def resolve_gpu_alloc_map(self, info: graphene.ResolveInfo) -> Mapping[str, int]:
ctx: GraphQueryContext = info.context
return await ctx.registry.scan_gpu_alloc_map(self.id)
raw_alloc_map = await redis_helper.execute(
ctx.redis_stat, lambda r: r.get(f"gpu_alloc_map.{self.id}")
)
if raw_alloc_map:
return json.loads(raw_alloc_map)
else:
return {}

async def resolve_hardware_metadata(
self,
Expand Down Expand Up @@ -435,7 +448,13 @@ async def resolve_container_count(self, info: graphene.ResolveInfo) -> int:

async def resolve_gpu_alloc_map(self, info: graphene.ResolveInfo) -> Mapping[str, int]:
ctx: GraphQueryContext = info.context
return await ctx.registry.scan_gpu_alloc_map(self.id)
raw_alloc_map = await redis_helper.execute(
ctx.redis_stat, lambda r: r.get(f"gpu_alloc_map.{self.id}")
)
if raw_alloc_map:
return json.loads(raw_alloc_map)
else:
return {}

_queryfilter_fieldspec: Mapping[str, FieldSpecItem] = {
"id": ("id", None),
Expand Down Expand Up @@ -878,3 +897,73 @@ async def mutate(

update_query = sa.update(agents).values(data).where(agents.c.id == id)
return await simple_db_mutate(cls, graph_ctx, update_query)


class RescanGPUAllocMaps(graphene.Mutation):
allowed_roles = (UserRole.SUPERADMIN,)

class Meta:
description = "Added in 25.1.0."

class Arguments:
agent_id = graphene.String(
description="Agent ID to rescan GPU alloc map, Pass None to rescan all agents",
required=False,
)

ok = graphene.Boolean()
msg = graphene.String()
task_id = graphene.UUID()

@classmethod
@privileged_mutation(
UserRole.SUPERADMIN,
lambda id, **kwargs: (None, id),
)
async def mutate(
cls,
root,
info: graphene.ResolveInfo,
agent_id: Optional[str] = None,
) -> RescanGPUAllocMaps:
log.info("rescanning GPU alloc maps")
graph_ctx: GraphQueryContext = info.context

if agent_id:
agent_ids = [agent_id]
else:
agent_ids = [agent.id async for agent in graph_ctx.registry.enumerate_instances()]

async def _scan_single_agent(agent_id: str, reporter: ProgressReporter) -> None:
await reporter.update(message=f"Agent {agent_id} GPU alloc map scanning...")

reporter_msg = ""
try:
alloc_map: Mapping[str, Any] = await graph_ctx.registry.scan_gpu_alloc_map(
AgentId(agent_id)
)
key = f"gpu_alloc_map.{agent_id}"
await redis_helper.execute(
graph_ctx.registry.redis_stat,
lambda r: r.set(name=key, value=json.dumps(alloc_map)),
)
except Exception as e:
reporter_msg = f"Failed to scan GPU alloc map for agent {agent_id}: {str(e)}"
log.error(reporter_msg)
else:
reporter_msg = f"Agent {agent_id} GPU alloc map scanned."

await reporter.update(
increment=1,
message=reporter_msg,
)

async def _rescan_alloc_map_task(reporter: ProgressReporter) -> None:
async with asyncio.TaskGroup() as tg:
for agent_id in agent_ids:
tg.create_task(_scan_single_agent(agent_id, reporter))

await reporter.update(message="GPU alloc map scanning completed")

task_id = await graph_ctx.background_task_manager.start(_rescan_alloc_map_task)
return RescanGPUAllocMaps(ok=True, msg="", task_id=task_id)
32 changes: 30 additions & 2 deletions tests/manager/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from unittest.mock import AsyncMock, MagicMock
from urllib.parse import quote_plus as urlquote

import aiofiles.os
import aiohttp
import asyncpg
import pytest
Expand Down Expand Up @@ -172,6 +173,8 @@ def local_config(
redis_addr = redis_container[1]
postgres_addr = postgres_container[1]

build_root = Path(os.environ["BACKEND_BUILD_ROOT"])

# Establish a self-contained config.
cfg = LocalConfig({
**etcd_config_iv.check({
Expand Down Expand Up @@ -206,6 +209,7 @@ def local_config(
"service-addr": HostPortPair("127.0.0.1", 29100 + get_parallel_slot() * 10),
"allowed-plugins": set(),
"disabled-plugins": set(),
"rpc-auth-manager-keypair": f"{build_root}/fixtures/manager/manager.key_secret",
},
"debug": {
"enabled": False,
Expand Down Expand Up @@ -257,7 +261,15 @@ def etcd_fixture(
"volumes": {
"_mount": str(vfolder_mount),
"_fsprefix": str(vfolder_fsprefix),
"_default_host": str(vfolder_host),
"default_host": str(vfolder_host),
"proxies": {
"local": {
"client_api": "http://127.0.0.1:6021",
"manager_api": "https://127.0.0.1:6022",
"secret": "some-secret-shared-with-storage-proxy",
"ssl_verify": "false",
}
},
},
"nodes": {},
"config": {
Expand Down Expand Up @@ -420,7 +432,12 @@ async def database_engine(local_config, database):


@pytest.fixture()
def database_fixture(local_config, test_db, database) -> Iterator[None]:
def extra_fixtures():
return {}


@pytest.fixture()
def database_fixture(local_config, test_db, database, extra_fixtures) -> Iterator[None]:
"""
Populate the example data as fixtures to the database
and delete them after use.
Expand All @@ -431,12 +448,20 @@ def database_fixture(local_config, test_db, database) -> Iterator[None]:
db_url = f"postgresql+asyncpg://{db_user}:{urlquote(db_pass)}@{db_addr}/{test_db}"

build_root = Path(os.environ["BACKEND_BUILD_ROOT"])

extra_fixture_file = tempfile.NamedTemporaryFile(delete=False)
extra_fixture_file_path = Path(extra_fixture_file.name)

with open(extra_fixture_file_path, "w") as f:
json.dump(extra_fixtures, f)

fixture_paths = [
build_root / "fixtures" / "manager" / "example-users.json",
build_root / "fixtures" / "manager" / "example-keypairs.json",
build_root / "fixtures" / "manager" / "example-set-user-main-access-keys.json",
build_root / "fixtures" / "manager" / "example-resource-presets.json",
build_root / "fixtures" / "manager" / "example-container-registries-harbor.json",
extra_fixture_file_path,
]

async def init_fixture() -> None:
Expand All @@ -461,6 +486,9 @@ async def init_fixture() -> None:
yield

async def clean_fixture() -> None:
if extra_fixture_file_path.exists():
await aiofiles.os.remove(extra_fixture_file_path)

engine: SAEngine = sa.ext.asyncio.create_async_engine(
db_url,
connect_args=pgsql_connect_opts,
Expand Down
1 change: 1 addition & 0 deletions tests/manager/models/gql_models/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_tests(name="tests")
Loading
Loading