diff --git a/vllm/config.py b/vllm/config.py
index f6bd8b1ad8f14..76031434df225 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -690,11 +690,6 @@ def verify_with_parallel_config(
         total_num_attention_heads = getattr(self.hf_text_config,
                                             "num_attention_heads", 0)
         tensor_parallel_size = parallel_config.tensor_parallel_size
-        if total_num_attention_heads % tensor_parallel_size != 0:
-            raise ValueError(
-                f"Total number of attention heads ({total_num_attention_heads})"
-                " must be divisible by tensor parallel size "
-                f"({tensor_parallel_size}).")
 
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if pipeline_parallel_size > 1:
@@ -745,7 +740,8 @@ def is_deepseek_mla(self) -> bool:
 
     def get_head_size(self) -> int:
         # TODO remove hard code
-        if self.is_deepseek_mla:
+        if self.is_deepseek_mla or hasattr(self.hf_text_config,
+                                           'qk_nope_head_dim'):
             if self.use_mla:
                 return self.hf_text_config.kv_lora_rank
             else:
@@ -822,13 +818,15 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         # the tensor parallel size. We will replicate the KV heads in the
         # case where the number of KV heads is smaller than the tensor
         # parallel size so each GPU has at least one KV head.
-        return max(1,
-                   total_num_kv_heads // parallel_config.tensor_parallel_size)
+        return max(1, (total_num_kv_heads +
+                       parallel_config.tensor_parallel_size - 1) //
+                   parallel_config.tensor_parallel_size)
 
     def get_num_attention_heads(self,
                                 parallel_config: "ParallelConfig") -> int:
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
-        return num_heads // parallel_config.tensor_parallel_size
+        return ((num_heads + parallel_config.tensor_parallel_size - 1) //
+                parallel_config.tensor_parallel_size)
 
     def get_layers_start_end_indices(
             self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 7fe9b68d4b9e8..70e0580a628cf 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -1149,6 +1149,22 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group
 
 
+def get_assigned_range(num_elements: int, tp_chunk: int = 1):
+    assert num_elements % tp_chunk == 0, 'Chunk size must divide the elements.'
+    num_elements = num_elements // tp_chunk
+
+    tp_rank = get_tensor_model_parallel_rank()
+    tp_world_size = get_tp_group().world_size
+    base_elements_per_rank = num_elements // tp_world_size
+    extra_elements = num_elements % tp_world_size
+    # Ranks < extra_elements get one extra element
+    elements_per_rank = base_elements_per_rank + (1 if tp_rank < extra_elements else 0)
+    start = (tp_rank * base_elements_per_rank + min(tp_rank, extra_elements))
+    end = start + elements_per_rank
+    assert tp_rank < tp_world_size - 1 or end == num_elements
+    return tp_chunk * start, tp_chunk * end
+
+
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index da0ce1885dbb2..88f1b237d85d6 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -4,7 +4,8 @@
 
 import torch
 
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_assigned_range,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
@@ -257,6 +258,7 @@ def __init__(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        tp_chunk: int = 1,
     ):
         super().__init__()
 
@@ -265,10 +267,9 @@ def __init__(
 
         self.tp_size = (tp_size if tp_size is not None else
                         get_tensor_model_parallel_world_size())
+        self.tp_chunk = tp_chunk
         self.top_k = top_k
         self.num_experts = num_experts
-        assert intermediate_size % self.tp_size == 0
-        self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
         self.use_grouped_topk = use_grouped_topk
@@ -289,8 +290,13 @@ def __init__(
                 UnquantizedFusedMoEMethod())
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
+            assert self.quant_method.tp_chunk(0) == self.quant_method.tp_chunk(1)
+            self.tp_chunk = self.quant_method.tp_chunk(0)
         assert self.quant_method is not None
 
+        start_idx, end_idx = get_assigned_range(intermediate_size, self.tp_chunk)
+        self.intermediate_start = start_idx
+        self.intermediate_size_per_partition = end_idx - start_idx
         moe_quant_params = {
             "num_experts": num_experts,
             "hidden_size": hidden_size,
@@ -326,7 +332,7 @@ def _load_model_weight_or_group_weight_scale(self,
                                                  expert_data: torch.Tensor,
                                                  shard_dim: int,
                                                  shard_id: str,
                                                  loaded_weight: torch.Tensor,
-                                                 tp_rank: int,
+                                                 tp_start: int,
                                                  load_full_w2: bool = False):
         """
         Load grouped weight scales for group quantization or model weights
@@ -334,7 +340,7 @@ def _load_model_weight_or_group_weight_scale(self,
             :param expert_data: parameter for a particular expert
             :param shard_id: either w1, w2, or w3
             :param loaded_weight: checkpoint weight to load into the param
-            :param tp_rank: tensor parallel rank
+            :param tp_start: tensor parallel slice start
             :param load_full_w2: whether or not the w2 loaded should be sharded.
""" if shard_id == "w2": @@ -343,19 +349,19 @@ def _load_model_weight_or_group_weight_scale(self, self._load_w2(shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank, + tp_start=tp_start, load_full=load_full_w2) elif shard_id in ("w1", "w3"): self._load_w13(shard_id=shard_id, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_start=tp_start) def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str, loaded_weight: torch.Tensor, - tp_rank: int): + tp_start: int): # for per channel weight quantization if shard_id == "w2": expert_data.copy_(loaded_weight) @@ -364,16 +370,15 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_start=tp_start) def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): + shard_id: str, loaded_weight: torch.Tensor, tp_start: int): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 - loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, - shard_size) + loaded_weight = loaded_weight.narrow(shard_dim, tp_start, shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": @@ -388,7 +393,7 @@ def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, loaded_weight: torch.Tensor, - tp_rank: int, + tp_start: int, load_full: bool = False): # Index the loaded weight for tp sharding. @@ -397,7 +402,7 @@ def _load_w2(self, shard_size = expert_data.shape[shard_dim] if not load_full: loaded_weight = loaded_weight.narrow(shard_dim, - shard_size * tp_rank, + tp_start, shard_size) # w2, down_proj: Load into only logical weight of w2. 
         expert_data.copy_(loaded_weight)
@@ -410,13 +415,13 @@ def _load_single_value(self, param: torch.nn.Parameter,
         param_data[expert_id] = loaded_weight
 
     def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
-                    shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
+                    shard_dim: int, loaded_weight: torch.Tensor, tp_start: int):
 
         if shard_id == "w2":
             self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank)
+                          tp_start=tp_start)
         else:
             assert shard_id in ("w1", "w3")
             expert_data.copy_(loaded_weight)
@@ -478,7 +483,7 @@ def weight_loader(self, param: torch.nn.Parameter,
                              shard_id=shard_id,
                              loaded_weight=loaded_weight,
                              expert_data=expert_data,
-                             tp_rank=tp_rank)
+                             tp_start=self.intermediate_start)
             return
 
         # Case weight scales and zero_points
@@ -495,7 +500,7 @@ def weight_loader(self, param: torch.nn.Parameter,
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank)
+                    tp_start=self.intermediate_start)
             elif quant_method in [
                     FusedMoeWeightScaleSupported.GROUP.value,
                     FusedMoeWeightScaleSupported.BLOCK.value,
@@ -505,7 +510,7 @@ def weight_loader(self, param: torch.nn.Parameter,
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank,
+                    tp_start=self.intermediate_start // self.tp_chunk,
                     load_full_w2=getattr(param, "load_full_w2", False))
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
@@ -532,7 +537,7 @@ def weight_loader(self, param: torch.nn.Parameter,
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank)
+                tp_start=self.intermediate_start)
             return
 
     @staticmethod
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 52263e96fb9f9..c85b41763a638 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1,5 +1,6 @@
 import itertools
 from abc import abstractmethod
+from math import lcm
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -8,6 +9,7 @@
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
+                              get_assigned_range,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
@@ -82,6 +84,16 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+def partition_size(total_size: int, tp_chunk: int):
+    """Computes the partition size for the current rank.
+    Args:
+        total_size: This is the full size of the data being partitioned.
+        tp_chunk: The splitting is done in multiples of this parameter.
+    """
+    chunk_start, chunk_end = get_assigned_range(divide(total_size, tp_chunk))
+    return (chunk_end - chunk_start) * tp_chunk
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
 
@@ -275,7 +287,8 @@ class ColumnParallelLinear(LinearBase):
         output_sizes: list of output sizes packed into one output, like for QKV
                        the list would be size 3.
         prefix: The name of the layer in the state dict, including all parents
-            (e.g. model.layers.0.qkv_proj)
+            (e.g. model.layers.0.qkv_proj)
+        tp_chunk: Tensor parallel chunk size for splitting weights and inputs across devices.
""" def __init__(self, @@ -287,23 +300,28 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, - prefix: str = ""): + prefix: str = "", + tp_chunk: int = 1): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix) self.gather_output = gather_output # Divide the weight matrix along the last dimension. - tp_size = get_tensor_model_parallel_world_size() assert self.quant_method is not None - self.output_size_per_partition = divide(self.output_size, tp_size) + self.tp_chunk = lcm(tp_chunk, self.quant_method.tp_chunk(0)) + + self.output_size_per_partition = partition_size(self.output_size, self.tp_chunk) self.output_partition_sizes = [self.output_size_per_partition] # If QKV or MergedColumn, use output size of each partition. if hasattr(self, "output_sizes"): + min_size = min(self.output_sizes) + factors = [divide(size, min_size) for size in self.output_sizes] self.output_partition_sizes = [ - divide(output_size, tp_size) - for output_size in self.output_sizes + partition_size(output_size, self.tp_chunk * factor) + for output_size, factor in zip(self.output_sizes, factors) ] + self.output_partition_factors = factors if output_sizes is None: output_sizes = [output_size] @@ -317,7 +335,8 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), + tp_chunk=tp_chunk) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -329,8 +348,10 @@ def __init__(self, else: self.register_parameter("bias", None) + def column_tp_chunk(self) -> int: + return self.tp_chunk + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) # Special case for GGUF @@ -351,10 +372,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param_data = param.data if output_dim is not None and not is_sharded_weight: + start_idx, end_idx = get_assigned_range( + loaded_weight.shape[output_dim], self.tp_chunk) shard_size = param_data.shape[output_dim] - start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + end_idx - start_idx) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). 
@@ -370,7 +392,20 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight=loaded_weight)
+
+        tp_chunk = self.tp_chunk
+        if isinstance(param, BlockQuantScaleParameter):
+            from vllm.model_executor.layers.quantization.fp8 import (
+                Fp8LinearMethod, Fp8MoEMethod)
+            assert self.quant_method is not None
+            assert isinstance(self.quant_method,
+                              (Fp8LinearMethod, Fp8MoEMethod))
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            assert weight_block_size is not None
+            block_n, _ = weight_block_size[0], weight_block_size[1]
+            tp_chunk = divide(tp_chunk, block_n)
+
+        param.load_column_parallel_weight(loaded_weight=loaded_weight, tp_chunk=tp_chunk)
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -428,8 +465,6 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size=input_size,
                          output_size=sum(output_sizes),
                          bias=bias,
@@ -438,6 +473,7 @@ def __init__(self,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
                          prefix=prefix)
+        assert all(output_size % self.tp_chunk == 0 for output_size in output_sizes)
 
     def weight_loader(self,
@@ -460,16 +496,12 @@ def weight_loader(self,
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            start_idx, end_idx = get_assigned_range(loaded_weight.size(output_dim), self.tp_chunk)
 
             if loaded_shard_id is not None:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                     shard_size)
+                                                     end_idx - start_idx)
                 param.shard_id.append(loaded_shard_id)
                 param.shard_id_map[loaded_shard_id] = len(param.data_container)
                 param.data_container.append(loaded_weight)
@@ -533,6 +565,7 @@ def weight_loader(self,
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
+            raise NotImplementedError("Need to support chunk based sharding for this path.")
             shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
             shard_size = self.output_sizes[loaded_shard_id] // tp_size
             # Special case for quantization.
@@ -639,6 +672,8 @@ def weight_loader_v2(self,
         assert loaded_shard_id < len(self.output_sizes)
 
         tp_size = get_tensor_model_parallel_world_size()
+        tp_chunk = self.tp_chunk
+        output_sizes = self.output_sizes
 
         if isinstance(param, BlockQuantScaleParameter):
             from vllm.model_executor.layers.quantization.fp8 import (
@@ -649,19 +684,26 @@ def weight_loader_v2(self,
             weight_block_size = self.quant_method.quant_config.weight_block_size
             assert weight_block_size is not None
             block_n, _ = weight_block_size[0], weight_block_size[1]
-            shard_offset = (
-                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
-                block_n) // tp_size
-            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
-                          block_n // tp_size)
-        else:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            tp_chunk = divide(tp_chunk, block_n)
+            output_sizes = [divide(size, block_n) for size in output_sizes]
+
+        shard_offset = 0
+        for factor, size in zip(
+                self.output_partition_factors[:loaded_shard_id],
+                output_sizes[:loaded_shard_id]):
+            shard_start, shard_end = get_assigned_range(
+                size, factor * tp_chunk)
+            shard_offset += shard_end - shard_start
+        tp_chunk *= self.output_partition_factors[loaded_shard_id]
+        size_start, size_end = get_assigned_range(
+            output_sizes[loaded_shard_id], tp_chunk)
+        shard_size = size_end - size_start
 
         param.load_merged_column_weight(loaded_weight=loaded_weight,
                                         shard_id=loaded_shard_id,
                                         shard_offset=shard_offset,
-                                        shard_size=shard_size)
+                                        shard_size=shard_size,
+                                        tp_chunk=tp_chunk)
 
 
 class QKVParallelLinear(ColumnParallelLinear):
@@ -1036,7 +1078,8 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  reduce_results: bool = True,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 tp_chunk: int = 1):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config, prefix)
 
@@ -1046,8 +1089,10 @@ def __init__(self,
         # Divide the weight matrix along the last dimension.
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
+        self.tp_chunk = lcm(tp_chunk, self.quant_method.tp_chunk(1))
+        start_idx, end_idx = get_assigned_range(input_size, self.tp_chunk)
+        self.input_size_per_partition = end_idx - start_idx
 
         self.quant_method.create_weights(
             layer=self,
@@ -1058,7 +1103,8 @@ def __init__(self,
             params_dtype=self.params_dtype,
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
-                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
+            tp_chunk=tp_chunk)
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
                              "results can lead to incorrect results")
@@ -1099,9 +1145,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         param_data = param.data
         if input_dim is not None and not is_sharded_weight:
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx, end_idx = get_assigned_range(loaded_weight.shape[input_dim], self.tp_chunk)
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
-                                                 shard_size)
+                                                 end_idx - start_idx)
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -1120,7 +1166,20 @@ def weight_loader_v2(self, param: BasevLLMParameter,
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
 
-        param.load_row_parallel_weight(loaded_weight=loaded_weight)
+        tp_chunk = self.tp_chunk
+        if isinstance(param, BlockQuantScaleParameter):
+            from vllm.model_executor.layers.quantization.fp8 import (
+                Fp8LinearMethod, Fp8MoEMethod)
+            assert self.quant_method is not None
+            assert isinstance(self.quant_method,
+                              (Fp8LinearMethod, Fp8MoEMethod))
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            assert weight_block_size is not None
+            _, block_k = weight_block_size[0], weight_block_size[1]
+            tp_chunk = divide(tp_chunk, block_k)
+
+        param.load_row_parallel_weight(loaded_weight=loaded_weight,
+                                       tp_chunk=tp_chunk)
 
     def forward(self, input_):
         if self.input_is_parallel:
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index 2fb2642dd5156..c9c0ae727d11e 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -39,6 +39,10 @@ def process_weights_after_loading(self, layer: nn.Module) -> None:
         """
         return
 
+    def tp_chunk(self, dim=None) -> int:
+        """Returns the tensor parallel chunk size along the specified dimension."""
+        return 256  # Reasonable default.
+
 
 def method_has_implemented_embedding(
         method_class: Type[QuantizeMethodBase]) -> bool:
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 57dd6e310297d..ada6883273847 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -146,7 +146,14 @@ def __init__(self, quant_config: Fp8Config):
         if self.block_quant:
             # Marlin doesn't support block-wise fp8
             self.use_marlin = False
-
+
+    def tp_chunk(self, dim=None) -> int:
+        if self.block_quant:
+            if dim is None:
+                return max(*self.quant_config.weight_block_size)  # Should actually be LCM.
+            return self.quant_config.weight_block_size[dim]
+        return super().tp_chunk(dim)
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -159,6 +166,7 @@ def create_weights(
     ):
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")
+        tp_chunk = extra_weight_attrs.get("tp_chunk") or self.tp_chunk()
 
         if self.block_quant:
             tp_size = get_tensor_model_parallel_world_size()
@@ -169,15 +177,19 @@ def create_weights(
             )
             # Required by row parallel
             if (tp_size > 1
-                    and input_size // input_size_per_partition == tp_size
+                    and output_size_per_partition == output_size
+                    and input_size_per_partition != input_size
+                    # Not column sharded
                     and input_size_per_partition % block_k != 0):
                 raise ValueError(
                     f"Weight input_size_per_partition = "
                     f"{input_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}.")
             # Required by column parallel or enabling merged weights
-            if (tp_size > 1 and output_size // output_size_per_partition
-                    == tp_size) or len(output_partition_sizes) > 1:
+            if (tp_size > 1
+                    and input_size_per_partition == input_size  # Not row sharded
+                    and output_size_per_partition != output_size
+                    ) or len(output_partition_sizes) > 1:
                 for output_partition_size in output_partition_sizes:
                     if output_partition_size % block_n != 0:
                         raise ValueError(
diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py
index 0b44f0d062c40..7a22bf1eef699 100644
--- a/vllm/model_executor/models/deepseek_v3.py
+++ b/vllm/model_executor/models/deepseek_v3.py
@@ -28,7 +28,8 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
+from vllm.distributed import (get_assigned_range,
+                              get_pp_group,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -98,14 +99,15 @@ def __init__(
         prefix: str = "",
     ):
         super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
+        tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > config.n_routed_experts:
+        if tp_size > config.n_routed_experts:
             raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
+                f"Tensor parallel size {tp_size} is greater than "
                 f"the number of experts {config.n_routed_experts}.")
+        self.needs_reduce = tp_size > 1
 
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -160,7 +162,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             router_logits=router_logits) * self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if self.needs_reduce:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
@@ -202,9 +204,8 @@ def __init__(
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        start_head, end_head = get_assigned_range(num_heads, 2)
+        self.num_local_heads = end_head - start_head
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
@@ -222,14 +223,17 @@ def __init__(
                                                  self.qk_head_dim,
                                                  bias=False,
                                                  quant_config=quant_config,
-                                                 prefix=f"{prefix}.q_b_proj")
+                                                 prefix=f"{prefix}.q_b_proj",
+                                                 tp_chunk=(2 * self.qk_head_dim))
+                                                 #tp_chunk=self.qk_head_dim)
         else:
             self.q_proj = ColumnParallelLinear(self.hidden_size,
                                                self.num_heads * self.qk_head_dim,
                                                bias=False,
                                                quant_config=quant_config,
-                                               prefix=f"{prefix}.q_proj")
+                                               prefix=f"{prefix}.q_proj",
+                                               tp_chunk=(2 * self.qk_head_dim))
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
@@ -244,13 +248,16 @@ def __init__(
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.kv_b_proj")
+            prefix=f"{prefix}.kv_b_proj",
+            tp_chunk=(2 * (self.qk_nope_head_dim + self.v_head_dim)))
+            #tp_chunk=(self.qk_nope_head_dim + self.v_head_dim))
         # O projection.
         self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                         self.hidden_size,
                                         bias=False,
                                         quant_config=quant_config,
-                                        prefix=f"{prefix}.o_proj")
+                                        prefix=f"{prefix}.o_proj",
+                                        tp_chunk=(2 * self.v_head_dim))
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
             self.use_normal_rope = False
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index a9ce8af15d3bb..1dfb7c2e20395 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -4,7 +4,7 @@
 
 import torch
 from torch.nn import Parameter
 
-from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed import get_assigned_range, get_tensor_model_parallel_rank
 from vllm.logger import init_logger
 from vllm.model_executor.utils import _make_synced_weight_loader
@@ -98,15 +98,14 @@ def __init__(self, output_dim: int, **kwargs):
     def output_dim(self):
         return self._output_dim
 
-    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.data.shape[self.output_dim]
+    def load_column_parallel_weight(self, loaded_weight: torch.Tensor, tp_chunk: int = 1):
+        start_idx, end_idx = get_assigned_range(loaded_weight.shape[self.output_dim], tp_chunk)
         loaded_weight = loaded_weight.narrow(self.output_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             start_idx, end_idx - start_idx)
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
-    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
+    def load_merged_column_weight(self, loaded_weight: torch.Tensor, tp_chunk: int = 1, **kwargs):
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -119,15 +118,15 @@ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
 
         param_data = self.data
 
-        tp_rank = get_tensor_model_parallel_rank()
+        start_idx, end_idx = get_assigned_range(loaded_weight.shape[self.output_dim], tp_chunk)
         param_data = param_data.narrow(self.output_dim, shard_offset,
                                        shard_size)
         loaded_weight = loaded_weight.narrow(self.output_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             start_idx, end_idx - start_idx)
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_chunk: int = 1, **kwargs):
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -142,6 +141,7 @@ def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
                                           shard_offset=shard_offset,
                                           shard_size=shard_size)
         param_data = self.data
+        assert tp_chunk == 1  # This needs to be fixed.
         tp_rank = get_tensor_model_parallel_rank()
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset,
@@ -169,13 +169,13 @@ def __init__(self, input_dim: int, **kwargs):
     def input_dim(self):
         return self._input_dim
 
-    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
+    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, tp_chunk: int = 1):
+        start_idx, end_idx = get_assigned_range(loaded_weight.shape[self.input_dim], tp_chunk)
         shard_size = self.data.shape[self.input_dim]
         loaded_weight = loaded_weight.narrow(self.input_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             start_idx, end_idx - start_idx)
 
-        if len(loaded_weight.shape) == 0:
+        if len(loaded_weight.shape) == 0:  # How can this ever be hit? The narrow op would fail.
             loaded_weight = loaded_weight.reshape(1)
 
         assert self.data.shape == loaded_weight.shape
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index 08316ba74aad8..7475eee121ce8 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -5,6 +5,7 @@
 
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
+from vllm.distributed import get_assigned_range, get_tensor_model_parallel_rank
 from vllm.logger import init_logger
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
                         get_dtype_size, is_pin_memory_available)
@@ -36,7 +37,6 @@ def __init__(
         # Models like Jamba, have mixed typed layers, E.g Mamba
         self.num_attention_layers = model_config.get_num_layers_by_block_type(
             parallel_config, LayerBlockType.attention)
-        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
 
         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
@@ -70,8 +70,10 @@ def _allocate_kv_cache(
         device: str,
     ) -> List[torch.Tensor]:
         """Allocates KV cache on the specified device."""
+        total_num_kv_heads = self.model_config.get_total_num_kv_heads()
+        start_idx, end_idx = get_assigned_range(total_num_kv_heads, 2)
         kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
+            num_blocks, self.block_size, end_idx - start_idx, self.head_size)
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
         for _ in range(self.num_attention_layers):
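
Reviewer note (illustration only, not part of the patch): the arithmetic behind get_assigned_range() above is easier to see in isolation. The sketch below mirrors that logic, but takes the rank and world size as explicit arguments instead of reading them from the tensor-parallel group, so it runs standalone; the example numbers (128 heads, chunks of 2, TP=3) are hypothetical.

# sketch_assigned_range.py -- illustrative only; mirrors the patch's
# get_assigned_range(), with rank/world_size passed explicitly rather than
# taken from the TP group. Example values are hypothetical.
from typing import Tuple


def assigned_range(num_elements: int, rank: int, world_size: int,
                   tp_chunk: int = 1) -> Tuple[int, int]:
    """Return the [start, end) slice of num_elements owned by this rank.

    Work is handed out in units of tp_chunk; the first
    (num_chunks % world_size) ranks receive one extra chunk, so per-rank
    sizes differ by at most one chunk and num_elements need not be
    divisible by world_size.
    """
    assert num_elements % tp_chunk == 0, "chunk size must divide the elements"
    num_chunks = num_elements // tp_chunk
    base = num_chunks // world_size
    extra = num_chunks % world_size
    per_rank = base + (1 if rank < extra else 0)
    start = rank * base + min(rank, extra)
    end = start + per_rank
    return tp_chunk * start, tp_chunk * end


if __name__ == "__main__":
    # 128 attention heads, sliced in pairs (tp_chunk=2), across TP=3 ranks --
    # a split the removed divisibility checks would previously have rejected.
    world_size = 3
    ranges = [assigned_range(128, r, world_size, tp_chunk=2)
              for r in range(world_size)]
    print(ranges)  # [(0, 44), (44, 86), (86, 128)]
    assert sum(end - start for start, end in ranges) == 128

Under this scheme the leading ranks simply own one extra chunk, which is why the divisibility-by-TP checks removed in config.py are no longer required; the linear layers additionally fold the quantization block size into tp_chunk via lcm() so that block-quantized weights stay block-aligned.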