Skip to content

add ltx2 vae in sana-video;#13229

Merged
yiyixuxu merged 12 commits into huggingface:main from
lawrence-cj:sana-video-ltx2vae
Mar 18, 2026
Merged

add ltx2 vae in sana-video;#13229
yiyixuxu merged 12 commits into huggingface:main from
lawrence-cj:sana-video-ltx2vae

Conversation

@lawrence-cj
Copy link
Contributor

@lawrence-cj lawrence-cj commented Mar 9, 2026

This PR adds LTX-VAE support for SANA-Video.

Cc: @dg845 @sayakpaul

GPU memory needed: 47GB for LTX refiner

SANA-Video with LTX2-Refiner:

"""Sana Video + LTX2 Refiner: Stage 1 generate latent → Stage 2 refine (3 steps)."""

import gc
import torch
from diffusers import SanaVideoPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.pipelines.ltx2.export_utils import encode_video

device = "cuda"
dtype = torch.bfloat16
prompt = "A cat walking on the grass, facing the camera."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_score = 30
height, width, frames, frame_rate = 704, 1280, 81, 16.0
seed = 42

# ── Load all models ──
sana_pipe = SanaVideoPipeline.from_pretrained(
    "Efficient-Large-Model/SANA-Video_2B_720p_diffusers", torch_dtype=dtype,
)
sana_pipe.text_encoder.to(dtype)
sana_pipe.enable_model_cpu_offload()

ltx_pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=dtype)
ltx_pipe.load_lora_weights(
    "Lightricks/LTX-2", adapter_name="stage_2_distilled",
    weight_name="ltx-2-19b-distilled-lora-384.safetensors",
)
ltx_pipe.set_adapters("stage_2_distilled", 1.0)
ltx_pipe.vae.enable_tiling()
ltx_pipe.enable_model_cpu_offload()

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=dtype,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=ltx_pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)

# ── Stage 1: Sana Video ──
video_latent = sana_pipe(
    prompt=prompt + f" motion score: {motion_score}.", negative_prompt=negative_prompt,
    height=height, width=width, frames=frames,
    guidance_scale=6.0, num_inference_steps=50,
    generator=torch.Generator(device=device).manual_seed(seed),
    output_type="latent", return_dict=True,
).frames

del sana_pipe; gc.collect(); torch.cuda.empty_cache()

# ── Stage 1.5: Latent Upsample (2x spatial) ──
video_latent = upsample_pipe(
    latents=video_latent.to(device=device, dtype=dtype),
    latents_normalized=True,
    height=height, width=width, num_frames=frames,
    output_type="latent", return_dict=False,
)[0]
latents_mean = ltx_pipe.vae.latents_mean.view(1, -1, 1, 1, 1).to(video_latent.device, video_latent.dtype)
latents_std = ltx_pipe.vae.latents_std.view(1, -1, 1, 1, 1).to(video_latent.device, video_latent.dtype)
video_latent = (video_latent - latents_mean) * ltx_pipe.vae.config.scaling_factor / latents_std

# ── Stage 2: LTX2 Refine ──
packed = LTX2Pipeline._pack_latents(
    video_latent.to(device=device, dtype=dtype),
    patch_size=ltx_pipe.transformer_spatial_patch_size,
    patch_size_t=ltx_pipe.transformer_temporal_patch_size,
)
_, _, lF, lH, lW = video_latent.shape
pH, pW, pT = (
    lH * ltx_pipe.vae_spatial_compression_ratio,
    lW * ltx_pipe.vae_spatial_compression_ratio,
    (lF - 1) * ltx_pipe.vae_temporal_compression_ratio + 1,
)

del video_latent
gc.collect()
torch.cuda.empty_cache()

video, _audio = ltx_pipe(
    latents=packed,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=pH,
    width=pW,
    num_frames=pT,
    num_inference_steps=3,
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    frame_rate=frame_rate,
    generator=torch.Generator(device=device).manual_seed(seed),
    output_type="np",
    return_dict=False,
)

video = torch.from_numpy((video * 255).round().astype("uint8"))
encode_video(
    video[0],
    fps=frame_rate,
    audio=None,
    audio_sample_rate=None,
    output_path=args.output_path,
)

Result

sana_ltx_refined.mp4

@sayakpaul
Copy link
Member

@lawrence-cj thanks for the PR! Could you also provide some sample outputs?

