From a059d4abedd731a143a47831835b46f862cb2298 Mon Sep 17 00:00:00 2001 From: Siddartha Naidu Date: Mon, 3 Feb 2025 04:30:47 +0000 Subject: [PATCH] Add support for tensors/heads not divisible by GPUs --- vllm/config.py | 13 +- vllm/distributed/parallel_state.py | 16 +++ vllm/model_executor/layers/fused_moe/layer.py | 45 ++++--- vllm/model_executor/layers/linear.py | 127 +++++++++++++----- .../layers/quantization/base_config.py | 4 + .../model_executor/layers/quantization/fp8.py | 20 ++- vllm/model_executor/models/deepseek_v3.py | 29 ++-- vllm/model_executor/parameter.py | 26 ++-- vllm/worker/cache_engine.py | 6 +- 9 files changed, 194 insertions(+), 92 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 2f4a7ad769d98..4baf5fcafe1ef 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -706,11 +706,6 @@ def verify_with_parallel_config( total_num_attention_heads = getattr(self.hf_text_config, "num_attention_heads", 0) tensor_parallel_size = parallel_config.tensor_parallel_size - if total_num_attention_heads % tensor_parallel_size != 0: - raise ValueError( - f"Total number of attention heads ({total_num_attention_heads})" - " must be divisible by tensor parallel size " - f"({tensor_parallel_size}).") pipeline_parallel_size = parallel_config.pipeline_parallel_size if pipeline_parallel_size > 1: @@ -839,13 +834,15 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: # the tensor parallel size. We will replicate the KV heads in the # case where the number of KV heads is smaller than the tensor # parallel size so each GPU has at least one KV head. - return max(1, - total_num_kv_heads // parallel_config.tensor_parallel_size) + return max(1, (total_num_kv_heads + + parallel_config.tensor_parallel_size - 1) // + parallel_config.tensor_parallel_size) def get_num_attention_heads(self, parallel_config: "ParallelConfig") -> int: num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) - return num_heads // parallel_config.tensor_parallel_size + return ((num_heads + parallel_config.tensor_parallel_size - 1) // + parallel_config.tensor_parallel_size) def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> Tuple[int, int]: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c5c5dfbbab76b..d8125e314b501 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1151,6 +1151,22 @@ def get_tensor_model_parallel_rank(): return get_tp_group().rank_in_group +def get_assigned_range(num_elements: int, tp_chunk: int = 1): + assert num_elements % tp_chunk == 0, 'Chunk size must divide the elements.' 
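+    # Balanced split: e.g. 10 chunks over 4 ranks yields per-rank counts
+    # [3, 3, 2, 2]; the first (num_chunks % world_size) ranks get one extra.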
+ num_elements = num_elements // tp_chunk + + tp_rank = get_tensor_model_parallel_rank() + tp_world_size = get_tp_group().world_size + base_elements_per_rank = num_elements // tp_world_size + extra_elements = num_elements % tp_world_size + # Ranks < extra_elements get one extra element + elements_per_rank = base_elements_per_rank + (1 if tp_rank < extra_elements else 0) + start = (tp_rank * base_elements_per_rank + min(tp_rank, extra_elements)) + end = start + elements_per_rank + assert tp_rank < tp_world_size - 1 or end == num_elements + return tp_chunk * start, tp_chunk * end + + def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3c7ef5e0080ff..4d492ad6fdad9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -6,7 +6,8 @@ import torch -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_assigned_range, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.logger import init_logger @@ -259,6 +260,7 @@ def __init__( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + tp_chunk: int = 1 ): super().__init__() @@ -267,10 +269,9 @@ def __init__( self.tp_size = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()) + self.tp_chunk = 1 self.top_k = top_k self.num_experts = num_experts - assert intermediate_size % self.tp_size == 0 - self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize self.use_grouped_topk = use_grouped_topk @@ -291,8 +292,13 @@ def __init__( UnquantizedFusedMoEMethod()) else: self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method.tp_chunk(0) == self.quant_method.tp_chunk(1) + self.tp_chunk = self.quant_method.tp_chunk(0) assert self.quant_method is not None + start_idx, end_idx = get_assigned_range(intermediate_size, self.tp_chunk) + self.intermediate_start = start_idx + self.intermediate_size_per_partition = end_idx - start_idx moe_quant_params = { "num_experts": num_experts, "hidden_size": hidden_size, @@ -328,7 +334,7 @@ def _load_model_weight_or_group_weight_scale(self, expert_data: torch.Tensor, shard_id: str, loaded_weight: torch.Tensor, - tp_rank: int, + tp_start: int, load_full_w2: bool = False): """ Load grouped weight scales for group quantization or model weights @@ -336,7 +342,7 @@ def _load_model_weight_or_group_weight_scale(self, :param expert_data: parameter for a particular expert :param shard_id: either w1, w2, or w3 :param loaded_weight: checkpoint weight to load into the param - :param tp_rank: tensor parallel rank + :param tp_start: tensor parallel slice start :param load_full_w2: whether or not the w2 loaded should be sharded. 
""" if shard_id == "w2": @@ -345,19 +351,19 @@ def _load_model_weight_or_group_weight_scale(self, self._load_w2(shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank, + tp_start=tp_start, load_full=load_full_w2) elif shard_id in ("w1", "w3"): self._load_w13(shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_start=tp_start) def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str, loaded_weight: torch.Tensor, - tp_rank: int): + tp_start: int): # for per channel weight quantization if shard_id == "w2": expert_data.copy_(loaded_weight) @@ -366,16 +372,15 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_start=tp_start) def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): + shard_id: str, loaded_weight: torch.Tensor, tp_start: int): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 - loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, - shard_size) + loaded_weight = loaded_weight.narrow(shard_dim, tp_start, shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": @@ -390,7 +395,7 @@ def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, loaded_weight: torch.Tensor, - tp_rank: int, + tp_start: int, load_full: bool = False): # Index the loaded weight for tp sharding. @@ -399,7 +404,7 @@ def _load_w2(self, shard_size = expert_data.shape[shard_dim] if not load_full: loaded_weight = loaded_weight.narrow(shard_dim, - shard_size * tp_rank, + tp_start, shard_size) # w2, down_proj: Load into only logical weight of w2. 
expert_data.copy_(loaded_weight) @@ -412,13 +417,13 @@ def _load_single_value(self, param: torch.nn.Parameter, param_data[expert_id] = loaded_weight def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, - shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): + shard_dim: int, loaded_weight: torch.Tensor, tp_start: int): if shard_id == "w2": self._load_w2(shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_start=tp_start) else: assert shard_id in ("w1", "w3") expert_data.copy_(loaded_weight) @@ -480,7 +485,7 @@ def weight_loader(self, param: torch.nn.Parameter, shard_id=shard_id, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_start=self.intermediate_start) return # Case weight scales and zero_points @@ -497,7 +502,7 @@ def weight_loader(self, param: torch.nn.Parameter, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_start=self.intermediate_start) elif quant_method in [ FusedMoeWeightScaleSupported.GROUP.value, FusedMoeWeightScaleSupported.BLOCK.value, @@ -507,7 +512,7 @@ def weight_loader(self, param: torch.nn.Parameter, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank, + tp_start=self.intermediate_start // self.tp_chunk, load_full_w2=getattr(param, "load_full_w2", False)) elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: self._load_per_tensor_weight_scale(shard_id=shard_id, @@ -534,7 +539,7 @@ def weight_loader(self, param: torch.nn.Parameter, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_start=self.intermediate_start) return @staticmethod diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 08f1e103e53b7..1f78e6e6cd020 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,6 +2,7 @@ import itertools from abc import abstractmethod +from math import lcm from typing import Dict, List, Optional, Tuple import torch @@ -10,6 +11,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_assigned_range, split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) @@ -84,6 +86,16 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): return param[shard_id], loaded_weight +def partition_size(total_size: int, tp_chunk: int): + """Computes the partition size for the current rank. + Args: + total_size: This is the full size of the data being partitioned. + tp_chunk: The splitting is done in multiples of this parameter. + """ + chunk_start, chunk_end = get_assigned_range(divide(total_size, tp_chunk)) + return (chunk_end - chunk_start) * tp_chunk + + class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @@ -277,7 +289,8 @@ class ColumnParallelLinear(LinearBase): output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) + (e.g. model.layers.0.qkv_proj) + tp_chunk: Tensor parallel chunk size for splitting weights and inputs across devices. 
""" def __init__(self, @@ -289,23 +302,28 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, - prefix: str = ""): + prefix: str = "", + tp_chunk: int = 1): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix) self.gather_output = gather_output # Divide the weight matrix along the last dimension. - tp_size = get_tensor_model_parallel_world_size() assert self.quant_method is not None - self.output_size_per_partition = divide(self.output_size, tp_size) + self.tp_chunk = lcm(tp_chunk, self.quant_method.tp_chunk(0)) + + self.output_size_per_partition = partition_size(self.output_size, self.tp_chunk) self.output_partition_sizes = [self.output_size_per_partition] # If QKV or MergedColumn, use output size of each partition. if hasattr(self, "output_sizes"): + min_size = min(self.output_sizes) + factors = [divide(size, min_size) for size in self.output_sizes] self.output_partition_sizes = [ - divide(output_size, tp_size) - for output_size in self.output_sizes + partition_size(output_size, self.tp_chunk * factor) + for output_size, factor in zip(self.output_sizes, factors) ] + self.output_partition_factors = factors if output_sizes is None: output_sizes = [output_size] @@ -319,7 +337,8 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), + tp_chunk=tp_chunk) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -331,8 +350,10 @@ def __init__(self, else: self.register_parameter("bias", None) + def column_tp_chunk(self) -> int: + return self.tp_chunk + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) # Special case for GGUF @@ -353,10 +374,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param_data = param.data if output_dim is not None and not is_sharded_weight: + start_idx, end_idx = get_assigned_range( + loaded_weight.shape[output_dim], self.tp_chunk) shard_size = param_data.shape[output_dim] - start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + end_idx - start_idx) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). 
@@ -372,7 +394,22 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight=loaded_weight)
+
+        tp_chunk = self.tp_chunk
+        if isinstance(param, BlockQuantScaleParameter):
+            from vllm.model_executor.layers.quantization.fp8 import (
+                Fp8LinearMethod, Fp8MoEMethod)
+            assert self.quant_method is not None
+            assert isinstance(self.quant_method,
+                              (Fp8LinearMethod, Fp8MoEMethod))
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            assert weight_block_size is not None
+            block_n, _ = weight_block_size[0], weight_block_size[1]
+            # self.tp_chunk is an lcm that already includes block_n.
+            assert tp_chunk % block_n == 0
+            tp_chunk = divide(tp_chunk, block_n)
+
+        param.load_column_parallel_weight(loaded_weight=loaded_weight, tp_chunk=tp_chunk)
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -430,8 +467,6 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size=input_size,
                          output_size=sum(output_sizes),
                          bias=bias,
@@ -440,6 +475,7 @@ def __init__(self,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
                          prefix=prefix)
+        assert all(output_size % self.tp_chunk == 0 for output_size in output_sizes)
 
     def weight_loader(self,
                       param: Parameter,
@@ -462,16 +498,12 @@ def weight_loader(self,
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            start_idx, end_idx = get_assigned_range(loaded_weight.size(output_dim), self.tp_chunk)
 
             if loaded_shard_id is not None:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                     shard_size)
+                                                     end_idx - start_idx)
                 param.shard_id.append(loaded_shard_id)
                 param.shard_id_map[loaded_shard_id] = len(param.data_container)
                 param.data_container.append(loaded_weight)
@@ -535,6 +567,7 @@ def weight_loader(self,
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
+            raise NotImplementedError("Need to support chunk-based sharding for this path.")
             shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
             shard_size = self.output_sizes[loaded_shard_id] // tp_size
             # Special case for quantization.
@@ -641,6 +674,8 @@ def weight_loader_v2(self, assert loaded_shard_id < len(self.output_sizes) tp_size = get_tensor_model_parallel_world_size() + tp_chunk = self.tp_chunk + output_sizes = self.output_sizes if isinstance(param, BlockQuantScaleParameter): from vllm.model_executor.layers.quantization.fp8 import ( @@ -651,19 +686,26 @@ def weight_loader_v2(self, weight_block_size = self.quant_method.quant_config.weight_block_size assert weight_block_size is not None block_n, _ = weight_block_size[0], weight_block_size[1] - shard_offset = ( - (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // - block_n) // tp_size - shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // - block_n // tp_size) - else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + tp_chunk = divide(tp_chunk, block_n) + output_sizes = [divide(size, block_n) for size in output_sizes] + + shard_offset = 0 + for factor, size in zip( + self.output_partition_factors[:loaded_shard_id], + output_sizes[:loaded_shard_id]): + shard_start, shard_end = get_assigned_range( + size, factor * tp_chunk) + shard_offset += shard_end - shard_start + tp_chunk *= self.output_partition_factors[loaded_shard_id] + size_start, size_end = get_assigned_range( + output_sizes[loaded_shard_id], tp_chunk) + shard_size = size_end - size_start param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=loaded_shard_id, shard_offset=shard_offset, - shard_size=shard_size) + shard_size=shard_size, + tp_chunk=tp_chunk) class QKVParallelLinear(ColumnParallelLinear): @@ -1038,7 +1080,8 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + tp_chunk: int = 1): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix) @@ -1048,8 +1091,10 @@ def __init__(self, # Divide the weight matrix along the last dimension. self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() - self.input_size_per_partition = divide(input_size, self.tp_size) assert self.quant_method is not None + self.tp_chunk = lcm(tp_chunk, self.quant_method.tp_chunk(1)) + start_idx, end_idx = get_assigned_range(input_size, self.tp_chunk) + self.input_size_per_partition = end_idx - start_idx self.quant_method.create_weights( layer=self, @@ -1060,7 +1105,8 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), + tp_chunk=tp_chunk) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") @@ -1101,9 +1147,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param_data = param.data if input_dim is not None and not is_sharded_weight: shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size + start_idx, end_idx = get_assigned_range(loaded_weight.shape[input_dim], self.tp_chunk) loaded_weight = loaded_weight.narrow(input_dim, start_idx, - shard_size) + end_idx - start_idx) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). 
@@ -1122,7 +1168,20 @@ def weight_loader_v2(self, param: BasevLLMParameter, assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - param.load_row_parallel_weight(loaded_weight=loaded_weight) + tp_chunk = self.tp_chunk + if isinstance(param, BlockQuantScaleParameter): + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8LinearMethod, Fp8MoEMethod) + assert self.quant_method is not None + assert isinstance(self.quant_method, + (Fp8LinearMethod, Fp8MoEMethod)) + weight_block_size = self.quant_method.quant_config.weight_block_size + assert weight_block_size is not None + _, block_k = weight_block_size[0], weight_block_size[1] + tp_chunk = divide(tp_chunk, block_k) + + param.load_row_parallel_weight(loaded_weight=loaded_weight, + tp_chunk=tp_chunk) def forward(self, input_): if self.input_is_parallel: diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 2eefcc4f30516..468585af56fb1 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -41,6 +41,10 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: """ return + def tp_chunk(self, dim=None) -> int: + """Returns the tensor parallel chunk size along the specified dimension.""" + return 256 # Reasonable default. + def method_has_implemented_embedding( method_class: Type[QuantizeMethodBase]) -> bool: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 86e025310f4ef..48883a33a3e2d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -150,7 +150,14 @@ def __init__(self, quant_config: Fp8Config): if self.block_quant: # Marlin doesn't support block-wise fp8 self.use_marlin = False - + + def tp_chunk(self, dim=None) -> int: + if self.block_quant: + if dim is None: + return max(*self.quant_config.weight_block_size) # Should actually be LCM. 
+            return self.quant_config.weight_block_size[dim]
+        return super().tp_chunk(dim)
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -163,6 +170,7 @@
     ):
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")
+        tp_chunk = extra_weight_attrs.get("tp_chunk") or self.tp_chunk()
 
         if self.block_quant:
             tp_size = get_tensor_model_parallel_world_size()
@@ -173,15 +181,19 @@
             )
             # Required by row parallel
             if (tp_size > 1
-                    and input_size // input_size_per_partition == tp_size
+                    and output_size_per_partition == output_size
+                    and input_size_per_partition != input_size
+                    # Not column sharded
                     and input_size_per_partition % block_k != 0):
                 raise ValueError(
                     f"Weight input_size_per_partition = "
                     f"{input_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}.")
             # Required by column parallel or enabling merged weights
-            if (tp_size > 1 and output_size // output_size_per_partition
-                    == tp_size) or len(output_partition_sizes) > 1:
+            if (tp_size > 1
+                    and input_size_per_partition == input_size  # Not row sharded
+                    and output_size_per_partition != output_size
+                ) or len(output_partition_sizes) > 1:
                 for output_partition_size in output_partition_sizes:
                     if output_partition_size % block_n != 0:
                         raise ValueError(
diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py
index a4829aa1a572b..c3aaa847f2ca6 100644
--- a/vllm/model_executor/models/deepseek_v3.py
+++ b/vllm/model_executor/models/deepseek_v3.py
@@ -31,7 +31,8 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
+from vllm.distributed import (get_assigned_range,
+                              get_pp_group,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -101,14 +102,15 @@ def __init__(
         prefix: str = "",
     ):
         super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
+        tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > config.n_routed_experts:
+        if tp_size > config.n_routed_experts:
             raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
+                f"Tensor parallel size {tp_size} is greater than "
                 f"the number of experts {config.n_routed_experts}.")
+        self.needs_reduce = tp_size > 1
 
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
" @@ -163,7 +165,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits=router_logits) * self.routed_scaling_factor if shared_output is not None: final_hidden_states = final_hidden_states + shared_output - if self.tp_size > 1: + if self.needs_reduce: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -205,9 +207,8 @@ def __init__( self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank self.num_heads = num_heads - tp_size = get_tensor_model_parallel_world_size() - assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size + start_head, end_head = get_assigned_range(num_heads, 2) + self.num_local_heads = end_head - start_head self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -225,14 +226,17 @@ def __init__( self.qk_head_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + prefix=f"{prefix}.q_b_proj", + tp_chunk=(2 * self.qk_head_dim)) + #tp_chunk=self.qk_head_dim) else: self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.q_proj") + prefix=f"{prefix}.q_proj", + tp_chunk=(2*self.qk_head_dim)) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, @@ -247,13 +251,16 @@ def __init__( self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") + prefix=f"{prefix}.kv_b_proj", + tp_chunk=(2 * (self.qk_nope_head_dim + self.v_head_dim))) + #tp_chunk=(self.qk_nope_head_dim + self.v_head_dim)) # O projection. self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.o_proj") + prefix=f"{prefix}.o_proj", + tp_chunk=(2*self.v_head_dim)) if rope_scaling: rope_scaling["rope_type"] = 'deepseek_yarn' self.use_normal_rope = False diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 2b1294bf7baa3..5daf3ada7d6ff 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -6,7 +6,7 @@ import torch from torch.nn import Parameter -from vllm.distributed import get_tensor_model_parallel_rank +from vllm.distributed import get_assigned_range from vllm.logger import init_logger from vllm.model_executor.utils import _make_synced_weight_loader @@ -100,15 +100,14 @@ def __init__(self, output_dim: int, **kwargs): def output_dim(self): return self._output_dim - def load_column_parallel_weight(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.data.shape[self.output_dim] + def load_column_parallel_weight(self, loaded_weight: torch.Tensor, tp_chunk: int = 1): + start_idx, end_idx = get_assigned_range(loaded_weight.shape[self.output_dim], tp_chunk) loaded_weight = loaded_weight.narrow(self.output_dim, - tp_rank * shard_size, shard_size) + start_idx, end_idx - start_idx) assert self.data.shape == loaded_weight.shape self.data.copy_(loaded_weight) - def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + def load_merged_column_weight(self, loaded_weight: torch.Tensor, tp_chunk: int = 1, **kwargs): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") @@ -121,15 +120,15 @@ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): param_data = self.data - tp_rank = get_tensor_model_parallel_rank() + start_idx, end_idx = 
         param_data = param_data.narrow(self.output_dim, shard_offset,
                                        shard_size)
         loaded_weight = loaded_weight.narrow(self.output_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             start_idx, end_idx - start_idx)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_chunk: int = 1, **kwargs):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -144,6 +143,7 @@ def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)
         param_data = self.data
+        assert tp_chunk == 1  # TODO: support tp_chunk > 1 for QKV loading.
         tp_rank = get_tensor_model_parallel_rank()
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset,
@@ -171,13 +171,13 @@ def __init__(self, input_dim: int, **kwargs):
     def input_dim(self):
         return self._input_dim
 
-    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
+    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, tp_chunk: int = 1):
+        start_idx, end_idx = get_assigned_range(loaded_weight.shape[self.input_dim], tp_chunk)
         shard_size = self.data.shape[self.input_dim]
         loaded_weight = loaded_weight.narrow(self.input_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             start_idx, end_idx - start_idx)
 
-        if len(loaded_weight.shape) == 0:
+        if len(loaded_weight.shape) == 0:  # Likely dead: narrow() fails on 0-d.
             loaded_weight = loaded_weight.reshape(1)
 
         assert self.data.shape == loaded_weight.shape
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index 252fe06600dae..5ecaf8cb66083 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -6,6 +6,7 @@
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
+from vllm.distributed import get_assigned_range, get_tensor_model_parallel_rank
 from vllm.logger import init_logger
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
                         get_dtype_size, is_pin_memory_available)
@@ -37,7 +38,6 @@ def __init__(
         # Models like Jamba, have mixed typed layers, E.g Mamba
         self.num_attention_layers = model_config.get_num_layers_by_block_type(
             parallel_config, LayerBlockType.attention)
-        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
 
         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
@@ -71,8 +71,10 @@ def _allocate_kv_cache(
         device: str,
     ) -> List[torch.Tensor]:
         """Allocates KV cache on the specified device."""
+        total_num_kv_heads = self.model_config.get_total_num_kv_heads()
+        start_idx, end_idx = get_assigned_range(total_num_kv_heads, 2)
         kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
+            num_blocks, self.block_size, end_idx - start_idx, self.head_size)
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
         for _ in range(self.num_attention_layers):
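
For reference, below is a minimal standalone sketch of the balanced range assignment that get_assigned_range performs. The helper name assigned_range and the example head/GPU counts are illustrative only and are not part of this patch.

from typing import Tuple


def assigned_range(num_elements: int, tp_rank: int, tp_world_size: int,
                   tp_chunk: int = 1) -> Tuple[int, int]:
    # Mirrors get_assigned_range(): split in units of tp_chunk, giving the
    # first (num_chunks % tp_world_size) ranks one extra chunk each.
    assert num_elements % tp_chunk == 0, 'Chunk size must divide the elements.'
    num_chunks = num_elements // tp_chunk
    base = num_chunks // tp_world_size
    extra = num_chunks % tp_world_size
    start = tp_rank * base + min(tp_rank, extra)
    end = start + base + (1 if tp_rank < extra else 0)
    return tp_chunk * start, tp_chunk * end


if __name__ == "__main__":
    # 128 attention heads over 3 GPUs, assigned in pairs (tp_chunk=2):
    # rank 0 -> (0, 44), rank 1 -> (44, 86), rank 2 -> (86, 128).
    for rank in range(3):
        print(rank, assigned_range(128, rank, 3, tp_chunk=2))

With the divisibility check removed in vllm/config.py, a configuration like this (128 heads on 3 GPUs) is no longer rejected; the first rank simply holds one extra pair of heads.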