Add LTX2 VAE in SANA-Video #13229
Conversation
Force-pushed 5db0b20 to c03b739
@lawrence-cj thanks for the PR! Could you also provide some sample outputs?
if getattr(self, "vae", None):
    if hasattr(self.vae.config, "scale_factor_temporal"):
        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
    elif hasattr(self.vae.config, "temporal_compression_ratio"):
        # LTX2 VAE uses temporal_compression_ratio
        self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
    else:
        self.vae_scale_factor_temporal = getattr(self.vae, "temporal_compression_ratio", 4)

    if hasattr(self.vae.config, "scale_factor_spatial"):
        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
    elif hasattr(self.vae.config, "spatial_compression_ratio"):
        # LTX2 VAE uses spatial_compression_ratio
        self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
    else:
        self.vae_scale_factor_spatial = getattr(self.vae, "spatial_compression_ratio", 8)
else:
    self.vae_scale_factor_temporal = 4
    self.vae_scale_factor_spatial = 8
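The attribute-fallback pattern in this snippet can be sketched in isolation. The config classes below are hypothetical stand-ins for the two naming conventions, not the actual diffusers VAE configs:

```python
class WanStyleConfig:
    """Hypothetical stand-in for a Wan-style VAE config."""
    scale_factor_temporal = 4
    scale_factor_spatial = 8


class LTX2StyleConfig:
    """Hypothetical stand-in for an LTX2-style VAE config."""
    temporal_compression_ratio = 8
    spatial_compression_ratio = 32


def temporal_scale(config, default=4):
    # Prefer the Wan-style attribute name, fall back to the
    # LTX2-style name, and finally to a hard-coded default.
    if hasattr(config, "scale_factor_temporal"):
        return config.scale_factor_temporal
    if hasattr(config, "temporal_compression_ratio"):
        return config.temporal_compression_ratio
    return default


print(temporal_scale(WanStyleConfig()))   # → 4
print(temporal_scale(LTX2StyleConfig()))  # → 8
print(temporal_scale(object()))           # → 4
```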
Hmm, should this be conditioned on the class type of the VAE being used?
sayakpaul
left a comment
Thanks, I just left one comment. But it looks good to me.
Updated code and result.
dg845
left a comment
Thanks for the PR! The code looks good to me. However, running the example script doesn't work for me because I don't have access to the Sana_video/safetensors/sana_ltxvae_sft checkpoint. Would it be possible to provide a checkpoint for testing?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @dg845, the repo-id is updated.
yiyixuxu
left a comment
thanks for the PR!
I think the code would be clearer if we just added a new code path based on the type of VAE:
if getattr(self, "vae", None):
    if isinstance(self.vae, AutoencoderKLLTX2Video):
        self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
        self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
    elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
        # current code
        ...
Similar for the latents_mean/latents_std.
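A runnable sketch of this isinstance-based dispatch, using dummy classes in place of the actual diffusers VAEs (the DummyLTX2VAE/DummyWanVAE names are hypothetical, for illustration only):

```python
from types import SimpleNamespace


class DummyLTX2VAE:
    """Hypothetical stand-in for an LTX2-style VAE."""
    config = SimpleNamespace(temporal_compression_ratio=8, spatial_compression_ratio=32)


class DummyWanVAE:
    """Hypothetical stand-in for a Wan-style VAE."""
    config = SimpleNamespace(scale_factor_temporal=4, scale_factor_spatial=8)


def scale_factors(vae):
    # Dispatch on the VAE class instead of probing config attribute names.
    if isinstance(vae, DummyLTX2VAE):
        return vae.config.temporal_compression_ratio, vae.config.spatial_compression_ratio
    elif isinstance(vae, DummyWanVAE):
        return vae.config.scale_factor_temporal, vae.config.scale_factor_spatial
    return 4, 8  # defaults when no (or an unknown) VAE is loaded


print(scale_factors(DummyLTX2VAE()))  # → (8, 32)
print(scale_factors(DummyWanVAE()))   # → (4, 8)
```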
Co-authored-by: YiYi Xu <yixu310@gmail.com>
dg845
left a comment
I think we should update the vae type annotation in SanaVideoPipeline.__init__ to reflect the fact that LTX-2 VAEs are now supported (this will also fix a runtime warning):
class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
    def __init__(
        ...
        vae: AutoencoderDC | AutoencoderKLWan | AutoencoderKLLTX2Video,
        ...
    ):
        ...

I also agree with #13229 (review) that the code would be clearer if we split the logic based on the VAE class. It may make the code less general, but I think the improved readability is worth it unless it is necessary to support arbitrary VAEs (if so, the vae type annotation should reflect this).
|
As an aside, I think the example script could be simplified by letting the LTX-2 refinement pipeline sample ...
# ── Stage 2: LTX2 Refine ──
packed = LTX2Pipeline._pack_latents(
    video_latent.to(device=device, dtype=dtype),
    patch_size=ltx_pipe.transformer_spatial_patch_size,
    patch_size_t=ltx_pipe.transformer_temporal_patch_size,
)
_, _, lF, lH, lW = video_latent.shape
pH, pW, pT = (
    lH * ltx_pipe.vae_spatial_compression_ratio,
    lW * ltx_pipe.vae_spatial_compression_ratio,
    (lF - 1) * ltx_pipe.vae_temporal_compression_ratio + 1,
)
del video_latent
gc.collect()
torch.cuda.empty_cache()

# Let audio_latents take on its default value of `None` so that latents are sampled from the prior
video, audio = ltx_pipe(
    latents=packed,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=pH,
    width=pW,
    num_frames=pT,
    num_inference_steps=3,
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,
    frame_rate=frame_rate,
    generator=torch.Generator(device=device).manual_seed(seed),
    output_type="np",
    return_dict=False,
)
video = torch.from_numpy((video * 255).round().astype("uint8"))
encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=ltx_pipe.vocoder.config.output_sampling_rate,
    output_path="sana_ltx2_refined_audio.mp4",
)

The resulting generated audio is reasonably good: sana_ltx2_refined_audio.mp4
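The pH/pW/pT computation above recovers pixel-space dimensions from the latent shape via the VAE compression ratios. A minimal sketch of the mapping (the 32x spatial / 8x temporal ratios below are assumed for illustration; read them from the pipeline in practice):

```python
def latent_to_pixel_dims(lF, lH, lW, spatial_ratio=32, temporal_ratio=8):
    """Map a latent video shape (frames, height, width) to pixel space.

    Spatial dims scale linearly by the spatial ratio; the temporal dim
    follows the (lF - 1) * ratio + 1 convention, so one latent frame
    maps to exactly one pixel frame.
    """
    return lH * spatial_ratio, lW * spatial_ratio, (lF - 1) * temporal_ratio + 1


# e.g. 8 latent frames at 16x24 latent resolution:
print(latent_to_pixel_dims(8, 16, 24))  # → (512, 768, 57)
```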
@dg845 thanks! Now the code logic depends on the VAE type.
This PR adds LTX-VAE support for SANA-Video.
Cc: @dg845 @sayakpaul
GPU memory needed: 47GB for LTX refiner
SANA-Video with LTX2-Refiner:
Result
sana_ltx_refined.mp4