Skip to main content

qwen3_tts/
lib.rs

1//! # Qwen3-TTS
2//!
3//! Pure Rust inference for [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS),
4//! a high-quality text-to-speech model from Alibaba.
5//!
6//! ## Features
7//!
8//! - **CPU inference** with optional MKL/Accelerate for faster BLAS operations
9//! - **CUDA** support for NVIDIA GPU acceleration
10//! - **Metal** support for Apple Silicon
11//! - **Streaming-friendly** architecture with incremental token generation
12//! - **Voice cloning** via ECAPA-TDNN speaker encoder (Base models)
13//! - **Auto-detection** of model variant from `config.json`
14//!
15//! ## Quick Start
16//!
17//! ```rust,ignore
18//! use qwen3_tts::{Qwen3TTS, SynthesisOptions, auto_device};
19//!
20//! // Load model — variant auto-detected from config.json
21//! let device = auto_device()?;
22//! let model = Qwen3TTS::from_pretrained("path/to/model", device)?;
23//!
24//! // Synthesize speech with default settings
25//! let audio = model.synthesize("Hello, world!", None)?;
26//! audio.save("output.wav")?;
27//!
28//! // Or with custom options
29//! let options = SynthesisOptions {
30//!     temperature: 0.8,
31//!     top_k: 30,
32//!     ..Default::default()
33//! };
34//! let audio = model.synthesize("Custom settings!", Some(options))?;
35//! ```
36//!
37//! ## Architecture
38//!
39//! The TTS pipeline consists of three stages:
40//!
41//! 1. **TalkerModel**: Transformer that generates semantic tokens from text
42//!    autoregressively. Uses dual embeddings (text + codec) with MRoPE
43//!    (multimodal rotary position encoding) across all variants.
44//!
45//! 2. **CodePredictor**: For each semantic token, generates 15 acoustic
46//!    tokens using a 5-layer autoregressive decoder. The code predictor
47//!    always has `hidden_size=1024` regardless of the talker size; 1.7B
48//!    models use a `small_to_mtp_projection` layer to bridge the gap.
49//!
50//! 3. **Decoder12Hz**: Converts the 16-codebook codec tokens to audio
51//!    waveform at 24kHz. Uses ConvNeXt blocks and transposed convolutions
52//!    for upsampling. Shared across all model variants.
53//!
54//! ## Model Variants
55//!
56//! Five official variants exist in two size classes:
57//!
58//! | Variant | Size | Talker hidden | Speaker conditioning | HuggingFace ID |
59//! |---------|------|---------------|---------------------|----------------|
60//! | 0.6B Base | 1.8 GB | 1024 | Voice cloning (ECAPA-TDNN) | `Qwen/Qwen3-TTS-12Hz-0.6B-Base` |
61//! | 0.6B CustomVoice | 1.8 GB | 1024 | 9 preset speakers | `Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice` |
62//! | 1.7B Base | 3.9 GB | 2048 | Voice cloning (ECAPA-TDNN) | `Qwen/Qwen3-TTS-12Hz-1.7B-Base` |
63//! | 1.7B CustomVoice | 3.9 GB | 2048 | 9 preset speakers | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice` |
64//! | 1.7B VoiceDesign | 3.8 GB | 2048 | Text-described voices | `Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign` |
65//!
66//! **Base**: Includes a speaker encoder for voice cloning from reference audio.
67//! Supports x_vector_only (speaker embedding) and ICL (in-context learning
68//! with reference audio + text) modes.
69//!
70//! **CustomVoice**: 9 preset speakers (Serena, Vivian, Ryan, Aiden, etc.) with
71//! no speaker encoder. Uses discrete speaker token IDs for voice selection.
72//!
73//! **VoiceDesign**: Creates novel voices from text descriptions (e.g.,
74//! "a deep male voice"). No speaker encoder or preset speakers.
75//!
76//! All variants share the same speech tokenizer and decoder weights. The
77//! code predictor architecture is identical (1024 hidden, 5 layers, 16 heads)
78//! across all variants.
79//!
80//! ## Sample Rate
81//!
82//! Output audio is always 24kHz mono. Use [`audio::resample()`] if you need
83//! a different sample rate.
84
85pub mod audio;
86pub mod generation;
87#[cfg(feature = "hub")]
88pub mod hub;
89pub mod models;
90pub mod profiling;
91pub mod tokenizer;
92
93use anyhow::Result;
94use candle_core::{DType, Device, IndexOp, Tensor};
95use serde::Serialize;
96use std::collections::HashMap;
97use std::path::Path;
98use std::time::Instant;
99
100use models::codec::{Decoder12Hz, Encoder12Hz};
101use models::speaker::SpeakerEncoder;
102use models::AnyKVCache;
103
// SynthesisTiming is defined below in this module and is already public at the crate root.
105
106/// Re-exports for convenience
107pub use audio::AudioBuffer;
108#[cfg(feature = "hub")]
109pub use hub::ModelPaths;
110pub use models::config::Qwen3TTSConfig;
111// StreamingSession is defined in this module, exported as top-level type
112pub use generation::SamplingContext;
113pub use models::talker::{codec_tokens, special_tokens, tts_tokens, Language, Speaker};
114pub use models::{
115    CodePredictor, CodePredictorConfig, ModelType, ParsedModelConfig, SpeakerEncoderConfig,
116    TalkerConfig, TalkerModel,
117};
118
/// Codec frames produced by the generation loop.
///
/// Each inner `Vec<u32>` is one frame of 16 codebook indices laid out as
/// `[semantic, acoustic_0, ..., acoustic_14]` (one semantic code followed by
/// 15 acoustic codes).
pub type FrameCodes = Vec<Vec<u32>>;
122
123/// Reference audio prompt for voice cloning.
124///
125/// Holds the speaker embedding and optional ICL (in-context learning) data.
126/// Created via [`Qwen3TTS::create_voice_clone_prompt`].
127pub struct VoiceClonePrompt {
128    /// Speaker embedding from the ECAPA-TDNN encoder, shape `[enc_dim]` (typically 1024).
129    pub speaker_embedding: Tensor,
130    /// Reference audio codec codes for ICL mode, shape `[T, 16]`. `None` = x_vector_only mode.
131    pub ref_codes: Option<Tensor>,
132    /// Tokenized reference text for ICL mode.
133    pub ref_text_ids: Option<Vec<u32>>,
134}
135
136/// Per-stage timing breakdown from a synthesis run.
137#[derive(Debug, Clone, Serialize)]
138pub struct SynthesisTiming {
139    /// Time spent in the prefill phase (ms).
140    pub prefill_ms: f64,
141    /// Time spent in the autoregressive generation loop (ms).
142    pub generation_ms: f64,
143    /// Number of codec frames generated.
144    pub generation_frames: usize,
145    /// Time spent decoding codec frames to audio (ms).
146    pub decode_ms: f64,
147}
148
149/// Main TTS interface using proper autoregressive pipeline.
150///
151/// Supports all 5 Qwen3-TTS model variants. Use [`model_type()`](Self::model_type)
152/// to check which variant was loaded and [`supports_voice_cloning()`](Self::supports_voice_cloning)
153/// / [`supports_preset_speakers()`](Self::supports_preset_speakers) to check capabilities.
154pub struct Qwen3TTS {
155    /// Talker model for semantic token generation
156    talker: TalkerModel,
157    /// Code predictor for acoustic token generation
158    code_predictor: CodePredictor,
159    /// 12Hz decoder for audio synthesis
160    decoder: Decoder12Hz,
161    /// Text tokenizer
162    text_tokenizer: tokenizer::TextTokenizer,
163    /// Speaker encoder for voice cloning (loaded when weights are present)
164    speaker_encoder: Option<SpeakerEncoder>,
165    /// Speech tokenizer encoder for ICL voice cloning (encodes reference audio → codes)
166    speech_encoder: Option<Encoder12Hz>,
167    /// Detected model variant (None if loaded without config.json)
168    model_type: Option<ModelType>,
169    /// Device to run inference on
170    device: Device,
171    /// Compute dtype for talker + code predictor (BF16 on CUDA, F32 otherwise)
172    compute_dtype: DType,
173}
174
175impl Qwen3TTS {
176    /// Load a model from a HuggingFace model ID or local path.
177    ///
178    /// Auto-detects the model variant (0.6B/1.7B, Base/CustomVoice/VoiceDesign)
179    /// from `config.json` if present, falling back to weight inspection.
180    ///
181    /// The text tokenizer is resolved from `model_id/tokenizer.json` if present,
182    /// otherwise downloaded from HuggingFace Hub. Use `tokenizer_id` to override.
183    pub fn from_pretrained(model_id: &str, device: Device) -> Result<Self> {
184        Self::from_pretrained_with_tokenizer(model_id, None, device)
185    }
186
187    /// Load a model with an explicit tokenizer source.
188    ///
189    /// `tokenizer_id` can be a local directory, a file path, or a HuggingFace
190    /// model ID (e.g. `"Qwen/Qwen2-0.5B"`). If `None`, resolves from the
191    /// model directory or falls back to the default tokenizer repo.
192    pub fn from_pretrained_with_tokenizer(
193        model_id: &str,
194        tokenizer_id: Option<&str>,
195        device: Device,
196    ) -> Result<Self> {
197        tracing::info!("Loading Qwen3-TTS from: {}", model_id);
198        tracing::info!("Compute dtype: {:?}", compute_dtype_for_device(&device));
199
200        // Try to parse config.json for auto-detection
201        let config_path = Path::new(model_id).join("config.json");
202        let parsed_config = if config_path.exists() {
203            match ParsedModelConfig::from_file(&config_path) {
204                Ok(cfg) => {
205                    tracing::info!("Detected model variant: {}", cfg.label());
206                    Some(cfg)
207                }
208                Err(e) => {
209                    tracing::warn!(
210                        "Failed to parse config.json, falling back to weight inspection: {}",
211                        e
212                    );
213                    None
214                }
215            }
216        } else {
217            None
218        };
219
220        // Load text tokenizer
221        let tok_source = tokenizer_id.unwrap_or(model_id);
222        let text_tokenizer = tokenizer::TextTokenizer::from_pretrained(tok_source)?;
223
224        // Load model weights
225        let model_path = Path::new(model_id).join("model.safetensors");
226        if !model_path.exists() {
227            anyhow::bail!(
228                "Model weights not found at {}. Please download the model first.",
229                model_path.display()
230            );
231        }
232        let weights = Self::load_weights(&model_path, &device)?;
233
234        // Load speech tokenizer for decoder
235        let st_path = Path::new(model_id).join("speech_tokenizer/model.safetensors");
236        let st_weights = if st_path.exists() {
237            Self::load_weights(&st_path, &device)?
238        } else {
239            // Fall back to looking in parent dir
240            let alt_path = Path::new(model_id)
241                .parent()
242                .map(|p| p.join("speech_tokenizer/model.safetensors"));
243            if let Some(p) = alt_path {
244                if p.exists() {
245                    Self::load_weights(&p, &device)?
246                } else {
247                    anyhow::bail!("Speech tokenizer weights not found");
248                }
249            } else {
250                anyhow::bail!("Speech tokenizer weights not found");
251            }
252        };
253
254        Self::build_from_components(
255            &weights,
256            &st_weights,
257            text_tokenizer,
258            parsed_config.as_ref(),
259            &device,
260        )
261    }
262
263    /// Load from pre-loaded weight tensors.
264    ///
265    /// Uses weight inspection for auto-detection. For config.json-based
266    /// detection, use [`from_pretrained`](Self::from_pretrained) instead.
267    pub fn from_weights(
268        model_weights: &HashMap<String, Tensor>,
269        decoder_weights: &HashMap<String, Tensor>,
270        text_tokenizer: tokenizer::TextTokenizer,
271        device: &Device,
272    ) -> Result<Self> {
273        Self::build_from_components(model_weights, decoder_weights, text_tokenizer, None, device)
274    }
275
276    /// Load from downloaded model paths.
277    ///
278    /// Use with [`ModelPaths::download`] for automatic model downloading.
279    ///
280    /// # Example
281    ///
282    /// ```rust,ignore
283    /// use qwen3_tts::{Qwen3TTS, ModelPaths, auto_device};
284    ///
285    /// let paths = ModelPaths::download(None)?;
286    /// let device = auto_device()?;
287    /// let model = Qwen3TTS::from_paths(&paths, device)?;
288    /// ```
289    #[cfg(feature = "hub")]
290    pub fn from_paths(paths: &hub::ModelPaths, device: Device) -> Result<Self> {
291        tracing::info!("Loading Qwen3-TTS from downloaded paths...");
292
293        let text_tokenizer = tokenizer::TextTokenizer::from_file(&paths.tokenizer)?;
294        let weights = Self::load_weights(&paths.model_weights, &device)?;
295        let st_weights = Self::load_weights(&paths.decoder_weights, &device)?;
296        let parsed_config = ParsedModelConfig::from_file(&paths.config).ok();
297
298        Self::build_from_components(
299            &weights,
300            &st_weights,
301            text_tokenizer,
302            parsed_config.as_ref(),
303            &device,
304        )
305    }
306
307    /// Shared builder: assembles all model components from pre-loaded weights.
308    ///
309    /// When `parsed_config` is `Some`, uses config.json dimensions and model type.
310    /// When `None`, auto-detects the model variant from weight shapes.
311    fn build_from_components(
312        model_weights: &HashMap<String, Tensor>,
313        decoder_weights: &HashMap<String, Tensor>,
314        text_tokenizer: tokenizer::TextTokenizer,
315        parsed_config: Option<&ParsedModelConfig>,
316        device: &Device,
317    ) -> Result<Self> {
318        let compute_dtype = compute_dtype_for_device(device);
319
320        // Build TalkerModel
321        let talker_config = if let Some(cfg) = parsed_config {
322            TalkerConfig::from_parsed(cfg)
323        } else {
324            Self::detect_talker_config(model_weights)?
325        };
326        let talker = TalkerModel::from_weights_with_config_dtype(
327            model_weights,
328            talker_config,
329            device,
330            compute_dtype,
331        )?;
332
333        // Build CodePredictor
334        let cp_config = if let Some(cfg) = parsed_config {
335            CodePredictorConfig::from_parsed(cfg)
336        } else {
337            let talker_hidden = talker.config().hidden_size;
338            if talker_hidden != 1024 {
339                CodePredictorConfig {
340                    codec_embed_dim: Some(talker_hidden),
341                    ..CodePredictorConfig::default()
342                }
343            } else {
344                CodePredictorConfig::default()
345            }
346        };
347        let cp_weights = Self::filter_weights(model_weights, "talker.code_predictor.");
348        let cp_vb = candle_nn::VarBuilder::from_tensors(cp_weights, compute_dtype, device);
349        let code_predictor = CodePredictor::new(cp_config, cp_vb)?;
350
351        // Decoder (always F32 — convolutional, no attention)
352        let decoder = Decoder12Hz::from_weights(decoder_weights, Default::default())?;
353
354        // Speaker encoder (always F32, only present in Base models)
355        let se_config = parsed_config.and_then(|c| c.speaker_encoder_config.clone());
356        let speaker_encoder =
357            Self::try_load_speaker_encoder(model_weights, se_config.as_ref(), device)?;
358
359        // Speech encoder for ICL voice cloning
360        let speech_encoder = Self::try_load_speech_encoder(decoder_weights, device)?;
361
362        let model_type = parsed_config.map(|c| c.model_type);
363
364        Ok(Self {
365            talker,
366            code_predictor,
367            decoder,
368            text_tokenizer,
369            speaker_encoder,
370            speech_encoder,
371            model_type,
372            device: device.clone(),
373            compute_dtype,
374        })
375    }
376
377    /// Detect talker config from weight shapes (fallback when no config.json).
378    fn detect_talker_config(weights: &HashMap<String, Tensor>) -> Result<TalkerConfig> {
379        let norm_weight = weights
380            .get("talker.model.norm.weight")
381            .ok_or_else(|| anyhow::anyhow!("Missing talker.model.norm.weight"))?;
382        let hidden_size = norm_weight.dim(0)?;
383        Ok(if hidden_size == 2048 {
384            TalkerConfig::custom_voice()
385        } else {
386            TalkerConfig::default()
387        })
388    }
389
390    /// Returns the detected model type, or `None` if loaded without config.json.
391    pub fn model_type(&self) -> Option<&ModelType> {
392        self.model_type.as_ref()
393    }
394
395    /// Whether this model supports voice cloning (Base models with speaker encoder).
396    pub fn supports_voice_cloning(&self) -> bool {
397        self.speaker_encoder.is_some()
398    }
399
400    /// Whether this model supports preset speaker selection (CustomVoice models).
401    ///
402    /// Returns `true` for CustomVoice, `false` for Base and VoiceDesign.
403    /// When `model_type` is unknown (loaded without config.json), returns `true`
404    /// as a permissive default.
405    pub fn supports_preset_speakers(&self) -> bool {
406        match &self.model_type {
407            Some(ModelType::CustomVoice) => true,
408            Some(ModelType::Base) | Some(ModelType::VoiceDesign) => false,
409            None => true, // permissive when unknown
410        }
411    }
412
413    /// Whether this model supports voice design (text-described voice conditioning).
414    ///
415    /// Returns `true` for VoiceDesign, `false` for all other variants.
416    pub fn supports_voice_design(&self) -> bool {
417        matches!(&self.model_type, Some(ModelType::VoiceDesign))
418    }
419
420    /// Synthesize speech from text with default voice (Ryan, English).
421    ///
422    /// Convenience wrapper around [`synthesize_with_voice`](Self::synthesize_with_voice).
423    pub fn synthesize(&self, text: &str, options: Option<SynthesisOptions>) -> Result<AudioBuffer> {
424        self.synthesize_with_voice(text, Speaker::Ryan, Language::English, options)
425    }
426
427    /// Synthesize speech with per-stage timing breakdown.
428    ///
429    /// Same as [`synthesize_with_voice`](Self::synthesize_with_voice) but also
430    /// returns a [`SynthesisTiming`] with prefill, generation, and decode durations.
431    /// Uses [`sync_device`] at timing boundaries for accurate GPU measurements.
432    pub fn synthesize_with_timing(
433        &self,
434        text: &str,
435        speaker: Speaker,
436        language: Language,
437        options: Option<SynthesisOptions>,
438    ) -> Result<(AudioBuffer, SynthesisTiming)> {
439        #[cfg(feature = "profiling")]
440        let _span = tracing::info_span!("synthesize").entered();
441
442        let options = options.unwrap_or_default();
443        let mut sampling_ctx = generation::SamplingContext::new(options.seed);
444        let input_ids = self.text_tokenizer.encode(text)?;
445        let gen_config = options.to_gen_config();
446
447        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
448            self.build_trailing_text(&input_ids)?;
449
450        // -- Prefill --
451        #[cfg(feature = "profiling")]
452        let _prefill_span = tracing::info_span!("prefill").entered();
453
454        sync_device(&self.device)?;
455        let t_prefill = Instant::now();
456
457        let mut kv_caches = self.talker.new_kv_caches(gen_config.max_new_tokens + 256);
458        let (hidden, logits) =
459            self.talker
460                .prefill_custom_voice(&input_ids, speaker, language, &mut kv_caches)?;
461        let prefill_len = hidden.dim(1)?;
462        let offset = prefill_len;
463        let last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;
464
465        sync_device(&self.device)?;
466        let prefill_ms = t_prefill.elapsed().as_secs_f64() * 1000.0;
467
468        #[cfg(feature = "profiling")]
469        drop(_prefill_span);
470
471        // -- Generation --
472        let t_gen = Instant::now();
473
474        let all_codes = self.generate_codes(
475            &gen_config,
476            &mut sampling_ctx,
477            &mut kv_caches,
478            offset,
479            last_hidden,
480            &logits,
481            &trailing_text_hidden,
482            trailing_text_len,
483            &tts_pad_embed,
484        )?;
485
486        sync_device(&self.device)?;
487        let generation_ms = t_gen.elapsed().as_secs_f64() * 1000.0;
488        let generation_frames = all_codes.len();
489
490        // -- Decode --
491        #[cfg(feature = "profiling")]
492        let _decode_span = tracing::info_span!("decode").entered();
493
494        let t_decode = Instant::now();
495        let audio = self.decode_codes(&all_codes)?;
496
497        sync_device(&self.device)?;
498        let decode_ms = t_decode.elapsed().as_secs_f64() * 1000.0;
499
500        let timing = SynthesisTiming {
501            prefill_ms,
502            generation_ms,
503            generation_frames,
504            decode_ms,
505        };
506
507        Ok((audio, timing))
508    }
509
510    /// Build trailing text embeddings from input token IDs.
511    ///
512    /// Returns (trailing_text_hidden, trailing_text_len, tts_pad_embed).
513    /// The trailing text is: remaining text tokens (all except first) projected + tts_eos.
514    /// After trailing text is exhausted, tts_pad is used for each subsequent step.
515    fn build_trailing_text(&self, input_ids: &[u32]) -> Result<(Tensor, usize, Tensor)> {
516        let trailing_text_hidden = if input_ids.len() > 1 {
517            let remaining_proj = self.talker.get_projected_text_embeddings(&input_ids[1..])?;
518            let tts_eos_embed = self.talker.get_tts_eos_embed()?;
519            Tensor::cat(&[&remaining_proj, &tts_eos_embed], 1)?
520        } else {
521            self.talker.get_tts_eos_embed()?
522        };
523        let trailing_text_len = trailing_text_hidden.dim(1)?;
524        let tts_pad_embed = self.talker.get_tts_pad_embed()?;
525        Ok((trailing_text_hidden, trailing_text_len, tts_pad_embed))
526    }
527
528    /// Core generation loop shared by all synthesis methods.
529    ///
530    /// Runs the autoregressive generation loop: for each frame, check EOS,
531    /// generate acoustic codes via CodePredictor, build the residual VQ sum
532    /// with trailing text fusion, and sample the next semantic token.
533    ///
534    /// Callers handle prefill (which varies by model variant) and post-processing
535    /// (decode, ICL ref_codes prepending, etc.).
536    #[allow(clippy::too_many_arguments)]
537    fn generate_codes(
538        &self,
539        gen_config: &generation::GenerationConfig,
540        sampling_ctx: &mut generation::SamplingContext,
541        kv_caches: &mut [AnyKVCache],
542        mut offset: usize,
543        mut last_hidden: Tensor,
544        initial_logits: &Tensor,
545        trailing_text_hidden: &Tensor,
546        trailing_text_len: usize,
547        tts_pad_embed: &Tensor,
548    ) -> Result<FrameCodes> {
549        // Pre-build the token suppression mask once (reused every frame)
550        let suppression_mask = generation::build_suppression_mask(
551            codec_tokens::CODEC_VOCAB_SIZE,
552            CODEC_EOS_TOKEN_ID,
553            &self.device,
554        )?;
555
556        // GPU-side repetition penalty mask: [1, vocab] — updated incrementally
557        // instead of transferring all generated tokens to CPU each frame.
558        let vocab_size = codec_tokens::CODEC_VOCAB_SIZE;
559        let mut penalty_mask = Tensor::zeros((1, vocab_size), DType::F32, &self.device)?;
560
561        // Pre-allocate code predictor KV caches (reused + reset each frame)
562        let mut cp_kv_caches = self.code_predictor.new_kv_caches();
563
564        // Sample first semantic token
565        let logits_2d = initial_logits.squeeze(1)?;
566        let logits_2d = self.apply_generation_penalties_gpu(
567            &logits_2d,
568            &penalty_mask,
569            gen_config,
570            0,
571            Some(&suppression_mask),
572        )?;
573        let mut semantic_token_tensor = generation::sample(&logits_2d, gen_config, sampling_ctx)?;
574        tracing::trace!(target: "gpu_sync", "to_vec1 in generate_codes first token");
575        let mut semantic_token: u32 = semantic_token_tensor.flatten_all()?.to_vec1::<u32>()?[0];
576        // Update penalty mask with this token (O(1) CPU work)
577        Self::update_penalty_mask(&mut penalty_mask, semantic_token, vocab_size)?;
578        let mut token_count: usize = 1;
579
580        // Accumulate frames as GPU tensors: Vec of [16] U32 tensors
581        // Deferred to_vec1 at the end eliminates per-frame acoustic code sync.
582        let mut gpu_frames: Vec<Tensor> = Vec::new();
583
584        #[cfg(feature = "profiling")]
585        let _gen_span = tracing::info_span!("generate_frames").entered();
586
587        for frame_idx in 0..gen_config.max_new_tokens {
588            if let Some(eos_id) = gen_config.eos_token_id {
589                if semantic_token == eos_id {
590                    break;
591                }
592            }
593
594            // Embedding lookup using GPU-resident token tensor (no CPU→GPU roundtrip)
595            let semantic_embed = self
596                .talker
597                .get_codec_embedding_from_tensor(&semantic_token_tensor)?;
598
599            #[cfg(feature = "profiling")]
600            let _cp_span = tracing::info_span!("code_predictor", frame = frame_idx).entered();
601
602            let acoustic_codes_tensor = self.code_predictor.generate_acoustic_codes(
603                &last_hidden,
604                &semantic_embed,
605                &mut cp_kv_caches,
606            )?;
607
608            #[cfg(feature = "profiling")]
609            drop(_cp_span);
610
611            // Build [16] frame tensor on GPU: [semantic_token, acoustic_0..14]
612            let frame_tensor = Tensor::cat(
613                &[&semantic_token_tensor.reshape(1)?, &acoustic_codes_tensor],
614                0,
615            )?;
616            gpu_frames.push(frame_tensor);
617
618            // Use GPU tensor directly for embedding lookup (avoids 15 CPU→GPU transfers)
619            let acoustic_embed_sum = self
620                .code_predictor
621                .get_acoustic_embeddings_sum_from_tensor(&acoustic_codes_tensor)?;
622            let summed = semantic_embed.add(&acoustic_embed_sum)?;
623
624            let text_addition = if frame_idx < trailing_text_len {
625                trailing_text_hidden.i((.., frame_idx..frame_idx + 1, ..))?
626            } else {
627                tts_pad_embed.clone()
628            };
629            let step_input = summed.add(&text_addition)?;
630
631            #[cfg(feature = "profiling")]
632            let _talker_span = tracing::info_span!("talker_step", frame = frame_idx).entered();
633
634            let (h, new_logits) =
635                self.talker
636                    .generate_step_with_embed(&step_input, kv_caches, offset)?;
637            offset += 1;
638            last_hidden = h;
639
640            #[cfg(feature = "profiling")]
641            drop(_talker_span);
642
643            #[cfg(feature = "profiling")]
644            let _sample_span = tracing::info_span!("sampling", frame = frame_idx).entered();
645
646            let logits_2d = new_logits.squeeze(1)?;
647            let logits_2d = self.apply_generation_penalties_gpu(
648                &logits_2d,
649                &penalty_mask,
650                gen_config,
651                token_count,
652                Some(&suppression_mask),
653            )?;
654            semantic_token_tensor = generation::sample(&logits_2d, gen_config, sampling_ctx)?;
655            tracing::trace!(target: "gpu_sync", "to_vec1 in generate_codes sampling");
656            semantic_token = semantic_token_tensor.flatten_all()?.to_vec1::<u32>()?[0];
657            Self::update_penalty_mask(&mut penalty_mask, semantic_token, vocab_size)?;
658            token_count += 1;
659        }
660
661        // Single GPU→CPU transfer: convert all accumulated GPU frames to FrameCodes
662        self.gpu_frames_to_frame_codes(&gpu_frames)
663    }
664
665    /// Update the GPU-side penalty mask for a single token ID.
666    ///
667    /// Sets `penalty_mask[0, token_id] = 1.0` using slice_assign with a
668    /// pre-built scalar. This is O(1) CPU work (no GPU→CPU transfer).
669    fn update_penalty_mask(
670        penalty_mask: &mut Tensor,
671        token_id: u32,
672        vocab_size: usize,
673    ) -> Result<()> {
674        let idx = token_id as usize;
675        if idx < vocab_size {
676            let one = Tensor::ones((1, 1), DType::F32, penalty_mask.device())?;
677            *penalty_mask = penalty_mask.slice_assign(&[0..1, idx..idx + 1], &one)?;
678        }
679        Ok(())
680    }
681
682    /// Convert accumulated GPU frame tensors to FrameCodes via a single bulk transfer.
683    fn gpu_frames_to_frame_codes(&self, gpu_frames: &[Tensor]) -> Result<FrameCodes> {
684        if gpu_frames.is_empty() {
685            return Ok(Vec::new());
686        }
687        // Stack all frames into [n_frames, 16], then single to_vec1
688        let stacked = Tensor::stack(gpu_frames, 0)?; // [n_frames, 16]
689        let n_frames = stacked.dim(0)?;
690        let flat: Vec<u32> = stacked.flatten_all()?.to_vec1()?;
691        let mut result = Vec::with_capacity(n_frames);
692        for f in 0..n_frames {
693            let start = f * 16;
694            result.push(flat[start..start + 16].to_vec());
695        }
696        Ok(result)
697    }
698
699    /// Synthesize speech with a specific voice and language.
700    ///
701    /// Uses the correct generation loop: CustomVoice prefill, autoregressive
702    /// semantic tokens, per-frame acoustic code prediction via CodePredictor,
703    /// residual VQ summation, and trailing text fusion.
704    ///
705    /// # Arguments
706    ///
707    /// * `text` - Text to synthesize
708    /// * `speaker` - Predefined speaker voice
709    /// * `language` - Target language
710    /// * `options` - Synthesis options (temperature, top_k, etc.)
711    ///
712    /// # Example
713    ///
714    /// ```rust,ignore
715    /// use qwen3_tts::{Qwen3TTS, Speaker, Language, SynthesisOptions};
716    ///
717    /// let audio = model.synthesize_with_voice(
718    ///     "Hello, world!",
719    ///     Speaker::Ryan,
720    ///     Language::English,
721    ///     None,
722    /// )?;
723    /// audio.save("output.wav")?;
724    /// ```
725    pub fn synthesize_with_voice(
726        &self,
727        text: &str,
728        speaker: Speaker,
729        language: Language,
730        options: Option<SynthesisOptions>,
731    ) -> Result<AudioBuffer> {
732        #[cfg(feature = "profiling")]
733        let _span = tracing::info_span!("synthesize").entered();
734
735        if let Some(ModelType::Base) = &self.model_type {
736            tracing::warn!(
737                "Using preset speaker {:?} on a Base model. Base models are trained for \
738                 voice cloning, not preset speakers — output will have an unpredictable voice. \
739                 Use synthesize_voice_clone() with reference audio instead.",
740                speaker
741            );
742        } else if let Some(ModelType::VoiceDesign) = &self.model_type {
743            tracing::warn!(
744                "Using preset speaker {:?} on a VoiceDesign model. VoiceDesign models \
745                 are trained for text-described voice creation, not preset speakers.",
746                speaker
747            );
748        }
749
750        let options = options.unwrap_or_default();
751        let mut sampling_ctx = generation::SamplingContext::new(options.seed);
752        let input_ids = self.text_tokenizer.encode(text)?;
753
754        let gen_config = options.to_gen_config();
755
756        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
757            self.build_trailing_text(&input_ids)?;
758
759        // Prefill with CustomVoice format
760        #[cfg(feature = "profiling")]
761        let _prefill_span = tracing::info_span!("prefill").entered();
762
763        let mut kv_caches = self.talker.new_kv_caches(gen_config.max_new_tokens + 256);
764        let (hidden, logits) =
765            self.talker
766                .prefill_custom_voice(&input_ids, speaker, language, &mut kv_caches)?;
767        let prefill_len = hidden.dim(1)?;
768        let offset = prefill_len;
769        let last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;
770
771        #[cfg(feature = "profiling")]
772        drop(_prefill_span);
773
774        let all_codes = self.generate_codes(
775            &gen_config,
776            &mut sampling_ctx,
777            &mut kv_caches,
778            offset,
779            last_hidden,
780            &logits,
781            &trailing_text_hidden,
782            trailing_text_len,
783            &tts_pad_embed,
784        )?;
785
786        // Decode to audio
787        #[cfg(feature = "profiling")]
788        let _decode_span = tracing::info_span!("decode").entered();
789
790        self.decode_codes(&all_codes)
791    }
792
793    /// Synthesize speech using a text-described voice (VoiceDesign model).
794    ///
795    /// Uses the same generation loop as [`Self::synthesize_with_voice`] but runs the
796    /// VoiceDesign prefill instead of the predefined-speaker prefill. The voice
797    /// is conditioned on a natural language description (e.g., "A cheerful young
798    /// female voice with high pitch and energetic tone").
799    ///
800    /// The instruct text is tokenized with ChatML framing:
801    /// `<|im_start|>user\n{instruct}<|im_end|>\n`
802    ///
803    /// # Arguments
804    ///
805    /// * `text` - Text to synthesize
806    /// * `instruct` - Natural language voice description
807    /// * `language` - Target language
808    /// * `options` - Synthesis options (temperature, top_k, etc.)
809    pub fn synthesize_voice_design(
810        &self,
811        text: &str,
812        instruct: &str,
813        language: Language,
814        options: Option<SynthesisOptions>,
815    ) -> Result<AudioBuffer> {
816        #[cfg(feature = "profiling")]
817        let _span = tracing::info_span!("synthesize").entered();
818
819        if let Some(ref mt) = self.model_type {
820            if *mt != ModelType::VoiceDesign {
821                tracing::warn!(
822                    "Using VoiceDesign synthesis on a {:?} model. This model was not trained \
823                     for text-described voice conditioning — output may be unpredictable.",
824                    mt
825                );
826            }
827        }
828
829        let options = options.unwrap_or_default();
830        let mut sampling_ctx = generation::SamplingContext::new(options.seed);
831        let input_ids = self.text_tokenizer.encode(text)?;
832
833        // Tokenize instruct with ChatML user framing: <|im_start|>user\n{instruct}<|im_end|>\n
834        let instruct_text = format!("<|im_start|>user\n{}<|im_end|>\n", instruct);
835        let instruct_ids = self.text_tokenizer.encode(&instruct_text)?;
836
837        let gen_config = options.to_gen_config();
838
839        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
840            self.build_trailing_text(&input_ids)?;
841
842        // Prefill with VoiceDesign format
843        #[cfg(feature = "profiling")]
844        let _prefill_span = tracing::info_span!("prefill").entered();
845
846        let mut kv_caches = self.talker.new_kv_caches(gen_config.max_new_tokens + 256);
847        let (hidden, logits) = self.talker.prefill_voice_design(
848            &input_ids,
849            &instruct_ids,
850            language,
851            &mut kv_caches,
852        )?;
853        let prefill_len = hidden.dim(1)?;
854        let offset = prefill_len;
855        let last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;
856
857        #[cfg(feature = "profiling")]
858        drop(_prefill_span);
859
860        let all_codes = self.generate_codes(
861            &gen_config,
862            &mut sampling_ctx,
863            &mut kv_caches,
864            offset,
865            last_hidden,
866            &logits,
867            &trailing_text_hidden,
868            trailing_text_len,
869            &tts_pad_embed,
870        )?;
871
872        // Decode to audio
873        #[cfg(feature = "profiling")]
874        let _decode_span = tracing::info_span!("decode").entered();
875
876        self.decode_codes(&all_codes)
877    }
878
879    /// Convert list of frame codes to tensor [batch, 16, num_frames]
880    pub fn codes_to_tensor(&self, codes: &[Vec<u32>]) -> Result<Tensor> {
881        codes_to_tensor(codes, &self.device)
882    }
883
884    /// Decode raw frame codes to audio.
885    ///
886    /// Takes a slice of frames (each frame is a `Vec<u32>` of 16 codebook values)
887    /// and runs the 12Hz decoder to produce an audio waveform at 24kHz.
888    pub fn decode_codes(&self, codes: &[Vec<u32>]) -> Result<AudioBuffer> {
889        let tensor = self.codes_to_tensor(codes)?;
890        self.decode_tensor(&tensor)
891    }
892
893    /// Decode a codes tensor `[1, 16, T]` to audio.
894    fn decode_tensor(&self, codes: &Tensor) -> Result<AudioBuffer> {
895        let waveform = self.decoder.decode(codes)?;
896        AudioBuffer::from_tensor(waveform, 24000)
897    }
898
899    /// Synthesize speech using a cloned voice, returning raw codes alongside audio.
900    ///
901    /// Identical to [`synthesize_voice_clone`](Self::synthesize_voice_clone) but also
902    /// returns the raw generated codes (`Vec<Vec<u32>>`) for debugging.
903    /// Each inner `Vec<u32>` is one frame: `[semantic, acoustic_0..14]` (16 values).
904    pub fn synthesize_voice_clone_debug(
905        &self,
906        text: &str,
907        prompt: &VoiceClonePrompt,
908        language: Language,
909        options: Option<SynthesisOptions>,
910    ) -> Result<(AudioBuffer, FrameCodes)> {
911        #[cfg(feature = "profiling")]
912        let _span = tracing::info_span!("synthesize").entered();
913
914        let options = options.unwrap_or_default();
915        let mut sampling_ctx = generation::SamplingContext::new(options.seed);
916        let input_ids = self.text_tokenizer.encode(text)?;
917
918        // Determine if ICL mode is active (ref_codes + ref_text present)
919        let is_icl = prompt.ref_codes.is_some() && prompt.ref_text_ids.is_some();
920
921        // ICL mode adjustments (matching mlx-audio):
922        let repetition_penalty = if is_icl {
923            options.repetition_penalty.max(ICL_MIN_REPETITION_PENALTY)
924        } else {
925            options.repetition_penalty
926        };
927        let max_new_tokens = if is_icl {
928            options
929                .max_length
930                .min(ICL_MIN_FRAMES.max(input_ids.len() * ICL_FRAMES_PER_TOKEN))
931        } else {
932            options.max_length
933        };
934        let mut gen_config = options.to_gen_config();
935        gen_config.max_new_tokens = max_new_tokens;
936        gen_config.repetition_penalty = repetition_penalty;
937
938        // Cast speaker embedding to compute dtype (speaker encoder produces F32)
939        let speaker_embed = prompt.speaker_embedding.to_dtype(self.compute_dtype)?;
940
941        // Voice clone prefill (9 positions for ICL, 10 for x_vector_only)
942        #[cfg(feature = "profiling")]
943        let _prefill_span = tracing::info_span!("prefill").entered();
944
945        // Allocate Talker KV cache with enough capacity for:
946        // - prefill tokens (9/10 tokens)
947        // - ICL prompt extension (ref text + target text + ref codec frames)
948        // - autoregressive generation budget
949        //
950        // A fixed budget `max_new_tokens + 256` can under-allocate when ICL prompt
951        // length is large, causing KV cache overflow before generation starts.
952        let mut kv_cache_capacity = gen_config.max_new_tokens + 256;
953        if is_icl {
954            let ref_text_len = prompt.ref_text_ids.as_ref().map_or(0, |ids| ids.len());
955            let ref_frames = match &prompt.ref_codes {
956                Some(codes) => codes.dim(0)?,
957                None => 0,
958            };
959
960            // Conservative upper bound that safely covers both streaming and
961            // non-streaming ICL layouts.
962            let text_part = ref_text_len + input_ids.len() + 1; // +tts_eos
963            let codec_part = ref_frames + 1; // +codec_bos
964            let icl_upper_bound = text_part + codec_part;
965            let prefill_upper_bound = 10;
966            let safety_margin = 32;
967            let required_capacity =
968                prefill_upper_bound + icl_upper_bound + gen_config.max_new_tokens + safety_margin;
969            kv_cache_capacity = kv_cache_capacity.max(required_capacity);
970        }
971
972        let mut kv_caches = self.talker.new_kv_caches(kv_cache_capacity);
973        let (hidden, logits) = self.talker.prefill_voice_clone(
974            &input_ids,
975            &speaker_embed,
976            language,
977            is_icl,
978            &mut kv_caches,
979        )?;
980        let prefill_len = hidden.dim(1)?;
981        let mut offset = prefill_len;
982
983        // Initialize last_hidden from prefill; updated by ICL block if active.
984        let mut last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;
985
986        // ICL extension (if reference codes + text are provided)
987        let (trailing_text_hidden, logits) = if let (Some(ref_codes), Some(ref_text_ids)) =
988            (&prompt.ref_codes, &prompt.ref_text_ids)
989        {
990            let ref_codec_embeds = self.sum_ref_codec_embeddings(ref_codes)?;
991
992            // In ICL mode, all text tokens go into the ICL prompt (Python:
993            // text_id=input_id[:, 3:-5] passes ALL target text tokens).
994            // In the non-ICL path the first text token is consumed by the prefill,
995            // so only the remaining tokens go to trailing_text.
996            let (icl_embed, icl_trailing) =
997                self.talker
998                    .build_icl_prompt(&input_ids, ref_text_ids, &ref_codec_embeds, true)?;
999
1000            let icl_len = icl_embed.dim(1)?;
1001            if icl_len > 0 {
1002                let mask = models::transformer::create_causal_mask(icl_len, offset, &self.device)?;
1003
1004                let mut icl_hidden = icl_embed;
1005                for (i, layer) in self.talker.layers_iter().enumerate() {
1006                    icl_hidden = layer.forward(
1007                        &icl_hidden,
1008                        self.talker.rope(),
1009                        Some(&mask),
1010                        Some(&mut kv_caches[i]),
1011                        offset,
1012                    )?;
1013                }
1014                icl_hidden = self.talker.apply_norm(&icl_hidden)?;
1015                offset += icl_len;
1016
1017                let last_icl_hidden = icl_hidden.i((.., icl_len - 1..icl_len, ..))?;
1018                let new_logits = self.talker.apply_codec_head(&last_icl_hidden)?;
1019
1020                // Update last_hidden so the code predictor is conditioned on
1021                // the ICL context, not the stale prefill hidden state.
1022                last_hidden = last_icl_hidden;
1023
1024                (icl_trailing, new_logits)
1025            } else {
1026                let trailing = self.build_default_trailing_text(&input_ids)?;
1027                (trailing, logits)
1028            }
1029        } else {
1030            let trailing = self.build_default_trailing_text(&input_ids)?;
1031            (trailing, logits)
1032        };
1033
1034        #[cfg(feature = "profiling")]
1035        drop(_prefill_span);
1036
1037        let trailing_text_len = trailing_text_hidden.dim(1)?;
1038        let tts_pad_embed = self.talker.get_tts_pad_embed()?;
1039
1040        let all_codes = self.generate_codes(
1041            &gen_config,
1042            &mut sampling_ctx,
1043            &mut kv_caches,
1044            offset,
1045            last_hidden,
1046            &logits,
1047            &trailing_text_hidden,
1048            trailing_text_len,
1049            &tts_pad_embed,
1050        )?;
1051
1052        // Prepend ref_codes for ICL decoder context (same fix as synthesize_voice_clone)
1053        #[cfg(feature = "profiling")]
1054        let _decode_span = tracing::info_span!("decode").entered();
1055
1056        let audio = if let Some(ref_codes) = &prompt.ref_codes {
1057            let ref_frames = self.tensor_to_frame_codes(ref_codes)?;
1058            let ref_len = ref_frames.len();
1059            let mut combined = ref_frames;
1060            combined.extend(all_codes.iter().cloned());
1061
1062            let mut audio = self.decode_codes(&combined)?;
1063            let total_frames = combined.len();
1064            // Proportional cut: matches official Qwen3-TTS Python implementation
1065            // cut = ref_len / total_len * wav.shape[0]
1066            let cut_samples = ref_len * audio.len() / total_frames.max(1);
1067            tracing::debug!(
1068                "ICL decode: ref_frames={}, gen_frames={}, total_samples={}, cut_samples={}",
1069                ref_len,
1070                all_codes.len(),
1071                audio.len(),
1072                cut_samples,
1073            );
1074            audio.samples = audio.samples[cut_samples.min(audio.len())..].to_vec();
1075            audio
1076        } else {
1077            self.decode_codes(&all_codes)?
1078        };
1079        Ok((audio, all_codes))
1080    }
1081
1082    /// Get the device this model is running on
1083    pub fn device(&self) -> &Device {
1084        &self.device
1085    }
1086
1087    /// Create a streaming synthesis session with a specific voice and language.
1088    ///
1089    /// Returns an iterator that yields audio chunks as they are generated.
1090    /// Each chunk contains approximately `chunk_frames` frames worth of audio
1091    /// (default: 10 frames = ~800ms at 12.5 Hz frame rate).
1092    ///
1093    /// # Example
1094    ///
1095    /// ```rust,ignore
1096    /// use qwen3_tts::{Qwen3TTS, Speaker, Language, SynthesisOptions};
1097    ///
1098    /// let options = SynthesisOptions::default();
1099    /// for chunk in model.synthesize_streaming("Hello!", Speaker::Ryan, Language::English, options)? {
1100    ///     let audio = chunk?;
1101    ///     // Play or process audio chunk (each ~800ms)
1102    /// }
1103    /// ```
1104    pub fn synthesize_streaming(
1105        &self,
1106        text: &str,
1107        speaker: Speaker,
1108        language: Language,
1109        options: SynthesisOptions,
1110    ) -> Result<StreamingSession<'_>> {
1111        let input_ids = self.text_tokenizer.encode(text)?;
1112        StreamingSession::new(self, &input_ids, speaker, language, options)
1113    }
1114
1115    /// Synthesize speech using a text-described voice (VoiceDesign model), streaming.
1116    ///
1117    /// Same as [`Self::synthesize_voice_design`] but returns a streaming session
1118    /// that yields audio chunks as they are generated.
1119    ///
1120    /// The instruct text is tokenized with ChatML framing:
1121    /// `<|im_start|>user\n{instruct}<|im_end|>\n`
1122    ///
1123    /// # Arguments
1124    ///
1125    /// * `text` - Text to synthesize
1126    /// * `instruct` - Natural language voice description (e.g., "A cheerful young female voice")
1127    /// * `language` - Target language
1128    /// * `options` - Synthesis options (temperature, top_k, chunk_frames, etc.)
1129    pub fn synthesize_voice_design_streaming(
1130        &self,
1131        text: &str,
1132        instruct: &str,
1133        language: Language,
1134        options: SynthesisOptions,
1135    ) -> Result<StreamingSession<'_>> {
1136        if let Some(ref mt) = self.model_type {
1137            if *mt != ModelType::VoiceDesign {
1138                tracing::warn!(
1139                    "Using VoiceDesign synthesis on a {:?} model. This model was not trained \
1140                     for text-described voice conditioning — output may be unpredictable.",
1141                    mt
1142                );
1143            }
1144        }
1145
1146        let input_ids = self.text_tokenizer.encode(text)?;
1147
1148        // Tokenize instruct with ChatML user framing: <|im_start|>user\n{instruct}<|im_end|>\n
1149        let instruct_text = format!("<|im_start|>user\n{}<|im_end|>\n", instruct);
1150        let instruct_ids = self.text_tokenizer.encode(&instruct_text)?;
1151
1152        StreamingSession::new_voice_design(self, &input_ids, &instruct_ids, language, options)
1153    }
1154
1155    // ── Voice cloning API ─────────────────────────────────────────────────
1156
1157    /// Create a voice clone prompt from reference audio.
1158    ///
1159    /// When `ref_text` is `None`, produces an **x_vector_only** prompt (speaker
1160    /// embedding only). When `Some`, produces an **ICL** prompt (speaker embedding
1161    /// + reference audio codes + reference text) — requires a speech encoder.
1162    ///
1163    /// # Errors
1164    ///
1165    /// Returns an error if the speaker encoder is not loaded.
1166    pub fn create_voice_clone_prompt(
1167        &self,
1168        ref_audio: &AudioBuffer,
1169        ref_text: Option<&str>,
1170    ) -> Result<VoiceClonePrompt> {
1171        let encoder = self.speaker_encoder.as_ref().ok_or_else(|| {
1172            let hint = match &self.model_type {
1173                Some(ModelType::CustomVoice) => {
1174                    " CustomVoice models use preset speakers (synthesize_with_voice), \
1175                     not voice cloning. Use a Base model for voice cloning."
1176                }
1177                Some(ModelType::VoiceDesign) => {
1178                    " VoiceDesign models use text-described voices, not voice cloning. \
1179                     Use a Base model for voice cloning."
1180                }
1181                _ => {
1182                    " Ensure model weights contain `speaker_encoder.*` keys \
1183                     (only Base models include a speaker encoder)."
1184                }
1185            };
1186            anyhow::anyhow!("Speaker encoder not available.{}", hint)
1187        })?;
1188
1189        // Resample to 24kHz if needed — both encoders assume 24kHz input
1190        let ref_audio_24k;
1191        let ref_audio = if ref_audio.sample_rate != 24000 {
1192            tracing::info!(
1193                "Resampling reference audio from {}Hz to 24000Hz",
1194                ref_audio.sample_rate
1195            );
1196            ref_audio_24k = audio::resample_to_24k(ref_audio)?;
1197            &ref_audio_24k
1198        } else {
1199            ref_audio
1200        };
1201
1202        let speaker_embedding = encoder.encode(ref_audio)?; // [enc_dim]
1203
1204        // ICL data: encode reference audio to codes and tokenize reference text
1205        let (ref_codes, ref_text_ids) = if let Some(text) = ref_text {
1206            let speech_enc = self.speech_encoder.as_ref().ok_or_else(|| {
1207                anyhow::anyhow!(
1208                    "ICL voice cloning requires a speech encoder, but it was not loaded. \
1209                     Ensure the speech tokenizer weights contain encoder keys, or use \
1210                     x_vector_only mode by passing ref_text=None."
1211                )
1212            })?;
1213
1214            let codes = speech_enc.encode(ref_audio)?; // [T_frames, 16]
1215            let text_ids = self.text_tokenizer.encode(text)?;
1216
1217            (Some(codes), Some(text_ids))
1218        } else {
1219            (None, None)
1220        };
1221
1222        Ok(VoiceClonePrompt {
1223            speaker_embedding,
1224            ref_codes,
1225            ref_text_ids,
1226        })
1227    }
1228
1229    /// Synthesize speech using a cloned voice.
1230    ///
1231    /// Uses the same generation loop as [`Self::synthesize_with_voice`] but runs the
1232    /// voice-clone prefill instead of the predefined-speaker prefill.
1233    ///
1234    /// When the prompt contains ICL data (ref_codes + ref_text_ids), the model
1235    /// is conditioned on reference audio/text to better reproduce the speaker's voice.
1236    pub fn synthesize_voice_clone(
1237        &self,
1238        text: &str,
1239        prompt: &VoiceClonePrompt,
1240        language: Language,
1241        options: Option<SynthesisOptions>,
1242    ) -> Result<AudioBuffer> {
1243        self.synthesize_voice_clone_debug(text, prompt, language, options)
1244            .map(|(audio, _codes)| audio)
1245    }
1246
1247    /// Convert a ref_codes tensor `[T_frames, 16]` to `Vec<Vec<u32>>` frame format.
1248    fn tensor_to_frame_codes(&self, codes: &Tensor) -> Result<FrameCodes> {
1249        let (n_frames, n_codebooks) = codes.dims2()?;
1250        let codes_u32 = codes.to_dtype(DType::U32)?;
1251        let mut frames = Vec::with_capacity(n_frames);
1252        for f in 0..n_frames {
1253            let frame_tensor = codes_u32.i(f)?; // [16]
1254            let frame_vec: Vec<u32> = frame_tensor.to_vec1()?;
1255            debug_assert_eq!(frame_vec.len(), n_codebooks);
1256            frames.push(frame_vec);
1257        }
1258        Ok(frames)
1259    }
1260
1261    /// Sum reference codec embeddings across all 16 codebook groups.
1262    ///
1263    /// For each frame:
1264    /// - Group 0 (semantic): `talker.codec_embedding(ref_codes[:, 0])`
1265    /// - Groups 1–15 (acoustic): `code_predictor.codec_embeddings[i-1](ref_codes[:, i])`
1266    /// - Sum all 16 → single embedding per frame
1267    ///
1268    /// # Arguments
1269    /// * `ref_codes` — shape `[T_frames, 16]` of i64 codes
1270    ///
1271    /// # Returns
1272    /// Tensor of shape `[1, T_frames, hidden_size]`
1273    fn sum_ref_codec_embeddings(&self, ref_codes: &Tensor) -> Result<Tensor> {
1274        // Group 0: semantic codes → talker.codec_embedding
1275        let semantic_codes = ref_codes.i((.., 0))?; // [T_frames]
1276        let semantic_codes = semantic_codes.to_dtype(candle_core::DType::U32)?;
1277        let summed = self.talker.get_codec_embedding_batch(&semantic_codes)?; // [1, T, hidden]
1278
1279        // Groups 1-15: acoustic codes → code_predictor.embed_codes_for_group
1280        let mut summed = summed;
1281        for group in 1..16 {
1282            let group_codes = ref_codes.i((.., group))?; // [T_frames]
1283            let group_codes = group_codes.to_dtype(candle_core::DType::U32)?;
1284            let group_embed = self
1285                .code_predictor
1286                .embed_codes_for_group(group - 1, &group_codes)?; // [1, T, embed_dim]
1287            summed = summed.add(&group_embed)?;
1288        }
1289
1290        Ok(summed)
1291    }
1292
1293    /// Build default trailing text embeddings (for non-ICL mode).
1294    fn build_default_trailing_text(&self, input_ids: &[u32]) -> Result<Tensor> {
1295        let (hidden, _len, _pad) = self.build_trailing_text(input_ids)?;
1296        Ok(hidden)
1297    }
1298
1299    /// Apply repetition penalty, token suppression, and min_new_tokens EOS suppression
1300    /// using a pre-built `[1, vocab]` penalty mask on GPU.
1301    ///
1302    /// The mask is updated incrementally via [`update_penalty_mask`] after each
1303    /// sampled token, eliminating the O(n) GPU→CPU transfer that grows with
1304    /// each frame.
1305    fn apply_generation_penalties_gpu(
1306        &self,
1307        logits: &Tensor,
1308        penalty_mask: &Tensor,
1309        config: &generation::GenerationConfig,
1310        token_count: usize,
1311        suppression_mask: Option<&generation::SuppressionMask>,
1312    ) -> Result<Tensor> {
1313        let logits = logits.to_dtype(DType::F32)?;
1314
1315        // 1. Repetition penalty via pre-built GPU mask
1316        let logits = if config.repetition_penalty != 1.0 {
1317            generation::apply_repetition_penalty_with_mask(
1318                &logits,
1319                penalty_mask,
1320                config.repetition_penalty,
1321            )?
1322        } else {
1323            logits
1324        };
1325
1326        // 2. Token suppression
1327        let logits = if let Some(mask) = suppression_mask {
1328            generation::apply_token_suppression_with_mask(&logits, mask)?
1329        } else {
1330            generation::apply_token_suppression(
1331                &logits,
1332                codec_tokens::CODEC_VOCAB_SIZE,
1333                CODEC_EOS_TOKEN_ID,
1334            )?
1335        };
1336
1337        // 3. Min new tokens EOS suppression
1338        if token_count < config.min_new_tokens {
1339            if let Some(eos_id) = config.eos_token_id {
1340                let vocab = logits.dim(1)?;
1341                let batch = logits.dim(0)?;
1342                let mut mask_data = vec![0.0f32; vocab];
1343                mask_data[eos_id as usize] = 1.0;
1344                let eos_mask = Tensor::new(mask_data.as_slice(), logits.device())?
1345                    .unsqueeze(0)?
1346                    .broadcast_as((batch, vocab))?;
1347                let neg_inf = Tensor::new(&[f32::NEG_INFINITY], logits.device())?
1348                    .broadcast_as((batch, vocab))?;
1349                let zeros = Tensor::zeros((batch, vocab), DType::F32, logits.device())?;
1350                let is_eos = eos_mask.gt(&zeros)?;
1351                return Ok(is_eos.where_cond(&neg_inf, &logits)?);
1352            }
1353        }
1354
1355        Ok(logits)
1356    }
1357
1358    /// Returns `true` if a speech encoder is loaded (ICL voice cloning is available).
1359    pub fn has_speech_encoder(&self) -> bool {
1360        self.speech_encoder.is_some()
1361    }
1362
1363    // ── Private helpers ─────────────────────────────────────────────────
1364
1365    /// Attempt to load the speaker encoder from model weights.
1366    ///
1367    /// Returns `Ok(Some(encoder))` if `speaker_encoder.*` keys are found,
1368    /// `Ok(None)` if they are absent. When `config` is provided, uses the
1369    /// parsed enc_dim; otherwise falls back to defaults (enc_dim=1024).
1370    fn try_load_speaker_encoder(
1371        weights: &HashMap<String, Tensor>,
1372        config: Option<&SpeakerEncoderConfig>,
1373        device: &Device,
1374    ) -> Result<Option<SpeakerEncoder>> {
1375        let has_se_weights = weights.keys().any(|k| k.starts_with("speaker_encoder."));
1376        if !has_se_weights {
1377            return Ok(None);
1378        }
1379
1380        let config = config.cloned().unwrap_or_default();
1381        tracing::info!(
1382            "Loading speaker encoder (ECAPA-TDNN, enc_dim={}) for voice cloning...",
1383            config.enc_dim
1384        );
1385        let se_weights = Self::filter_weights(weights, "speaker_encoder.");
1386        let se_vb = candle_nn::VarBuilder::from_tensors(se_weights, DType::F32, device);
1387        let encoder = SpeakerEncoder::new(config, se_vb)?;
1388        Ok(Some(encoder))
1389    }
1390
1391    /// Attempt to load the speech encoder (Mimi) from speech tokenizer weights.
1392    ///
1393    /// The speech encoder encodes raw audio to 12Hz codec codes, needed for
1394    /// ICL voice cloning. Returns `Ok(None)` if encoder keys are absent or
1395    /// loading fails (non-fatal — ICL mode just won't be available).
1396    fn try_load_speech_encoder(
1397        weights: &HashMap<String, Tensor>,
1398        device: &Device,
1399    ) -> Result<Option<Encoder12Hz>> {
1400        // Check for encoder-related keys (either HF or candle format)
1401        let has_encoder_keys = weights
1402            .keys()
1403            .any(|k| k.starts_with("encoder.") || k.starts_with("encoder_transformer."));
1404        if !has_encoder_keys {
1405            return Ok(None);
1406        }
1407
1408        tracing::debug!("Attempting to load speech encoder (Mimi) for ICL voice cloning...");
1409        match Encoder12Hz::from_weights(weights, device) {
1410            Ok(enc) => {
1411                tracing::info!("Loaded speech encoder — ICL voice cloning available");
1412                Ok(Some(enc))
1413            }
1414            Err(e) => {
1415                tracing::debug!(
1416                    "Speech encoder not available ({}). ICL voice cloning disabled.",
1417                    e
1418                );
1419                Ok(None)
1420            }
1421        }
1422    }
1423
1424    /// Load weights from safetensors file.
1425    ///
1426    /// Tensors are loaded in their native dtype (typically BF16 for Qwen3-TTS).
1427    /// Each component's VarBuilder handles casting to its target dtype.
1428    fn load_weights(path: &Path, device: &Device) -> Result<HashMap<String, Tensor>> {
1429        Ok(candle_core::safetensors::load(path, device)?)
1430    }
1431
1432    /// Filter weights by prefix, removing the prefix from keys.
1433    pub(crate) fn filter_weights(
1434        weights: &HashMap<String, Tensor>,
1435        prefix: &str,
1436    ) -> HashMap<String, Tensor> {
1437        weights
1438            .iter()
1439            .filter_map(|(k, v)| {
1440                k.strip_prefix(prefix)
1441                    .map(|stripped| (stripped.to_string(), v.clone()))
1442            })
1443            .collect()
1444    }
1445}
1446
1447/// Convert a slice of codec frames into a tensor of shape `[1, 16, T]`.
1448///
1449/// Each frame must contain exactly 16 codebook values. The output layout is
1450/// `[q0_f0, q0_f1, ...], [q1_f0, q1_f1, ...]` matching the decoder's expectation.
1451pub fn codes_to_tensor(codes: &[Vec<u32>], device: &Device) -> Result<Tensor> {
1452    let num_frames = codes.len();
1453    if num_frames == 0 {
1454        return Ok(Tensor::zeros((1, 16, 0), DType::I64, device)?);
1455    }
1456
1457    let mut data = vec![0i64; 16 * num_frames];
1458    for (frame, frame_codes) in codes.iter().enumerate() {
1459        for (q, &code) in frame_codes.iter().enumerate() {
1460            data[q * num_frames + frame] = code as i64;
1461        }
1462    }
1463
1464    Ok(Tensor::from_vec(data, (1, 16, num_frames), device)?)
1465}
1466
1467/// Return the recommended compute dtype for the given device.
1468///
1469/// Returns `BF16` for CUDA/Metal (lower memory, faster attention) and `F32` for CPU.
1470pub fn compute_dtype_for_device(device: &Device) -> DType {
1471    if device.is_cuda() || device.is_metal() {
1472        DType::BF16
1473    } else {
1474        DType::F32
1475    }
1476}
1477
1478/// Force the GPU to complete all pending work before returning.
1479///
1480/// On CUDA/Metal, GPU operations are asynchronous — `Instant::now()` would
1481/// measure submission time, not completion time. This helper forces a sync
1482/// by creating a tiny tensor and reading it back to the CPU.
1483///
1484/// On CPU this is a no-op.
1485pub fn sync_device(device: &Device) -> Result<()> {
1486    match device {
1487        Device::Cpu => Ok(()),
1488        _ => {
1489            // Force a GPU→CPU sync by reading a scalar back
1490            let _: Vec<f32> = Tensor::zeros(1, DType::F32, device)?.to_vec1()?;
1491            Ok(())
1492        }
1493    }
1494}
1495
/// The codec end-of-sequence token ID (2150).
///
/// Generation stops when this token is sampled. This is in the codec vocabulary
/// `[0, 3072)`, not the text vocabulary.
pub const CODEC_EOS_TOKEN_ID: u32 = codec_tokens::CODEC_EOS;

