[Infer] Add pir_model path for server infer. (#9790)
aooxin authored Jan 17, 2025
1 parent d039ad2 commit fb3e4c0
Showing 1 changed file with 7 additions and 22 deletions.
llm/server/server/server/engine/infer.py (29 changes: 7 additions & 22 deletions)
@@ -25,6 +25,7 @@
 import paddle
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
+from paddle.base.framework import use_pir_api
 from paddlenlp.trl.llm_utils import get_rotary_position_embedding
 from paddlenlp_ops import step_paddle
 from server.data.processor import DataProcessor
@@ -467,32 +468,16 @@ def _init_predictor(self):
         predictor init
         """
         device_id = self.rank % 8
-        self.model_file = os.path.join(self.model_dir, f"model.pdmodel")
-        self.param_file = os.path.join(self.model_dir, f"model.pdiparams")
+        if use_pir_api():
+            self.model_file = os.path.join(self.model_dir, f"model.json")
+            self.param_file = os.path.join(self.model_dir, f"model.pdiparams")
+        else:
+            self.model_file = os.path.join(self.model_dir, f"model.pdmodel")
+            self.param_file = os.path.join(self.model_dir, f"model.pdiparams")
         config = paddle.inference.Config(self.model_file, self.param_file)

         config.switch_ir_optim(False)
         config.enable_use_gpu(100, device_id)
-
-        # distributed config
-        if self.mp_degree > 1:
-            trainer_endpoints = fleet.worker_endpoints()
-            current_endpoint = trainer_endpoints[self.rank]
-            dist_config = config.dist_config()
-            dist_config.set_ranks(self.nranks, self.rank)
-            dist_config.set_endpoints(trainer_endpoints, current_endpoint)
-            dist_config.enable_dist_model(True)
-            if self.config.distributed_config_path:
-                dist_config.set_comm_init_config(self.config.distributed_config_path)
-            else:
-                raise Exception("Please set DISTRIBUTED_CONFIG env variable.")
-                logger.warning(
-                    f"Use default distributed config, please set env DISTRIBUTED_CONFIG"
-                )
-                dist_config.set_comm_init_config(
-                    os.path.join(Dir_Path + "/config", "rank_mapping_mp{}.csv".format(self.nranks)))
-
-            config.set_dist_config(dist_config)
         self.predictor = paddle.inference.create_predictor(config)
         self.input_names = self.predictor.get_input_names()
         self.seq_lens_handle = self.predictor.get_input_handle('seq_lens_this_time')
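Note on the new branch: with the PIR API enabled, the exported program is loaded from model.json, while the legacy program format uses model.pdmodel; the parameters file (model.pdiparams) is the same in both cases. The snippet below is a minimal, self-contained sketch of that selection logic applied outside the server engine; the build_predictor helper, its arguments, and the single-GPU settings are illustrative assumptions and not part of this commit.

import os

import paddle
from paddle.base.framework import use_pir_api


def build_predictor(model_dir, device_id=0):
    """Create a Paddle Inference predictor, picking the model file by IR mode."""
    if use_pir_api():
        # PIR export: the program is stored as model.json
        model_file = os.path.join(model_dir, "model.json")
    else:
        # Legacy export: the program is stored as model.pdmodel
        model_file = os.path.join(model_dir, "model.pdmodel")
    param_file = os.path.join(model_dir, "model.pdiparams")

    config = paddle.inference.Config(model_file, param_file)
    config.switch_ir_optim(False)          # keep the exported graph as-is
    config.enable_use_gpu(100, device_id)  # 100 MB initial GPU memory pool
    return paddle.inference.create_predictor(config)

A caller would then obtain input handles from the returned predictor, e.g. predictor.get_input_handle('seq_lens_this_time'), as the server engine does above.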
