
Commit

Add support for tensors/heads not divisible by GPUs
siddartha-RE committed Feb 3, 2025
1 parent 325f679 commit 641fb85
Showing 9 changed files with 196 additions and 93 deletions.
16 changes: 7 additions & 9 deletions vllm/config.py
@@ -690,11 +690,6 @@ def verify_with_parallel_config(
         total_num_attention_heads = getattr(self.hf_text_config,
                                             "num_attention_heads", 0)
         tensor_parallel_size = parallel_config.tensor_parallel_size
-        if total_num_attention_heads % tensor_parallel_size != 0:
-            raise ValueError(
-                f"Total number of attention heads ({total_num_attention_heads})"
-                " must be divisible by tensor parallel size "
-                f"({tensor_parallel_size}).")
 
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if pipeline_parallel_size > 1:

Check failure (GitHub Actions / pre-commit): Ruff F841 at vllm/config.py:690:9: Local variable `total_num_attention_heads` is assigned to but never used.
Check failure (GitHub Actions / pre-commit): Ruff F841 at vllm/config.py:692:9: Local variable `tensor_parallel_size` is assigned to but never used.
@@ -745,7 +740,8 @@ def is_deepseek_mla(self) -> bool:
 
     def get_head_size(self) -> int:
         # TODO remove hard code
-        if self.is_deepseek_mla:
+        if self.is_deepseek_mla or hasattr(self.hf_text_config,
+                                           'qk_nope_head_dim'):
             if self.use_mla:
                 return self.hf_text_config.kv_lora_rank
             else:
@@ -822,13 +818,15 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         # the tensor parallel size. We will replicate the KV heads in the
         # case where the number of KV heads is smaller than the tensor
         # parallel size so each GPU has at least one KV head.
-        return max(1,
-                   total_num_kv_heads // parallel_config.tensor_parallel_size)
+        return max(1, (total_num_kv_heads
+                       + parallel_config.tensor_parallel_size - 1) //
+                   parallel_config.tensor_parallel_size)
 
     def get_num_attention_heads(self,
                                 parallel_config: "ParallelConfig") -> int:
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
-        return num_heads // parallel_config.tensor_parallel_size
+        return ((num_heads + parallel_config.tensor_parallel_size - 1) //
+                parallel_config.tensor_parallel_size)
 
     def get_layers_start_end_indices(
             self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
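Note on the config.py changes above: `get_num_kv_heads` and `get_num_attention_heads` now use ceiling division, so a tensor parallel size that does not evenly divide the head count is no longer rejected; lower-numbered ranks simply carry one more head than the rest. A minimal standalone sketch of the rounding (illustrative only, not part of the commit):

def ceil_div(a: int, b: int) -> int:
    # Same arithmetic as (a + b - 1) // b in the new return statements.
    return -(-a // b)

# Example: 30 attention heads on 4 GPUs.
# The old code required 30 % 4 == 0 and raised otherwise; the new accessors
# report ceil(30 / 4) = 8 heads per rank, and the concrete per-rank split
# comes from get_assigned_range (added below in parallel_state.py).
assert ceil_div(30, 4) == 8
assert max(1, ceil_div(2, 4)) == 1  # KV heads are still replicated when scarce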
16 changes: 16 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -1149,6 +1149,22 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group
 
 
+def get_assigned_range(num_elements: int, tp_chunk: int = 1):
+    assert num_elements % tp_chunk == 0, 'Chunk size must divide the elements.'
+    num_elements = num_elements // tp_chunk
+
+    tp_rank = get_tensor_model_parallel_rank()
+    tp_world_size = get_tp_group().world_size
+    base_elements_per_rank = num_elements // tp_world_size
+    extra_elements = num_elements % tp_world_size
+    # Ranks < extra_elements get one extra element
+    elements_per_rank = base_elements_per_rank + (1 if tp_rank < extra_elements else 0)
+    start = (tp_rank * base_elements_per_rank + min(tp_rank, extra_elements))
+    end = start + elements_per_rank
+    assert tp_rank < tp_world_size - 1 or end == num_elements
+    return tp_chunk * start, tp_chunk * end
+
+
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP
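The new `get_assigned_range` helper splits `num_elements` (optionally in units of `tp_chunk`) across the tensor-parallel ranks, giving lower-numbered ranks one extra chunk when the division is uneven. A standalone sketch of the same arithmetic with the rank and world size passed explicitly, so it can be checked without initializing vLLM's distributed groups (the function name and example values here are illustrative):

def assigned_range(num_elements: int, tp_rank: int, tp_world_size: int,
                   tp_chunk: int = 1):
    # Same arithmetic as get_assigned_range, but with the TP rank and world
    # size taken as arguments instead of read from the TP group.
    assert num_elements % tp_chunk == 0
    num_chunks = num_elements // tp_chunk
    base = num_chunks // tp_world_size
    extra = num_chunks % tp_world_size
    per_rank = base + (1 if tp_rank < extra else 0)
    start = tp_rank * base + min(tp_rank, extra)
    return tp_chunk * start, tp_chunk * (start + per_rank)

# 30 elements over 4 ranks: sizes 8, 8, 7, 7 covering [0, 30) contiguously.
assert [assigned_range(30, r, 4) for r in range(4)] == [
    (0, 8), (8, 16), (16, 23), (23, 30)]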
45 changes: 25 additions & 20 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -4,7 +4,8 @@
 
 import torch
 
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_assigned_range,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
@@ -257,6 +258,7 @@ def __init__(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        tp_chunk: int = 1
     ):
         super().__init__()
 
@@ -265,10 +267,9 @@ def __init__(
 
         self.tp_size = (tp_size if tp_size is not None else
                         get_tensor_model_parallel_world_size())
+        self.tp_chunk = 1
         self.top_k = top_k
         self.num_experts = num_experts
-        assert intermediate_size % self.tp_size == 0
-        self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
         self.use_grouped_topk = use_grouped_topk
@@ -289,8 +290,13 @@ def __init__(
                 UnquantizedFusedMoEMethod())
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
+            assert self.quant_method.tp_chunk(0) == self.quant_method.tp_chunk(1)
+            self.tp_chunk = self.quant_method.tp_chunk(0)
         assert self.quant_method is not None
 
+        start_idx, end_idx = get_assigned_range(intermediate_size, self.tp_chunk)
+        self.intermediate_start = start_idx
+        self.intermediate_size_per_partition = end_idx - start_idx
         moe_quant_params = {
             "num_experts": num_experts,
             "hidden_size": hidden_size,
@@ -326,15 +332,15 @@ def _load_model_weight_or_group_weight_scale(self,
                                                  expert_data: torch.Tensor,
                                                  shard_id: str,
                                                  loaded_weight: torch.Tensor,
-                                                 tp_rank: int,
+                                                 tp_start: int,
                                                  load_full_w2: bool = False):
         """
         Load grouped weight scales for group quantization or model weights
         :param shard_dim: dimension to shard
         :param expert_data: parameter for a particular expert
         :param shard_id: either w1, w2, or w3
         :param loaded_weight: checkpoint weight to load into the param
-        :param tp_rank: tensor parallel rank
+        :param tp_start: tensor parallel slice start
        :param load_full_w2: whether or not the w2 loaded should be sharded.
         """
         if shard_id == "w2":
@@ -343,19 +349,19 @@ def _load_model_weight_or_group_weight_scale(self,
             self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank,
+                          tp_start=tp_start,
                           load_full=load_full_w2)
         elif shard_id in ("w1", "w3"):
             self._load_w13(shard_id=shard_id,
                            shard_dim=shard_dim,
                            loaded_weight=loaded_weight,
                            expert_data=expert_data,
-                           tp_rank=tp_rank)
+                           tp_start=tp_start)
 
     def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
                                        shard_dim: int, shard_id: str,
                                        loaded_weight: torch.Tensor,
-                                       tp_rank: int):
+                                       tp_start: int):
         # for per channel weight quantization
         if shard_id == "w2":
             expert_data.copy_(loaded_weight)
@@ -364,16 +370,15 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank)
+                tp_start=tp_start)
 
     def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
-                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+                  shard_id: str, loaded_weight: torch.Tensor, tp_start: int):
 
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
-                                             shard_size)
+        loaded_weight = loaded_weight.narrow(shard_dim, tp_start, shard_size)
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         if shard_id == "w1":
@@ -388,7 +393,7 @@ def _load_w2(self,
                  expert_data: torch.Tensor,
                  shard_dim: int,
                  loaded_weight: torch.Tensor,
-                 tp_rank: int,
+                 tp_start: int,
                  load_full: bool = False):
 
         # Index the loaded weight for tp sharding.
@@ -397,7 +402,7 @@ def _load_w2(self,
         shard_size = expert_data.shape[shard_dim]
         if not load_full:
             loaded_weight = loaded_weight.narrow(shard_dim,
-                                                 shard_size * tp_rank,
+                                                 tp_start,
                                                  shard_size)
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
@@ -410,13 +415,13 @@ def _load_single_value(self, param: torch.nn.Parameter,
         param_data[expert_id] = loaded_weight
 
     def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
-                    shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
+                    shard_dim: int, loaded_weight: torch.Tensor, tp_start: int):
 
         if shard_id == "w2":
             self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank)
+                          tp_start=tp_start)
         else:
             assert shard_id in ("w1", "w3")
             expert_data.copy_(loaded_weight)
@@ -478,7 +483,7 @@ def weight_loader(self, param: torch.nn.Parameter,
                 shard_id=shard_id,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank)
+                tp_start=self.intermediate_start)
             return
 
         # Case weight scales and zero_points
@@ -495,7 +500,7 @@ def weight_loader(self, param: torch.nn.Parameter,
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank)
+                    tp_start=self.intermediate_start)
             elif quant_method in [
                     FusedMoeWeightScaleSupported.GROUP.value,
                     FusedMoeWeightScaleSupported.BLOCK.value,
@@ -505,7 +510,7 @@ def weight_loader(self, param: torch.nn.Parameter,
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank,
+                    tp_start=self.intermediate_start // self.tp_chunk,
                     load_full_w2=getattr(param, "load_full_w2", False))
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
@@ -532,7 +537,7 @@ def weight_loader(self, param: torch.nn.Parameter,
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank)
+                tp_start=self.intermediate_start)
             return
 
     @staticmethod
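Taken together, the MoE layer now receives an explicit `[self.intermediate_start, self.intermediate_start + self.intermediate_size_per_partition)` slice from `get_assigned_range`, aligned to `tp_chunk` (taken from the quant method, e.g. a quantization block size), and the loaders narrow checkpoint tensors by that start offset instead of by `shard_size * tp_rank`. A small standalone sketch of the narrowing; the helper name, tensor shapes, and chunk size below are made up for illustration:

import torch

def narrow_to_partition(loaded_weight: torch.Tensor, shard_dim: int,
                        tp_start: int, shard_size: int) -> torch.Tensor:
    # Mirrors _load_w13/_load_w2 after this change: slice by an explicit
    # start offset rather than shard_size * tp_rank.
    return loaded_weight.narrow(shard_dim, tp_start, shard_size)

# Intermediate size 1536, tp_chunk 128, 3 ranks: chunk-aligned ranges are
# (0, 512), (512, 1024), (1024, 1536); here rank 1 slices its share of a
# (hidden, intermediate) weight.
full_weight = torch.randn(4096, 1536)
part = narrow_to_partition(full_weight, 1, 512, 512)
assert part.shape == (4096, 512)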
