Files
2026-03-22 17:26:26 -04:00

78 lines
3.4 KiB
Diff

diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index badedfc54..e05c0eea4 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -20,9 +20,21 @@ logger = init_logger(__name__)
 if has_triton_kernels():
     try:
         import triton_kernels.swiglu
-        from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
-        from triton_kernels.routing import RoutingData, routing, routing_from_bitmatrix
-        from triton_kernels.tensor import Bitmatrix
+        from triton_kernels.matmul_ogs import (
+            FnSpecs,
+            FusedActivation,
+            GatherIndx,
+            RoutingData,
+            ScatterIndx,
+            matmul_ogs,
+        )
+        from triton_kernels.tensor import (
+            BIT,
+            Bitmatrix,
+            SparseMatrix,
+            make_ragged_tensor_metadata,
+        )
+        from triton_kernels.topk import topk as triton_topk
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
@@ -84,8 +96,22 @@ def triton_kernel_moe_forward(
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    routing_data, gather_idx, scatter_idx = routing(
-        gating_output, topk, sm_first=not renormalize
+    # Intended to reproduce the removed routing(..., sm_first=not renormalize):
+    # without renormalization, softmax over the last (expert) dim before the
+    # top-k; otherwise triton_topk applies the softmax itself (presumably over
+    # the selected top-k values only -- confirm against triton_kernels.topk).
+    sm_first = not renormalize
+    if sm_first:
+        gating_output = torch.softmax(gating_output, dim=-1)
+    sparse_logits = triton_topk(
+        gating_output, topk, apply_softmax=not sm_first, y_indx=None, n_rows=None
+    )
+
+    # NOTE(review): make_routing_data builds a fresh Bitmatrix/SparseMatrix from
+    # indx/vals; if sparse_logits already carries a bitmatrix mask, deriving the
+    # routing data from it directly would avoid that extra pass.
+    routing_data, gather_idx, scatter_idx = make_routing_data(
+        sparse_logits.indx, sparse_logits.vals, gating_output.shape[-1]
     )
 
     return triton_kernel_fused_experts(
@@ -202,14 +228,32 @@ def make_routing_data(
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
     bitmatrix = Bitmatrix(
-        bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=None
+        bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
     )
 
     # matmul_ogs expects invalid topk_weights to be -1s
     topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
-    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
-        bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
+
+    # Rebuild what routing_from_bitmatrix used to return: wrap the top-k
+    # selection in a SparseMatrix and read the sorted indices and column sums
+    # from its mask metadata.
+    sparse_logits = SparseMatrix(indx=topk_ids, vals=topk_weights, mask=bitmatrix)
+    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
+    ragged_batch_metadata = make_ragged_tensor_metadata(
+        sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]
+    )
+    # Weights reordered by combine_indx (GatherIndx's first field below).
+    gate_scal = sparse_logits.vals.flatten()[combine_indx]
+    routing_data = RoutingData(
+        gate_scal,
+        ragged_batch_metadata.block_sizes,
+        num_local_experts,
+        num_topk,
+        ragged_batch_metadata,
     )
+    gather_indx = GatherIndx(combine_indx, dispatch_indx)
+    scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
 
     return routing_data, gather_indx, scatter_indx
 