diff --git a/config.py b/config.py index 0743be36..c83e195d 100644 --- a/config.py +++ b/config.py @@ -1,46 +1,37 @@ """Load configuration settings for protonfixes""" -import os -from configparser import ConfigParser +from config_base import ConfigBase +from dataclasses import dataclass +from pathlib import Path -try: - from .logger import log -except ImportError: - from logger import log +class Config(ConfigBase): + """Configuration for umu-protonfix""" + @dataclass + class MainSection: + """General parameters + + Attributes: + enable_checks (bool): Run checks (`checks.py`) before the fix is executed. + enable_global_fixes (bool): Enables included fixes. If deactivated, only local fixes (`~/.config/protonfixes/localfixes`) are executed. -CONF_FILE = '~/.config/protonfixes/config.ini' -DEFAULT_CONF = """ -[main] -enable_checks = true -enable_splash = false -enable_global_fixes = true + """ + enable_checks: bool = True + enable_global_fixes: bool = True -[path] -cache_dir = ~/.cache/protonfixes -""" + @dataclass + class PathSection: + """Path parameters -CONF = ConfigParser() -CONF.read_string(DEFAULT_CONF) + Attributes: + cache_dir (Path): The path that should be used to create temporary and cached files. -try: - CONF.read(os.path.expanduser(CONF_FILE)) + """ -except Exception: - log.debug('Unable to read config file ' + CONF_FILE) + cache_dir: Path = Path.home() / '.cache/protonfixes' + main: MainSection + path: PathSection -def opt_bool(opt: str) -> bool: - """Convert bool ini strings to actual boolean values""" - return opt.lower() in ['yes', 'y', 'true', '1'] - - -locals().update({x: opt_bool(y) for x, y in CONF['main'].items() if 'enable' in x}) - -locals().update({x: os.path.expanduser(y) for x, y in CONF['path'].items()}) - -try: - [os.makedirs(os.path.expanduser(d)) for n, d in CONF['path'].items()] -except OSError: - pass +config = Config(Path.home() / '.config/protonfixes/config.ini') diff --git a/config_base.py b/config_base.py new file mode 100644 index 00000000..242c9c77 --- /dev/null +++ b/config_base.py @@ -0,0 +1,186 @@ +"""Load configuration settings for protonfixes""" + +import re + +from configparser import ConfigParser +from dataclasses import is_dataclass +from pathlib import Path + +from typing import Any +from collections.abc import Callable + +from logger import log + +class ConfigBase: + """Base class for configuration objects. + + This reflects a given config file and populates the object with it's values. + It also injects attributes from the sub classes, this isn't compatible with static type checking though. + You can define the attributes accordingly to satisfy type checkers. + """ + + __CAMEL_CASE_PATTERN: re.Pattern = re.compile('((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))') + + @classmethod + def snake_case(cls, input: str) -> str: + """Converts CamelCase to snake_case. + + Args: + input (str): The string to convert. + + Returns: + str: The converted string. + + """ + return cls.__CAMEL_CASE_PATTERN.sub(r'_\1', input).lower() + + + @staticmethod + def __log(message: str, level: str = 'INFO') -> None: + log.log(f'[CONFIG]: {message}', level) + + + def __init__(self, path: Path) -> None: + """Initialize the instance from a given config file. + + Defaults will be used if the file doesn't exist. + The file will also be created in this case. + + Args: + path (Path): The reflected config file's path. + + Raises: + IsADirectoryError: If the path exists, but isn't a file. + + """ + assert path + if path.is_file(): + self.parse_config_file(path) + elif not path.exists(): + self.init_sections() + self.write_config_file(path) + else: + raise IsADirectoryError(f'Given path "{path.absolute()}" exists, but is not a file.') + + + def init_sections(self, force: bool = False) -> None: + """Find sub-classes and initialize them as attributes. + + Sub-classes are initialized and injected as attributes. + Example: `MainSection` will be injected as `main` to the config (this) object. + + Args: + force (bool, optional): Force initialization? This results in a reset. Defaults to False. + + """ + for (member_name, member) in self.__class__.__dict__.items(): + # Find non private section definitions + if not member_name.endswith('Section') or member_name.startswith('_'): + continue + if not is_dataclass(member): + continue + + # Convert section definition class name to variable name (MyCfgSection -> my_cfg) + section_name = member_name.removesuffix('Section') + section_name = self.snake_case(section_name) + + # Do not override existing members by default + if hasattr(self, section_name) and not force: + continue + + # Initialize section class as a member + setattr(self, section_name, member()) # pyright: ignore [reportCallIssue] + + + def parse_config_file(self, file: Path) -> bool: + """Parse a config file. + + This resets the data in the sections, regardless if the file exists or is loaded. + + Args: + file (Path): The reflected config file's path. + + Returns: + bool: True, if the config file was successfully loaded. + + """ + # Initialize / reset sections to defaults + self.init_sections(True) + + # Only precede if the config file exists + if not file.is_file(): + return False + + try: + parser = ConfigParser() + parser.read(file) + + # Iterate over local config section objects + for (section_name, section) in self.__dict__.items(): + if not parser.has_section(section_name): + continue + + parser_items = parser[section_name] + + # FIXME: match is not supported in Python 3.9 + def _get_parse_function(type_name: str) -> Callable[[str], Any]: + # Mapping of type_name to according value get function + value = { + 'int': parser_items.getint, + 'float': parser_items.getfloat, + 'bool': parser_items.getboolean, + 'Path': lambda option: Path(parser_items.get(option, '')), + 'PosixPath': lambda option: Path(parser_items.get(option, '')), + 'str': parser_items.get + }.get(type_name, None) + if not value: + value = parser_items.get + self.__log(f'Unknown type "{type_name}", falling back to "str".', 'WARN') + return value + + # Iterate over the option objects in this section + for (option_name, option_item) in section.__dict__.items(): + # Get values from config and set it on object + type_name = type(option_item).__name__ + func = _get_parse_function(type_name) + value = func(option_name) + setattr(section, option_name, value) + except Exception as ex: + self.__log(f'Failed to parse config file "{file}". Exception: "{ex}"', 'CRIT') + return False + return True + + + def write_config_file(self, file: Path) -> bool: + """Write the current config to a file. + + Args: + file (Path): The file path to write to. + + Returns: + bool: True, if the file was successfully written. + + """ + # Only precede if the parent directory exists + if not file.parent.is_dir(): + self.__log(f'Parent directory "{file.parent}" does not exist. Abort.', 'WARN') + return False + + # Create and populate ConfigParser + try: + parser = ConfigParser() + # Iterate over local config section objects + for (section_name, section_item) in self.__dict__.items(): + if not parser.has_section(section_name): + parser.add_section(section_name) + + for (option_name, option_item) in section_item.__dict__.items(): + parser.set(section_name, option_name, str(option_item)) + + # Write config file + with file.open(mode='w') as stream: + parser.write(stream) + except Exception as ex: + self.__log(f'Failed to create config file "{file}". Exception: "{ex}"', 'CRIT') + return False + return True diff --git a/fix.py b/fix.py index 825833a5..29102475 100644 --- a/fix.py +++ b/fix.py @@ -8,11 +8,11 @@ from importlib import import_module try: - from . import config + from .config import config from .checks import run_checks from .logger import log except ImportError: - import config + from config import config from checks import run_checks from logger import log @@ -175,15 +175,15 @@ def run_fix(game_id: str) -> None: if game_id is None: return - if config.enable_checks: + if config.main.enable_checks: run_checks() # execute default.py (local) - if not _run_fix_local(game_id, True) and config.enable_global_fixes: + if not _run_fix_local(game_id, True) and config.main.enable_global_fixes: _run_fix(game_id, True) # global # execute .py (local) - if not _run_fix_local(game_id, False) and config.enable_global_fixes: + if not _run_fix_local(game_id, False) and config.main.enable_global_fixes: _run_fix(game_id, False) # global diff --git a/util.py b/util.py index 6b1c0d6c..0e02464e 100644 --- a/util.py +++ b/util.py @@ -18,9 +18,11 @@ try: from .logger import log + from .config import config from .steamhelper import install_app except ImportError: from logger import log + from config import config from steamhelper import install_app try: @@ -428,8 +430,7 @@ def patch_libcuda() -> bool: Returns true if the library was patched correctly. Otherwise returns false """ - cache_dir = os.path.expanduser('~/.cache/protonfixes') - os.makedirs(cache_dir, exist_ok=True) + config.path.cache_dir.mkdir(parents=True, exist_ok=True) try: # Use shutil.which to find ldconfig binary @@ -472,10 +473,9 @@ def patch_libcuda() -> bool: log.info(f'Found 64-bit libcuda.so at: {libcuda_path}') - patched_library = os.path.join(cache_dir, 'libcuda.patched.so') + patched_library = config.path.cache_dir / 'libcuda.patched.so' try: - with open(libcuda_path, 'rb') as f: - binary_data = f.read() + binary_data = patched_library.read_bytes() except OSError as e: log.crit(f'Unable to read libcuda.so: {e}') return False @@ -496,11 +496,10 @@ def patch_libcuda() -> bool: patched_binary_data = bytes.fromhex(hex_data) try: - with open(patched_library, 'wb') as f: - f.write(patched_binary_data) + patched_library.write_bytes(patched_binary_data) # Set permissions to rwxr-xr-x (755) - os.chmod(patched_library, 0o755) + patched_library.chmod(0o755) log.debug(f'Permissions set to rwxr-xr-x for {patched_library}') except OSError as e: log.crit(f'Unable to write patched libcuda.so to {patched_library}: {e}') @@ -810,12 +809,12 @@ def install_battleye_runtime() -> None: def install_all_from_tgz(url: str, path: str = os.getcwd()) -> None: """Install all files from a downloaded tar.gz""" - cache_dir = os.path.expanduser('~/.cache/protonfixes') - os.makedirs(cache_dir, exist_ok=True) + config.path.cache_dir.mkdir(parents=True, exist_ok=True) + tgz_file_name = os.path.basename(url) - tgz_file_path = os.path.join(cache_dir, tgz_file_name) + tgz_file_path = config.path.cache_dir / tgz_file_name - if tgz_file_name not in os.listdir(cache_dir): + if not tgz_file_path.is_file(): log.info('Downloading ' + tgz_file_name) urllib.request.urlretrieve(url, tgz_file_path) @@ -830,12 +829,12 @@ def install_from_zip(url: str, filename: str, path: str = os.getcwd()) -> None: log.info(f'File {filename} found in {path}') return - cache_dir = os.path.expanduser('~/.cache/protonfixes') - os.makedirs(cache_dir, exist_ok=True) + config.path.cache_dir.mkdir(parents=True, exist_ok=True) + zip_file_name = os.path.basename(url) - zip_file_path = os.path.join(cache_dir, zip_file_name) + zip_file_path = config.path.cache_dir / zip_file_name - if zip_file_name not in os.listdir(cache_dir): + if not zip_file_path.is_file(): log.info(f'Downloading {filename} to {zip_file_path}') urllib.request.urlretrieve(url, zip_file_path)