diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index badedfc54..e05c0eea4 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -20,9 +20,16 @@ logger = init_logger(__name__)
 if has_triton_kernels():
     try:
         import triton_kernels.swiglu
-        from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
-        from triton_kernels.routing import RoutingData, routing, routing_from_bitmatrix
-        from triton_kernels.tensor import Bitmatrix
+        from triton_kernels.matmul_ogs import (
+            FnSpecs,
+            FusedActivation,
+            GatherIndx,
+            RoutingData,
+            ScatterIndx,
+            matmul_ogs,
+        )
+        from triton_kernels.tensor import BIT, Bitmatrix, SparseMatrix, make_ragged_tensor_metadata
+        from triton_kernels.topk import topk as triton_topk
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
@@ -84,8 +91,17 @@ def triton_kernel_moe_forward(
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    routing_data, gather_idx, scatter_idx = routing(
-        gating_output, topk, sm_first=not renormalize
+    # Use new topk API instead of deprecated routing
+    sm_first = not renormalize
+    if sm_first:
+        gating_output = torch.softmax(gating_output, dim=-1)
+    sparse_logits = triton_topk(
+        gating_output, topk, apply_softmax=not sm_first, y_indx=None, n_rows=None
+    )
+
+    # Convert to routing data using the existing make_routing_data function
+    routing_data, gather_idx, scatter_idx = make_routing_data(
+        sparse_logits.indx, sparse_logits.vals, gating_output.shape[-1]
     )

     return triton_kernel_fused_experts(
@@ -202,14 +218,29 @@ def make_routing_data(
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
     bitmatrix = Bitmatrix(
-        bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=None
+        bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
     )

     # matmul_ogs expects invalid topk_weights to be -1s
     topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
-    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
-        bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
+
+    # Use new SparseMatrix API instead of deprecated routing_from_bitmatrix
+    sparse_logits = SparseMatrix(indx=topk_ids, vals=topk_weights, mask=bitmatrix)
+    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
+    ragged_batch_metadata = make_ragged_tensor_metadata(
+        sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]
+    )
+    gate_scal = sparse_logits.vals.flatten()[combine_indx]
+    routing_data = RoutingData(
+        gate_scal,
+        ragged_batch_metadata.block_sizes,
+        num_local_experts,
+        num_topk,
+        ragged_batch_metadata,
     )
+    gather_indx = GatherIndx(combine_indx, dispatch_indx)
+    scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
     return routing_data, gather_indx, scatter_indx
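
Reviewer note: the rewritten forward path makes the old `sm_first=not renormalize` behavior explicit. When `sm_first` is true, softmax runs over all experts before top-k and the kept weights no longer sum to 1; otherwise `triton_topk(..., apply_softmax=True)` renormalizes over only the selected experts. A minimal pure-PyTorch sketch of the two orderings, illustrative only and not the triton_kernels implementation:

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
k = 2

# sm_first=True (renormalize=False): softmax over ALL experts, then keep
# the top-k probabilities as-is; they sum to less than 1.
probs = torch.softmax(logits, dim=-1)
vals_sm_first, ids_sm_first = torch.topk(probs, k, dim=-1)

# sm_first=False (renormalize=True): take the top-k logits first, then
# softmax over just those k, so the kept weights sum to exactly 1.
top_logits, ids_renorm = torch.topk(logits, k, dim=-1)
vals_renorm = torch.softmax(top_logits, dim=-1)

# Softmax is monotone, so both orderings select the same experts.
assert torch.equal(ids_sm_first, ids_renorm)
print(vals_sm_first.sum(dim=-1), vals_renorm.sum(dim=-1))  # < 1 vs exactly 1
```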
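
Reviewer note: the `make_routing_data` changes hand-assemble what the deprecated `routing_from_bitmatrix` used to return: an expert-major dispatch permutation, the matching combine permutation, and ragged per-expert batch metadata, now read off `SparseMatrix.mask_metadata`. The sketch below illustrates the permutation relationship in plain PyTorch; the `argsort` construction and the mapping onto `col_sorted_indx`/`row_sorted_indx`/`col_sum` are illustrative assumptions, not the library's internals:

```python
import torch

# Hypothetical sizes, for illustration only.
n_tokens, n_experts, n_topk = 4, 8, 2
torch.manual_seed(0)

gating = torch.softmax(torch.randn(n_tokens, n_experts), dim=-1)
topk_weights, topk_ids = torch.topk(gating, n_topk, dim=-1)

flat_ids = topk_ids.flatten()  # one entry per (token, expert) assignment

# Expert-major order: gathering rows with this permutation gives each
# expert a contiguous block of its tokens (the dispatch side).
dispatch_indx = torch.argsort(flat_ids, stable=True)
# Its inverse scatters expert outputs back to token order (the combine
# side); GatherIndx/ScatterIndx in the diff pair the two directions.
combine_indx = torch.argsort(dispatch_indx)

# Composing the two permutations yields the identity.
assert torch.equal(dispatch_indx[combine_indx], torch.arange(flat_ids.numel()))

# Per-expert assignment counts: the shape information the diff passes
# (as mask_metadata.col_sum) into make_ragged_tensor_metadata.
print(torch.bincount(flat_ids, minlength=n_experts))
```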