Skip to content

Commit

Permalink
add download models from www.modelscope.cn (lm-sys#2830)
Browse files Browse the repository at this point in the history
Co-authored-by: mulin.lyh <[email protected]>
  • Loading branch information
2 people authored and zhanghao.smooth committed Jan 26, 2024
1 parent 246644a commit cdcecfd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
13 changes: 13 additions & 0 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,19 @@ def load_model(
if dtype is not None: # Overwrite dtype if it is provided in the arguments.
kwargs["torch_dtype"] = dtype

if os.environ.get("FASTCHAT_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
try:
from modelscope.hub.snapshot_download import snapshot_download

model_path = snapshot_download(model_id=model_path, revision=revision)
except ImportError as e:
warnings.warn(
"Use model from www.modelscope.cn need pip install modelscope"
)
raise e

# hackable
if isinstance(adapter,RerankAdapter):
model, tokenizer = adapter.load_model(model_path, {"device": device})
Expand Down
3 changes: 3 additions & 0 deletions fastchat/serve/model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
device: str,
num_gpus: int,
max_gpu_memory: str,
revision: str = None,
dtype: Optional[torch.dtype] = None,
load_8bit: bool = False,
cpu_offloading: bool = False,
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
self.model, self.tokenizer = load_model(
model_path,
revision=revision,
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
Expand Down Expand Up @@ -391,6 +393,7 @@ def create_model_worker():
args.model_path,
args.model_names,
args.limit_worker_concurrency,
revision=args.revision,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
Expand Down

0 comments on commit cdcecfd

Please sign in to comment.