first commit
This commit is contained in:

patches/gpt_oss_triton_moe.patch (new file, 77 lines)
@@ -0,0 +1,77 @@
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index badedfc54..e05c0eea4 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -20,9 +20,16 @@ logger = init_logger(__name__)
 if has_triton_kernels():
     try:
         import triton_kernels.swiglu
-        from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
-        from triton_kernels.routing import RoutingData, routing, routing_from_bitmatrix
-        from triton_kernels.tensor import Bitmatrix
+        from triton_kernels.matmul_ogs import (
+            FnSpecs,
+            FusedActivation,
+            GatherIndx,
+            RoutingData,
+            ScatterIndx,
+            matmul_ogs,
+        )
+        from triton_kernels.tensor import BIT, Bitmatrix, SparseMatrix, make_ragged_tensor_metadata
+        from triton_kernels.topk import topk as triton_topk
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
@@ -84,8 +91,17 @@ def triton_kernel_moe_forward(
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    routing_data, gather_idx, scatter_idx = routing(
-        gating_output, topk, sm_first=not renormalize
+    # Use new topk API instead of deprecated routing
+    sm_first = not renormalize
+    if sm_first:
+        gating_output = torch.softmax(gating_output, dim=-1)
+    sparse_logits = triton_topk(
+        gating_output, topk, apply_softmax=not sm_first, y_indx=None, n_rows=None
+    )
+
+    # Convert to routing data using the existing make_routing_data function
+    routing_data, gather_idx, scatter_idx = make_routing_data(
+        sparse_logits.indx, sparse_logits.vals, gating_output.shape[-1]
     )
 
     return triton_kernel_fused_experts(
@@ -202,14 +218,29 @@ def make_routing_data(
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
     bitmatrix = Bitmatrix(
-        bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=None
+        bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
     )
 
     # matmul_ogs expects invalid topk_weights to be -1s
     topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
-    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
-        bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
+
+    # Use new SparseMatrix API instead of deprecated routing_from_bitmatrix
+    sparse_logits = SparseMatrix(indx=topk_ids, vals=topk_weights, mask=bitmatrix)
+    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
+    ragged_batch_metadata = make_ragged_tensor_metadata(
+        sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]
+    )
+    gate_scal = sparse_logits.vals.flatten()[combine_indx]
+    routing_data = RoutingData(
+        gate_scal,
+        ragged_batch_metadata.block_sizes,
+        num_local_experts,
+        num_topk,
+        ragged_batch_metadata,
     )
+    gather_indx = GatherIndx(combine_indx, dispatch_indx)
+    scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
 
     return routing_data, gather_indx, scatter_indx
||||
Reference in New Issue
Block a user