/// Number of audio samples per codec frame at 24 kHz.
///
/// 1920 samples / 24000 Hz = 80 ms per frame, i.e. 12.5 frames per second.
/// The codec is referred to as "12Hz", but the exact frame rate is 12.5 Hz.
pub const SAMPLES_PER_FRAME: usize = 1920;

/// ICL mode: minimum frames to generate regardless of text length (matching mlx-audio).
/// At 80 ms per frame this is 6 seconds of audio.
const ICL_MIN_FRAMES: usize = 75;

/// ICL mode: estimated frames per input text token for max-length cap (matching mlx-audio)
const ICL_FRAMES_PER_TOKEN: usize = 6;

/// ICL mode: minimum repetition penalty to prevent degenerate loops (matching mlx-audio)
const ICL_MIN_REPETITION_PENALTY: f64 = 1.5;
1508/// ICL mode: estimated frames per input text token for max-length cap (matching mlx-audio)
1509const ICL_FRAMES_PER_TOKEN: usize = 6;
1510
1511/// ICL mode: minimum repetition penalty to prevent degenerate loops (matching mlx-audio)
1512const ICL_MIN_REPETITION_PENALTY: f64 = 1.5;
1513
1514/// Streaming synthesis session.
1515///
1516/// Yields audio chunks as they are generated. Use with
1517/// [`Qwen3TTS::synthesize_streaming`].
1518pub struct StreamingSession<'a> {
1519    model: &'a Qwen3TTS,
1520    config: generation::GenerationConfig,
1521    sampling_ctx: generation::SamplingContext,
1522    kv_caches: Vec<AnyKVCache>,
1523    offset: usize,
1524    last_hidden: Tensor,
1525    current_token: Option<u32>,
1526    current_token_tensor: Option<Tensor>,
1527    frames_generated: usize,
1528    frame_buffer: FrameCodes,
1529    chunk_frames: usize,
1530    done: bool,
1531    // Trailing text state for residual VQ + text fusion
1532    trailing_text_hidden: Tensor,
1533    trailing_text_len: usize,
1534    tts_pad_embed: Tensor,
1535    // GPU-side repetition penalty mask [1, vocab] — updated incrementally
1536    penalty_mask: Tensor,
1537    token_count: usize,
1538    // Pre-built suppression mask (reused every frame)
1539    suppression_mask: generation::SuppressionMask,
1540    // Pre-allocated code predictor KV caches (reused + reset each frame)
1541    cp_kv_caches: Vec<AnyKVCache>,
1542}
1543
1544impl<'a> StreamingSession<'a> {
1545    fn new(
1546        model: &'a Qwen3TTS,
1547        input_ids: &[u32],
1548        speaker: Speaker,
1549        language: Language,
1550        options: SynthesisOptions,
1551    ) -> Result<Self> {
1552        let sampling_ctx = generation::SamplingContext::new(options.seed);
1553        let config = options.to_gen_config();
1554
1555        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
1556            model.build_trailing_text(input_ids)?;
1557
1558        let mut kv_caches = model.talker.new_kv_caches(config.max_new_tokens + 256);
1559        let prefill_result =
1560            model
1561                .talker
1562                .prefill_custom_voice(input_ids, speaker, language, &mut kv_caches)?;
1563
1564        Self::from_prefill(
1565            model,
1566            config,
1567            sampling_ctx,
1568            kv_caches,
1569            prefill_result,
1570            trailing_text_hidden,
1571            trailing_text_len,
1572            tts_pad_embed,
1573            options.chunk_frames,
1574        )
1575    }
1576
1577    /// Create a streaming session using voice design (text-described voice).
1578    ///
1579    /// Uses `prefill_voice_design` instead of `prefill_custom_voice` to condition
1580    /// on a natural language voice description rather than a predefined speaker.
1581    fn new_voice_design(
1582        model: &'a Qwen3TTS,
1583        input_ids: &[u32],
1584        instruct_ids: &[u32],
1585        language: Language,
1586        options: SynthesisOptions,
1587    ) -> Result<Self> {
1588        let sampling_ctx = generation::SamplingContext::new(options.seed);
1589        let config = options.to_gen_config();
1590
1591        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
1592            model.build_trailing_text(input_ids)?;
1593
1594        let mut kv_caches = model.talker.new_kv_caches(config.max_new_tokens + 256);
1595        let prefill_result =
1596            model
1597                .talker
1598                .prefill_voice_design(input_ids, instruct_ids, language, &mut kv_caches)?;
1599
1600        Self::from_prefill(
1601            model,
1602            config,
1603            sampling_ctx,
1604            kv_caches,
1605            prefill_result,
1606            trailing_text_hidden,
1607            trailing_text_len,
1608            tts_pad_embed,
1609            options.chunk_frames,
1610        )
1611    }
1612
1613    /// Shared post-prefill constructor.
1614    ///
1615    /// Extracts `last_hidden` from the prefill result, builds the suppression and
1616    /// penalty masks, samples the first semantic token, and assembles the session.
1617    #[allow(clippy::too_many_arguments)]
1618    fn from_prefill(
1619        model: &'a Qwen3TTS,
1620        config: generation::GenerationConfig,
1621        mut sampling_ctx: generation::SamplingContext,
1622        kv_caches: Vec<AnyKVCache>,
1623        prefill_result: (Tensor, Tensor),
1624        trailing_text_hidden: Tensor,
1625        trailing_text_len: usize,
1626        tts_pad_embed: Tensor,
1627        chunk_frames: usize,
1628    ) -> Result<Self> {
1629        let (hidden, logits) = prefill_result;
1630        let prefill_len = hidden.dim(1)?;
1631        let last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;
1632
1633        // Build suppression mask once for reuse across all frames
1634        let suppression_mask = generation::build_suppression_mask(
1635            codec_tokens::CODEC_VOCAB_SIZE,
1636            CODEC_EOS_TOKEN_ID,
1637            &model.device,
1638        )?;
1639
1640        // Sample first token with full penalty pipeline
1641        let vocab_size = codec_tokens::CODEC_VOCAB_SIZE;
1642        let mut penalty_mask = Tensor::zeros((1, vocab_size), DType::F32, &model.device)?;
1643        let logits_2d = logits.squeeze(1)?;
1644        let logits_2d = model.apply_generation_penalties_gpu(
1645            &logits_2d,
1646            &penalty_mask,
1647            &config,
1648            0,
1649            Some(&suppression_mask),
1650        )?;
1651        let first_token = generation::sample(&logits_2d, &config, &mut sampling_ctx)?;
1652        let first_token_id: u32 = first_token.flatten_all()?.to_vec1::<u32>()?[0];
1653        Qwen3TTS::update_penalty_mask(&mut penalty_mask, first_token_id, vocab_size)?;
1654
1655        let done = config.eos_token_id == Some(first_token_id);
1656        let cp_kv_caches = model.code_predictor.new_kv_caches();
1657
1658        Ok(Self {
1659            model,
1660            config,
1661            sampling_ctx,
1662            kv_caches,
1663            offset: prefill_len,
1664            last_hidden,
1665            current_token: if done { None } else { Some(first_token_id) },
1666            current_token_tensor: if done { None } else { Some(first_token) },
1667            frames_generated: 0,
1668            frame_buffer: Vec::new(),
1669            chunk_frames,
1670            done,
1671            trailing_text_hidden,
1672            trailing_text_len,
1673            tts_pad_embed,
1674            penalty_mask,
1675            token_count: 1,
1676            suppression_mask,
1677            cp_kv_caches,
1678        })
1679    }
1680
1681    /// Generate the next chunk of audio.
1682    ///
1683    /// Returns `Some(AudioBuffer)` for each chunk, or `None` when generation is complete.
1684    pub fn next_chunk(&mut self) -> Result<Option<AudioBuffer>> {
1685        if self.done {
1686            // Flush remaining buffer
1687            if !self.frame_buffer.is_empty() {
1688                let codes = self.model.codes_to_tensor(&self.frame_buffer)?;
1689                self.frame_buffer.clear();
1690                let audio = self.model.decoder.decode(&codes)?;
1691                return Ok(Some(AudioBuffer::from_tensor(audio, 24000)?));
1692            }
1693            return Ok(None);
1694        }
1695
1696        // Generate frames until we have enough for a chunk
1697        while self.frame_buffer.len() < self.chunk_frames
1698            && self.frames_generated < self.config.max_new_tokens
1699        {
1700            let (token_id, token_tensor) =
1701                match (self.current_token, self.current_token_tensor.take()) {
1702                    (Some(id), Some(t)) => (id, t),
1703                    _ => {
1704                        self.done = true;
1705                        break;
1706                    }
1707                };
1708
1709            // Embedding lookup using GPU-resident token tensor (no CPU→GPU roundtrip)
1710            let semantic_embed = self
1711                .model
1712                .talker
1713                .get_codec_embedding_from_tensor(&token_tensor)?;
1714
1715            // Generate 15 acoustic codes (stays on GPU)
1716            let acoustic_codes_tensor = self.model.code_predictor.generate_acoustic_codes(
1717                &self.last_hidden,
1718                &semantic_embed,
1719                &mut self.cp_kv_caches,
1720            )?;
1721
1722            // Build frame on GPU, then transfer for frame_buffer
1723            let semantic_t = Tensor::new(&[token_id], self.model.device())?;
1724            let frame_tensor = Tensor::cat(&[&semantic_t, &acoustic_codes_tensor], 0)?;
1725            let frame_codes: Vec<u32> = frame_tensor.to_vec1()?;
1726            self.frame_buffer.push(frame_codes);
1727
1728            let frame_idx = self.frames_generated;
1729            self.frames_generated += 1;
1730
1731            // Build residual VQ sum + trailing text for next step
1732            let acoustic_embed_sum = self
1733                .model
1734                .code_predictor
1735                .get_acoustic_embeddings_sum_from_tensor(&acoustic_codes_tensor)?;
1736            let summed = semantic_embed.add(&acoustic_embed_sum)?;
1737
1738            let text_addition = if frame_idx < self.trailing_text_len {
1739                self.trailing_text_hidden
1740                    .i((.., frame_idx..frame_idx + 1, ..))?
1741            } else {
1742                self.tts_pad_embed.clone()
1743            };
1744            let step_input = summed.add(&text_addition)?;
1745
1746            // Run talker step with fused embedding
1747            let (h, new_logits) = self.model.talker.generate_step_with_embed(
1748                &step_input,
1749                &mut self.kv_caches,
1750                self.offset,
1751            )?;
1752            self.offset += 1;
1753            self.last_hidden = h;
1754
1755            // Sample next semantic token with repetition penalty + token suppression + min_new_tokens
1756            let logits_2d = new_logits.squeeze(1)?;
1757            let logits_2d = self.model.apply_generation_penalties_gpu(
1758                &logits_2d,
1759                &self.penalty_mask,
1760                &self.config,
1761                self.token_count,
1762                Some(&self.suppression_mask),
1763            )?;
1764            let next_token_tensor =
1765                generation::sample(&logits_2d, &self.config, &mut self.sampling_ctx)?;
1766            let next_token_id: u32 = next_token_tensor.flatten_all()?.to_vec1::<u32>()?[0];
1767            Qwen3TTS::update_penalty_mask(
1768                &mut self.penalty_mask,
1769                next_token_id,
1770                codec_tokens::CODEC_VOCAB_SIZE,
1771            )?;
1772            self.token_count += 1;
1773
1774            if self.config.eos_token_id == Some(next_token_id) {
1775                self.current_token = None;
1776                self.current_token_tensor = None;
1777                self.done = true;
1778            } else {
1779                self.current_token = Some(next_token_id);
1780                self.current_token_tensor = Some(next_token_tensor);
1781            }
1782        }
1783
1784        // Decode the buffered frames
1785        if self.frame_buffer.is_empty() {
1786            return Ok(None);
1787        }
1788
1789        let codes = self.model.codes_to_tensor(&self.frame_buffer)?;
1790        self.frame_buffer.clear();
1791        let audio = self.model.decoder.decode(&codes)?;
1792        Ok(Some(AudioBuffer::from_tensor(audio, 24000)?))
1793    }
1794
1795    /// Returns the total number of frames generated so far.
1796    pub fn frames_generated(&self) -> usize {
1797        self.frames_generated
1798    }
1799
1800    /// Returns true if generation is complete.
1801    pub fn is_done(&self) -> bool {
1802        self.done && self.frame_buffer.is_empty()
1803    }
1804}
1805
1806impl<'a> Iterator for StreamingSession<'a> {
1807    type Item = Result<AudioBuffer>;
1808
1809    fn next(&mut self) -> Option<Self::Item> {
1810        match self.next_chunk() {
1811            Ok(Some(audio)) => Some(Ok(audio)),
1812            Ok(None) => None,
1813            Err(e) => Some(Err(e)),
1814        }
1815    }
1816}
1817
1818/// Options for speech synthesis
1819#[derive(Debug, Clone)]
1820pub struct SynthesisOptions {
1821    /// Maximum number of frames to generate
1822    pub max_length: usize,
1823    /// Sampling temperature (higher = more random)
1824    pub temperature: f64,
1825    /// Top-k sampling
1826    pub top_k: usize,
1827    /// Top-p (nucleus) sampling
1828    pub top_p: f64,
1829    /// Repetition penalty (1.0 = disabled, 1.05 = Python default)
1830    pub repetition_penalty: f64,
1831    /// End-of-sequence token ID (defaults to codec EOS token 2150)
1832    pub eos_token_id: Option<u32>,
1833    /// Frames per streaming chunk (default: 10 = ~800ms)
1834    pub chunk_frames: usize,
1835    /// Minimum tokens before EOS is allowed (default: 2, matching Python)
1836    pub min_new_tokens: usize,
1837    /// Random seed for deterministic generation. `None` = non-deterministic.
1838    pub seed: Option<u64>,
1839}
1840
1841impl SynthesisOptions {
1842    /// Convert to a [`GenerationConfig`](generation::GenerationConfig) for the generation loop.
1843    pub fn to_gen_config(&self) -> generation::GenerationConfig {
1844        generation::GenerationConfig {
1845            max_new_tokens: self.max_length,
1846            temperature: self.temperature,
1847            top_k: self.top_k,
1848            top_p: self.top_p,
1849            repetition_penalty: self.repetition_penalty,
1850            eos_token_id: self.eos_token_id,
1851            min_new_tokens: self.min_new_tokens,
1852        }
1853    }
1854}
1855
1856impl Default for SynthesisOptions {
1857    fn default() -> Self {
1858        Self {
1859            max_length: 2048,
1860            temperature: 0.9,
1861            top_k: 50,
1862            top_p: 0.9,
1863            repetition_penalty: 1.05,
1864            eos_token_id: Some(CODEC_EOS_TOKEN_ID),
1865            chunk_frames: 10, // ~800ms per chunk at 12.5 Hz
1866            min_new_tokens: 2,
1867            seed: None,
1868        }
1869    }
1870}
1871
1872/// Select the best available compute device for inference.
1873///
1874/// Checks for available hardware in order: CUDA → Metal → CPU.
1875/// Falls back to CPU if no GPU acceleration is available.
1876///
1877/// # Feature Flags
1878///
1879/// - `cuda`: Enables NVIDIA GPU support
1880/// - `metal`: Enables Apple Silicon GPU support
1881///
1882/// # Example
1883///
1884/// ```rust,ignore
1885/// let device = qwen3_tts::auto_device()?;
1886/// let model = Qwen3TTS::from_pretrained("path/to/model", device)?;
1887/// ```
1888pub fn auto_device() -> Result<Device> {
1889    #[cfg(feature = "cuda")]
1890    {
1891        if let Ok(device) = Device::cuda_if_available(0) {
1892            if device.is_cuda() {
1893                tracing::info!("Using CUDA device");
1894                return Ok(device);
1895            }
1896        }
1897    }
1898
1899    #[cfg(feature = "metal")]
1900    {
1901        if let Ok(device) = Device::new_metal(0) {
1902            tracing::info!("Using Metal device");
1903            return Ok(device);
1904        }
1905    }
1906
1907    tracing::info!("Using CPU device");
1908    Ok(Device::Cpu)
1909}
1910
1911/// Parse a device string into a [`Device`].
1912///
1913/// Supported formats:
1914/// - `"auto"` — select best available via [`auto_device`]
1915/// - `"cpu"` — force CPU
1916/// - `"cuda"` or `"cuda:0"` — CUDA device 0
1917/// - `"cuda:N"` — CUDA device N
1918/// - `"metal"` — Apple Silicon GPU
1919///
1920/// # Errors
1921///
1922/// Returns an error if the device string is unrecognized, the requested
1923/// backend wasn't compiled in, or hardware initialization fails.
1924pub fn parse_device(device_str: &str) -> Result<Device> {
1925    match device_str.to_lowercase().as_str() {
1926        "auto" => auto_device(),
1927        "cpu" => Ok(Device::Cpu),
1928        s if s.starts_with("cuda") => {
1929            #[cfg(feature = "cuda")]
1930            {
1931                let ordinal: usize = if s == "cuda" {
1932                    0
1933                } else if let Some(idx) = s.strip_prefix("cuda:") {
1934                    idx.parse()
1935                        .map_err(|e| anyhow::anyhow!("invalid CUDA device index: {e}"))?
1936                } else {
1937                    0
1938                };
1939                Device::cuda_if_available(ordinal)
1940                    .map_err(|e| anyhow::anyhow!("failed to init CUDA device {ordinal}: {e}"))
1941            }
1942            #[cfg(not(feature = "cuda"))]
1943            anyhow::bail!("CUDA support not compiled in. Rebuild with: cargo build --features cuda")
1944        }
1945        "metal" => {
1946            #[cfg(feature = "metal")]
1947            {
1948                Device::new_metal(0)
1949                    .map_err(|e| anyhow::anyhow!("failed to init Metal device: {e}"))
1950            }
1951            #[cfg(not(feature = "metal"))]
1952            anyhow::bail!(
1953                "Metal support not compiled in. Rebuild with: cargo build --features metal"
1954            )
1955        }
1956        other => {
1957            anyhow::bail!("unknown device '{other}'. Supported: auto, cpu, cuda, cuda:N, metal")
1958        }
1959    }
1960}
1961
1962/// Human-readable label for a [`Device`].
1963pub fn device_info(device: &Device) -> String {
1964    match device {
1965        Device::Cpu => "CPU".to_string(),
1966        Device::Cuda(_) => "CUDA".to_string(),
1967        Device::Metal(_) => "Metal".to_string(),
1968    }
1969}
1970
1971#[cfg(test)]
1972mod tests {
1973    use super::*;
1974
1975    #[test]
1976    fn test_synthesis_options_default() {
1977        let options = SynthesisOptions::default();
1978        assert_eq!(options.max_length, 2048);
1979        assert!((options.temperature - 0.9).abs() < 1e-6);
1980        assert_eq!(options.top_k, 50);
1981        assert!((options.top_p - 0.9).abs() < 1e-6);
1982        assert!((options.repetition_penalty - 1.05).abs() < 1e-6);
1983        assert_eq!(options.eos_token_id, Some(CODEC_EOS_TOKEN_ID));
1984        assert_eq!(options.chunk_frames, 10);
1985        assert_eq!(options.min_new_tokens, 2);
1986    }
1987
1988    #[test]
1989    fn test_synthesis_options_custom() {
1990        let options = SynthesisOptions {
1991            max_length: 512,
1992            temperature: 0.5,
1993            top_k: 10,
1994            top_p: 0.8,
1995            repetition_penalty: 1.2,
1996            eos_token_id: Some(CODEC_EOS_TOKEN_ID),
1997            chunk_frames: 5,
1998            min_new_tokens: 0,
1999            seed: Some(42),
2000        };
2001        assert_eq!(options.max_length, 512);
2002        assert!((options.temperature - 0.5).abs() < 1e-6);
2003        assert_eq!(options.eos_token_id, Some(CODEC_EOS_TOKEN_ID));
2004        assert_eq!(options.chunk_frames, 5);
2005    }
2006
2007    #[test]
2008    fn test_synthesis_options_clone() {
2009        let options = SynthesisOptions::default();
2010        let cloned = options.clone();
2011        assert_eq!(cloned.max_length, options.max_length);
2012        assert_eq!(cloned.top_k, options.top_k);
2013    }
2014
2015    #[test]
2016    fn test_synthesis_options_debug() {
2017        let options = SynthesisOptions::default();
2018        let debug_str = format!("{:?}", options);
2019        assert!(debug_str.contains("max_length"));
2020        assert!(debug_str.contains("2048"));
2021    }
2022
2023    #[test]
2024    fn test_auto_device() {
2025        // Should always succeed on CPU
2026        let device = auto_device().unwrap();
2027        // Just verify it returns a valid device
2028        assert!(
2029            matches!(device, Device::Cpu)
2030                || matches!(device, Device::Cuda(_))
2031                || matches!(device, Device::Metal(_))
2032        );
2033    }
2034
2035    #[test]
2036    fn test_audio_buffer_reexport() {
2037        // Verify re-exports work
2038        let buffer = AudioBuffer::new(vec![0.0f32; 100], 24000);
2039        assert_eq!(buffer.sample_rate, 24000);
2040    }
2041
2042    #[test]
2043    fn test_config_reexport() {
2044        // Verify config re-export works
2045        let config = Qwen3TTSConfig::default();
2046        assert_eq!(config.model_type, "qwen3_tts");
2047    }
2048
2049    #[test]
2050    fn test_codes_to_tensor_empty() {
2051        let device = Device::Cpu;
2052        let codes: Vec<Vec<u32>> = vec![];
2053        let tensor = codes_to_tensor(&codes, &device).unwrap();
2054        assert_eq!(tensor.dims(), &[1, 16, 0]);
2055    }
2056
2057    #[test]
2058    fn test_codes_to_tensor_single_frame() {
2059        let device = Device::Cpu;
2060        let codes = vec![vec![0u32; 16]];
2061        let tensor = codes_to_tensor(&codes, &device).unwrap();
2062        assert_eq!(tensor.dims(), &[1, 16, 1]);
2063    }
2064
2065    #[test]
2066    fn test_codes_to_tensor_layout() {
2067        let device = Device::Cpu;
2068        // 2 frames, each with 16 codebooks
2069        let codes = vec![
2070            (0..16).map(|i| i as u32).collect::<Vec<_>>(), // frame 0
2071            (100..116).map(|i| i as u32).collect::<Vec<_>>(), // frame 1
2072        ];
2073        let tensor = codes_to_tensor(&codes, &device).unwrap();
2074        assert_eq!(tensor.dims(), &[1, 16, 2]);
2075
2076        // Verify layout: tensor[0, q, frame] = codes[frame][q]
2077        let vals: Vec<i64> = tensor.flatten_all().unwrap().to_vec1().unwrap();
2078        // q=0: [frame0_q0, frame1_q0] = [0, 100]
2079        assert_eq!(vals[0], 0);
2080        assert_eq!(vals[1], 100);
2081        // q=1: [frame0_q1, frame1_q1] = [1, 101]
2082        assert_eq!(vals[2], 1);
2083        assert_eq!(vals[3], 101);
2084    }
2085
2086    #[test]
2087    fn test_parse_device_cpu() {
2088        let device = parse_device("cpu").unwrap();
2089        assert!(matches!(device, Device::Cpu));
2090    }
2091
2092    #[test]
2093    fn test_parse_device_auto() {
2094        let device = parse_device("auto").unwrap();
2095        // Should succeed regardless of hardware
2096        assert!(
2097            matches!(device, Device::Cpu)
2098                || matches!(device, Device::Cuda(_))
2099                || matches!(device, Device::Metal(_))
2100        );
2101    }
2102
2103    #[test]
2104    fn test_parse_device_unknown() {
2105        let result = parse_device("tpu");
2106        assert!(result.is_err());
2107    }
2108
2109    #[test]
2110    fn test_parse_device_case_insensitive() {
2111        let device = parse_device("CPU").unwrap();
2112        assert!(matches!(device, Device::Cpu));
2113    }
2114
2115    #[test]
2116    fn test_device_info() {
2117        assert_eq!(device_info(&Device::Cpu), "CPU");
2118    }
2119
2120    #[test]
2121    fn test_compute_dtype_for_device() {
2122        let dtype = compute_dtype_for_device(&Device::Cpu);
2123        assert_eq!(dtype, DType::F32);
2124    }
2125
2126    #[test]
2127    fn test_update_penalty_mask() {
2128        let device = Device::Cpu;
2129        let vocab_size = 3072;
2130        let mut mask = Tensor::zeros((1, vocab_size), DType::F32, &device).unwrap();
2131
2132        Qwen3TTS::update_penalty_mask(&mut mask, 42, vocab_size).unwrap();
2133
2134        let vals: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap();
2135        assert_eq!(vals[42], 1.0);
2136        // Neighboring positions should be untouched
2137        assert_eq!(vals[41], 0.0);
2138        assert_eq!(vals[43], 0.0);
2139    }
2140
2141    #[test]
2142    fn test_update_penalty_mask_out_of_range() {
2143        let device = Device::Cpu;
2144        let vocab_size = 3072;
2145        let mut mask = Tensor::zeros((1, vocab_size), DType::F32, &device).unwrap();
2146
2147        // Token beyond vocab_size should be a no-op (no panic)
2148        Qwen3TTS::update_penalty_mask(&mut mask, 9999, vocab_size).unwrap();
2149
2150        let sum: f32 = mask.sum_all().unwrap().to_scalar().unwrap();
2151        assert_eq!(sum, 0.0);
2152    }
2153
2154    #[test]
2155    fn test_suppression_mask_deterministic() {
2156        let device = Device::Cpu;
2157        let vocab = codec_tokens::CODEC_VOCAB_SIZE;
2158        let mask1 = generation::build_suppression_mask(vocab, CODEC_EOS_TOKEN_ID, &device).unwrap();
2159        let mask2 = generation::build_suppression_mask(vocab, CODEC_EOS_TOKEN_ID, &device).unwrap();
2160
2161        // Apply both masks to uniform logits and verify identical output
2162        let logits = Tensor::ones((1, vocab), DType::F32, &device).unwrap();
2163        let out1 = generation::apply_token_suppression_with_mask(&logits, &mask1).unwrap();
2164        let out2 = generation::apply_token_suppression_with_mask(&logits, &mask2).unwrap();
2165        let v1: Vec<f32> = out1.flatten_all().unwrap().to_vec1().unwrap();
2166        let v2: Vec<f32> = out2.flatten_all().unwrap().to_vec1().unwrap();
2167        assert_eq!(v1, v2);
2168    }
2169}