1pub mod audio;
86pub mod generation;
87#[cfg(feature = "hub")]
88pub mod hub;
89pub mod models;
90pub mod profiling;
91pub mod tokenizer;
92
use anyhow::{Context, Result};
use candle_core::{DType, Device, IndexOp, Tensor};
use serde::Serialize;
use std::collections::HashMap;
use std::path::Path;
use std::time::Instant;
99
100use models::codec::{Decoder12Hz, Encoder12Hz};
101use models::speaker::SpeakerEncoder;
102use models::AnyKVCache;
103
104pub use audio::AudioBuffer;
108#[cfg(feature = "hub")]
109pub use hub::ModelPaths;
110pub use models::config::Qwen3TTSConfig;
111pub use generation::SamplingContext;
113pub use models::talker::{codec_tokens, special_tokens, tts_tokens, Language, Speaker};
114pub use models::{
115 CodePredictor, CodePredictorConfig, ModelType, ParsedModelConfig, SpeakerEncoderConfig,
116 TalkerConfig, TalkerModel,
117};
118
/// Generated codec frames: one `Vec<u32>` per frame, holding the 16 codebook
/// entries (index 0 is the semantic code, 1..16 the acoustic codes).
pub type FrameCodes = Vec<Vec<u32>>;
122
/// Conditioning inputs for voice cloning, built by
/// [`Qwen3TTS::create_voice_clone_prompt`].
pub struct VoiceClonePrompt {
    /// Embedding of the reference speaker produced by the speaker encoder.
    pub speaker_embedding: Tensor,
    /// Codec codes of the reference audio; `Some` only in ICL mode.
    pub ref_codes: Option<Tensor>,
    /// Token ids of the reference transcript; `Some` only in ICL mode.
    pub ref_text_ids: Option<Vec<u32>>,
}
135
/// Wall-clock timings for the phases of a synthesis call
/// (see [`Qwen3TTS::synthesize_with_timing`]).
#[derive(Debug, Clone, Serialize)]
pub struct SynthesisTiming {
    /// Milliseconds spent in prompt prefill.
    pub prefill_ms: f64,
    /// Milliseconds spent in the autoregressive generation loop.
    pub generation_ms: f64,
    /// Number of codec frames generated.
    pub generation_frames: usize,
    /// Milliseconds spent decoding codes to audio.
    pub decode_ms: f64,
}
148
/// End-to-end Qwen3-TTS synthesis pipeline: text tokenizer, talker
/// transformer, code predictor, codec decoder, and optional
/// speaker/speech encoders (present only in some model variants).
pub struct Qwen3TTS {
    /// Autoregressive model producing semantic codec tokens from text.
    talker: TalkerModel,
    /// Predicts the per-frame acoustic codes from the talker's hidden state.
    code_predictor: CodePredictor,
    /// Codec decoder that turns code tensors into 24 kHz audio.
    decoder: Decoder12Hz,
    /// Tokenizer for input text (and ICL reference transcripts).
    text_tokenizer: tokenizer::TextTokenizer,
    /// Present only when the weights include `speaker_encoder.*` keys;
    /// required for voice cloning.
    speaker_encoder: Option<SpeakerEncoder>,
    /// Audio-to-codes encoder; required for ICL voice cloning.
    speech_encoder: Option<Encoder12Hz>,
    /// Variant parsed from config.json, when one was available.
    model_type: Option<ModelType>,
    /// Device all model tensors live on.
    device: Device,
    /// BF16 on GPU backends, F32 on CPU (see [`compute_dtype_for_device`]).
    compute_dtype: DType,
}
174
175impl Qwen3TTS {
176 pub fn from_pretrained(model_id: &str, device: Device) -> Result<Self> {
184 Self::from_pretrained_with_tokenizer(model_id, None, device)
185 }
186
187 pub fn from_pretrained_with_tokenizer(
193 model_id: &str,
194 tokenizer_id: Option<&str>,
195 device: Device,
196 ) -> Result<Self> {
197 tracing::info!("Loading Qwen3-TTS from: {}", model_id);
198 tracing::info!("Compute dtype: {:?}", compute_dtype_for_device(&device));
199
200 let config_path = Path::new(model_id).join("config.json");
202 let parsed_config = if config_path.exists() {
203 match ParsedModelConfig::from_file(&config_path) {
204 Ok(cfg) => {
205 tracing::info!("Detected model variant: {}", cfg.label());
206 Some(cfg)
207 }
208 Err(e) => {
209 tracing::warn!(
210 "Failed to parse config.json, falling back to weight inspection: {}",
211 e
212 );
213 None
214 }
215 }
216 } else {
217 None
218 };
219
220 let tok_source = tokenizer_id.unwrap_or(model_id);
222 let text_tokenizer = tokenizer::TextTokenizer::from_pretrained(tok_source)?;
223
224 let model_path = Path::new(model_id).join("model.safetensors");
226 if !model_path.exists() {
227 anyhow::bail!(
228 "Model weights not found at {}. Please download the model first.",
229 model_path.display()
230 );
231 }
232 let weights = Self::load_weights(&model_path, &device)?;
233
234 let st_path = Path::new(model_id).join("speech_tokenizer/model.safetensors");
236 let st_weights = if st_path.exists() {
237 Self::load_weights(&st_path, &device)?
238 } else {
239 let alt_path = Path::new(model_id)
241 .parent()
242 .map(|p| p.join("speech_tokenizer/model.safetensors"));
243 if let Some(p) = alt_path {
244 if p.exists() {
245 Self::load_weights(&p, &device)?
246 } else {
247 anyhow::bail!("Speech tokenizer weights not found");
248 }
249 } else {
250 anyhow::bail!("Speech tokenizer weights not found");
251 }
252 };
253
254 Self::build_from_components(
255 &weights,
256 &st_weights,
257 text_tokenizer,
258 parsed_config.as_ref(),
259 &device,
260 )
261 }
262
263 pub fn from_weights(
268 model_weights: &HashMap<String, Tensor>,
269 decoder_weights: &HashMap<String, Tensor>,
270 text_tokenizer: tokenizer::TextTokenizer,
271 device: &Device,
272 ) -> Result<Self> {
273 Self::build_from_components(model_weights, decoder_weights, text_tokenizer, None, device)
274 }
275
276 #[cfg(feature = "hub")]
290 pub fn from_paths(paths: &hub::ModelPaths, device: Device) -> Result<Self> {
291 tracing::info!("Loading Qwen3-TTS from downloaded paths...");
292
293 let text_tokenizer = tokenizer::TextTokenizer::from_file(&paths.tokenizer)?;
294 let weights = Self::load_weights(&paths.model_weights, &device)?;
295 let st_weights = Self::load_weights(&paths.decoder_weights, &device)?;
296 let parsed_config = ParsedModelConfig::from_file(&paths.config).ok();
297
298 Self::build_from_components(
299 &weights,
300 &st_weights,
301 text_tokenizer,
302 parsed_config.as_ref(),
303 &device,
304 )
305 }
306
    /// Assemble all sub-models (talker, code predictor, codec decoder, and
    /// the optional speaker/speech encoders) from pre-loaded weight maps.
    ///
    /// When `parsed_config` is `None`, configurations are inferred from the
    /// weight shapes instead.
    fn build_from_components(
        model_weights: &HashMap<String, Tensor>,
        decoder_weights: &HashMap<String, Tensor>,
        text_tokenizer: tokenizer::TextTokenizer,
        parsed_config: Option<&ParsedModelConfig>,
        device: &Device,
    ) -> Result<Self> {
        let compute_dtype = compute_dtype_for_device(device);

        // Talker config: from config.json when available, otherwise inferred
        // from the weight shapes.
        let talker_config = if let Some(cfg) = parsed_config {
            TalkerConfig::from_parsed(cfg)
        } else {
            Self::detect_talker_config(model_weights)?
        };
        let talker = TalkerModel::from_weights_with_config_dtype(
            model_weights,
            talker_config,
            device,
            compute_dtype,
        )?;

        // Code predictor config: without a parsed config, a non-default
        // talker hidden size implies a matching codec embedding dim.
        let cp_config = if let Some(cfg) = parsed_config {
            CodePredictorConfig::from_parsed(cfg)
        } else {
            let talker_hidden = talker.config().hidden_size;
            if talker_hidden != 1024 {
                CodePredictorConfig {
                    codec_embed_dim: Some(talker_hidden),
                    ..CodePredictorConfig::default()
                }
            } else {
                CodePredictorConfig::default()
            }
        };
        let cp_weights = Self::filter_weights(model_weights, "talker.code_predictor.");
        let cp_vb = candle_nn::VarBuilder::from_tensors(cp_weights, compute_dtype, device);
        let code_predictor = CodePredictor::new(cp_config, cp_vb)?;

        let decoder = Decoder12Hz::from_weights(decoder_weights, Default::default())?;

        // Optional components: only present in some model variants, so a
        // missing one is not an error.
        let se_config = parsed_config.and_then(|c| c.speaker_encoder_config.clone());
        let speaker_encoder =
            Self::try_load_speaker_encoder(model_weights, se_config.as_ref(), device)?;

        let speech_encoder = Self::try_load_speech_encoder(decoder_weights, device)?;

        let model_type = parsed_config.map(|c| c.model_type);

        Ok(Self {
            talker,
            code_predictor,
            decoder,
            text_tokenizer,
            speaker_encoder,
            speech_encoder,
            model_type,
            device: device.clone(),
            compute_dtype,
        })
    }
376
377 fn detect_talker_config(weights: &HashMap<String, Tensor>) -> Result<TalkerConfig> {
379 let norm_weight = weights
380 .get("talker.model.norm.weight")
381 .ok_or_else(|| anyhow::anyhow!("Missing talker.model.norm.weight"))?;
382 let hidden_size = norm_weight.dim(0)?;
383 Ok(if hidden_size == 2048 {
384 TalkerConfig::custom_voice()
385 } else {
386 TalkerConfig::default()
387 })
388 }
389
390 pub fn model_type(&self) -> Option<&ModelType> {
392 self.model_type.as_ref()
393 }
394
395 pub fn supports_voice_cloning(&self) -> bool {
397 self.speaker_encoder.is_some()
398 }
399
400 pub fn supports_preset_speakers(&self) -> bool {
406 match &self.model_type {
407 Some(ModelType::CustomVoice) => true,
408 Some(ModelType::Base) | Some(ModelType::VoiceDesign) => false,
409 None => true, }
411 }
412
413 pub fn supports_voice_design(&self) -> bool {
417 matches!(&self.model_type, Some(ModelType::VoiceDesign))
418 }
419
420 pub fn synthesize(&self, text: &str, options: Option<SynthesisOptions>) -> Result<AudioBuffer> {
424 self.synthesize_with_voice(text, Speaker::Ryan, Language::English, options)
425 }
426
    /// Like [`Self::synthesize_with_voice`] but also returns wall-clock
    /// timings for the prefill, generation, and decode phases.
    ///
    /// The device is synced before starting and after finishing each timed
    /// phase so GPU work is fully attributed to the right phase.
    pub fn synthesize_with_timing(
        &self,
        text: &str,
        speaker: Speaker,
        language: Language,
        options: Option<SynthesisOptions>,
    ) -> Result<(AudioBuffer, SynthesisTiming)> {
        #[cfg(feature = "profiling")]
        let _span = tracing::info_span!("synthesize").entered();

        let options = options.unwrap_or_default();
        let mut sampling_ctx = generation::SamplingContext::new(options.seed);
        let input_ids = self.text_tokenizer.encode(text)?;
        let gen_config = options.to_gen_config();

        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
            self.build_trailing_text(&input_ids)?;

        #[cfg(feature = "profiling")]
        let _prefill_span = tracing::info_span!("prefill").entered();

        // Drain any queued GPU work so the prefill timer measures only prefill.
        sync_device(&self.device)?;
        let t_prefill = Instant::now();

        // Capacity covers the generation budget plus prompt headroom.
        let mut kv_caches = self.talker.new_kv_caches(gen_config.max_new_tokens + 256);
        let (hidden, logits) =
            self.talker
                .prefill_custom_voice(&input_ids, speaker, language, &mut kv_caches)?;
        let prefill_len = hidden.dim(1)?;
        let offset = prefill_len;
        // Only the final position's hidden state seeds the generation loop.
        let last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;

        sync_device(&self.device)?;
        let prefill_ms = t_prefill.elapsed().as_secs_f64() * 1000.0;

        #[cfg(feature = "profiling")]
        drop(_prefill_span);

        let t_gen = Instant::now();

        let all_codes = self.generate_codes(
            &gen_config,
            &mut sampling_ctx,
            &mut kv_caches,
            offset,
            last_hidden,
            &logits,
            &trailing_text_hidden,
            trailing_text_len,
            &tts_pad_embed,
        )?;

        sync_device(&self.device)?;
        let generation_ms = t_gen.elapsed().as_secs_f64() * 1000.0;
        let generation_frames = all_codes.len();

        #[cfg(feature = "profiling")]
        let _decode_span = tracing::info_span!("decode").entered();

        let t_decode = Instant::now();
        let audio = self.decode_codes(&all_codes)?;

        sync_device(&self.device)?;
        let decode_ms = t_decode.elapsed().as_secs_f64() * 1000.0;

        let timing = SynthesisTiming {
            prefill_ms,
            generation_ms,
            generation_frames,
            decode_ms,
        };

        Ok((audio, timing))
    }
509
510 fn build_trailing_text(&self, input_ids: &[u32]) -> Result<(Tensor, usize, Tensor)> {
516 let trailing_text_hidden = if input_ids.len() > 1 {
517 let remaining_proj = self.talker.get_projected_text_embeddings(&input_ids[1..])?;
518 let tts_eos_embed = self.talker.get_tts_eos_embed()?;
519 Tensor::cat(&[&remaining_proj, &tts_eos_embed], 1)?
520 } else {
521 self.talker.get_tts_eos_embed()?
522 };
523 let trailing_text_len = trailing_text_hidden.dim(1)?;
524 let tts_pad_embed = self.talker.get_tts_pad_embed()?;
525 Ok((trailing_text_hidden, trailing_text_len, tts_pad_embed))
526 }
527
    /// Autoregressive frame-generation loop shared by all synthesis paths.
    ///
    /// Per iteration: the sampled semantic token is embedded, the code
    /// predictor produces the acoustic codes for the frame, the summed codec
    /// embeddings plus the next trailing-text (or pad) embedding are fed back
    /// through the talker, and the next semantic token is sampled from the
    /// penalized logits. Stops on EOS or after `max_new_tokens` frames.
    /// Frames stay on the device until a single bulk transfer at the end.
    #[allow(clippy::too_many_arguments)]
    fn generate_codes(
        &self,
        gen_config: &generation::GenerationConfig,
        sampling_ctx: &mut generation::SamplingContext,
        kv_caches: &mut [AnyKVCache],
        mut offset: usize,
        mut last_hidden: Tensor,
        initial_logits: &Tensor,
        trailing_text_hidden: &Tensor,
        trailing_text_len: usize,
        tts_pad_embed: &Tensor,
    ) -> Result<FrameCodes> {
        // Precomputed once: mask of codec tokens that must never be sampled.
        let suppression_mask = generation::build_suppression_mask(
            codec_tokens::CODEC_VOCAB_SIZE,
            CODEC_EOS_TOKEN_ID,
            &self.device,
        )?;

        // 1 in every column whose token has been emitted (repetition penalty).
        let vocab_size = codec_tokens::CODEC_VOCAB_SIZE;
        let mut penalty_mask = Tensor::zeros((1, vocab_size), DType::F32, &self.device)?;

        let mut cp_kv_caches = self.code_predictor.new_kv_caches();

        // Sample the first semantic token from the prefill logits.
        let logits_2d = initial_logits.squeeze(1)?;
        let logits_2d = self.apply_generation_penalties_gpu(
            &logits_2d,
            &penalty_mask,
            gen_config,
            0,
            Some(&suppression_mask),
        )?;
        let mut semantic_token_tensor = generation::sample(&logits_2d, gen_config, sampling_ctx)?;
        tracing::trace!(target: "gpu_sync", "to_vec1 in generate_codes first token");
        // Host copy of the token is needed only for the EOS check below.
        let mut semantic_token: u32 = semantic_token_tensor.flatten_all()?.to_vec1::<u32>()?[0];
        Self::update_penalty_mask(&mut penalty_mask, semantic_token, vocab_size)?;
        let mut token_count: usize = 1;

        // Completed frames, kept on the device to avoid per-frame syncs.
        let mut gpu_frames: Vec<Tensor> = Vec::new();

        #[cfg(feature = "profiling")]
        let _gen_span = tracing::info_span!("generate_frames").entered();

        for frame_idx in 0..gen_config.max_new_tokens {
            // EOS is checked before committing the frame, so the EOS token
            // itself is never emitted.
            if let Some(eos_id) = gen_config.eos_token_id {
                if semantic_token == eos_id {
                    break;
                }
            }

            let semantic_embed = self
                .talker
                .get_codec_embedding_from_tensor(&semantic_token_tensor)?;

            #[cfg(feature = "profiling")]
            let _cp_span = tracing::info_span!("code_predictor", frame = frame_idx).entered();

            let acoustic_codes_tensor = self.code_predictor.generate_acoustic_codes(
                &last_hidden,
                &semantic_embed,
                &mut cp_kv_caches,
            )?;

            #[cfg(feature = "profiling")]
            drop(_cp_span);

            // Frame layout: [semantic, acoustic...] concatenated on dim 0.
            let frame_tensor = Tensor::cat(
                &[&semantic_token_tensor.reshape(1)?, &acoustic_codes_tensor],
                0,
            )?;
            gpu_frames.push(frame_tensor);

            // Next talker input: codec embedding sum plus the next trailing
            // text embedding (or the pad embedding once text is exhausted).
            let acoustic_embed_sum = self
                .code_predictor
                .get_acoustic_embeddings_sum_from_tensor(&acoustic_codes_tensor)?;
            let summed = semantic_embed.add(&acoustic_embed_sum)?;

            let text_addition = if frame_idx < trailing_text_len {
                trailing_text_hidden.i((.., frame_idx..frame_idx + 1, ..))?
            } else {
                tts_pad_embed.clone()
            };
            let step_input = summed.add(&text_addition)?;

            #[cfg(feature = "profiling")]
            let _talker_span = tracing::info_span!("talker_step", frame = frame_idx).entered();

            let (h, new_logits) =
                self.talker
                    .generate_step_with_embed(&step_input, kv_caches, offset)?;
            offset += 1;
            last_hidden = h;

            #[cfg(feature = "profiling")]
            drop(_talker_span);

            #[cfg(feature = "profiling")]
            let _sample_span = tracing::info_span!("sampling", frame = frame_idx).entered();

            let logits_2d = new_logits.squeeze(1)?;
            let logits_2d = self.apply_generation_penalties_gpu(
                &logits_2d,
                &penalty_mask,
                gen_config,
                token_count,
                Some(&suppression_mask),
            )?;
            semantic_token_tensor = generation::sample(&logits_2d, gen_config, sampling_ctx)?;
            tracing::trace!(target: "gpu_sync", "to_vec1 in generate_codes sampling");
            semantic_token = semantic_token_tensor.flatten_all()?.to_vec1::<u32>()?[0];
            Self::update_penalty_mask(&mut penalty_mask, semantic_token, vocab_size)?;
            token_count += 1;
        }

        self.gpu_frames_to_frame_codes(&gpu_frames)
    }
664
665 fn update_penalty_mask(
670 penalty_mask: &mut Tensor,
671 token_id: u32,
672 vocab_size: usize,
673 ) -> Result<()> {
674 let idx = token_id as usize;
675 if idx < vocab_size {
676 let one = Tensor::ones((1, 1), DType::F32, penalty_mask.device())?;
677 *penalty_mask = penalty_mask.slice_assign(&[0..1, idx..idx + 1], &one)?;
678 }
679 Ok(())
680 }
681
682 fn gpu_frames_to_frame_codes(&self, gpu_frames: &[Tensor]) -> Result<FrameCodes> {
684 if gpu_frames.is_empty() {
685 return Ok(Vec::new());
686 }
687 let stacked = Tensor::stack(gpu_frames, 0)?; let n_frames = stacked.dim(0)?;
690 let flat: Vec<u32> = stacked.flatten_all()?.to_vec1()?;
691 let mut result = Vec::with_capacity(n_frames);
692 for f in 0..n_frames {
693 let start = f * 16;
694 result.push(flat[start..start + 16].to_vec());
695 }
696 Ok(result)
697 }
698
    /// Synthesize `text` with a preset speaker and language.
    ///
    /// Warns (but proceeds) when the loaded model variant is not trained for
    /// preset speakers.
    pub fn synthesize_with_voice(
        &self,
        text: &str,
        speaker: Speaker,
        language: Language,
        options: Option<SynthesisOptions>,
    ) -> Result<AudioBuffer> {
        #[cfg(feature = "profiling")]
        let _span = tracing::info_span!("synthesize").entered();

        // Preset speakers only behave predictably on CustomVoice checkpoints.
        if let Some(ModelType::Base) = &self.model_type {
            tracing::warn!(
                "Using preset speaker {:?} on a Base model. Base models are trained for \
                 voice cloning, not preset speakers — output will have an unpredictable voice. \
                 Use synthesize_voice_clone() with reference audio instead.",
                speaker
            );
        } else if let Some(ModelType::VoiceDesign) = &self.model_type {
            tracing::warn!(
                "Using preset speaker {:?} on a VoiceDesign model. VoiceDesign models \
                 are trained for text-described voice creation, not preset speakers.",
                speaker
            );
        }

        let options = options.unwrap_or_default();
        let mut sampling_ctx = generation::SamplingContext::new(options.seed);
        let input_ids = self.text_tokenizer.encode(text)?;

        let gen_config = options.to_gen_config();

        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
            self.build_trailing_text(&input_ids)?;

        #[cfg(feature = "profiling")]
        let _prefill_span = tracing::info_span!("prefill").entered();

        // Capacity covers the generation budget plus prompt headroom.
        let mut kv_caches = self.talker.new_kv_caches(gen_config.max_new_tokens + 256);
        let (hidden, logits) =
            self.talker
                .prefill_custom_voice(&input_ids, speaker, language, &mut kv_caches)?;
        let prefill_len = hidden.dim(1)?;
        let offset = prefill_len;
        // Only the final position's hidden state seeds the generation loop.
        let last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;

        #[cfg(feature = "profiling")]
        drop(_prefill_span);

        let all_codes = self.generate_codes(
            &gen_config,
            &mut sampling_ctx,
            &mut kv_caches,
            offset,
            last_hidden,
            &logits,
            &trailing_text_hidden,
            trailing_text_len,
            &tts_pad_embed,
        )?;

        #[cfg(feature = "profiling")]
        let _decode_span = tracing::info_span!("decode").entered();

        self.decode_codes(&all_codes)
    }
792
    /// Synthesize `text` with a voice described in natural language via
    /// `instruct` (VoiceDesign models).
    ///
    /// Warns (but proceeds) when the loaded model is not a VoiceDesign
    /// variant.
    pub fn synthesize_voice_design(
        &self,
        text: &str,
        instruct: &str,
        language: Language,
        options: Option<SynthesisOptions>,
    ) -> Result<AudioBuffer> {
        #[cfg(feature = "profiling")]
        let _span = tracing::info_span!("synthesize").entered();

        if let Some(ref mt) = self.model_type {
            if *mt != ModelType::VoiceDesign {
                tracing::warn!(
                    "Using VoiceDesign synthesis on a {:?} model. This model was not trained \
                     for text-described voice conditioning — output may be unpredictable.",
                    mt
                );
            }
        }

        let options = options.unwrap_or_default();
        let mut sampling_ctx = generation::SamplingContext::new(options.seed);
        let input_ids = self.text_tokenizer.encode(text)?;

        // The instruction is wrapped in the chat template the talker expects.
        let instruct_text = format!("<|im_start|>user\n{}<|im_end|>\n", instruct);
        let instruct_ids = self.text_tokenizer.encode(&instruct_text)?;

        let gen_config = options.to_gen_config();

        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
            self.build_trailing_text(&input_ids)?;

        #[cfg(feature = "profiling")]
        let _prefill_span = tracing::info_span!("prefill").entered();

        let mut kv_caches = self.talker.new_kv_caches(gen_config.max_new_tokens + 256);
        let (hidden, logits) = self.talker.prefill_voice_design(
            &input_ids,
            &instruct_ids,
            language,
            &mut kv_caches,
        )?;
        let prefill_len = hidden.dim(1)?;
        let offset = prefill_len;
        // Only the final position's hidden state seeds the generation loop.
        let last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;

        #[cfg(feature = "profiling")]
        drop(_prefill_span);

        let all_codes = self.generate_codes(
            &gen_config,
            &mut sampling_ctx,
            &mut kv_caches,
            offset,
            last_hidden,
            &logits,
            &trailing_text_hidden,
            trailing_text_len,
            &tts_pad_embed,
        )?;

        #[cfg(feature = "profiling")]
        let _decode_span = tracing::info_span!("decode").entered();

        self.decode_codes(&all_codes)
    }
878
879 pub fn codes_to_tensor(&self, codes: &[Vec<u32>]) -> Result<Tensor> {
881 codes_to_tensor(codes, &self.device)
882 }
883
884 pub fn decode_codes(&self, codes: &[Vec<u32>]) -> Result<AudioBuffer> {
889 let tensor = self.codes_to_tensor(codes)?;
890 self.decode_tensor(&tensor)
891 }
892
893 fn decode_tensor(&self, codes: &Tensor) -> Result<AudioBuffer> {
895 let waveform = self.decoder.decode(codes)?;
896 AudioBuffer::from_tensor(waveform, 24000)
897 }
898
    /// Voice-clone synthesis that also returns the generated frame codes
    /// (useful for debugging and analysis).
    ///
    /// Runs in ICL mode when the prompt carries both reference codes and a
    /// reference transcript; otherwise only the speaker embedding conditions
    /// the output. In ICL mode the reference frames are decoded together with
    /// the generated ones and the reference portion is trimmed off the audio.
    pub fn synthesize_voice_clone_debug(
        &self,
        text: &str,
        prompt: &VoiceClonePrompt,
        language: Language,
        options: Option<SynthesisOptions>,
    ) -> Result<(AudioBuffer, FrameCodes)> {
        #[cfg(feature = "profiling")]
        let _span = tracing::info_span!("synthesize").entered();

        let options = options.unwrap_or_default();
        let mut sampling_ctx = generation::SamplingContext::new(options.seed);
        let input_ids = self.text_tokenizer.encode(text)?;

        // ICL mode requires both the reference codes and their transcript.
        let is_icl = prompt.ref_codes.is_some() && prompt.ref_text_ids.is_some();

        // ICL overrides: enforce a repetition-penalty floor and cap the
        // generation budget relative to the input-text length.
        let repetition_penalty = if is_icl {
            options.repetition_penalty.max(ICL_MIN_REPETITION_PENALTY)
        } else {
            options.repetition_penalty
        };
        let max_new_tokens = if is_icl {
            options
                .max_length
                .min(ICL_MIN_FRAMES.max(input_ids.len() * ICL_FRAMES_PER_TOKEN))
        } else {
            options.max_length
        };
        let mut gen_config = options.to_gen_config();
        gen_config.max_new_tokens = max_new_tokens;
        gen_config.repetition_penalty = repetition_penalty;

        // Match the model's compute dtype (the embedding may arrive as F32).
        let speaker_embed = prompt.speaker_embedding.to_dtype(self.compute_dtype)?;

        #[cfg(feature = "profiling")]
        let _prefill_span = tracing::info_span!("prefill").entered();

        // KV capacity: in ICL mode the cache must also hold the reference
        // prompt (text + codec frames), so grow it to an upper bound.
        let mut kv_cache_capacity = gen_config.max_new_tokens + 256;
        if is_icl {
            let ref_text_len = prompt.ref_text_ids.as_ref().map_or(0, |ids| ids.len());
            let ref_frames = match &prompt.ref_codes {
                Some(codes) => codes.dim(0)?,
                None => 0,
            };

            // text part = ref transcript + input text + 1; codec part = ref frames + 1.
            let text_part = ref_text_len + input_ids.len() + 1; let codec_part = ref_frames + 1; let icl_upper_bound = text_part + codec_part;
            let prefill_upper_bound = 10;
            let safety_margin = 32;
            let required_capacity =
                prefill_upper_bound + icl_upper_bound + gen_config.max_new_tokens + safety_margin;
            kv_cache_capacity = kv_cache_capacity.max(required_capacity);
        }

        let mut kv_caches = self.talker.new_kv_caches(kv_cache_capacity);
        let (hidden, logits) = self.talker.prefill_voice_clone(
            &input_ids,
            &speaker_embed,
            language,
            is_icl,
            &mut kv_caches,
        )?;
        let prefill_len = hidden.dim(1)?;
        let mut offset = prefill_len;

        let mut last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;

        // ICL: run the reference prompt through the talker layers manually so
        // its KV entries are cached, then continue generation from its end.
        let (trailing_text_hidden, logits) = if let (Some(ref_codes), Some(ref_text_ids)) =
            (&prompt.ref_codes, &prompt.ref_text_ids)
        {
            let ref_codec_embeds = self.sum_ref_codec_embeddings(ref_codes)?;

            let (icl_embed, icl_trailing) =
                self.talker
                    .build_icl_prompt(&input_ids, ref_text_ids, &ref_codec_embeds, true)?;

            let icl_len = icl_embed.dim(1)?;
            if icl_len > 0 {
                let mask = models::transformer::create_causal_mask(icl_len, offset, &self.device)?;

                let mut icl_hidden = icl_embed;
                for (i, layer) in self.talker.layers_iter().enumerate() {
                    icl_hidden = layer.forward(
                        &icl_hidden,
                        self.talker.rope(),
                        Some(&mask),
                        Some(&mut kv_caches[i]),
                        offset,
                    )?;
                }
                icl_hidden = self.talker.apply_norm(&icl_hidden)?;
                offset += icl_len;

                // Logits and hidden state are re-derived from the ICL prompt's
                // last position, replacing the prefill ones.
                let last_icl_hidden = icl_hidden.i((.., icl_len - 1..icl_len, ..))?;
                let new_logits = self.talker.apply_codec_head(&last_icl_hidden)?;

                last_hidden = last_icl_hidden;

                (icl_trailing, new_logits)
            } else {
                let trailing = self.build_default_trailing_text(&input_ids)?;
                (trailing, logits)
            }
        } else {
            let trailing = self.build_default_trailing_text(&input_ids)?;
            (trailing, logits)
        };

        #[cfg(feature = "profiling")]
        drop(_prefill_span);

        let trailing_text_len = trailing_text_hidden.dim(1)?;
        let tts_pad_embed = self.talker.get_tts_pad_embed()?;

        let all_codes = self.generate_codes(
            &gen_config,
            &mut sampling_ctx,
            &mut kv_caches,
            offset,
            last_hidden,
            &logits,
            &trailing_text_hidden,
            trailing_text_len,
            &tts_pad_embed,
        )?;

        #[cfg(feature = "profiling")]
        let _decode_span = tracing::info_span!("decode").entered();

        // ICL: decode reference + generated frames together (gives the
        // decoder continuous context), then cut away the reference portion
        // proportionally by frame count.
        let audio = if let Some(ref_codes) = &prompt.ref_codes {
            let ref_frames = self.tensor_to_frame_codes(ref_codes)?;
            let ref_len = ref_frames.len();
            let mut combined = ref_frames;
            combined.extend(all_codes.iter().cloned());

            let mut audio = self.decode_codes(&combined)?;
            let total_frames = combined.len();
            let cut_samples = ref_len * audio.len() / total_frames.max(1);
            tracing::debug!(
                "ICL decode: ref_frames={}, gen_frames={}, total_samples={}, cut_samples={}",
                ref_len,
                all_codes.len(),
                audio.len(),
                cut_samples,
            );
            audio.samples = audio.samples[cut_samples.min(audio.len())..].to_vec();
            audio
        } else {
            self.decode_codes(&all_codes)?
        };
        Ok((audio, all_codes))
    }
1081
1082 pub fn device(&self) -> &Device {
1084 &self.device
1085 }
1086
1087 pub fn synthesize_streaming(
1105 &self,
1106 text: &str,
1107 speaker: Speaker,
1108 language: Language,
1109 options: SynthesisOptions,
1110 ) -> Result<StreamingSession<'_>> {
1111 let input_ids = self.text_tokenizer.encode(text)?;
1112 StreamingSession::new(self, &input_ids, speaker, language, options)
1113 }
1114
1115 pub fn synthesize_voice_design_streaming(
1130 &self,
1131 text: &str,
1132 instruct: &str,
1133 language: Language,
1134 options: SynthesisOptions,
1135 ) -> Result<StreamingSession<'_>> {
1136 if let Some(ref mt) = self.model_type {
1137 if *mt != ModelType::VoiceDesign {
1138 tracing::warn!(
1139 "Using VoiceDesign synthesis on a {:?} model. This model was not trained \
1140 for text-described voice conditioning — output may be unpredictable.",
1141 mt
1142 );
1143 }
1144 }
1145
1146 let input_ids = self.text_tokenizer.encode(text)?;
1147
1148 let instruct_text = format!("<|im_start|>user\n{}<|im_end|>\n", instruct);
1150 let instruct_ids = self.text_tokenizer.encode(&instruct_text)?;
1151
1152 StreamingSession::new_voice_design(self, &input_ids, &instruct_ids, language, options)
1153 }
1154
1155 pub fn create_voice_clone_prompt(
1167 &self,
1168 ref_audio: &AudioBuffer,
1169 ref_text: Option<&str>,
1170 ) -> Result<VoiceClonePrompt> {
1171 let encoder = self.speaker_encoder.as_ref().ok_or_else(|| {
1172 let hint = match &self.model_type {
1173 Some(ModelType::CustomVoice) => {
1174 " CustomVoice models use preset speakers (synthesize_with_voice), \
1175 not voice cloning. Use a Base model for voice cloning."
1176 }
1177 Some(ModelType::VoiceDesign) => {
1178 " VoiceDesign models use text-described voices, not voice cloning. \
1179 Use a Base model for voice cloning."
1180 }
1181 _ => {
1182 " Ensure model weights contain `speaker_encoder.*` keys \
1183 (only Base models include a speaker encoder)."
1184 }
1185 };
1186 anyhow::anyhow!("Speaker encoder not available.{}", hint)
1187 })?;
1188
1189 let ref_audio_24k;
1191 let ref_audio = if ref_audio.sample_rate != 24000 {
1192 tracing::info!(
1193 "Resampling reference audio from {}Hz to 24000Hz",
1194 ref_audio.sample_rate
1195 );
1196 ref_audio_24k = audio::resample_to_24k(ref_audio)?;
1197 &ref_audio_24k
1198 } else {
1199 ref_audio
1200 };
1201
1202 let speaker_embedding = encoder.encode(ref_audio)?; let (ref_codes, ref_text_ids) = if let Some(text) = ref_text {
1206 let speech_enc = self.speech_encoder.as_ref().ok_or_else(|| {
1207 anyhow::anyhow!(
1208 "ICL voice cloning requires a speech encoder, but it was not loaded. \
1209 Ensure the speech tokenizer weights contain encoder keys, or use \
1210 x_vector_only mode by passing ref_text=None."
1211 )
1212 })?;
1213
1214 let codes = speech_enc.encode(ref_audio)?; let text_ids = self.text_tokenizer.encode(text)?;
1216
1217 (Some(codes), Some(text_ids))
1218 } else {
1219 (None, None)
1220 };
1221
1222 Ok(VoiceClonePrompt {
1223 speaker_embedding,
1224 ref_codes,
1225 ref_text_ids,
1226 })
1227 }
1228
1229 pub fn synthesize_voice_clone(
1237 &self,
1238 text: &str,
1239 prompt: &VoiceClonePrompt,
1240 language: Language,
1241 options: Option<SynthesisOptions>,
1242 ) -> Result<AudioBuffer> {
1243 self.synthesize_voice_clone_debug(text, prompt, language, options)
1244 .map(|(audio, _codes)| audio)
1245 }
1246
1247 fn tensor_to_frame_codes(&self, codes: &Tensor) -> Result<FrameCodes> {
1249 let (n_frames, n_codebooks) = codes.dims2()?;
1250 let codes_u32 = codes.to_dtype(DType::U32)?;
1251 let mut frames = Vec::with_capacity(n_frames);
1252 for f in 0..n_frames {
1253 let frame_tensor = codes_u32.i(f)?; let frame_vec: Vec<u32> = frame_tensor.to_vec1()?;
1255 debug_assert_eq!(frame_vec.len(), n_codebooks);
1256 frames.push(frame_vec);
1257 }
1258 Ok(frames)
1259 }
1260
1261 fn sum_ref_codec_embeddings(&self, ref_codes: &Tensor) -> Result<Tensor> {
1274 let semantic_codes = ref_codes.i((.., 0))?; let semantic_codes = semantic_codes.to_dtype(candle_core::DType::U32)?;
1277 let summed = self.talker.get_codec_embedding_batch(&semantic_codes)?; let mut summed = summed;
1281 for group in 1..16 {
1282 let group_codes = ref_codes.i((.., group))?; let group_codes = group_codes.to_dtype(candle_core::DType::U32)?;
1284 let group_embed = self
1285 .code_predictor
1286 .embed_codes_for_group(group - 1, &group_codes)?; summed = summed.add(&group_embed)?;
1288 }
1289
1290 Ok(summed)
1291 }
1292
1293 fn build_default_trailing_text(&self, input_ids: &[u32]) -> Result<Tensor> {
1295 let (hidden, _len, _pad) = self.build_trailing_text(input_ids)?;
1296 Ok(hidden)
1297 }
1298
1299 fn apply_generation_penalties_gpu(
1306 &self,
1307 logits: &Tensor,
1308 penalty_mask: &Tensor,
1309 config: &generation::GenerationConfig,
1310 token_count: usize,
1311 suppression_mask: Option<&generation::SuppressionMask>,
1312 ) -> Result<Tensor> {
1313 let logits = logits.to_dtype(DType::F32)?;
1314
1315 let logits = if config.repetition_penalty != 1.0 {
1317 generation::apply_repetition_penalty_with_mask(
1318 &logits,
1319 penalty_mask,
1320 config.repetition_penalty,
1321 )?
1322 } else {
1323 logits
1324 };
1325
1326 let logits = if let Some(mask) = suppression_mask {
1328 generation::apply_token_suppression_with_mask(&logits, mask)?
1329 } else {
1330 generation::apply_token_suppression(
1331 &logits,
1332 codec_tokens::CODEC_VOCAB_SIZE,
1333 CODEC_EOS_TOKEN_ID,
1334 )?
1335 };
1336
1337 if token_count < config.min_new_tokens {
1339 if let Some(eos_id) = config.eos_token_id {
1340 let vocab = logits.dim(1)?;
1341 let batch = logits.dim(0)?;
1342 let mut mask_data = vec![0.0f32; vocab];
1343 mask_data[eos_id as usize] = 1.0;
1344 let eos_mask = Tensor::new(mask_data.as_slice(), logits.device())?
1345 .unsqueeze(0)?
1346 .broadcast_as((batch, vocab))?;
1347 let neg_inf = Tensor::new(&[f32::NEG_INFINITY], logits.device())?
1348 .broadcast_as((batch, vocab))?;
1349 let zeros = Tensor::zeros((batch, vocab), DType::F32, logits.device())?;
1350 let is_eos = eos_mask.gt(&zeros)?;
1351 return Ok(is_eos.where_cond(&neg_inf, &logits)?);
1352 }
1353 }
1354
1355 Ok(logits)
1356 }
1357
1358 pub fn has_speech_encoder(&self) -> bool {
1360 self.speech_encoder.is_some()
1361 }
1362
1363 fn try_load_speaker_encoder(
1371 weights: &HashMap<String, Tensor>,
1372 config: Option<&SpeakerEncoderConfig>,
1373 device: &Device,
1374 ) -> Result<Option<SpeakerEncoder>> {
1375 let has_se_weights = weights.keys().any(|k| k.starts_with("speaker_encoder."));
1376 if !has_se_weights {
1377 return Ok(None);
1378 }
1379
1380 let config = config.cloned().unwrap_or_default();
1381 tracing::info!(
1382 "Loading speaker encoder (ECAPA-TDNN, enc_dim={}) for voice cloning...",
1383 config.enc_dim
1384 );
1385 let se_weights = Self::filter_weights(weights, "speaker_encoder.");
1386 let se_vb = candle_nn::VarBuilder::from_tensors(se_weights, DType::F32, device);
1387 let encoder = SpeakerEncoder::new(config, se_vb)?;
1388 Ok(Some(encoder))
1389 }
1390
1391 fn try_load_speech_encoder(
1397 weights: &HashMap<String, Tensor>,
1398 device: &Device,
1399 ) -> Result<Option<Encoder12Hz>> {
1400 let has_encoder_keys = weights
1402 .keys()
1403 .any(|k| k.starts_with("encoder.") || k.starts_with("encoder_transformer."));
1404 if !has_encoder_keys {
1405 return Ok(None);
1406 }
1407
1408 tracing::debug!("Attempting to load speech encoder (Mimi) for ICL voice cloning...");
1409 match Encoder12Hz::from_weights(weights, device) {
1410 Ok(enc) => {
1411 tracing::info!("Loaded speech encoder — ICL voice cloning available");
1412 Ok(Some(enc))
1413 }
1414 Err(e) => {
1415 tracing::debug!(
1416 "Speech encoder not available ({}). ICL voice cloning disabled.",
1417 e
1418 );
1419 Ok(None)
1420 }
1421 }
1422 }
1423
1424 fn load_weights(path: &Path, device: &Device) -> Result<HashMap<String, Tensor>> {
1429 Ok(candle_core::safetensors::load(path, device)?)
1430 }
1431
1432 pub(crate) fn filter_weights(
1434 weights: &HashMap<String, Tensor>,
1435 prefix: &str,
1436 ) -> HashMap<String, Tensor> {
1437 weights
1438 .iter()
1439 .filter_map(|(k, v)| {
1440 k.strip_prefix(prefix)
1441 .map(|stripped| (stripped.to_string(), v.clone()))
1442 })
1443 .collect()
1444 }
1445}
1446
1447pub fn codes_to_tensor(codes: &[Vec<u32>], device: &Device) -> Result<Tensor> {
1452 let num_frames = codes.len();
1453 if num_frames == 0 {
1454 return Ok(Tensor::zeros((1, 16, 0), DType::I64, device)?);
1455 }
1456
1457 let mut data = vec![0i64; 16 * num_frames];
1458 for (frame, frame_codes) in codes.iter().enumerate() {
1459 for (q, &code) in frame_codes.iter().enumerate() {
1460 data[q * num_frames + frame] = code as i64;
1461 }
1462 }
1463
1464 Ok(Tensor::from_vec(data, (1, 16, num_frames), device)?)
1465}
1466
1467pub fn compute_dtype_for_device(device: &Device) -> DType {
1471 if device.is_cuda() || device.is_metal() {
1472 DType::BF16
1473 } else {
1474 DType::F32
1475 }
1476}
1477
1478pub fn sync_device(device: &Device) -> Result<()> {
1486 match device {
1487 Device::Cpu => Ok(()),
1488 _ => {
1489 let _: Vec<f32> = Tensor::zeros(1, DType::F32, device)?.to_vec1()?;
1491 Ok(())
1492 }
1493 }
1494}
1495
/// Codec token id that marks end-of-speech in the generated token stream.
pub const CODEC_EOS_TOKEN_ID: u32 = codec_tokens::CODEC_EOS;

