Skip to main content

qwen3_tts/models/
code_predictor.rs

1//! Code Predictor for Qwen3-TTS
2//!
3//! The code predictor generates acoustic tokens (groups 2-16) given the
4//! semantic token (group 1) and the hidden state from the talker model.
5//!
6//! Architecture:
7//! - 5 transformer layers with same structure as talker
8//! - 15 codec embeddings (one per acoustic group)
9//! - 15 lm_heads (one per acoustic group)
10
11use anyhow::Result;
12use candle_core::{IndexOp, Module, Tensor, D};
13use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
14
15use super::config::Qwen3TTSConfig;
16use super::kv_cache::{AnyKVCache, KVCache, PreAllocKVCache};
17use super::transformer::{DecoderLayer, RoPEType, RotaryEmbedding};
18use candle_core::DType;
19
/// Code predictor configuration
///
/// Hyper-parameters for the small causal transformer that predicts the
/// acoustic codec groups. Mirrors the talker's layer structure at a smaller
/// scale; see [`CodePredictorConfig::default`] for the stock values.
#[derive(Debug, Clone)]
pub struct CodePredictorConfig {
    /// Hidden dimension
    pub hidden_size: usize,
    /// Intermediate size for MLP
    pub intermediate_size: usize,
    /// Number of transformer layers
    pub num_hidden_layers: usize,
    /// Number of attention heads
    pub num_attention_heads: usize,
    /// Number of KV heads (for GQA)
    pub num_key_value_heads: usize,
    /// Head dimension
    pub head_dim: usize,
    /// RMS norm epsilon
    pub rms_norm_eps: f64,
    /// RoPE theta
    pub rope_theta: f64,
    /// Vocabulary size for codec tokens
    pub vocab_size: usize,
    /// Number of code groups (total, including semantic)
    pub num_code_groups: usize,
    /// Codec embedding dimension (may differ from hidden_size for CustomVoice models)
    /// When different from hidden_size, a small_to_mtp_projection is used
    /// to map embeddings down to `hidden_size`. `None` means "same as hidden_size".
    pub codec_embed_dim: Option<usize>,
}
47
48impl Default for CodePredictorConfig {
49    fn default() -> Self {
50        Self {
51            hidden_size: 1024,
52            intermediate_size: 3072,
53            num_hidden_layers: 5,
54            num_attention_heads: 16,
55            num_key_value_heads: 8,
56            head_dim: 128,
57            rms_norm_eps: 1e-6,
58            rope_theta: 1000000.0,
59            vocab_size: 2048,
60            num_code_groups: 16,
61            codec_embed_dim: None, // When None, uses hidden_size
62        }
63    }
64}
65
66impl CodePredictorConfig {
67    /// Create config from parsed HuggingFace config.json.
68    ///
69    /// When the talker hidden_size differs from the code predictor hidden_size
70    /// (e.g. 1.7B models: talker=2048, CP=1024), `codec_embed_dim` is set to
71    /// the talker's hidden_size so the `small_to_mtp_projection` layer is created.
72    pub fn from_parsed(parsed: &super::config::ParsedModelConfig) -> Self {
73        let codec_embed_dim = if parsed.talker_hidden_size != parsed.cp_hidden_size {
74            Some(parsed.talker_hidden_size)
75        } else {
76            None
77        };
78        Self {
79            hidden_size: parsed.cp_hidden_size,
80            intermediate_size: parsed.cp_intermediate_size,
81            num_hidden_layers: parsed.cp_num_hidden_layers,
82            num_attention_heads: parsed.cp_num_attention_heads,
83            num_key_value_heads: parsed.cp_num_key_value_heads,
84            head_dim: parsed.cp_head_dim,
85            rms_norm_eps: parsed.cp_rms_norm_eps,
86            rope_theta: parsed.cp_rope_theta,
87            vocab_size: parsed.cp_vocab_size,
88            num_code_groups: parsed.cp_num_code_groups,
89            codec_embed_dim,
90        }
91    }
92
93    /// Get the codec embedding dimension (defaults to hidden_size)
94    pub fn codec_embed_dim(&self) -> usize {
95        self.codec_embed_dim.unwrap_or(self.hidden_size)
96    }
97
98    /// Create config for CustomVoice model
99    pub fn custom_voice() -> Self {
100        Self {
101            hidden_size: 1024,
102            intermediate_size: 3072,
103            num_hidden_layers: 5,
104            num_attention_heads: 16,
105            num_key_value_heads: 8,
106            head_dim: 128,
107            rms_norm_eps: 1e-6,
108            rope_theta: 1000000.0,
109            vocab_size: 2048,
110            num_code_groups: 16,
111            codec_embed_dim: Some(2048), // CustomVoice uses 2048-dim codec embeddings
112        }
113    }
114
115    /// Create a Qwen3TTSConfig for building decoder layers
116    fn to_layer_config(&self) -> Qwen3TTSConfig {
117        Qwen3TTSConfig {
118            hidden_size: self.hidden_size,
119            intermediate_size: self.intermediate_size,
120            num_hidden_layers: self.num_hidden_layers,
121            num_attention_heads: self.num_attention_heads,
122            num_key_value_heads: Some(self.num_key_value_heads),
123            head_dim_override: Some(self.head_dim),
124            rms_norm_eps: self.rms_norm_eps,
125            rope_theta: self.rope_theta,
126            vocab_size: self.vocab_size,
127            ..Default::default()
128        }
129    }
130}
131
/// Code predictor model
///
/// A small causal transformer that, given the talker's hidden state and the
/// semantic (group-1) token embedding, autoregressively predicts the 15
/// acoustic codec tokens (groups 2-16). Caches are reset per call, so one
/// instance serves every frame.
pub struct CodePredictor {
    /// Codec embeddings for each acoustic group (0-14 for groups 2-16)
    codec_embeddings: Vec<Embedding>,
    /// Projection from codec_embed_dim to hidden_size (for CustomVoice models)
    small_to_mtp_projection: Option<Linear>,
    /// Transformer layers
    layers: Vec<DecoderLayer>,
    /// Final normalization
    norm: RmsNorm,
    /// LM heads for each acoustic group (0-14 for groups 2-16)
    lm_heads: Vec<Linear>,
    /// Rotary embeddings
    rope: RoPEType,
    /// Configuration
    config: CodePredictorConfig,
    /// Cached causal mask for prefill (always 2×2, created once)
    prefill_mask: Tensor,
    /// Device (needed for PreAllocKVCache creation)
    device: candle_core::Device,
    /// Compute dtype (needed for PreAllocKVCache creation)
    dtype: DType,
}
155
156impl CodePredictor {
157    /// Create new code predictor
158    pub fn new(config: CodePredictorConfig, vb: VarBuilder) -> Result<Self> {
159        let layer_config = config.to_layer_config();
160        let num_acoustic_groups = config.num_code_groups - 1;
161        let codec_embed_dim = config.codec_embed_dim();
162
163        // Create codec embeddings (one per acoustic group)
164        // Note: for CustomVoice, codec_embed_dim (2048) differs from hidden_size (1024)
165        let mut codec_embeddings = Vec::with_capacity(num_acoustic_groups);
166        for i in 0..num_acoustic_groups {
167            codec_embeddings.push(embedding(
168                config.vocab_size,
169                codec_embed_dim,
170                vb.pp(format!("model.codec_embedding.{}", i)),
171            )?);
172        }
173
174        // Projection layer for CustomVoice models (2048 -> 1024)
175        let small_to_mtp_projection = if codec_embed_dim != config.hidden_size {
176            Some(candle_nn::linear(
177                codec_embed_dim,
178                config.hidden_size,
179                vb.pp("small_to_mtp_projection"),
180            )?)
181        } else {
182            None
183        };
184
185        // Create transformer layers
186        let mut layers = Vec::with_capacity(config.num_hidden_layers);
187        for i in 0..config.num_hidden_layers {
188            layers.push(DecoderLayer::new(
189                &layer_config,
190                vb.pp(format!("model.layers.{}", i)),
191            )?);
192        }
193
194        // Final norm
195        let norm = rms_norm(config.hidden_size, config.rms_norm_eps, vb.pp("model.norm"))?;
196
197        // LM heads (one per acoustic group)
198        let mut lm_heads = Vec::with_capacity(num_acoustic_groups);
199        for i in 0..num_acoustic_groups {
200            lm_heads.push(linear_no_bias(
201                config.hidden_size,
202                config.vocab_size,
203                vb.pp(format!("lm_head.{}", i)),
204            )?);
205        }
206
207        // Rotary embeddings
208        let rope = RoPEType::Standard(RotaryEmbedding::new(
209            config.head_dim,
210            1024, // Max sequence length for code predictor
211            config.rope_theta,
212            vb.device(),
213        )?);
214
215        // Pre-build the 2×2 causal mask for prefill (talker_hidden + semantic_embed).
216        // This never changes, so building it once avoids per-frame allocation.
217        let prefill_mask = super::transformer::create_causal_mask(2, 0, vb.device())?;
218
219        let device = vb.device().clone();
220        let dtype = vb.dtype();
221
222        Ok(Self {
223            codec_embeddings,
224            small_to_mtp_projection,
225            layers,
226            norm,
227            lm_heads,
228            rope,
229            config,
230            prefill_mask,
231            device,
232            dtype,
233        })
234    }
235
236    /// Generate next token logits for a specific group
237    ///
238    /// # Arguments
239    /// * `hidden` - Hidden states from forward pass, shape [batch, seq, hidden]
240    /// * `group_idx` - Which acoustic group (0-14 for groups 2-16)
241    /// * `position` - Which position to use for prediction
242    pub fn get_logits(&self, hidden: &Tensor, group_idx: usize, position: usize) -> Result<Tensor> {
243        let pos_hidden = hidden.i((.., position..position + 1, ..))?;
244        Ok(self.lm_heads[group_idx].forward(&pos_hidden)?)
245    }
246
247    /// Run a prefill pass through the code predictor transformer layers.
248    ///
249    /// Takes pre-built hidden states (e.g. talker_hidden concatenated with code
250    /// embeddings), runs through all layers with KV caches, and returns the
251    /// normed hidden states. Use `get_logits` to extract per-group predictions.
252    ///
253    /// This is a low-level method for reference validation.
254    pub fn forward_prefill(
255        &self,
256        hidden: &Tensor,
257        _prev_codes: &[u32],
258        kv_caches: &mut [AnyKVCache],
259    ) -> Result<Tensor> {
260        let device = hidden.device();
261        let input = if let Some(proj) = &self.small_to_mtp_projection {
262            proj.forward(hidden)?
263        } else {
264            hidden.clone()
265        };
266
267        let seq_len = input.dim(1)?;
268        let mask = self.create_causal_mask(seq_len, device)?;
269
270        let mut h = input;
271        for (i, layer) in self.layers.iter().enumerate() {
272            h = layer.forward(&h, &self.rope, Some(&mask), Some(&mut kv_caches[i]), 0)?;
273        }
274        Ok(self.norm.forward(&h)?)
275    }
276
277    /// Create a set of KV caches for the code predictor (one per layer).
278    ///
279    /// Callers should create this once and pass it to [`CodePredictor::generate_acoustic_codes`]
280    /// on each frame — the method resets the caches internally, avoiding
281    /// per-frame allocation.
282    pub fn new_kv_caches(&self) -> Vec<AnyKVCache> {
283        // Code predictor: 2 prefill + 15 decode = 17 max tokens
284        const CP_MAX_SEQ: usize = 17;
285
286        (0..self.config.num_hidden_layers)
287            .map(|_| {
288                if self.device.is_cuda() || self.device.is_metal() {
289                    PreAllocKVCache::new(
290                        1, // batch
291                        self.config.num_key_value_heads,
292                        CP_MAX_SEQ,
293                        self.config.head_dim,
294                        self.dtype,
295                        &self.device,
296                    )
297                    .map(AnyKVCache::PreAlloc)
298                    .unwrap_or_else(|_| AnyKVCache::Concat(KVCache::new()))
299                } else {
300                    AnyKVCache::Concat(KVCache::new())
301                }
302            })
303            .collect()
304    }
305
    /// Generate all 15 acoustic tokens autoregressively.
    ///
    /// Each acoustic code is predicted conditioned on the talker hidden state,
    /// the semantic token embedding, and all previously generated acoustic codes.
    /// Uses KV caching for sequential generation. Selection is greedy (argmax)
    /// at every step.
    ///
    /// # Arguments
    /// * `talker_hidden` - Hidden state from talker model, shape `[batch, 1, hidden]`
    /// * `semantic_embed` - Embedding of semantic token, shape `[batch, 1, hidden]`
    /// * `cp_kv_caches` - Reusable KV caches (created via [`CodePredictor::new_kv_caches`]). Reset internally each call.
    ///
    /// # Returns
    /// GPU tensor of shape `[num_acoustic]` containing the 15 acoustic code IDs.
    /// Stays on device to avoid GPU→CPU sync; callers should use tensor ops directly.
    pub fn generate_acoustic_codes(
        &self,
        talker_hidden: &Tensor,
        semantic_embed: &Tensor,
        cp_kv_caches: &mut [AnyKVCache],
    ) -> Result<Tensor> {
        #[cfg(feature = "profiling")]
        let _span = tracing::info_span!("code_predictor_inner").entered();

        // Reset caches from previous frame
        for cache in cp_kv_caches.iter_mut() {
            cache.reset();
        }

        let device = talker_hidden.device();
        let num_acoustic = self.config.num_code_groups - 1; // 15 acoustic codes

        // Step 1: Prefill with [talker_hidden, semantic_embed]
        let input = Tensor::cat(&[talker_hidden, semantic_embed], 1)?;

        // Apply projection if needed (CustomVoice: 2048 -> 1024)
        let input = if let Some(proj) = &self.small_to_mtp_projection {
            proj.forward(&input)?
        } else {
            input
        };

        let seq_len = input.dim(1)?;
        // Use cached mask for the standard 2-token prefill, create on-the-fly otherwise.
        // `dynamic_mask` is declared in the outer scope so the reference taken in the
        // else-branch lives long enough for `mask` to borrow it.
        let dynamic_mask;
        let mask = if seq_len == 2 {
            &self.prefill_mask
        } else {
            dynamic_mask = self.create_causal_mask(seq_len, device)?;
            &dynamic_mask
        };

        // Prefill runs at position offset 0; it fills positions 0..seq_len in the caches.
        let mut hidden = input;
        for (i, layer) in self.layers.iter().enumerate() {
            hidden = layer.forward(
                &hidden,
                &self.rope,
                Some(mask),
                Some(&mut cp_kv_caches[i]),
                0,
            )?;
        }
        hidden = self.norm.forward(&hidden)?;

        // Step 2: Predict first acoustic code from last position
        // Keep codes as GPU tensors to avoid per-step GPU→CPU syncs.
        // Pre-allocate a single [num_acoustic] tensor and write each code into it
        // to avoid Tensor::cat overhead on many small tensors.
        let last_hidden = hidden.i((.., seq_len - 1..seq_len, ..))?;
        let logits = self.lm_heads[0].forward(&last_hidden)?;
        let first_code = logits.argmax(D::Minus1)?.flatten_all()?; // [1] tensor on GPU

        let mut all_codes = Tensor::zeros(num_acoustic, candle_core::DType::U32, device)?;
        let range = 0..1;
        all_codes = all_codes.slice_assign(&[range], &first_code)?;

        // Also keep a reference to the latest code for embedding lookup
        let mut prev_code = first_code;

        // Step 3: Autoregressively generate remaining 14 codes.
        // `offset` is the absolute position passed to layer.forward for each new
        // token (the prefill already occupied positions 0..seq_len).
        let mut offset = seq_len;
        for group_idx in 1..num_acoustic {
            // Embed previous code using the previous group's embedding (stays on GPU)
            let code_embed = self.codec_embeddings[group_idx - 1].forward(&prev_code)?;
            let code_embed = code_embed.unsqueeze(0)?; // [1, 1, codec_embed_dim]

            // Apply projection if needed
            let code_embed = if let Some(proj) = &self.small_to_mtp_projection {
                proj.forward(&code_embed)?
            } else {
                code_embed
            };

            // Single token attending to all previous positions via KV cache —
            // no masking needed (all-zeros mask is a no-op).
            let mut h = code_embed;
            for (i, layer) in self.layers.iter().enumerate() {
                h = layer.forward(&h, &self.rope, None, Some(&mut cp_kv_caches[i]), offset)?;
            }
            h = self.norm.forward(&h)?;

            // Predict next code (stays on GPU); each group has its own LM head.
            let logits = self.lm_heads[group_idx].forward(&h)?;
            let next_code = logits.argmax(D::Minus1)?.flatten_all()?; // [1] tensor on GPU
            let range = group_idx..group_idx + 1;
            all_codes = all_codes.slice_assign(&[range], &next_code)?;
            prev_code = next_code;
            offset += 1;
        }

        Ok(all_codes)
    }
417
    /// Build a causal mask for a `seq_len`-token sequence on `device`.
    ///
    /// Thin wrapper over [`super::transformer::create_causal_mask`], always
    /// passing 0 as the second (offset) argument.
    fn create_causal_mask(&self, seq_len: usize, device: &candle_core::Device) -> Result<Tensor> {
        super::transformer::create_causal_mask(seq_len, 0, device)
    }
421
422    /// Get acoustic code embedding for a specific group
423    ///
424    /// group_idx: 0-14 for acoustic groups 2-16
425    /// Returns: [1, 1, codec_embed_dim] tensor
426    pub fn get_acoustic_embedding(
427        &self,
428        code: u32,
429        group_idx: usize,
430        device: &candle_core::Device,
431    ) -> Result<Tensor> {
432        if group_idx >= self.codec_embeddings.len() {
433            anyhow::bail!(
434                "Invalid group_idx {} (max {})",
435                group_idx,
436                self.codec_embeddings.len() - 1
437            );
438        }
439        let code_tensor = Tensor::new(&[code], device)?;
440        let embed = self.codec_embeddings[group_idx].forward(&code_tensor)?;
441        Ok(embed.unsqueeze(0)?) // [1, 1, codec_embed_dim]
442    }
443
444    /// Embed a sequence of codes for a specific acoustic group.
445    ///
446    /// Used by ICL voice cloning to build reference codec embeddings.
447    ///
448    /// # Arguments
449    /// * `group_idx` — acoustic group (0–14 for codebook groups 2–16)
450    /// * `codes` — 1-D i64 tensor of codec token IDs, shape `[T]`
451    ///
452    /// # Returns
453    /// Tensor of shape `[1, T, codec_embed_dim]`
454    pub fn embed_codes_for_group(&self, group_idx: usize, codes: &Tensor) -> Result<Tensor> {
455        if group_idx >= self.codec_embeddings.len() {
456            anyhow::bail!(
457                "Invalid group_idx {} (max {})",
458                group_idx,
459                self.codec_embeddings.len() - 1
460            );
461        }
462        let embed = self.codec_embeddings[group_idx].forward(codes)?; // [T, codec_embed_dim]
463        Ok(embed.unsqueeze(0)?) // [1, T, codec_embed_dim]
464    }
465
466    /// Get sum of all acoustic code embeddings
467    ///
468    /// acoustic_codes: 15 acoustic codes for groups 2-16
469    /// Returns: [1, 1, codec_embed_dim] tensor with summed embeddings
470    pub fn get_acoustic_embeddings_sum(
471        &self,
472        acoustic_codes: &[u32],
473        device: &candle_core::Device,
474    ) -> Result<Tensor> {
475        if acoustic_codes.len() != self.codec_embeddings.len() {
476            anyhow::bail!(
477                "Expected {} acoustic codes, got {}",
478                self.codec_embeddings.len(),
479                acoustic_codes.len()
480            );
481        }
482
483        let first = self.get_acoustic_embedding(acoustic_codes[0], 0, device)?;
484        acoustic_codes[1..]
485            .iter()
486            .enumerate()
487            .try_fold(first, |acc, (i, &code)| {
488                let embed = self.get_acoustic_embedding(code, i + 1, device)?;
489                acc.add(&embed).map_err(Into::into)
490            })
491    }
492
493    /// Get sum of all acoustic code embeddings from a GPU tensor.
494    ///
495    /// Like `get_acoustic_embeddings_sum` but takes codes as a \[num_acoustic\] tensor
496    /// already on device, avoiding 15 small CPU→GPU transfers.
497    pub fn get_acoustic_embeddings_sum_from_tensor(
498        &self,
499        acoustic_codes: &Tensor,
500    ) -> Result<Tensor> {
501        let n = acoustic_codes.dim(0)?;
502        if n != self.codec_embeddings.len() {
503            anyhow::bail!(
504                "Expected {} acoustic codes, got {}",
505                self.codec_embeddings.len(),
506                n
507            );
508        }
509
510        let first_code = acoustic_codes.narrow(0, 0, 1)?;
511        let first = self.codec_embeddings[0]
512            .forward(&first_code)?
513            .unsqueeze(0)?;
514        (1..n).try_fold(first, |acc, i| {
515            let code = acoustic_codes.narrow(0, i, 1)?;
516            let embed = self.codec_embeddings[i].forward(&code)?.unsqueeze(0)?;
517            acc.add(&embed).map_err(Into::into)
518        })
519    }
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};
    use candle_nn::VarMap;

    /// Build a zero-initialized VarBuilder backed by a fresh VarMap.
    fn create_mock_vb(device: &Device) -> VarBuilder<'static> {
        let weights = VarMap::new();
        VarBuilder::from_varmap(&weights, DType::F32, device)
    }

    #[test]
    fn test_config_default() {
        let cfg = CodePredictorConfig::default();
        assert_eq!(cfg.hidden_size, 1024);
        assert_eq!(cfg.num_hidden_layers, 5);
        assert_eq!(cfg.num_code_groups, 16);
    }

    #[test]
    fn test_code_predictor_construction() {
        let device = Device::Cpu;
        let vb = create_mock_vb(&device);

        // Tiny geometry so construction stays fast.
        let config = CodePredictorConfig {
            hidden_size: 32,
            intermediate_size: 64,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            head_dim: 8,
            vocab_size: 64,
            num_code_groups: 4,
            ..Default::default()
        };

        let predictor = CodePredictor::new(config, vb);
        assert!(predictor.is_ok());
        let predictor = predictor.unwrap();

        // num_code_groups - 1 acoustic groups drive embedding/head counts.
        assert_eq!(predictor.codec_embeddings.len(), 3);
        assert_eq!(predictor.lm_heads.len(), 3);
        assert_eq!(predictor.layers.len(), 2);
    }
}