Add support for tensors/heads not divisible by GPUs #1

Open: wants to merge 2 commits into base: main
13 changes: 5 additions & 8 deletions vllm/config.py
@@ -703,14 +703,9 @@
self,
parallel_config: "ParallelConfig",
) -> None:
total_num_attention_heads = getattr(self.hf_text_config,
"num_attention_heads", 0)
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError(
f"Total number of attention heads ({total_num_attention_heads})"
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")

Check failure on line 706 in vllm/config.py (GitHub Actions / pre-commit): Ruff (F841): Local variable `total_num_attention_heads` is assigned to but never used
Check failure on line 708 in vllm/config.py (GitHub Actions / pre-commit): Ruff (F841): Local variable `tensor_parallel_size` is assigned to but never used

pipeline_parallel_size = parallel_config.pipeline_parallel_size
if pipeline_parallel_size > 1:
@@ -839,13 +834,15 @@
# the tensor parallel size. We will replicate the KV heads in the
# case where the number of KV heads is smaller than the tensor
# parallel size so each GPU has at least one KV head.
return max(1,
total_num_kv_heads // parallel_config.tensor_parallel_size)
return max(1, (total_num_kv_heads
+ parallel_config.tensor_parallel_size - 1) //
parallel_config.tensor_parallel_size)

def get_num_attention_heads(self,
parallel_config: "ParallelConfig") -> int:
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size
return ((num_heads + parallel_config.tensor_parallel_size - 1) //
parallel_config.tensor_parallel_size)

def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
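For reference, the new `get_num_kv_heads` and `get_num_attention_heads` above switch from floor to ceiling division, so the per-rank head count is rounded up instead of requiring the totals to divide evenly. A minimal sketch of that arithmetic (standalone Python, values chosen for illustration):

```python
def heads_per_rank(total_heads: int, tp_size: int) -> int:
    # Ceiling division: the size of the largest per-rank slice.
    return (total_heads + tp_size - 1) // tp_size

# 40 heads on 6 GPUs: floor division gives 6 and the removed check above
# rejected the configuration outright; ceiling division sizes each partition
# at 7, with later ranks holding fewer heads.
assert heads_per_rank(40, 6) == 7
assert heads_per_rank(32, 8) == 4  # evenly divisible case is unchanged
```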
16 changes: 16 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -1151,6 +1151,22 @@ def get_tensor_model_parallel_rank():
return get_tp_group().rank_in_group


def get_assigned_range(num_elements: int, tp_chunk: int = 1):
assert num_elements % tp_chunk == 0, 'Chunk size must divide the elements.'
num_elements = num_elements // tp_chunk

tp_rank = get_tensor_model_parallel_rank()
tp_world_size = get_tp_group().world_size
base_elements_per_rank = num_elements // tp_world_size
extra_elements = num_elements % tp_world_size
# Ranks < extra_elements get one extra element
elements_per_rank = base_elements_per_rank + (1 if tp_rank < extra_elements else 0)
start = (tp_rank * base_elements_per_rank + min(tp_rank, extra_elements))
end = start + elements_per_rank
assert tp_rank < tp_world_size - 1 or end == num_elements
return tp_chunk * start, tp_chunk * end


def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP
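To make the splitting rule in `get_assigned_range` concrete: elements are grouped into chunks of `tp_chunk`, the chunks are divided as evenly as possible across the TP ranks, and the first `num_chunks % world_size` ranks receive one extra chunk. A self-contained sketch of the same arithmetic with the rank and world size passed in explicitly (the real function reads them from the TP process group):

```python
def assigned_range(num_elements: int, tp_rank: int, tp_world_size: int,
                   tp_chunk: int = 1):
    # Same arithmetic as get_assigned_range, with rank/world size as arguments.
    assert num_elements % tp_chunk == 0
    chunks = num_elements // tp_chunk
    base = chunks // tp_world_size
    extra = chunks % tp_world_size
    count = base + (1 if tp_rank < extra else 0)  # first `extra` ranks get one more
    start = tp_rank * base + min(tp_rank, extra)
    return tp_chunk * start, tp_chunk * (start + count)

# 10 elements over 4 ranks (tp_chunk=1):
print([assigned_range(10, r, 4) for r in range(4)])
# -> [(0, 3), (3, 6), (6, 8), (8, 10)]
```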
@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
}
}
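The JSON above has the shape of a fused-MoE Triton tuning table: each top-level key is a token-batch size and each value is the Triton launch configuration (block sizes, group size, warps, stages) to use for it. Assuming it is consumed the way such tables usually are, by picking the entry whose key is closest to the actual batch size, a rough lookup sketch would be (the file path and the nearest-key fallback rule are assumptions, not taken from this diff):

```python
import json

def load_moe_configs(path: str) -> dict[int, dict]:
    # Keys are stored as strings in the JSON; convert them to ints for lookup.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: dict[int, dict], num_tokens: int) -> dict:
    # Use the tuned entry whose batch size is closest to the actual token count
    # (assumed fallback behaviour; the diff itself only adds the table).
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]
```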
45 changes: 25 additions & 20 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -6,7 +6,8 @@

import torch

from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_assigned_range,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
@@ -259,6 +260,7 @@
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
tp_chunk: int = 1
):
super().__init__()

@@ -267,10 +269,9 @@

self.tp_size = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
self.tp_chunk = 1
self.top_k = top_k
self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
@@ -291,8 +292,13 @@
UnquantizedFusedMoEMethod())
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method.tp_chunk(0) == self.quant_method.tp_chunk(1)
self.tp_chunk = self.quant_method.tp_chunk(0)
assert self.quant_method is not None

start_idx, end_idx = get_assigned_range(intermediate_size, self.tp_chunk)
self.intermediate_start = start_idx
self.intermediate_size_per_partition = end_idx - start_idx
moe_quant_params = {
"num_experts": num_experts,
"hidden_size": hidden_size,
@@ -328,15 +334,15 @@
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
tp_start: int,
load_full_w2: bool = False):
"""
Load grouped weight scales for group quantization or model weights
:param shard_dim: dimension to shard
:param expert_data: parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param tp_start: tensor parallel slice start
:param load_full_w2: whether or not the w2 loaded should be sharded.
"""
if shard_id == "w2":
@@ -345,19 +351,19 @@
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
tp_start=tp_start,
load_full=load_full_w2)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_start=tp_start)

def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
shard_dim: int, shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int):
tp_start: int):
# for per channel weight quantization
if shard_id == "w2":
expert_data.copy_(loaded_weight)
@@ -366,16 +372,15 @@
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_start=tp_start)

def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
shard_id: str, loaded_weight: torch.Tensor, tp_start: int):

# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
loaded_weight = loaded_weight.narrow(shard_dim, tp_start, shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
@@ -390,7 +395,7 @@
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
tp_start: int,
load_full: bool = False):

# Index the loaded weight for tp sharding.
@@ -399,7 +404,7 @@
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
tp_start,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
@@ -412,13 +417,13 @@
param_data[expert_id] = loaded_weight

def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
shard_dim: int, loaded_weight: torch.Tensor, tp_start: int):

if shard_id == "w2":
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_start=tp_start)
else:
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)
@@ -447,7 +452,7 @@
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()

Check failure on line 455 in vllm/model_executor/layers/fused_moe/layer.py (GitHub Actions / pre-commit): Ruff (F841): Local variable `tp_rank` is assigned to but never used

# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
@@ -480,7 +485,7 @@
shard_id=shard_id,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_start=self.intermediate_start)
return

# Case weight scales and zero_points
@@ -497,7 +502,7 @@
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_start=self.intermediate_start)
elif quant_method in [
FusedMoeWeightScaleSupported.GROUP.value,
FusedMoeWeightScaleSupported.BLOCK.value,
@@ -507,7 +512,7 @@
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
tp_start=self.intermediate_start // self.tp_chunk,
load_full_w2=getattr(param, "load_full_w2", False))
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
@@ -534,7 +539,7 @@
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_start=self.intermediate_start)
return

@staticmethod
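The recurring change in this file is that weight loading now narrows each checkpoint tensor at an explicit element offset (`tp_start`, taken from `self.intermediate_start` as computed by `get_assigned_range`) instead of at `tp_rank * shard_size`, which is only correct when every rank holds an equally sized shard. A minimal sketch of the difference, using plain tensors and illustrative sizes:

```python
import torch

# Illustrative only: one expert's down-projection weight with
# intermediate_size=10, split over tp_size=4 ranks.
full_w2 = torch.randn(8, 10)
ranges = [(0, 3), (3, 6), (6, 8), (8, 10)]  # per-rank (start, end) from the splitter

for tp_rank, (start, end) in enumerate(ranges):
    shard_size = end - start
    # Old scheme: offset = shard_size * tp_rank, which lands on the wrong
    # columns once shards are uneven (e.g. rank 2 would start at 4, not 6).
    # New scheme: the start offset is passed in directly.
    shard = full_w2.narrow(1, start, shard_size)
    assert shard.shape == (8, shard_size)
```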