From 129a37a1f439b53126bdd888921cb21aabcc811c Mon Sep 17 00:00:00 2001 From: Mike Degatano Date: Wed, 5 Feb 2025 08:24:37 -0500 Subject: [PATCH] Prevent race condition with location reload and backups list (#5602) --- supervisor/backups/manager.py | 49 +++++++++++++++++++++++------------ tests/api/test_backups.py | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 17 deletions(-) diff --git a/supervisor/backups/manager.py b/supervisor/backups/manager.py index 4f5d0d81263..af236695c96 100644 --- a/supervisor/backups/manager.py +++ b/supervisor/backups/manager.py @@ -230,13 +230,15 @@ def load(self) -> Awaitable[None]: async def reload(self, location: str | None | type[DEFAULT] = DEFAULT) -> bool: """Load exists backups.""" + backups: dict[str, Backup] = {} + async def _load_backup(location_name: str | None, tar_file: Path) -> bool: """Load the backup.""" backup = Backup(self.coresys, tar_file, "temp", location_name) if await backup.load(): - if backup.slug in self._backups: + if backup.slug in backups: try: - self._backups[backup.slug].consolidate(backup) + backups[backup.slug].consolidate(backup) except BackupInvalidError as err: _LOGGER.error( "Ignoring backup %s in %s due to: %s", @@ -247,23 +249,18 @@ async def _load_backup(location_name: str | None, tar_file: Path) -> bool: return False else: - self._backups[backup.slug] = Backup( + backups[backup.slug] = Backup( self.coresys, tar_file, backup.slug, location_name, backup.data ) return True return False - # Single location refresh clears out just that part of the cache and rebuilds it - if location != DEFAULT: - locations = {location: self.backup_locations[location]} - for backup in self.list_backups: - if location in backup.all_locations: - del backup.all_locations[location] - else: - locations = self.backup_locations - self._backups = {} - + locations = ( + self.backup_locations + if location == DEFAULT + else {location: self.backup_locations[location]} + ) tasks = [ 
self.sys_create_task(_load_backup(_location, tar_file)) for _location, path in locations.items() @@ -274,10 +271,28 @@ async def _load_backup(location_name: str | None, tar_file: Path) -> bool: if tasks: await asyncio.wait(tasks) - # Remove any backups with no locations from cache (only occurs in single location refresh) - if location != DEFAULT: - for backup in list(self.list_backups): - if not backup.all_locations: + # For a full reload, replace our cache with new one + if location == DEFAULT: + self._backups = backups + return True + + # For a location reload, merge new cache in with existing + for backup in list(self.list_backups): + if backup.slug in backups: + try: + backup.consolidate(backups[backup.slug]) + except BackupInvalidError as err: + _LOGGER.error( + "Ignoring backup %s in %s due to: %s", + backup.slug, + location, + err, + ) + + elif location in backup.all_locations: + if len(backup.all_locations) > 1: + del backup.all_locations[location] + else: del self._backups[backup.slug] return True diff --git a/tests/api/test_backups.py b/tests/api/test_backups.py index bfa24500b83..45ddca10a96 100644 --- a/tests/api/test_backups.py +++ b/tests/api/test_backups.py @@ -1300,3 +1300,52 @@ async def test_missing_file_removes_backup_from_cache( # Wait for reload task to complete and confirm backup is removed await asyncio.sleep(0) assert not coresys.backups.list_backups + + +@pytest.mark.usefixtures("tmp_supervisor_data") +async def test_immediate_list_after_missing_file_restore( + api_client: TestClient, coresys: CoreSys +): + """Test race with reload for missing file on restore does not error.""" + coresys.core.state = CoreState.RUNNING + coresys.hardware.disk.get_disk_free_space = lambda x: 5000 + + backup_file = get_fixture_path("backup_example.tar") + bad_location = Path(copy(backup_file, coresys.config.path_backup)) + # Copy a second backup in so there's something to reload later + copy(get_fixture_path("backup_example_enc.tar"), 
coresys.config.path_backup) + await coresys.backups.reload() + + # After reload, remove one of the files and confirm we have an out of date cache + bad_location.unlink() + assert coresys.backups.get("7fed74c8").all_locations.keys() == {None} + + event = asyncio.Event() + orig_wait = asyncio.wait + + async def mock_wait(tasks: list[asyncio.Task], *args, **kwargs): + """Mock for asyncio wait that allows force of race condition.""" + if tasks[0].get_coro().__qualname__.startswith("BackupManager.reload"): + await event.wait() + return await orig_wait(tasks, *args, **kwargs) + + with patch("supervisor.backups.manager.asyncio.wait", new=mock_wait): + resp = await api_client.post( + "/backups/7fed74c8/restore/partial", + json={"location": ".local", "folders": ["ssl"]}, + ) + assert resp.status == 404 + + await asyncio.sleep(0) + resp = await api_client.get("/backups") + assert resp.status == 200 + result = await resp.json() + assert len(result["data"]["backups"]) == 2 + + event.set() + await asyncio.sleep(0.1) + resp = await api_client.get("/backups") + assert resp.status == 200 + result = await resp.json() + assert len(result["data"]["backups"]) == 1 + assert result["data"]["backups"][0]["slug"] == "93b462f8"