PeftModel work well on python backend? #6158
Replies: 14 comments
-
Hi @YooSungHyun, if I understand correctly, you are running into an issue that
-
No, I think this statement is not working: self.model.set_adapter("bb")

Model config (the model name is secret, shown here as MODEL_NAME):

name: MODEL_NAME
backend: "python"
max_batch_size: 0
parameters [
{
key: "gpu_shard"
value: {string_value: "0"}
},
{
key: "low_cpu_mem_usage"
value: {string_value: "1"}
},
{
key: "load_in_8bit"
value: {string_value: "1"}
},
{
key: "torch_dtype"
value: {string_value: "bfloat16"}
}
]
input [
{
name: "sequence"
dims: [1]
data_type: TYPE_STRING
},
{
name: "temperature"
dims: [1]
data_type: TYPE_FP32
},
{
name: "max_new_tokens"
dims: [1]
data_type: TYPE_INT16
},
{
name: "no_repeat_ngram_size"
dims: [1]
data_type: TYPE_INT16
},
{
name: "num_beam"
dims: [1]
data_type: TYPE_INT16
},
{
name: "top_k"
dims: [1]
data_type: TYPE_INT16
},
{
name: "top_p"
dims: [1]
data_type: TYPE_FP32
},
{
name: "length_penalty"
dims: [1]
data_type: TYPE_FP32
},
{
name: "repetition_penalty"
dims: [1]
data_type: TYPE_FP32
},
{
name: "do_sample"
dims: [1]
data_type: TYPE_BOOL
},
{
name: "eos_token_id"
dims: [1,-1]
data_type: TYPE_INT16
},
{
name: "model_name"
dims: [1]
data_type: TYPE_STRING
}
]
output [
{
name: "messages"
data_type: TYPE_STRING
dims: [1]
}
]
instance_group [
{
count: 1
kind: KIND_GPU
gpus: [ 0 ]
}
]

JSON for curl (curl -X POST -H "Content-Type: application/json" -d @this.json localhost:8000/v2/models/MODEL_NAME/infer):

{
"id": "0",
"inputs": [
{
"name": "sequence",
"shape": [
1
],
"datatype": "BYTES",
"parameters": {},
"data": [
INPUT_TEXT(skip)
]
},
{
"name": "temperature",
"shape": [
1
],
"datatype": "FP32",
"parameters": {},
"data": [
0.7
]
},
{
"name": "max_new_tokens",
"shape": [
1
],
"datatype": "INT16",
"parameters": {},
"data": [
128
]
},
{
"name": "no_repeat_ngram_size",
"shape": [
1
],
"datatype": "INT16",
"parameters": {},
"data": [
0
]
},
{
"name": "num_beam",
"shape": [
1
],
"datatype": "INT16",
"parameters": {},
"data": [
1
]
},
{
"name": "top_k",
"shape": [
1
],
"datatype": "INT16",
"parameters": {},
"data": [
50
]
},
{
"name": "top_p",
"shape": [
1
],
"datatype": "FP32",
"parameters": {},
"data": [
0.95
]
},
{
"name": "length_penalty",
"shape": [
1
],
"datatype": "FP32",
"parameters": {},
"data": [
0.0
]
},
{
"name": "repetition_penalty",
"shape": [
1
],
"datatype": "FP32",
"parameters": {},
"data": [
1.0
]
},
{
"name": "do_sample",
"shape": [
1
],
"datatype": "BOOL",
"parameters": {},
"data": [
true
]
},
{
"name": "eos_token_id",
"shape": [
1,
3
],
"datatype": "INT16",
"parameters": {},
"data": [
0,
1,
2
]
},
{
"name": "model_name",
"shape": [
1
],
"datatype": "BYTES",
"parameters": {},
"data": [
"bb"
]
}
]
}

Maybe, I guess, PeftModelForCausalLM's generate function is not working on the python backend? The model_name variable is a bit confusing, so I replaced it with the literal adapter name: self.model.set_adapter("bb")

model_output = self.model.generate(input_ids=input_ids.cuda(), generation_config=generation_config)[0] gives the same result, so I think it is not working.
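For reference, here is a minimal sketch of the adapter-switching pattern being attempted, as it would run outside of Triton (paths are placeholders; the adapter names "aa"/"bb" are the ones used in this thread):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Placeholder paths -- the real model/adapter paths in this thread are private.
base = AutoModelForCausalLM.from_pretrained(
    "BASE_MODEL_PATH", torch_dtype=torch.bfloat16, device_map={"": 0}
)
model = PeftModel.from_pretrained(base, "AA_LORA_WEIGHTS", adapter_name="aa")
model.load_adapter("BB_LORA_WEIGHTS", adapter_name="bb")
model.eval()

tokenizer = AutoTokenizer.from_pretrained("BASE_MODEL_PATH")
input_ids = tokenizer("some prompt", return_tensors="pt").input_ids.cuda()

# Generate with the "bb" adapter active.
model.set_adapter("bb")
with torch.no_grad():
    out_bb = model.generate(input_ids=input_ids, max_new_tokens=32)

# Generate with the base model only, via the disable_adapter() context manager.
with model.disable_adapter(), torch.no_grad():
    out_base = model.generate(input_ids=input_ids, max_new_tokens=32)
```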
-
I found the problem... but I don't know why torch.compile causes it. 🤔

model = AutoModelForCausalLM.from_pretrained(
    optional_config["model_path"],
    low_cpu_mem_usage=bool(strtobool(optional_config["low_cpu_mem_usage"])),
    load_in_8bit=bool(strtobool(optional_config["load_in_8bit"])),
    torch_dtype=getattr(torch, optional_config["torch_dtype"], None),
    device_map=device_map,
)
# model = torch.compile(model)
self.model = PeftModel.from_pretrained(
    model,
    optional_config["aa_lora_weights"],
    adapter_name=optional_config["aa_lora_name"],
    device_map=device_map,
)
self.model.load_adapter(
    optional_config["bb_lora_weights"],
    adapter_name=optional_config["bb_lora_name"],
    device_map=device_map,
)
self.model.eval()
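A quick way to narrow this down (my suggestion, not something from the thread) is to check what torch.compile returns when it wraps the PeftModel: in PyTorch 2.x it returns an OptimizedModule that keeps the original module in _orig_mod, so you can verify whether adapter switching is still visible on both objects:

```python
import torch

# Hypothetical diagnostic: `peft_model` stands for the PeftModel built above.
compiled = torch.compile(peft_model)

print(type(compiled))                              # usually torch._dynamo.eval_frame.OptimizedModule
print(type(getattr(compiled, "_orig_mod", None)))  # the wrapped PeftModel, if present

compiled.set_adapter("bb")
print(compiled.active_adapter)    # should show "bb" if the call reached the PeftModel
print(peft_model.active_adapter)  # compare against the un-wrapped model
```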
-
@YooSungHyun Thank you for sharing. You are using PyTorch 2.0, right? This is an interesting find. I am not sure if it is a Triton-related issue, but I will file a ticket to investigate.
-
@oandreeva-nv Yes, I used PyTorch 2.0.1.
-
If I use it like this, it works fine:

model = AutoModelForCausalLM.from_pretrained(
    optional_config["model_path"],
    low_cpu_mem_usage=bool(strtobool(optional_config["low_cpu_mem_usage"])),
    load_in_8bit=bool(strtobool(optional_config["load_in_8bit"])),
    torch_dtype=getattr(torch, optional_config["torch_dtype"], None),
    device_map=device_map,
)
self.model = PeftModel.from_pretrained(
    model,
    optional_config["aa_lora_weights"],
    adapter_name=optional_config["aa_lora_name"],
    device_map=device_map,
)
self.model.load_adapter(
    optional_config["bb_lora_weights"],
    adapter_name=optional_config["bb_lora_name"],
    device_map=device_map,
)
model = torch.compile(model)
self.model.eval()
-
@YooSungHyun I've noticed that in the original post, you didn't use

So to summarize:
-
@oandreeva-nv Exactly right.
-
@YooSungHyun Can you share the code for deploying a PEFT model with the python backend?
-
import os
import torch
import logging
import numpy as np
import triton_python_backend_utils as pb_utils
from setproctitle import setproctitle
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast, GenerationConfig
from peft import PeftModel
from distutils.util import strtobool
import time
setproctitle("tritonchild-LLM")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p")
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
def np_bytes_to_str_list(inputs: np.ndarray):
    # https://github.com/triton-inference-server/server/issues/4996
    assert os.environ["PYTHONIOENCODING"] == "UTF-8", "Must export `PYTHONIOENCODING=UTF-8`"
    return np.char.decode(inputs.astype("bytes"), "utf-8").tolist()


class TritonPythonModel:
    def initialize(self, args):
        device_id = int(args["model_instance_device_id"])
        model_config = pb_utils.ModelConfig(args["model_config"]).as_dict()
        parameters = model_config["parameters"]
        optional_config = dict()
        for key, value in parameters.items():
            optional_config[key] = value["string_value"]
        logger.info(f"optional_config: {optional_config}")
        if strtobool(optional_config["gpu_shard"]):
            device_map = "auto"
        else:
            device_map = {"": device_id}
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained(
            optional_config["model_path"],
            model_max_length=2048,
            truncation_side="left",
        )
        model = AutoModelForCausalLM.from_pretrained(
            optional_config["model_path"],
            low_cpu_mem_usage=bool(strtobool(optional_config["low_cpu_mem_usage"])),
            load_in_8bit=bool(strtobool(optional_config["load_in_8bit"])),
            torch_dtype=getattr(torch, optional_config["torch_dtype"], None),
            device_map=device_map,
        )
        model = PeftModel.from_pretrained(
            model,
            optional_config["qa_lora_weights"],
            adapter_name=optional_config["qa_lora_name"],
            device_map=device_map,
        )
        self.model = torch.compile(model)
        self.model.eval()

    def execute(self, requests):
        responses = []
        for request in requests:
            start = time.time()
            sequence = np_bytes_to_str_list(pb_utils.get_input_tensor_by_name(request, "sequence").as_numpy())
            temperature = pb_utils.get_input_tensor_by_name(request, "temperature").as_numpy().item(0)
            max_new_tokens = pb_utils.get_input_tensor_by_name(request, "max_new_tokens").as_numpy().item(0)
            no_repeat_ngram_size = (
                pb_utils.get_input_tensor_by_name(request, "no_repeat_ngram_size").as_numpy().item(0)
            )
            num_beam = pb_utils.get_input_tensor_by_name(request, "num_beam").as_numpy().item(0)
            top_k = pb_utils.get_input_tensor_by_name(request, "top_k").as_numpy().item(0)
            top_p = pb_utils.get_input_tensor_by_name(request, "top_p").as_numpy().item(0)
            length_penalty = pb_utils.get_input_tensor_by_name(request, "length_penalty").as_numpy().item(0)
            repetition_penalty = pb_utils.get_input_tensor_by_name(request, "repetition_penalty").as_numpy().item(0)
            do_sample = pb_utils.get_input_tensor_by_name(request, "do_sample").as_numpy().item(0)
            eos_token_id = pb_utils.get_input_tensor_by_name(request, "eos_token_id").as_numpy()
            task_name = np_bytes_to_str_list(pb_utils.get_input_tensor_by_name(request, "task_name").as_numpy())
            generation_config = GenerationConfig(
                temperature=temperature,  # 0.5
                max_new_tokens=max_new_tokens,
                no_repeat_ngram_size=no_repeat_ngram_size,
                num_beams=num_beam,
                num_return_sequences=1,
                top_k=top_k,
                top_p=top_p,
                renormalize_logits=True,
                length_penalty=length_penalty,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                pad_token_id=0,
                eos_token_id=eos_token_id[0].tolist(),
            )
            logger.info(generation_config)
            input_ids = self.tokenizer(sequence, return_tensors="pt", truncation=True).input_ids
            # try:
            if task_name[0] == "LLM":
                with self.model.disable_adapter():
                    model_output = self.model.generate(
                        input_ids=input_ids.cuda(), generation_config=generation_config
                    )[0]
            else:
                self.model.set_adapter(task_name[0])
                model_output = self.model.generate(input_ids=input_ids.cuda(), generation_config=generation_config)[0]
            output = self.tokenizer.decode(model_output, skip_special_tokens=False)
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[pb_utils.Tensor("messages", np.char.encode(output, "utf-8"))]
            )
            responses.append(inference_response)
            just_res = model_output[
                (model_output == 1).nonzero(as_tuple=True)[0] + 1 : (model_output == 2).nonzero(as_tuple=True)[0]
            ]
            just_text = self.tokenizer.decode(just_res, skip_special_tokens=False)
            print(just_text)
            end = time.time()
            tot_time = end - start
            tps = round(len(just_text) / tot_time, 2)
            print(f"\n{tot_time:.5f} sec", f"({tps:.2f} T/s)\n")
        return responses
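For completeness, a minimal Python client sketch for calling this model over HTTP with tritonclient. Note the input names follow the execute() code above (which reads "task_name"), while the config shown earlier declares "model_name" instead, so adjust to whatever config you actually deploy; MODEL_NAME and the prompt are placeholders:

```python
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

def make_input(name, array, datatype):
    # Helper: wrap a numpy array as a Triton InferInput.
    inp = httpclient.InferInput(name, list(array.shape), datatype)
    inp.set_data_from_numpy(array)
    return inp

inputs = [
    make_input("sequence", np.array(["INPUT_TEXT".encode("utf-8")], dtype=np.object_), "BYTES"),
    make_input("temperature", np.array([0.7], dtype=np.float32), "FP32"),
    make_input("max_new_tokens", np.array([128], dtype=np.int16), "INT16"),
    make_input("no_repeat_ngram_size", np.array([0], dtype=np.int16), "INT16"),
    make_input("num_beam", np.array([1], dtype=np.int16), "INT16"),
    make_input("top_k", np.array([50], dtype=np.int16), "INT16"),
    make_input("top_p", np.array([0.95], dtype=np.float32), "FP32"),
    make_input("length_penalty", np.array([0.0], dtype=np.float32), "FP32"),
    make_input("repetition_penalty", np.array([1.0], dtype=np.float32), "FP32"),
    make_input("do_sample", np.array([True]), "BOOL"),
    make_input("eos_token_id", np.array([[0, 1, 2]], dtype=np.int16), "INT16"),
    make_input("task_name", np.array(["bb".encode("utf-8")], dtype=np.object_), "BYTES"),
]

result = client.infer("MODEL_NAME", inputs)
print(result.as_numpy("messages"))
```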
-
Hi @YooSungHyun, for the
I wonder if it is working, because it differs from the previous posts and looks well formed?
-
An example from PyTorch demonstrating the placement of
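The referenced example is cut off above, so this is an assumption about what it showed: the usual placement in the PyTorch 2.x docs is to call torch.compile on the fully constructed module and then run inference through the compiled object, e.g.:

```python
import torch
import torch.nn as nn

# Minimal sketch of the typical torch.compile placement (toy module, not the
# thread's model): build the module first, then compile, then run inference.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
model.eval()
compiled = torch.compile(model)

with torch.no_grad():
    out = compiled(torch.randn(1, 16))
print(out.shape)
```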
-
@kthui I think it is just an incompatibility between peft and torch.compile...
-
Hi YooSungHyun, I will convert this issue into a discussion, since you are suspecting a compatibility issue "between peft and torch.compile" and we find
-
Description
I use PeftModel like this on the python backend, but every result is the disable_adapter result...

print(model_name[0]) -> "aa" or "bb" prints correctly
print(self.model.active_adapter) -> "aa" or "bb" prints correctly
print(self.model) -> the PeftModelForCausalLM model and the LoRA adapter layers print correctly

If I don't use Triton and just run plain Python inference code, it works fine. Does Triton not support PeftModel's adapter swap?

Also, when that code runs on pytriton it works well, but I have to use the python backend for grpc streaming and the multi instance API.
Triton Information
tritonserver:23.05, not customized
To Reproduce
Use Hugging Face's GPTNeoXForCausalLM and PeftModel, with a request input to change the adapter.
Expected behavior
The PEFT adapter changes and inference runs correctly.