Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve task submission and handling on delta update #355

Merged
merged 3 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions umu/umu_bspatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,12 @@ def __init__( # noqa: D107
self._arc_manifest: list[ManifestEntry] = self._arc_contents["manifest"]
self._compat_tool = compat_tool
self._thread_pool = thread_pool
self._futures: list[Future] = []
# Collection where each task creates a new file within an existing compatibility tool
self._add: list[Future] = []
# Collection where each task updates an existing file
self._update: list[Future] = []
# Collection where each task verifies an existing file
self._verify: list[Future] = []

def add_binaries(self) -> None:
"""Add binaries within a compatibility tool.
Expand All @@ -131,7 +136,7 @@ def add_binaries(self) -> None:
build_file: Path = self._compat_tool.joinpath(item["name"])
if item["type"] == FileType.File.value:
# Decompress the zstd data and write the file
self._futures.append(
self._add.append(
self._thread_pool.submit(self._write_proton_file, build_file, item)
)
continue
Expand Down Expand Up @@ -160,7 +165,7 @@ def update_binaries(self) -> None:
build_file: Path = self._compat_tool.joinpath(item["name"])
if item["type"] == FileType.File.value:
# For files, apply a binary patch
self._futures.append(
self._update.append(
self._thread_pool.submit(self._patch_proton_file, build_file, item)
)
continue
Expand Down Expand Up @@ -209,33 +214,27 @@ def delete_binaries(self) -> None:
def verify_integrity(self) -> None:
    """Verify the expected mode, size, file and digest of the compatibility tool.

    Submits one verification task per manifest entry to the thread pool.
    The resulting futures are collected in ``self._verify``; the caller is
    responsible for awaiting them and handling any ValueError /
    FileNotFoundError raised by a failed check.
    """
    # NOTE(review): the scraped diff interleaved the removed line
    # (self._futures.append) with the added one; the merged PR version
    # appends to the dedicated self._verify collection.
    for item in self._arc_manifest:
        self._verify.append(
            self._thread_pool.submit(self._check_binaries, self._compat_tool, item)
        )

def result(self) -> tuple[list[Future], list[Future], list[Future]]:
    """Return all the currently submitted tasks.

    Returns:
        A tuple ``(verify, add, update)`` of the future collections
        populated by ``verify_integrity``, ``add_binaries`` and
        ``update_binaries`` respectively. Callers unpack this order
        (see the ``futures, *_ = patcher.result()`` site in umu_proton).
    """
    # NOTE(review): the scraped diff interleaved the old single-list
    # signature with the new tuple-returning one; the merged PR version
    # returns the three task collections as a tuple.
    return (self._verify, self._add, self._update)

def _check_binaries(
self, proton: Path, item: ManifestEntry
) -> ManifestEntry | None:
def _check_binaries(self, proton: Path, item: ManifestEntry) -> ManifestEntry:
rpath: Path = proton.joinpath(item["name"])

try:
with rpath.open("rb") as fp:
stats: os.stat_result = os.fstat(fp.fileno())
xxhash: int = 0
if item["size"] != stats.st_size:
log.error(
"Expected size %s, received %s", item["size"], stats.st_size
)
return None
err: str = f"Expected size {item['size']}, received {stats.st_size}"
raise ValueError(err)
if item["mode"] != stats.st_mode:
log.error(
"Expected mode %s, received %s", item["mode"], stats.st_mode
)
return None
err: str = f"Expected mode {item['mode']}, received {stats.st_mode}"
raise ValueError(err)
if stats.st_size > MMAP_MIN:
with mmap(fp.fileno(), length=0, access=ACCESS_READ) as mm:
# Ignore. Passing an mmap is valid here
Expand All @@ -245,11 +244,11 @@ def _check_binaries(
else:
xxhash = xxh3_64_intdigest(fp.read())
if item["xxhash"] != xxhash:
log.error("Expected xxhash %s, received %s", item["xxhash"], xxhash)
return None
err: str = f"Expected xxhash {item['xxhash']}, received {xxhash}"
raise ValueError(err)
except FileNotFoundError:
log.debug("Aborting partial update, file not found: %s", rpath)
return None
raise

return item

Expand Down
32 changes: 17 additions & 15 deletions umu/umu_proton.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ALL_COMPLETED, FIRST_EXCEPTION, ThreadPoolExecutor
from concurrent.futures import wait as futures_wait
from enum import Enum
from hashlib import sha512
from http import HTTPStatus
from importlib.util import find_spec
from itertools import chain
from pathlib import Path
from re import split as resplit
from shutil import move
Expand Down Expand Up @@ -575,24 +577,21 @@ def _get_delta(
# Apply the patch
for content in cbor["contents"]:
src: str = content["source"]

if src.startswith((ProtonVersion.GE.value, ProtonVersion.UMU.value)):
patchers.append(_apply_delta(proton, content, thread_pool))
continue

subdir: Path | None = next(umu_compat.joinpath(version).rglob(src), None)
if not subdir:
log.error("Could not find subdirectory '%s', skipping", subdir)
continue

patchers.append(_apply_delta(subdir, content, thread_pool))
renames.append((subdir, subdir.parent / content["target"]))

# Wait for results and rename versioned subdirectories
start: float = time.time_ns()
for patcher in filter(None, patchers):
for future in filter(None, patcher.result()):
future.result()
_, *futures = patcher.result()
futures_wait(list(chain.from_iterable(futures)), return_when=ALL_COMPLETED)

for rename in renames:
orig, new = rename
Expand All @@ -614,25 +613,28 @@ def _apply_delta(
thread_pool: ThreadPoolExecutor,
) -> CustomPatcher | None:
patcher: CustomPatcher = CustomPatcher(content, path, thread_pool)
is_updated: bool = False

# Verify the identity of the build. At this point the patch file is authenticated.
# Note, this will skip the update if the user had tinkered with their build. We do
# this so we can ensure the result of each binary patch isn't garbage
patcher.verify_integrity()

for item in patcher.result():
if item.result() is None:
is_updated = True
break
# Handle tasks that failed metadata validation. On success, skip waiting for results
futures, *_ = patcher.result()
done, not_done = futures_wait(futures, return_when=FIRST_EXCEPTION)
for future in done:
try:
future.result()
except (FileNotFoundError, ValueError) as e:
log.exception(e)
for future in not_done:
future.cancel()
return None

if is_updated:
log.debug("%s (latest) validation failed, skipping", os.environ["PROTONPATH"])
return None
futures_wait(not_done, return_when=ALL_COMPLETED)

# Patch the current build, upgrading proton to the latest
log.info("%s is OK, applying partial update...", os.environ["PROTONPATH"])

patcher.update_binaries()
patcher.add_binaries()
patcher.delete_binaries()
Expand Down
Loading