Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xtuner/v1/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ZLossContext,
ZLossKwargs,
)
from .mtp_loss import MTPLossContext


__all__ = [
Expand All @@ -29,6 +30,7 @@
"BaseLossContext",
"BaseLossKwargs",
"LMHeadLossContext",
"MTPLossContext",
]

import torch
Expand Down
92 changes: 92 additions & 0 deletions xtuner/v1/loss/mtp_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch.distributed.device_mesh import DeviceMesh

from xtuner.v1.loss.ce_loss import CELossConfig, CELossKwargs, LMHeadLossContext
from xtuner.v1.module.mtp.utils import roll_packed_tensor
from xtuner.v1.utils.device import get_device


# Device that loss kwargs are moved to in `MTPLossConfig.build()`; resolved
# once at import time via the project helper (CPU/GPU/NPU as available).
DEVICE = get_device()


class MTPLossKwargs(CELossKwargs):
    """Keyword arguments consumed by the MTP loss computation.

    This subclass introduces no fields beyond ``CELossKwargs``; it exists so
    MTP loss kwargs are a distinct type. ``MTPLossConfig.build()`` rolls the
    labels for the target MTP depth *before* constructing this object, so
    ``shifted_labels`` arrives here already rolled.

    Args:
        shifted_labels (torch.Tensor): Labels that have been shifted and then
            rolled for the MTP depth.
        loss_weight (torch.Tensor | None): Optional per-token loss weight.
    """


class MTPLossConfig(CELossConfig):
    """Loss configuration for Multi-Token Prediction (MTP).

    Adds a single ``mtp_depth`` field on top of ``CELossConfig``. During
    ``build()`` the labels are rolled by ``mtp_depth`` extra positions so each
    MTP layer predicts tokens further ahead. This class is for internal use by
    the model and is not part of the user-facing configuration surface.

    Args:
        mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses
            ``mtp_depth=1`` (shift=-1 on top of the existing label shift).
    """

    mtp_depth: int

    @property
    def loss_ctx_cls(self) -> type["MTPLossContext"]:
        # Route context construction to the MTP-specific context class.
        return MTPLossContext

    @property
    def _loss_kwargs_cls(self) -> type["MTPLossKwargs"]:
        # Route kwargs construction to the MTP-specific kwargs class.
        return MTPLossKwargs

    def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext | None":
        """Build an ``MTPLossContext`` from a data dict.

        Rolls ``shifted_labels`` by ``-mtp_depth`` positions, per sequence and
        respecting packed-sequence boundaries, before constructing the loss
        context. Rolling happens on the full (unsplit) sequence so that
        boundary positions and ``cu_seq_lens`` stay consistent; any
        sequence-parallel split is applied afterwards.

        Args:
            data (dict): Data dict with loss-related fields. Required keys:
                ``shifted_labels``, ``seq_ctx``.
                NOTE(review): a ``loss_weight`` entry in ``data``, if present,
                is not forwarded here — confirm against ``CELossConfig.build``
                whether that is intended.
            sp_mesh (DeviceMesh | None): Sequence parallel mesh.

        Returns:
            MTPLossContext | None: The built loss context, or ``None`` when
            ``shifted_labels`` is absent from ``data``.
        """
        if "shifted_labels" not in data:
            return None

        seq_boundaries = data["seq_ctx"].cu_seq_lens_k
        # Roll within each packed sequence; vacated tail positions are filled
        # with -100 so they are ignored by the cross-entropy loss.
        rolled_labels = roll_packed_tensor(
            data["shifted_labels"],
            seq_boundaries,
            shifts=-self.mtp_depth,
            dim=-1,
            fill_value=-100,
        )

        kwargs = MTPLossKwargs(shifted_labels=rolled_labels).to(DEVICE)
        if sp_mesh is not None and sp_mesh.size() > 1:
            kwargs = kwargs.sp_split(sp_mesh)
        return MTPLossContext(self, kwargs)


class MTPLossContext(LMHeadLossContext):
    """Loss context for Multi-Token Prediction (MTP).

    A thin marker subclass: all computation is inherited unchanged from
    ``LMHeadLossContext``. Label rolling happens upstream in
    ``MTPLossConfig.build()``, so nothing is overridden here.

    Args:
        loss_cfg (MTPLossConfig): The MTP loss configuration.
        loss_kwargs (MTPLossKwargs): Keyword arguments whose labels have
            already been rolled for the MTP depth.
    """
27 changes: 13 additions & 14 deletions xtuner/v1/model/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
BalancingLossContext,
BaseLossContext,
LMHeadLossContext,
MTPLossContext,
ZLossConfig,
ZLossContext,
)
from xtuner.v1.loss.mtp_loss import MTPLossConfig
from xtuner.v1.model.base import (
DEFAULT_FLOAT8_CFG,
BaseModel,
Expand All @@ -56,7 +58,7 @@
)
from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer
from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEActFnConfig, MoEBlock, MoEDecoderLayer
from xtuner.v1.module.mtp import MTPBlock, MTPConfig, MTPLayer, roll_packed_tensor
from xtuner.v1.module.mtp import MTPBlock, MTPConfig, MTPLayer
from xtuner.v1.utils import (
get_device,
get_logger,
Expand Down Expand Up @@ -323,15 +325,17 @@ def build_loss_ctx_batch( # type: ignore[override]

# Add MTP loss contexts if MTP is enabled
if self.config.mtp_config is not None:
# Build MTP loss contexts using the same approach as LM loss
# Each MTP depth needs its own loss context
for mtp_idx in range(self.config.mtp_config.num_layers):
# MTP needs to shift labels multiple times. Since rebuild the `shifted_labels` in data_batch
mtp_loss_ctx_list = self._build_loss_ctx(self.config.lm_loss_cfg, _data_batch, sp_mesh)
mtp_loss_cfg = MTPLossConfig(
**self.config.lm_loss_cfg.model_dump(),
mtp_depth=mtp_idx + 1,
)
mtp_loss_ctx_list = self._build_loss_ctx(mtp_loss_cfg, _data_batch, sp_mesh)
if mtp_loss_ctx_list is not None:
loss_ctx_cls = mtp_loss_ctx_list[0].__class__
mtp_loss_ctx_list = loss_ctx_cls.build_batches(
mtp_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh
mtp_loss_ctx_list = MTPLossContext.build_batches( # type: ignore[assignment]
cast(list[MTPLossContext], mtp_loss_ctx_list), # type: ignore[arg-type]
cu_seq_lens_list=cu_seq_lens_list,
sp_mesh=sp_mesh,
)
for i, mtp_loss_ctx in enumerate(mtp_loss_ctx_list):
if "mtp" not in res[i]:
Expand Down Expand Up @@ -693,13 +697,8 @@ def _forward(
# Compute MTP losses for each depth
mtp_losses = torch.tensor(0.0, device=DEVICE)
for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)):
shifted_tensor = mtp_ctx.loss_kwargs.shifted_labels
mtp_ctx.loss_kwargs.shifted_labels = roll_packed_tensor(
shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1, fill_value=-100
)

mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden
mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(LMHeadLossContext, mtp_ctx))
mtp_loss, _ = self.lm_head(mtp_hidden_states, cast(MTPLossContext, mtp_ctx))
mtp_losses += mtp_loss

output["router_logits"][f"mtp_layer{idx}"] = mtp_router_results
Expand Down
Loading