diff --git a/pygit2/branches.py b/pygit2/branches.py index 21d887ad..4d637da5 100644 --- a/pygit2/branches.py +++ b/pygit2/branches.py @@ -27,7 +27,7 @@ from typing import TYPE_CHECKING -from ._pygit2 import Commit, Oid +from ._pygit2 import Branch, Commit, Oid, Reference from .enums import BranchType, ReferenceType # Need BaseRepository for type hints, but don't let it cause a circular dependency @@ -37,7 +37,10 @@ class Branches: def __init__( - self, repository: BaseRepository, flag: BranchType = BranchType.ALL, commit=None + self, + repository: BaseRepository, + flag: BranchType = BranchType.ALL, + commit: Commit | Oid | str | None = None, ): self._repository = repository self._flag = flag @@ -76,13 +79,13 @@ def __iter__(self): if self._commit is None or self.get(branch_name) is not None: yield branch_name - def create(self, name: str, commit, force=False): + def create(self, name: str, commit: Commit, force: bool = False): return self._repository.create_branch(name, commit, force) def delete(self, name: str): self[name].delete() - def _valid(self, branch): + def _valid(self, branch: Branch | Reference): if branch.type == ReferenceType.SYMBOLIC: branch = branch.resolve() @@ -92,9 +95,9 @@ def _valid(self, branch): or self._repository.descendant_of(branch.target, self._commit) ) - def with_commit(self, commit): + def with_commit(self, commit: Commit | Oid | str | None): assert self._commit is None return Branches(self._repository, self._flag, commit) - def __contains__(self, name): + def __contains__(self, name: str): return self.get(name) is not None