use anyhow::Result;
use candle_core::{DType, IndexOp, Module, Tensor, D};
use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};

use super::config::Qwen3TTSConfig;
use super::kv_cache::{AnyKVCache, KVCache, PreAllocKVCache};
use super::transformer::{DecoderLayer, RoPEType, RotaryEmbedding};
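
/// Configuration for the code predictor (MTP) sub-model, which expands each
/// semantic codec token into the remaining acoustic codebook groups.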
#[derive(Debug, Clone)]
pub struct CodePredictorConfig {
    /// Width of the transformer hidden states.
    pub hidden_size: usize,
    /// Width of the MLP intermediate layer.
    pub intermediate_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of attention heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads (grouped-query attention).
    pub num_key_value_heads: usize,
    /// Dimension of each attention head.
    pub head_dim: usize,
    /// Epsilon used by the RMS-norm layers.
    pub rms_norm_eps: f64,
    /// Base frequency for rotary position embeddings.
    pub rope_theta: f64,
    /// Codebook size shared by all codec groups.
    pub vocab_size: usize,
    /// Total number of codec groups (one semantic plus the acoustic groups).
    pub num_code_groups: usize,
    /// Width of the codec embeddings when it differs from `hidden_size`;
    /// `None` means they are already `hidden_size` wide.
    pub codec_embed_dim: Option<usize>,
}

impl Default for CodePredictorConfig {
    fn default() -> Self {
        Self {
            hidden_size: 1024,
            intermediate_size: 3072,
            num_hidden_layers: 5,
            num_attention_heads: 16,
            num_key_value_heads: 8,
            head_dim: 128,
            rms_norm_eps: 1e-6,
            rope_theta: 1_000_000.0,
            vocab_size: 2048,
            num_code_groups: 16,
            codec_embed_dim: None,
        }
    }
}

impl CodePredictorConfig {
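
    /// Builds the code-predictor configuration from the parsed model config.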
    pub fn from_parsed(parsed: &super::config::ParsedModelConfig) -> Self {
        // Keep the talker's width for the codec embeddings when it differs from
        // the code predictor's hidden size; a projection bridges the two later.
        let codec_embed_dim = if parsed.talker_hidden_size != parsed.cp_hidden_size {
            Some(parsed.talker_hidden_size)
        } else {
            None
        };
        Self {
            hidden_size: parsed.cp_hidden_size,
            intermediate_size: parsed.cp_intermediate_size,
            num_hidden_layers: parsed.cp_num_hidden_layers,
            num_attention_heads: parsed.cp_num_attention_heads,
            num_key_value_heads: parsed.cp_num_key_value_heads,
            head_dim: parsed.cp_head_dim,
            rms_norm_eps: parsed.cp_rms_norm_eps,
            rope_theta: parsed.cp_rope_theta,
            vocab_size: parsed.cp_vocab_size,
            num_code_groups: parsed.cp_num_code_groups,
            codec_embed_dim,
        }
    }
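
    /// Effective codec embedding width, falling back to `hidden_size`.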
    pub fn codec_embed_dim(&self) -> usize {
        self.codec_embed_dim.unwrap_or(self.hidden_size)
    }
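
    /// Preset for the custom-voice checkpoint, whose codec embeddings are
    /// 2048 wide and are projected down to the 1024-wide hidden states.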
    pub fn custom_voice() -> Self {
        Self {
            hidden_size: 1024,
            intermediate_size: 3072,
            num_hidden_layers: 5,
            num_attention_heads: 16,
            num_key_value_heads: 8,
            head_dim: 128,
            rms_norm_eps: 1e-6,
            rope_theta: 1_000_000.0,
            vocab_size: 2048,
            num_code_groups: 16,
            codec_embed_dim: Some(2048),
        }
    }
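
    /// Maps this configuration onto the shared transformer layer config.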
    fn to_layer_config(&self) -> Qwen3TTSConfig {
        Qwen3TTSConfig {
            hidden_size: self.hidden_size,
            intermediate_size: self.intermediate_size,
            num_hidden_layers: self.num_hidden_layers,
            num_attention_heads: self.num_attention_heads,
            num_key_value_heads: Some(self.num_key_value_heads),
            head_dim_override: Some(self.head_dim),
            rms_norm_eps: self.rms_norm_eps,
            rope_theta: self.rope_theta,
            vocab_size: self.vocab_size,
            ..Default::default()
        }
    }
}
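
/// Small autoregressive transformer that expands each talker step into the
/// full set of acoustic codec codes, one codebook group at a time.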
pub struct CodePredictor {
    /// One embedding table per acoustic codebook group.
    codec_embeddings: Vec<Embedding>,
    /// Projects codec embeddings into `hidden_size` when the widths differ.
    small_to_mtp_projection: Option<Linear>,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    /// One output head per acoustic codebook group.
    lm_heads: Vec<Linear>,
    rope: RoPEType,
    config: CodePredictorConfig,
    /// Cached causal mask for the common two-token prefill.
    prefill_mask: Tensor,
    device: candle_core::Device,
    dtype: DType,
}

impl CodePredictor {
    pub fn new(config: CodePredictorConfig, vb: VarBuilder) -> Result<Self> {
        let layer_config = config.to_layer_config();
        let num_acoustic_groups = config.num_code_groups - 1;
        let codec_embed_dim = config.codec_embed_dim();

        // One embedding table per acoustic group.
        let mut codec_embeddings = Vec::with_capacity(num_acoustic_groups);
        for i in 0..num_acoustic_groups {
            codec_embeddings.push(embedding(
                config.vocab_size,
                codec_embed_dim,
                vb.pp(format!("model.codec_embedding.{}", i)),
            )?);
        }

        // Only needed when the codec embeddings are not already `hidden_size` wide.
        let small_to_mtp_projection = if codec_embed_dim != config.hidden_size {
            Some(candle_nn::linear(
                codec_embed_dim,
                config.hidden_size,
                vb.pp("small_to_mtp_projection"),
            )?)
        } else {
            None
        };

        let mut layers = Vec::with_capacity(config.num_hidden_layers);
        for i in 0..config.num_hidden_layers {
            layers.push(DecoderLayer::new(
                &layer_config,
                vb.pp(format!("model.layers.{}", i)),
            )?);
        }

        let norm = rms_norm(config.hidden_size, config.rms_norm_eps, vb.pp("model.norm"))?;

        // One output head per acoustic group.
        let mut lm_heads = Vec::with_capacity(num_acoustic_groups);
        for i in 0..num_acoustic_groups {
            lm_heads.push(linear_no_bias(
                config.hidden_size,
                config.vocab_size,
                vb.pp(format!("lm_head.{}", i)),
            )?);
        }

        let rope = RoPEType::Standard(RotaryEmbedding::new(
            config.head_dim,
            1024, // max cached rotary positions; the code predictor only sees short sequences
            config.rope_theta,
            vb.device(),
        )?);

        // Pre-built mask for the common two-token prefill (talker hidden + semantic embed).
        let prefill_mask = super::transformer::create_causal_mask(2, 0, vb.device())?;

        let device = vb.device().clone();
        let dtype = vb.dtype();

        Ok(Self {
            codec_embeddings,
            small_to_mtp_projection,
            layers,
            norm,
            lm_heads,
            rope,
            config,
            prefill_mask,
            device,
            dtype,
        })
    }
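
    /// Logits for codebook group `group_idx` at a single `position` of `hidden`.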
    pub fn get_logits(&self, hidden: &Tensor, group_idx: usize, position: usize) -> Result<Tensor> {
        let pos_hidden = hidden.i((.., position..position + 1, ..))?;
        Ok(self.lm_heads[group_idx].forward(&pos_hidden)?)
    }
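
    /// Prefill pass: projects `hidden` if needed, runs it through all layers
    /// with a causal mask (filling `kv_caches`), and returns the normalized
    /// hidden states.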
    pub fn forward_prefill(
        &self,
        hidden: &Tensor,
        _prev_codes: &[u32],
        kv_caches: &mut [AnyKVCache],
    ) -> Result<Tensor> {
        let device = hidden.device();
        let input = if let Some(proj) = &self.small_to_mtp_projection {
            proj.forward(hidden)?
        } else {
            hidden.clone()
        };

        let seq_len = input.dim(1)?;
        let mask = self.create_causal_mask(seq_len, device)?;

        let mut h = input;
        for (i, layer) in self.layers.iter().enumerate() {
            h = layer.forward(&h, &self.rope, Some(&mask), Some(&mut kv_caches[i]), 0)?;
        }
        Ok(self.norm.forward(&h)?)
    }
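
    /// Builds one KV cache per layer. On CUDA/Metal a pre-allocated cache is
    /// used when possible; otherwise (or on CPU) a concat-based cache is used.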
    pub fn new_kv_caches(&self) -> Vec<AnyKVCache> {
        // Two-token prefill plus one decode step per remaining group, with headroom.
        const CP_MAX_SEQ: usize = 17;

        (0..self.config.num_hidden_layers)
            .map(|_| {
                if self.device.is_cuda() || self.device.is_metal() {
                    PreAllocKVCache::new(
                        1, // batch size
                        self.config.num_key_value_heads,
                        CP_MAX_SEQ,
                        self.config.head_dim,
                        self.dtype,
                        &self.device,
                    )
                    .map(AnyKVCache::PreAlloc)
                    .unwrap_or_else(|_| AnyKVCache::Concat(KVCache::new()))
                } else {
                    AnyKVCache::Concat(KVCache::new())
                }
            })
            .collect()
    }
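
    /// Generates all acoustic codes for one talker step.
    ///
    /// The talker hidden state and the semantic-code embedding form a two-token
    /// prefill; the first acoustic code is greedily decoded from the last
    /// position, then each remaining group is decoded one token at a time,
    /// feeding the previous group's code embedding back in. Returns a `U32`
    /// tensor of shape `[num_code_groups - 1]`.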
    pub fn generate_acoustic_codes(
        &self,
        talker_hidden: &Tensor,
        semantic_embed: &Tensor,
        cp_kv_caches: &mut [AnyKVCache],
    ) -> Result<Tensor> {
        #[cfg(feature = "profiling")]
        let _span = tracing::info_span!("code_predictor_inner").entered();

        // The caches are reused across talker steps; clear them first.
        for cache in cp_kv_caches.iter_mut() {
            cache.reset();
        }

        let device = talker_hidden.device();
        let num_acoustic = self.config.num_code_groups - 1;

        // Two-token prefill: talker hidden state followed by the semantic-code embedding.
        let input = Tensor::cat(&[talker_hidden, semantic_embed], 1)?;
        let input = if let Some(proj) = &self.small_to_mtp_projection {
            proj.forward(&input)?
        } else {
            input
        };

        let seq_len = input.dim(1)?;
        let dynamic_mask;
        let mask = if seq_len == 2 {
            &self.prefill_mask
        } else {
            dynamic_mask = self.create_causal_mask(seq_len, device)?;
            &dynamic_mask
        };

        let mut hidden = input;
        for (i, layer) in self.layers.iter().enumerate() {
            hidden = layer.forward(
                &hidden,
                &self.rope,
                Some(mask),
                Some(&mut cp_kv_caches[i]),
                0,
            )?;
        }
        hidden = self.norm.forward(&hidden)?;

        // Greedily decode the first acoustic code from the last prefill position.
        let last_hidden = hidden.i((.., seq_len - 1..seq_len, ..))?;
        let logits = self.lm_heads[0].forward(&last_hidden)?;
        let first_code = logits.argmax(D::Minus1)?.flatten_all()?;

        let mut all_codes = Tensor::zeros(num_acoustic, DType::U32, device)?;
        let range = 0..1;
        all_codes = all_codes.slice_assign(&[range], &first_code)?;

        let mut prev_code = first_code;

        // Decode the remaining groups one token at a time, reusing the KV caches.
        let mut offset = seq_len;
        for group_idx in 1..num_acoustic {
            let code_embed = self.codec_embeddings[group_idx - 1].forward(&prev_code)?;
            let code_embed = code_embed.unsqueeze(0)?;
            let code_embed = if let Some(proj) = &self.small_to_mtp_projection {
                proj.forward(&code_embed)?
            } else {
                code_embed
            };

            let mut h = code_embed;
            for (i, layer) in self.layers.iter().enumerate() {
                h = layer.forward(&h, &self.rope, None, Some(&mut cp_kv_caches[i]), offset)?;
            }
            h = self.norm.forward(&h)?;

            let logits = self.lm_heads[group_idx].forward(&h)?;
            let next_code = logits.argmax(D::Minus1)?.flatten_all()?;
            let range = group_idx..group_idx + 1;
            all_codes = all_codes.slice_assign(&[range], &next_code)?;
            prev_code = next_code;
            offset += 1;
        }

        Ok(all_codes)
    }

    fn create_causal_mask(&self, seq_len: usize, device: &candle_core::Device) -> Result<Tensor> {
        super::transformer::create_causal_mask(seq_len, 0, device)
    }
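
    /// Embedding for a single `code` in the table for acoustic group
    /// `group_idx`, shaped `[1, 1, codec_embed_dim]`.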
    pub fn get_acoustic_embedding(
        &self,
        code: u32,
        group_idx: usize,
        device: &candle_core::Device,
    ) -> Result<Tensor> {
        if group_idx >= self.codec_embeddings.len() {
            anyhow::bail!(
                "Invalid group_idx {} (max {})",
                group_idx,
                self.codec_embeddings.len() - 1
            );
        }
        let code_tensor = Tensor::new(&[code], device)?;
        let embed = self.codec_embeddings[group_idx].forward(&code_tensor)?;
        Ok(embed.unsqueeze(0)?)
    }
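
    /// Embeds a tensor of codes with the table for acoustic group `group_idx`,
    /// returning shape `[1, len, codec_embed_dim]`.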
    pub fn embed_codes_for_group(&self, group_idx: usize, codes: &Tensor) -> Result<Tensor> {
        if group_idx >= self.codec_embeddings.len() {
            anyhow::bail!(
                "Invalid group_idx {} (max {})",
                group_idx,
                self.codec_embeddings.len() - 1
            );
        }
        let embed = self.codec_embeddings[group_idx].forward(codes)?;
        Ok(embed.unsqueeze(0)?)
    }
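
    /// Sums one embedding per acoustic group for the given codes, returning
    /// shape `[1, 1, codec_embed_dim]`.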
    pub fn get_acoustic_embeddings_sum(
        &self,
        acoustic_codes: &[u32],
        device: &candle_core::Device,
    ) -> Result<Tensor> {
        if acoustic_codes.len() != self.codec_embeddings.len() {
            anyhow::bail!(
                "Expected {} acoustic codes, got {}",
                self.codec_embeddings.len(),
                acoustic_codes.len()
            );
        }

        let first = self.get_acoustic_embedding(acoustic_codes[0], 0, device)?;
        acoustic_codes[1..]
            .iter()
            .enumerate()
            .try_fold(first, |acc, (i, &code)| {
                let embed = self.get_acoustic_embedding(code, i + 1, device)?;
                acc.add(&embed).map_err(Into::into)
            })
    }
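
    /// Like [`Self::get_acoustic_embeddings_sum`], but takes the codes as a
    /// tensor of shape `[num_code_groups - 1]`.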
    pub fn get_acoustic_embeddings_sum_from_tensor(
        &self,
        acoustic_codes: &Tensor,
    ) -> Result<Tensor> {
        let n = acoustic_codes.dim(0)?;
        if n != self.codec_embeddings.len() {
            anyhow::bail!(
                "Expected {} acoustic codes, got {}",
                self.codec_embeddings.len(),
                n
            );
        }

        let first_code = acoustic_codes.narrow(0, 0, 1)?;
        let first = self.codec_embeddings[0]
            .forward(&first_code)?
            .unsqueeze(0)?;
        (1..n).try_fold(first, |acc, i| {
            let code = acoustic_codes.narrow(0, i, 1)?;
            let embed = self.codec_embeddings[i].forward(&code)?.unsqueeze(0)?;
            acc.add(&embed).map_err(Into::into)
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};
    use candle_nn::VarMap;

    fn create_mock_vb(device: &Device) -> VarBuilder<'static> {
        let varmap = VarMap::new();
        VarBuilder::from_varmap(&varmap, DType::F32, device)
    }

    #[test]
    fn test_config_default() {
        let config = CodePredictorConfig::default();
        assert_eq!(config.num_hidden_layers, 5);
        assert_eq!(config.num_code_groups, 16);
        assert_eq!(config.hidden_size, 1024);
    }

    #[test]
    fn test_code_predictor_construction() {
        let device = Device::Cpu;
        let vb = create_mock_vb(&device);

        let config = CodePredictorConfig {
            hidden_size: 32,
            intermediate_size: 64,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            head_dim: 8,
            vocab_size: 64,
            num_code_groups: 4,
            ..Default::default()
        };

        let predictor = CodePredictor::new(config, vb);
        assert!(predictor.is_ok());

        let predictor = predictor.unwrap();
        assert_eq!(predictor.codec_embeddings.len(), 3); // num_code_groups - 1
        assert_eq!(predictor.layers.len(), 2);
        assert_eq!(predictor.lm_heads.len(), 3);
    }
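
    // `codec_embed_dim()` should fall back to `hidden_size` when no explicit
    // embedding width is configured, and report the configured width otherwise.
    #[test]
    fn test_codec_embed_dim_fallback() {
        let config = CodePredictorConfig::default();
        assert_eq!(config.codec_embed_dim(), config.hidden_size);

        let custom = CodePredictorConfig::custom_voice();
        assert_eq!(custom.codec_embed_dim(), 2048);
    }

    // When `codec_embed_dim` differs from `hidden_size`, construction should
    // create `small_to_mtp_projection`; the tiny sizes here are arbitrary.
    #[test]
    fn test_projection_created_when_widths_differ() {
        let device = Device::Cpu;
        let vb = create_mock_vb(&device);

        let config = CodePredictorConfig {
            hidden_size: 32,
            intermediate_size: 64,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            head_dim: 8,
            vocab_size: 64,
            num_code_groups: 4,
            codec_embed_dim: Some(16),
            ..Default::default()
        };

        let predictor = CodePredictor::new(config, vb).expect("construction should succeed");
        assert!(predictor.small_to_mtp_projection.is_some());
    }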
}