/// Audio samples produced per codec frame. At the decoder's 24 kHz output
/// rate (see `AudioBuffer::from_tensor(audio, 24000)`) this is 80 ms/frame.
pub const SAMPLES_PER_FRAME: usize = 1920;

/// Minimum number of reference-audio frames for ICL voice cloning
/// (NOTE(review): usage not visible in this chunk — confirm at the call site).
const ICL_MIN_FRAMES: usize = 75;

/// Frames-per-text-token ratio used when sizing ICL prompts
/// (NOTE(review): usage not visible in this chunk — confirm at the call site).
const ICL_FRAMES_PER_TOKEN: usize = 6;

/// Floor applied to the repetition penalty during ICL generation
/// (NOTE(review): usage not visible in this chunk — confirm at the call site).
const ICL_MIN_REPETITION_PENALTY: f64 = 1.5;
1513
/// Incremental synthesis session: generates codec frames token-by-token and
/// yields decoded audio in chunks of `chunk_frames` frames.
///
/// Created via the model's synthesis entry points; also usable as an
/// [`Iterator`] over `Result<AudioBuffer>`.
pub struct StreamingSession<'a> {
    // Model whose talker / code predictor / decoder drive this session.
    model: &'a Qwen3TTS,
    // Generation parameters derived from `SynthesisOptions::to_gen_config`.
    config: generation::GenerationConfig,
    // RNG state used by `generation::sample`.
    sampling_ctx: generation::SamplingContext,
    // Talker KV caches, prefilled before the session starts.
    kv_caches: Vec<AnyKVCache>,
    // Absolute sequence position for the next talker step (starts at the
    // prefill length, incremented once per generated frame).
    offset: usize,
    // Talker hidden state of the most recent step; conditions the code
    // predictor for the next frame.
    last_hidden: Tensor,
    // Next semantic token awaiting consumption; `None` once EOS was sampled.
    current_token: Option<u32>,
    // Same token as a tensor, used for embedding lookup without a re-upload.
    current_token_tensor: Option<Tensor>,
    // Number of full codec frames produced so far.
    frames_generated: usize,
    // Frames accumulated since the last decode; drained by `next_chunk`.
    frame_buffer: FrameCodes,
    // How many frames to buffer before decoding one audio chunk.
    chunk_frames: usize,
    // True once generation has finished (EOS or token exhaustion).
    done: bool,
    // Per-frame text hidden states added to the first `trailing_text_len`
    // generation steps.
    trailing_text_hidden: Tensor,
    trailing_text_len: usize,
    // Pad embedding added to steps after the trailing text is exhausted.
    tts_pad_embed: Tensor,
    // (1, vocab) mask with 1.0 at already-sampled ids; drives the
    // repetition penalty in `apply_generation_penalties_gpu`.
    penalty_mask: Tensor,
    // Count of semantic tokens sampled so far (1 right after prefill).
    token_count: usize,
    // Precomputed mask suppressing disallowed codec tokens during sampling.
    suppression_mask: generation::SuppressionMask,
    // Code-predictor KV caches, reset per session.
    cp_kv_caches: Vec<AnyKVCache>,
}
1543
impl<'a> StreamingSession<'a> {
    /// Starts a streaming session for a built-in (custom-voice) speaker.
    ///
    /// Runs the talker prefill over `input_ids` and hands the result to
    /// `from_prefill`, which samples the first codec token.
    fn new(
        model: &'a Qwen3TTS,
        input_ids: &[u32],
        speaker: Speaker,
        language: Language,
        options: SynthesisOptions,
    ) -> Result<Self> {
        let sampling_ctx = generation::SamplingContext::new(options.seed);
        let config = options.to_gen_config();

        // Text hidden states are re-injected one frame at a time during
        // generation (see `next_chunk`).
        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
            model.build_trailing_text(input_ids)?;

        // +256 extra cache slots beyond max_new_tokens — presumably headroom
        // for the prefill prompt length; confirm against `new_kv_caches`.
        let mut kv_caches = model.talker.new_kv_caches(config.max_new_tokens + 256);
        let prefill_result =
            model
                .talker
                .prefill_custom_voice(input_ids, speaker, language, &mut kv_caches)?;

        Self::from_prefill(
            model,
            config,
            sampling_ctx,
            kv_caches,
            prefill_result,
            trailing_text_hidden,
            trailing_text_len,
            tts_pad_embed,
            options.chunk_frames,
        )
    }

    /// Starts a streaming session for voice design: the voice is described by
    /// `instruct_ids` (an instruction prompt) instead of a named speaker.
    ///
    /// Identical to `new` except for the prefill entry point.
    fn new_voice_design(
        model: &'a Qwen3TTS,
        input_ids: &[u32],
        instruct_ids: &[u32],
        language: Language,
        options: SynthesisOptions,
    ) -> Result<Self> {
        let sampling_ctx = generation::SamplingContext::new(options.seed);
        let config = options.to_gen_config();

        let (trailing_text_hidden, trailing_text_len, tts_pad_embed) =
            model.build_trailing_text(input_ids)?;

        let mut kv_caches = model.talker.new_kv_caches(config.max_new_tokens + 256);
        let prefill_result =
            model
                .talker
                .prefill_voice_design(input_ids, instruct_ids, language, &mut kv_caches)?;

        Self::from_prefill(
            model,
            config,
            sampling_ctx,
            kv_caches,
            prefill_result,
            trailing_text_hidden,
            trailing_text_len,
            tts_pad_embed,
            options.chunk_frames,
        )
    }

    /// Finishes session construction from a talker prefill result:
    /// applies penalties/suppression to the prefill logits, samples the first
    /// semantic token, and seeds the session state.
    ///
    /// If the very first sampled token is EOS the session starts in the
    /// `done` state and produces no audio.
    #[allow(clippy::too_many_arguments)]
    fn from_prefill(
        model: &'a Qwen3TTS,
        config: generation::GenerationConfig,
        mut sampling_ctx: generation::SamplingContext,
        kv_caches: Vec<AnyKVCache>,
        prefill_result: (Tensor, Tensor),
        trailing_text_hidden: Tensor,
        trailing_text_len: usize,
        tts_pad_embed: Tensor,
        chunk_frames: usize,
    ) -> Result<Self> {
        // `hidden` is (batch, prefill_len, dim); keep only the final position
        // as the conditioning state for the first code-predictor call.
        let (hidden, logits) = prefill_result;
        let prefill_len = hidden.dim(1)?;
        let last_hidden = hidden.i((.., prefill_len - 1..prefill_len, ..))?;

        let suppression_mask = generation::build_suppression_mask(
            codec_tokens::CODEC_VOCAB_SIZE,
            CODEC_EOS_TOKEN_ID,
            &model.device,
        )?;

        // Repetition-penalty mask starts empty and accumulates sampled ids.
        let vocab_size = codec_tokens::CODEC_VOCAB_SIZE;
        let mut penalty_mask = Tensor::zeros((1, vocab_size), DType::F32, &model.device)?;
        let logits_2d = logits.squeeze(1)?;
        // token_count = 0 for the first sample (min_new_tokens gating etc.).
        let logits_2d = model.apply_generation_penalties_gpu(
            &logits_2d,
            &penalty_mask,
            &config,
            0,
            Some(&suppression_mask),
        )?;
        let first_token = generation::sample(&logits_2d, &config, &mut sampling_ctx)?;
        let first_token_id: u32 = first_token.flatten_all()?.to_vec1::<u32>()?[0];
        Qwen3TTS::update_penalty_mask(&mut penalty_mask, first_token_id, vocab_size)?;

        // An immediate EOS means there is nothing to synthesize.
        let done = config.eos_token_id == Some(first_token_id);
        let cp_kv_caches = model.code_predictor.new_kv_caches();

        Ok(Self {
            model,
            config,
            sampling_ctx,
            kv_caches,
            offset: prefill_len,
            last_hidden,
            current_token: if done { None } else { Some(first_token_id) },
            current_token_tensor: if done { None } else { Some(first_token) },
            frames_generated: 0,
            frame_buffer: Vec::new(),
            chunk_frames,
            done,
            trailing_text_hidden,
            trailing_text_len,
            tts_pad_embed,
            penalty_mask,
            token_count: 1,
            suppression_mask,
            cp_kv_caches,
        })
    }

    /// Generates up to `chunk_frames` codec frames and decodes them to audio.
    ///
    /// Returns `Ok(Some(audio))` per chunk and `Ok(None)` once the session is
    /// exhausted. The final chunk may be shorter than `chunk_frames`.
    pub fn next_chunk(&mut self) -> Result<Option<AudioBuffer>> {
        if self.done {
            // Flush any frames buffered before generation ended.
            if !self.frame_buffer.is_empty() {
                let codes = self.model.codes_to_tensor(&self.frame_buffer)?;
                self.frame_buffer.clear();
                let audio = self.model.decoder.decode(&codes)?;
                return Ok(Some(AudioBuffer::from_tensor(audio, 24000)?));
            }
            return Ok(None);
        }

        // One loop iteration = one codec frame: expand the pending semantic
        // token into acoustic codes, then advance the talker by one step.
        while self.frame_buffer.len() < self.chunk_frames
            && self.frames_generated < self.config.max_new_tokens
        {
            // Take the pending token; if either half is missing, generation
            // has ended and we fall through to decode what we have.
            let (token_id, token_tensor) =
                match (self.current_token, self.current_token_tensor.take()) {
                    (Some(id), Some(t)) => (id, t),
                    _ => {
                        self.done = true;
                        break;
                    }
                };

            let semantic_embed = self
                .model
                .talker
                .get_codec_embedding_from_tensor(&token_tensor)?;

            // Code predictor expands (last_hidden, semantic) into the 15
            // remaining acoustic codebooks for this frame.
            let acoustic_codes_tensor = self.model.code_predictor.generate_acoustic_codes(
                &self.last_hidden,
                &semantic_embed,
                &mut self.cp_kv_caches,
            )?;

            // Full frame = [semantic code, acoustic codes...].
            let semantic_t = Tensor::new(&[token_id], self.model.device())?;
            let frame_tensor = Tensor::cat(&[&semantic_t, &acoustic_codes_tensor], 0)?;
            let frame_codes: Vec<u32> = frame_tensor.to_vec1()?;
            self.frame_buffer.push(frame_codes);

            let frame_idx = self.frames_generated;
            self.frames_generated += 1;

            // Next talker input = semantic + summed acoustic embeddings +
            // either the next trailing-text hidden state or the pad embed.
            let acoustic_embed_sum = self
                .model
                .code_predictor
                .get_acoustic_embeddings_sum_from_tensor(&acoustic_codes_tensor)?;
            let summed = semantic_embed.add(&acoustic_embed_sum)?;

            let text_addition = if frame_idx < self.trailing_text_len {
                self.trailing_text_hidden
                    .i((.., frame_idx..frame_idx + 1, ..))?
            } else {
                self.tts_pad_embed.clone()
            };
            let step_input = summed.add(&text_addition)?;

            let (h, new_logits) = self.model.talker.generate_step_with_embed(
                &step_input,
                &mut self.kv_caches,
                self.offset,
            )?;
            self.offset += 1;
            self.last_hidden = h;

            // Sample the next semantic token under repetition penalty and
            // token suppression, then record it in the penalty mask.
            let logits_2d = new_logits.squeeze(1)?;
            let logits_2d = self.model.apply_generation_penalties_gpu(
                &logits_2d,
                &self.penalty_mask,
                &self.config,
                self.token_count,
                Some(&self.suppression_mask),
            )?;
            let next_token_tensor =
                generation::sample(&logits_2d, &self.config, &mut self.sampling_ctx)?;
            let next_token_id: u32 = next_token_tensor.flatten_all()?.to_vec1::<u32>()?[0];
            Qwen3TTS::update_penalty_mask(
                &mut self.penalty_mask,
                next_token_id,
                codec_tokens::CODEC_VOCAB_SIZE,
            )?;
            self.token_count += 1;

            if self.config.eos_token_id == Some(next_token_id) {
                self.current_token = None;
                self.current_token_tensor = None;
                self.done = true;
            } else {
                self.current_token = Some(next_token_id);
                self.current_token_tensor = Some(next_token_tensor);
            }
        }

        if self.frame_buffer.is_empty() {
            return Ok(None);
        }

        // Decode the buffered frames into a 24 kHz audio chunk.
        let codes = self.model.codes_to_tensor(&self.frame_buffer)?;
        self.frame_buffer.clear();
        let audio = self.model.decoder.decode(&codes)?;
        Ok(Some(AudioBuffer::from_tensor(audio, 24000)?))
    }

    /// Number of codec frames generated so far.
    pub fn frames_generated(&self) -> usize {
        self.frames_generated
    }

    /// True once generation has finished and all buffered frames were decoded.
    pub fn is_done(&self) -> bool {
        self.done && self.frame_buffer.is_empty()
    }
}
1805
1806impl<'a> Iterator for StreamingSession<'a> {
1807 type Item = Result<AudioBuffer>;
1808
1809 fn next(&mut self) -> Option<Self::Item> {
1810 match self.next_chunk() {
1811 Ok(Some(audio)) => Some(Ok(audio)),
1812 Ok(None) => None,
1813 Err(e) => Some(Err(e)),
1814 }
1815 }
1816}
1817
/// User-facing knobs for a synthesis run.
///
/// Converted into the internal `generation::GenerationConfig` via
/// [`SynthesisOptions::to_gen_config`]; see [`Default`] for the stock values.
#[derive(Debug, Clone)]
pub struct SynthesisOptions {
    /// Maximum number of codec frames to generate (maps to `max_new_tokens`).
    pub max_length: usize,
    /// Sampling temperature (default 0.9).
    pub temperature: f64,
    /// Top-k sampling cutoff (default 50).
    pub top_k: usize,
    /// Nucleus (top-p) sampling cutoff (default 0.9).
    pub top_p: f64,
    /// Repetition penalty applied to previously sampled tokens (default 1.05).
    pub repetition_penalty: f64,
    /// Token id that terminates generation; defaults to the codec EOS token.
    pub eos_token_id: Option<u32>,
    /// Frames buffered per streaming audio chunk (default 10).
    pub chunk_frames: usize,
    /// Minimum tokens to generate before EOS is honored.
    pub min_new_tokens: usize,
    /// RNG seed for reproducible sampling; `None` for a nondeterministic run.
    pub seed: Option<u64>,
}
1840
1841impl SynthesisOptions {
1842 pub fn to_gen_config(&self) -> generation::GenerationConfig {
1844 generation::GenerationConfig {
1845 max_new_tokens: self.max_length,
1846 temperature: self.temperature,
1847 top_k: self.top_k,
1848 top_p: self.top_p,
1849 repetition_penalty: self.repetition_penalty,
1850 eos_token_id: self.eos_token_id,
1851 min_new_tokens: self.min_new_tokens,
1852 }
1853 }
1854}
1855
1856impl Default for SynthesisOptions {
1857 fn default() -> Self {
1858 Self {
1859 max_length: 2048,
1860 temperature: 0.9,
1861 top_k: 50,
1862 top_p: 0.9,
1863 repetition_penalty: 1.05,
1864 eos_token_id: Some(CODEC_EOS_TOKEN_ID),
1865 chunk_frames: 10, min_new_tokens: 2,
1867 seed: None,
1868 }
1869 }
1870}
1871
1872pub fn auto_device() -> Result<Device> {
1889 #[cfg(feature = "cuda")]
1890 {
1891 if let Ok(device) = Device::cuda_if_available(0) {
1892 if device.is_cuda() {
1893 tracing::info!("Using CUDA device");
1894 return Ok(device);
1895 }
1896 }
1897 }
1898
1899 #[cfg(feature = "metal")]
1900 {
1901 if let Ok(device) = Device::new_metal(0) {
1902 tracing::info!("Using Metal device");
1903 return Ok(device);
1904 }
1905 }
1906
1907 tracing::info!("Using CPU device");
1908 Ok(Device::Cpu)
1909}
1910
1911pub fn parse_device(device_str: &str) -> Result<Device> {
1925 match device_str.to_lowercase().as_str() {
1926 "auto" => auto_device(),
1927 "cpu" => Ok(Device::Cpu),
1928 s if s.starts_with("cuda") => {
1929 #[cfg(feature = "cuda")]
1930 {
1931 let ordinal: usize = if s == "cuda" {
1932 0
1933 } else if let Some(idx) = s.strip_prefix("cuda:") {
1934 idx.parse()
1935 .map_err(|e| anyhow::anyhow!("invalid CUDA device index: {e}"))?
1936 } else {
1937 0
1938 };
1939 Device::cuda_if_available(ordinal)
1940 .map_err(|e| anyhow::anyhow!("failed to init CUDA device {ordinal}: {e}"))
1941 }
1942 #[cfg(not(feature = "cuda"))]
1943 anyhow::bail!("CUDA support not compiled in. Rebuild with: cargo build --features cuda")
1944 }
1945 "metal" => {
1946 #[cfg(feature = "metal")]
1947 {
1948 Device::new_metal(0)
1949 .map_err(|e| anyhow::anyhow!("failed to init Metal device: {e}"))
1950 }
1951 #[cfg(not(feature = "metal"))]
1952 anyhow::bail!(
1953 "Metal support not compiled in. Rebuild with: cargo build --features metal"
1954 )
1955 }
1956 other => {
1957 anyhow::bail!("unknown device '{other}'. Supported: auto, cpu, cuda, cuda:N, metal")
1958 }
1959 }
1960}
1961
1962pub fn device_info(device: &Device) -> String {
1964 match device {
1965 Device::Cpu => "CPU".to_string(),
1966 Device::Cuda(_) => "CUDA".to_string(),
1967 Device::Metal(_) => "Metal".to_string(),
1968 }
1969}
1970
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_synthesis_options_default() {
        // Default values are part of the public contract; pin each one.
        let opts = SynthesisOptions::default();
        assert_eq!(opts.max_length, 2048);
        assert!((opts.temperature - 0.9).abs() < 1e-6);
        assert_eq!(opts.top_k, 50);
        assert!((opts.top_p - 0.9).abs() < 1e-6);
        assert!((opts.repetition_penalty - 1.05).abs() < 1e-6);
        assert_eq!(opts.eos_token_id, Some(CODEC_EOS_TOKEN_ID));
        assert_eq!(opts.chunk_frames, 10);
        assert_eq!(opts.min_new_tokens, 2);
    }

    #[test]
    fn test_synthesis_options_custom() {
        let opts = SynthesisOptions {
            max_length: 512,
            temperature: 0.5,
            top_k: 10,
            top_p: 0.8,
            repetition_penalty: 1.2,
            eos_token_id: Some(CODEC_EOS_TOKEN_ID),
            chunk_frames: 5,
            min_new_tokens: 0,
            seed: Some(42),
        };
        assert_eq!(opts.max_length, 512);
        assert!((opts.temperature - 0.5).abs() < 1e-6);
        assert_eq!(opts.eos_token_id, Some(CODEC_EOS_TOKEN_ID));
        assert_eq!(opts.chunk_frames, 5);
    }

    #[test]
    fn test_synthesis_options_clone() {
        let original = SynthesisOptions::default();
        let copy = original.clone();
        assert_eq!(copy.max_length, original.max_length);
        assert_eq!(copy.top_k, original.top_k);
    }

    #[test]
    fn test_synthesis_options_debug() {
        let rendered = format!("{:?}", SynthesisOptions::default());
        assert!(rendered.contains("max_length"));
        assert!(rendered.contains("2048"));
    }

    #[test]
    fn test_auto_device() {
        // Whatever backend is available, the result must be one of the three.
        let device = auto_device().unwrap();
        assert!(matches!(
            device,
            Device::Cpu | Device::Cuda(_) | Device::Metal(_)
        ));
    }

    #[test]
    fn test_audio_buffer_reexport() {
        let buffer = AudioBuffer::new(vec![0.0f32; 100], 24000);
        assert_eq!(buffer.sample_rate, 24000);
    }

    #[test]
    fn test_config_reexport() {
        assert_eq!(Qwen3TTSConfig::default().model_type, "qwen3_tts");
    }

    #[test]
    fn test_codes_to_tensor_empty() {
        let frames: Vec<Vec<u32>> = Vec::new();
        let tensor = codes_to_tensor(&frames, &Device::Cpu).unwrap();
        assert_eq!(tensor.dims(), &[1, 16, 0]);
    }

    #[test]
    fn test_codes_to_tensor_single_frame() {
        let frames = vec![vec![0u32; 16]];
        let tensor = codes_to_tensor(&frames, &Device::Cpu).unwrap();
        assert_eq!(tensor.dims(), &[1, 16, 1]);
    }

    #[test]
    fn test_codes_to_tensor_layout() {
        let frame_a: Vec<u32> = (0..16).collect();
        let frame_b: Vec<u32> = (100..116).collect();
        let tensor = codes_to_tensor(&[frame_a, frame_b], &Device::Cpu).unwrap();
        assert_eq!(tensor.dims(), &[1, 16, 2]);

        // Flattened (1, 16, frames) order: [q0/f0, q0/f1, q1/f0, q1/f1, ...].
        let vals: Vec<i64> = tensor.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals[0], 0);
        assert_eq!(vals[1], 100);
        assert_eq!(vals[2], 1);
        assert_eq!(vals[3], 101);
    }

    #[test]
    fn test_parse_device_cpu() {
        assert!(matches!(parse_device("cpu").unwrap(), Device::Cpu));
    }

    #[test]
    fn test_parse_device_auto() {
        let device = parse_device("auto").unwrap();
        assert!(matches!(
            device,
            Device::Cpu | Device::Cuda(_) | Device::Metal(_)
        ));
    }

    #[test]
    fn test_parse_device_unknown() {
        assert!(parse_device("tpu").is_err());
    }

    #[test]
    fn test_parse_device_case_insensitive() {
        assert!(matches!(parse_device("CPU").unwrap(), Device::Cpu));
    }

    #[test]
    fn test_device_info() {
        assert_eq!(device_info(&Device::Cpu), "CPU");
    }

    #[test]
    fn test_compute_dtype_for_device() {
        assert_eq!(compute_dtype_for_device(&Device::Cpu), DType::F32);
    }

    #[test]
    fn test_update_penalty_mask() {
        let vocab_size = 3072;
        let mut mask = Tensor::zeros((1, vocab_size), DType::F32, &Device::Cpu).unwrap();

        Qwen3TTS::update_penalty_mask(&mut mask, 42, vocab_size).unwrap();

        // Exactly the sampled id is flagged; its neighbours stay untouched.
        let vals: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(vals[42], 1.0);
        assert_eq!(vals[41], 0.0);
        assert_eq!(vals[43], 0.0);
    }

    #[test]
    fn test_update_penalty_mask_out_of_range() {
        let vocab_size = 3072;
        let mut mask = Tensor::zeros((1, vocab_size), DType::F32, &Device::Cpu).unwrap();

        // An id past the vocab must be ignored, not panic or set anything.
        Qwen3TTS::update_penalty_mask(&mut mask, 9999, vocab_size).unwrap();

        let total: f32 = mask.sum_all().unwrap().to_scalar().unwrap();
        assert_eq!(total, 0.0);
    }

    #[test]
    fn test_suppression_mask_deterministic() {
        let device = Device::Cpu;
        let vocab = codec_tokens::CODEC_VOCAB_SIZE;
        let first = generation::build_suppression_mask(vocab, CODEC_EOS_TOKEN_ID, &device).unwrap();
        let second =
            generation::build_suppression_mask(vocab, CODEC_EOS_TOKEN_ID, &device).unwrap();

        // Two independently built masks must act identically on equal logits.
        let logits = Tensor::ones((1, vocab), DType::F32, &device).unwrap();
        let out_a = generation::apply_token_suppression_with_mask(&logits, &first).unwrap();
        let out_b = generation::apply_token_suppression_with_mask(&logits, &second).unwrap();
        let va: Vec<f32> = out_a.flatten_all().unwrap().to_vec1().unwrap();
        let vb: Vec<f32> = out_b.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(va, vb);
    }
}