diff --git a/xtuner/v1/loss/__init__.py b/xtuner/v1/loss/__init__.py index e54e39c54..513a2d15e 100644 --- a/xtuner/v1/loss/__init__.py +++ b/xtuner/v1/loss/__init__.py @@ -11,6 +11,7 @@ ZLossContext, ZLossKwargs, ) +from .mtp_loss import MTPLossContext __all__ = [ @@ -29,6 +30,7 @@ "BaseLossContext", "BaseLossKwargs", "LMHeadLossContext", + "MTPLossContext", ] import torch diff --git a/xtuner/v1/loss/mtp_loss.py b/xtuner/v1/loss/mtp_loss.py new file mode 100644 index 000000000..cce8db7b8 --- /dev/null +++ b/xtuner/v1/loss/mtp_loss.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch.distributed.device_mesh import DeviceMesh + +from xtuner.v1.loss.ce_loss import CELossConfig, CELossKwargs, LMHeadLossContext +from xtuner.v1.module.mtp.utils import roll_packed_tensor +from xtuner.v1.utils.device import get_device + + +DEVICE = get_device() + + +class MTPLossKwargs(CELossKwargs): + """Keyword arguments for MTP loss computation. + + Inherits all fields from CELossKwargs. The ``shifted_labels`` field is + expected to be pre-rolled by ``MTPLossConfig.build()`` before this object + is constructed, so no additional fields are required. + + Args: + shifted_labels (torch.Tensor): The shifted and rolled labels for MTP + loss computation. + loss_weight (torch.Tensor | None): Per-token loss weight. + """ + + +class MTPLossConfig(CELossConfig): + """Loss configuration for Multi-Token Prediction (MTP). + + Extends ``CELossConfig`` with a ``mtp_depth`` field that controls how many + additional positions the labels are rolled during ``build()``. This class + is intended for internal use by the model and is not exposed to users. + + Args: + mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses + ``mtp_depth=1`` (shift=-1 on top of the existing label shift). + """ + + mtp_depth: int + + @property + def loss_ctx_cls(self) -> type["MTPLossContext"]: + return MTPLossContext + + @property + def _loss_kwargs_cls(self) -> type["MTPLossKwargs"]: + return MTPLossKwargs + + def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext | None": + """Build MTPLossContext from data dict. + + Rolls ``shifted_labels`` by ``-mtp_depth`` positions (per-sequence, + respecting packed-sequence boundaries) before constructing the loss + context. The roll is performed on the full sequence prior to any + sequence-parallel split so that boundary positions and ``cu_seq_lens`` + are always consistent. + + Args: + data (dict): Data dict containing loss-related fields. + Required keys: ``shifted_labels``, ``seq_ctx``. + sp_mesh (DeviceMesh | None): Sequence parallel mesh. + + Returns: + MTPLossContext | None: Built loss context, or ``None`` if + ``shifted_labels`` is not present in ``data``. + """ + if "shifted_labels" not in data: + return None + + shifted_labels = data["shifted_labels"] + cu_seq_lens = data["seq_ctx"].cu_seq_lens_k + + rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=-100) + + loss_kwargs = MTPLossKwargs(shifted_labels=rolled).to(DEVICE) + if sp_mesh is not None and sp_mesh.size() > 1: + loss_kwargs = loss_kwargs.sp_split(sp_mesh) + + return MTPLossContext(self, loss_kwargs) + + +class MTPLossContext(LMHeadLossContext): + """Loss context for Multi-Token Prediction (MTP). + + Inherits all computation logic from ``LMHeadLossContext``. The label + rolling is handled upstream in ``MTPLossConfig.build()``, so no override + is needed here. + + Args: + loss_cfg (MTPLossConfig): The MTP loss configuration. + loss_kwargs (MTPLossKwargs): Pre-rolled keyword arguments for loss + computation. + """ diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 758d01422..faaa0c9eb 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -30,9 +30,11 @@ BalancingLossContext, BaseLossContext, LMHeadLossContext, + MTPLossContext, ZLossConfig, ZLossContext, ) +from xtuner.v1.loss.mtp_loss import MTPLossConfig from xtuner.v1.model.base import ( DEFAULT_FLOAT8_CFG, BaseModel, @@ -56,7 +58,7 @@ ) from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEActFnConfig, MoEBlock, MoEDecoderLayer -from xtuner.v1.module.mtp import MTPBlock, MTPConfig, MTPLayer, roll_packed_tensor +from xtuner.v1.module.mtp import MTPBlock, MTPConfig, MTPLayer from xtuner.v1.utils import ( get_device, get_logger, @@ -323,15 +325,17 @@ def build_loss_ctx_batch( # type: ignore[override] # Add MTP loss contexts if MTP is enabled if self.config.mtp_config is not None: - # Build MTP loss contexts using the same approach as LM loss - # Each MTP depth needs its own loss context for mtp_idx in range(self.config.mtp_config.num_layers): - # MTP needs to shift labels multiple times. Since rebuild the `shifted_labels` in data_batch - mtp_loss_ctx_list = self._build_loss_ctx(self.config.lm_loss_cfg, _data_batch, sp_mesh) + mtp_loss_cfg = MTPLossConfig( + **self.config.lm_loss_cfg.model_dump(), + mtp_depth=mtp_idx + 1, + ) + mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _data_batch, sp_mesh) if mtp_loss_ctx_list is not None: - loss_ctx_cls = mtp_loss_ctx_list[0].__class__ - mtp_loss_ctx_list = loss_ctx_cls.build_batches( - mtp_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh + mtp_loss_ctx_list = MTPLossContext.build_batches( # type: ignore[assignment] + cast(list[MTPLossContext], mtp_loss_ctx_list), # type: ignore[arg-type] + cu_seq_lens_list=cu_seq_lens_list, + sp_mesh=sp_mesh, ) for i, mtp_loss_ctx in enumerate(mtp_loss_ctx_list): if "mtp" not in res[i]: @@ -693,13 +697,8 @@ def _forward( # Compute MTP losses for each depth mtp_losses = torch.tensor(0.0, device=DEVICE) for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): - shifted_tensor = mtp_ctx.loss_kwargs.shifted_labels - mtp_ctx.loss_kwargs.shifted_labels = roll_packed_tensor( - shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1, fill_value=-100 - ) - mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden - mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(LMHeadLossContext, mtp_ctx)) + mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx)) mtp_losses += mtp_loss output["router_logits"][f"mtp_layer{idx}"] = mtp_router_results