Comment on lines +226 to +244
if getattr(self, "vae", None):
if hasattr(self.vae.config, "scale_factor_temporal"):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
elif hasattr(self.vae.config, "temporal_compression_ratio"):
# LTX2 VAE uses temporal_compression_ratio
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
else:
self.vae_scale_factor_temporal = getattr(self.vae, "temporal_compression_ratio", 4)

if hasattr(self.vae.config, "scale_factor_spatial"):
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
elif hasattr(self.vae.config, "spatial_compression_ratio"):
# LTX2 VAE uses spatial_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
else:
self.vae_scale_factor_spatial = getattr(self.vae, "spatial_compression_ratio", 8)
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, should this be conditioned on the class type of the VAE being used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix it.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I just left one comment. But it looks good to me.

@sayakpaul sayakpaul requested a review from dg845 March 10, 2026 03:19
@lawrence-cj
Copy link
Contributor Author

lawrence-cj commented Mar 10, 2026

Could you also provide some sample outputs?

Updated code and result.

@sayakpaul @dg845

Copy link
Collaborator

@dg845 dg845 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! The code looks good to me. However, running the example script doesn't work for me because I don't have access to the Sana_video/safetensors/sana_ltxvae_sft checkpoint. Would it be possible to provide a checkpoint for testing?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lawrence-cj
Copy link
Contributor Author

Hi @dg845 , the repo-id is updated: Efficient-Large-Model/SANA-Video_2B_720p_diffusers

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the PR!
I think the code would be clear if we just add a new code path based on the type of vae

if getattr(self, "vae", None):
     if isinstance(self.vae, AutoencoderKLLTX2Vide):
        self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
        self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
    elif isinstance(self.vae,  (AutoencoderDC, AutoencoderKLWan):
       # current code

similar for the latents_mean/latents_std

Copy link
Collaborator

@dg845 dg845 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should update the vae type annotation in SanaVideoPipeline.__init__ to reflect the fact that LTX-2 VAEs are now supported (this will also fix a runtime warning):

class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
    def __init__(
        ...
        vae: AutoencoderDC | AutoencoderKLWan | AutoencoderKLLTX2Video,
        ...
    ):
        ...

I also agree with #13229 (review) that the code would be more clear if we split the logic based on the VAE class. It may make the code less general but I think the improved readability is worth it unless it is necessary to support arbitrary VAEs (if so, the vae type annotation should reflect this).

@dg845
Copy link
Collaborator

dg845 commented Mar 18, 2026

As an aside, I think the example script could be simplified by letting the LTX-2 refinement pipeline sample audio_latents internally from the Gaussian prior:

...
# ── Stage 2: LTX2 Refine ──
# Pack the 5D video latent (B, C, F, H, W) into the transformer's token layout.
# (Continuation snippet: `video_latent`, `ltx_pipe`, `device`, `dtype`, etc.
# are defined earlier in the full example script.)
packed = LTX2Pipeline._pack_latents(
    video_latent.to(device=device, dtype=dtype),
    patch_size=ltx_pipe.transformer_spatial_patch_size,
    patch_size_t=ltx_pipe.transformer_temporal_patch_size,
)
# Recover pixel-space height/width/frame-count from the latent shape using the
# VAE compression ratios (temporal axis carries the extra first frame).
_, _, lF, lH, lW = video_latent.shape
pH, pW, pT = (
    lH * ltx_pipe.vae_spatial_compression_ratio,
    lW * ltx_pipe.vae_spatial_compression_ratio,
    (lF - 1) * ltx_pipe.vae_temporal_compression_ratio + 1,
)

# Release the unpacked latent before the refinement pass to lower peak VRAM.
del video_latent
gc.collect()
torch.cuda.empty_cache()

# Let audio_latents take on its default value of `None` so that latents are sampled from the prior
video, audio = ltx_pipe(
    latents=packed,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=pH,
    width=pW,
    num_frames=pT,
    num_inference_steps=3,
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    frame_rate=frame_rate,
    generator=torch.Generator(device=device).manual_seed(seed),
    output_type="np",
    return_dict=False,
)

# Convert float frames in [0, 1] to uint8 and mux video + generated audio.
video = torch.from_numpy((video * 255).round().astype("uint8"))
encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=ltx_pipe.vocoder.config.output_sampling_rate,
    output_path="sana_ltx2_refined_audio.mp4",
)

The resulting generated audio is reasonably good:

sana_ltx2_refined_audio.mp4

@lawrence-cj
Copy link
Contributor Author

@dg845 thanks! Now the code logic depends on the VAE type

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@yiyixuxu yiyixuxu merged commit c6f72ad into huggingface:main Mar 18, 2026
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants