Skip to content

Commit

Permalink
do not include pth and bin files as well
Browse files Browse the repository at this point in the history
  • Loading branch information
ackizilkale committed Jan 14, 2025
1 parent 8eec367 commit 53131cb
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions clarifai/runners/utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def validate_hftoken(cls, hf_token: str):
def download_checkpoints(self, checkpoint_path: str):
# throw error if huggingface_hub wasn't installed
try:
from huggingface_hub import snapshot_download
from huggingface_hub import list_repo_files, snapshot_download
except ImportError:
raise ImportError(self.HF_DOWNLOAD_TEXT)
if os.path.exists(checkpoint_path) and self.validate_download(checkpoint_path):
Expand All @@ -52,11 +52,17 @@ def download_checkpoints(self, checkpoint_path: str):
if not is_hf_model_exists:
logger.error("Model %s not found on Hugging Face" % (self.repo_id))
return False

ignore_patterns = None # Download everything.
repo_files = list_repo_files(repo_id=self.repo_id, token=self.token)
if any(f.endswith(".safetensors") for f in repo_files):
logger.info(f"SafeTensors found in {self.repo_id}, downloading only .safetensors files.")
ignore_patterns = ["original/*", "*.pth", "*.bin"]
snapshot_download(
repo_id=self.repo_id,
local_dir=checkpoint_path,
local_dir_use_symlinks=False,
ignore_patterns="original/*")
ignore_patterns=ignore_patterns)
except Exception as e:
logger.error(f"Error downloading model checkpoints {e}")
return False
Expand Down

0 comments on commit 53131cb

Please sign in to comment.