diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index badedfc54..e05c0eea4 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -20,9 +20,16 @@ logger = init_logger(__name__)
 if has_triton_kernels():
     try:
         import triton_kernels.swiglu
-        from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
-        from triton_kernels.routing import RoutingData, routing, routing_from_bitmatrix
-        from triton_kernels.tensor import Bitmatrix
+        from triton_kernels.matmul_ogs import (
+            FnSpecs,
+            FusedActivation,
+            GatherIndx,
+            RoutingData,
+            ScatterIndx,
+            matmul_ogs,
+        )
+        from triton_kernels.tensor import BIT, Bitmatrix, SparseMatrix, make_ragged_tensor_metadata
+        from triton_kernels.topk import topk as triton_topk
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
@@ -84,8 +91,17 @@ def triton_kernel_moe_forward(
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    routing_data, gather_idx, scatter_idx = routing(
-        gating_output, topk, sm_first=not renormalize
+    # Use new topk API instead of deprecated routing
+    sm_first = not renormalize
+    if sm_first:
+        gating_output = torch.softmax(gating_output, dim=-1)
+    sparse_logits = triton_topk(
+        gating_output, topk, apply_softmax=not sm_first, y_indx=None, n_rows=None
+    )
+
+    # Convert to routing data using the existing make_routing_data function
+    routing_data, gather_idx, scatter_idx = make_routing_data(
+        sparse_logits.indx, sparse_logits.vals, gating_output.shape[-1]
     )
 
     return triton_kernel_fused_experts(
@@ -202,14 +218,29 @@ def make_routing_data(
     bitmatrix_shape = [n_rows, bm_cols * 32]
     bitmatrix_shape_max = [n_rows, None]
     bitmatrix = Bitmatrix(
-        bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=None
+        bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max
    )
 
     # matmul_ogs expects invalid topk_weights to be -1s
     topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
-    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
-        bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
+
+    # Use new SparseMatrix API instead of deprecated routing_from_bitmatrix
+    sparse_logits = SparseMatrix(indx=topk_ids, vals=topk_weights, mask=bitmatrix)
+    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
+    ragged_batch_metadata = make_ragged_tensor_metadata(
+        sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]
+    )
+    gate_scal = sparse_logits.vals.flatten()[combine_indx]
+    routing_data = RoutingData(
+        gate_scal,
+        ragged_batch_metadata.block_sizes,
+        num_local_experts,
+        num_topk,
+        ragged_batch_metadata,
     )
+    gather_indx = GatherIndx(combine_indx, dispatch_indx)
+    scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
 
     return routing_data, gather_indx, scatter_indx
 
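
Aside (not part of the patch): the sm_first handling in triton_kernel_moe_forward reflects the usual distinction between softmax-before-top-k (renormalize=False) and top-k-then-softmax (renormalize=True), assuming triton_topk's apply_softmax renormalizes over only the selected experts. A minimal pure-PyTorch sketch of the two weightings the flag selects between, with no triton_kernels dependency:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
k = 2

# renormalize=False -> sm_first=True: softmax over all experts, then top-k;
# the selected weights sum to less than 1.
weights_all = torch.softmax(logits, dim=-1)
vals_sm_first, ids_sm_first = torch.topk(weights_all, k, dim=-1)

# renormalize=True -> sm_first=False: top-k on the raw logits, then softmax
# over only the k selected experts; the selected weights sum to 1.
top_logits, ids_renorm = torch.topk(logits, k, dim=-1)
vals_renorm = torch.softmax(top_logits, dim=-1)

print(vals_sm_first.sum(-1), vals_renorm.sum(-1))  # ~0.83 vs 1.0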
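The SparseMatrix-based replacement for routing_from_bitmatrix builds two orderings of the flattened (token, expert) assignments: an expert-grouped order for dispatching activations into matmul_ogs, and its inverse for combining results back into token order. A rough pure-PyTorch sketch of that bookkeeping, using argsort in place of the library's mask_metadata and making no claim to match triton_kernels' exact index conventions:

import torch

num_tokens, num_experts, k = 4, 8, 2
topk_vals, topk_ids = torch.topk(torch.randn(num_tokens, num_experts), k, dim=-1)

# Flattened (token, expert) pairs, in token (row-major) order.
flat_experts = topk_ids.flatten()

# Expert-grouped order: tokens routed to the same expert become contiguous,
# the layout the grouped matmul consumes (cf. dispatch_indx above).
dispatch_order = torch.argsort(flat_experts, stable=True)

# Inverse permutation: puts expert-ordered outputs back into token order so
# each token's top-k results can be combined (cf. combine_indx above).
combine_order = torch.argsort(dispatch_order)

# Tokens-per-expert histogram, the kind of count the ragged batch metadata
# is derived from.
tokens_per_expert = torch.bincount(flat_experts, minlength=num_experts)

# Round trip: dispatch then combine recovers the original flat order.
assert torch.equal(dispatch_order[combine_order], torch.arange(flat_experts.numel()))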