1pub mod asr_candle_pt;
2pub mod config;
3pub mod silero_vad;
4pub mod wavfrontend;
5
6use core::fmt;
7#[cfg(feature = "rknpu")]
8use std::{fs::File, io::BufReader};
9
10use hf_hub::api::sync::Api;
11use hound::WavReader;
12#[cfg(feature = "rknpu")]
13use ndarray::Axis;
14use ndarray::{s, ArrayView3};
15#[cfg(feature = "rknpu")]
16use ndarray::{Array2, Array3};
17#[cfg(feature = "rknpu")]
18use ndarray_npy::ReadNpyExt;
19use regex::Regex;
20#[cfg(feature = "rknpu")]
21use rknn_rs::prelude::{Rknn, RknnTensorFormat, RknnTensorType};
22use sentencepiece::SentencePieceProcessor;
23
24use asr_candle_pt::CandlePtAsrSession;
25use config::SenseVoiceConfig;
26use silero_vad::{VadConfig, VadOutput, VadProcessor, CHUNK_SIZE};
27use wavfrontend::{WavFrontend, WavFrontendConfig};
28
29#[cfg(feature = "stream")]
30use async_stream::stream;
31#[cfg(feature = "stream")]
32use futures::stream::Stream;
33#[cfg(feature = "stream")]
34use futures::StreamExt;
35
/// Language tag decoded from the model's `<|lang|>` output token.
#[derive(Debug, Copy, Clone)]
pub enum SenseVoiceLanguage {
    /// English (`en`).
    En,
    /// Mandarin Chinese (`zh`).
    Zh,
    /// Cantonese (`yue`).
    Yue,
    /// Japanese (`ja`).
    Ja,
    /// Korean (`ko`).
    Ko,
    /// No speech detected (`nospeech`).
    NoSpeech,
}

impl SenseVoiceLanguage {
    /// Parses a language tag case-insensitively; `None` for unknown tags.
    fn from_str(s: &str) -> Option<Self> {
        let tag = s.to_lowercase();
        let parsed = match tag.as_str() {
            "en" => Self::En,
            "zh" => Self::Zh,
            "yue" => Self::Yue,
            "ja" => Self::Ja,
            "ko" => Self::Ko,
            "nospeech" => Self::NoSpeech,
            _ => return None,
        };
        Some(parsed)
    }
}
80
/// Emotion tag decoded from the model's `<|emotion|>` output token.
#[derive(Debug, Copy, Clone)]
pub enum SenseVoiceEmo {
    Happy,
    Sad,
    Angry,
    Neutral,
    Fearful,
    Disgusted,
    Surprised,
    /// Emitted by the model as `EMO_UNKNOWN`.
    Unknown,
}

impl SenseVoiceEmo {
    /// Parses an emotion tag case-insensitively; `None` for unknown tags.
    fn from_str(s: &str) -> Option<Self> {
        let parsed = match s.to_uppercase().as_str() {
            "HAPPY" => Self::Happy,
            "SAD" => Self::Sad,
            "ANGRY" => Self::Angry,
            "NEUTRAL" => Self::Neutral,
            "FEARFUL" => Self::Fearful,
            "DISGUSTED" => Self::Disgusted,
            "SURPRISED" => Self::Surprised,
            "EMO_UNKNOWN" => Self::Unknown,
            _ => return None,
        };
        Some(parsed)
    }
}
131
/// Audio-event tag decoded from the model's `<|event|>` output token.
#[derive(Debug, Copy, Clone)]
pub enum SenseVoiceEvent {
    /// Background music.
    Bgm,
    Speech,
    Applause,
    Laughter,
    Cry,
    Sneeze,
    Breath,
    Cough,
    /// Emitted by the model as `EVENT_UNK`.
    Unknown,
}

impl SenseVoiceEvent {
    /// Parses an event tag case-insensitively; `None` for unknown tags.
    fn from_str(s: &str) -> Option<Self> {
        let parsed = match s.to_uppercase().as_str() {
            "BGM" => Self::Bgm,
            "SPEECH" => Self::Speech,
            "APPLAUSE" => Self::Applause,
            "LAUGHTER" => Self::Laughter,
            "CRY" => Self::Cry,
            "SNEEZE" => Self::Sneeze,
            "BREATH" => Self::Breath,
            "COUGH" => Self::Cough,
            "EVENT_UNK" => Self::Unknown,
            _ => return None,
        };
        Some(parsed)
    }
}
185
/// Text-normalization tag decoded from the model's `<|textnorm|>` output token.
#[derive(Debug, Copy, Clone)]
pub enum SenseVoicePunctuationNormalization {
    /// `with` tag — presumably "with inverse text normalization"; confirm
    /// against the model's token set.
    With,
    /// `woitn` tag — presumably "without ITN".
    Woitn,
}

impl SenseVoicePunctuationNormalization {
    /// Parses a text-normalization tag case-insensitively; `None` if unrecognized.
    fn from_str(s: &str) -> Option<Self> {
        let parsed = match s.to_lowercase().as_str() {
            "with" => Self::With,
            "woitn" => Self::Woitn,
            _ => return None,
        };
        Some(parsed)
    }
}
218
/// One recognized utterance: the model's rich-transcription tags plus the text.
#[derive(Debug)]
pub struct VoiceText {
    /// Detected language of the segment.
    pub language: SenseVoiceLanguage,
    /// Detected speaker emotion.
    pub emotion: SenseVoiceEmo,
    /// Detected audio event (speech, BGM, laughter, ...).
    pub event: SenseVoiceEvent,
    /// Text-normalization mode reported by the model.
    pub punctuation_normalization: SenseVoicePunctuationNormalization,
    /// Transcribed text content (may be empty for silence placeholders).
    pub content: String,
}
235
236fn parse_line(line: &str) -> Option<VoiceText> {
250 let re = Regex::new(r"^<\|(.*?)\|><\|(.*?)\|><\|(.*?)\|><\|(.*?)\|>(.*)$").unwrap();
251 if let Some(caps) = re.captures(line) {
252 let lang_str = &caps[1];
253 let emo_str = &caps[2];
254 let event_str = &caps[3];
255 let punct_str = &caps[4];
256 let content = &caps[5];
257
258 let language = SenseVoiceLanguage::from_str(lang_str)?;
259 let emotion = SenseVoiceEmo::from_str(emo_str)?;
260 let event = SenseVoiceEvent::from_str(event_str)?;
261 let punctuation_normalization = SenseVoicePunctuationNormalization::from_str(punct_str)?;
262
263 Some(VoiceText {
264 language,
265 emotion,
266 event,
267 punctuation_normalization,
268 content: content.to_string(),
269 })
270 } else {
271 None
272 }
273}
274
/// Crate-local error used for audio-validation and pipeline failures.
#[derive(Debug)]
struct SenseVoiceSmallError {
    message: String,
}

impl fmt::Display for SenseVoiceSmallError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same rendered text as before: "SenseVoiceSmallError: <message>".
        f.write_str("SenseVoiceSmallError: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for SenseVoiceSmallError {}

impl SenseVoiceSmallError {
    /// Wraps a human-readable message in an error value.
    pub fn new(message: &str) -> Self {
        Self {
            message: String::from(message),
        }
    }
}
311
/// SenseVoice-Small speech-to-text pipeline: Silero-VAD segmentation plus ASR
/// decoding via either an RKNN NPU session or a Candle `.pt` session.
#[derive(Debug)]
pub struct SenseVoiceSmall {
    /// Acoustic front-end that turns i16 PCM into model input features.
    asr_frontend: WavFrontend,
    /// Fixed encoder input length in frames for the RKNN model
    /// (set to 171 by `init()`, left 0 when the Candle backend is used).
    #[cfg(feature = "rknpu")]
    n_seq: usize,
    /// SentencePiece processor used to decode token ids back into text.
    spp: SentencePieceProcessor,

    /// RKNN inference session; `None` when the Candle backend is active.
    #[cfg(feature = "rknpu")]
    rknn: Option<Rknn>,
    /// Embedding table (rows of width 560) used to build the RKNN prompt.
    #[cfg(feature = "rknpu")]
    embedding: Option<ndarray::Array2<f32>>,

    /// Candle `.pt` ASR session; `None` when the RKNN backend is active.
    candle_pt_asr: Option<CandlePtAsrSession>,

    /// VAD configuration, also used to build a fresh VAD per `infer_vec` call.
    vad_config: VadConfig,
    /// Long-lived VAD instance used by the streaming API.
    #[cfg(feature = "stream")]
    silero_vad: VadProcessor,

    /// True when `recognition` should dispatch to the RKNN backend.
    use_rknn: bool,
}
342
343impl SenseVoiceSmall {
345 pub fn init(vadconfig: VadConfig) -> Result<Self, Box<dyn std::error::Error>> {
358 #[cfg(feature = "rknpu")]
359 {
360 let model_path = "happyme531/SenseVoiceSmall-RKNN2";
362
363 let api = Api::new().unwrap();
364 let repo = api.model(model_path.to_string());
365
366 let embedding_path = repo.get("embedding.npy")?;
368 let rknn_path = repo.get("sense-voice-encoder.rknn")?;
369 let sentence_path = repo.get("chn_jpn_yue_eng_ko_spectok.bpe.model")?;
370 let am_path = repo.get("am.mvn")?;
371
372 let config = SenseVoiceConfig {
373 model_path: rknn_path,
374 tokenizer_path: sentence_path,
375 cmvn_path: Some(am_path),
376 };
377
378 let embedding_file = File::open(embedding_path)?;
379 let embedding_reader = BufReader::new(embedding_file);
380 let embedding: Array2<f32> = Array2::read_npy(embedding_reader)?;
381 assert_eq!(embedding.shape()[1], 560, "Embedding dimension must be 560");
382
383 let rknn = Rknn::rknn_init(config.model_path)?;
384 let spp = SentencePieceProcessor::open(config.tokenizer_path)?;
385
386 let n_seq = 171;
387
388 let asr_frontend = WavFrontend::new(WavFrontendConfig {
391 lfr_m: 7,
392 cmvn_file: Some(
393 config
394 .cmvn_path
395 .as_ref()
396 .unwrap()
397 .to_str()
398 .unwrap()
399 .to_owned(),
400 ),
401 ..Default::default()
402 })?;
403
404 #[cfg(feature = "stream")]
405 let silero_vad = VadProcessor::new(vadconfig)?;
406
407 Ok(SenseVoiceSmall {
408 asr_frontend,
409 n_seq,
410 spp,
411 rknn: Some(rknn),
412 embedding: Some(embedding),
413 candle_pt_asr: None,
414 vad_config: vadconfig,
415 #[cfg(feature = "stream")]
416 silero_vad,
417 use_rknn: true,
418 })
419 }
420 #[cfg(not(feature = "rknpu"))]
421 {
422 Self::init_official_model_pt(vadconfig)
423 }
424 }
425
426 pub fn init_official_model_pt(
429 vadconfig: VadConfig,
430 ) -> Result<Self, Box<dyn std::error::Error>> {
431 let api = Api::new().unwrap();
432 let repo = api.model("FunAudioLLM/SenseVoiceSmall".to_owned());
433
434 let config = SenseVoiceConfig {
435 model_path: repo.get("model.pt")?,
436 tokenizer_path: repo.get("chn_jpn_yue_eng_ko_spectok.bpe.model")?,
437 cmvn_path: Some(repo.get("am.mvn")?),
438 };
439
440 Self::init_with_config(config, vadconfig)
441 }
442
    /// Initializes the recognizer from explicit model/tokenizer/CMVN paths.
    ///
    /// Only official `.pt` models are accepted here; RKNN models must go
    /// through `init()` because `SenseVoiceConfig` carries no embedding-table
    /// path, which the RKNN backend requires.
    ///
    /// # Errors
    /// Returns an error for `.rknn` paths (when `rknpu` is enabled), for any
    /// non-`.pt` model path, or when a component fails to load.
    pub fn init_with_config(
        config: SenseVoiceConfig,
        vadconfig: VadConfig,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        #[cfg(feature = "rknpu")]
        {
            // Reject RKNN models early: this entry point cannot supply the
            // embedding table the RKNN path needs.
            let is_rknn_model = config
                .model_path
                .extension()
                .map_or(false, |ext| ext == "rknn");
            if is_rknn_model {
                return Err("Manual loading of RKNN models via init_with_config is not fully supported yet (missing embedding path). Use init() for default RKNN model.".into());
            }
        }

        // Case-insensitive `.pt` extension check.
        let is_pt_model = config
            .model_path
            .extension()
            .and_then(|ext| ext.to_str())
            .map(|ext| ext.eq_ignore_ascii_case("pt"))
            .unwrap_or(false);
        if !is_pt_model {
            return Err(std::io::Error::other(
                "Candle ASR now only supports official .pt model paths.",
            )
            .into());
        }
        let candle_pt_asr = Some(CandlePtAsrSession::new(&config.model_path)?);

        let spp = SentencePieceProcessor::open(&config.tokenizer_path)?;

        // Same front-end settings as init() (lfr_m = 7, optional CMVN file).
        let asr_frontend = WavFrontend::new(WavFrontendConfig {
            lfr_m: 7,
            cmvn_file: config.cmvn_path.map(|p| p.to_string_lossy().to_string()),
            ..Default::default()
        })?;

        #[cfg(feature = "stream")]
        let silero_vad = VadProcessor::new(vadconfig)?;

        // n_seq is only meaningful for the RKNN backend; keep it zeroed here.
        #[cfg(feature = "rknpu")]
        let n_seq = 0;

        Ok(SenseVoiceSmall {
            asr_frontend,
            #[cfg(feature = "rknpu")]
            n_seq,
            spp,
            #[cfg(feature = "rknpu")]
            rknn: None,
            #[cfg(feature = "rknpu")]
            embedding: None,
            candle_pt_asr,
            vad_config: vadconfig,
            #[cfg(feature = "stream")]
            silero_vad,
            use_rknn: false,
        })
    }
513
    /// Configures the streaming VAD's silence notification, delegating to
    /// `VadProcessor::set_notify_silence_after_ms`; `None` presumably
    /// disables the notification — see `VadProcessor` for exact semantics.
    #[cfg(feature = "stream")]
    pub fn set_vad_silence_notification(&mut self, ms: Option<u32>) {
        self.silero_vad.set_notify_silence_after_ms(ms);
    }
520
521 pub fn infer_vec(
523 &self,
524 content: Vec<i16>,
525 _sample_rate: u32, ) -> Result<Vec<VoiceText>, Box<dyn std::error::Error>> {
527 let mut vad = VadProcessor::new(self.vad_config)?;
529 let mut ret = Vec::new();
530
531 let chunk_size = CHUNK_SIZE;
532 let mut padded_content = content.clone();
534 let remainder = padded_content.len() % chunk_size;
535 if remainder != 0 {
536 padded_content.extend(std::iter::repeat(0).take(chunk_size - remainder));
537 }
538
539 for chunk in padded_content.chunks_exact(chunk_size) {
540 let chunk_arr: &[i16; CHUNK_SIZE] = chunk.try_into()?;
541 if let Some(output) = vad.process_chunk(chunk_arr) {
542 match output {
543 VadOutput::Segment(segment) => {
544 let vt = self.recognition(&segment)?;
545 ret.push(vt);
546 }
547 VadOutput::SilenceNotification => {
548 ret.push(VoiceText {
551 language: SenseVoiceLanguage::NoSpeech,
552 emotion: SenseVoiceEmo::Unknown,
553 event: SenseVoiceEvent::Unknown,
554 punctuation_normalization: SenseVoicePunctuationNormalization::Woitn,
555 content: String::new(),
556 });
557 }
558 }
559 }
560 }
561
562 if let Some(output) = vad.finish() {
563 match output {
564 VadOutput::Segment(segment) => {
565 let vt = self.recognition(&segment)?;
566 ret.push(vt);
567 }
568 VadOutput::SilenceNotification => {
569 ret.push(VoiceText {
571 language: SenseVoiceLanguage::NoSpeech,
572 emotion: SenseVoiceEmo::Unknown,
573 event: SenseVoiceEvent::Unknown,
574 punctuation_normalization: SenseVoicePunctuationNormalization::Woitn,
575 content: String::new(),
576 });
577 }
578 }
579 }
580
581 Ok(ret)
582 }
583
    /// Runs ASR on a single VAD speech segment (i16 PCM) and parses the
    /// model's tagged output line into a `VoiceText`.
    ///
    /// Dispatches to the RKNN backend when `use_rknn` is set, otherwise to
    /// the Candle `.pt` session.
    pub fn recognition(&self, segment: &[i16]) -> Result<VoiceText, Box<dyn std::error::Error>> {
        let audio_feats = self.asr_frontend.extract_features(segment)?;

        if self.use_rknn {
            #[cfg(feature = "rknpu")]
            {
                if let Some(rknn) = &self.rknn {
                    // language index 0 and use_itn=false are hard-coded here;
                    // see prepare_rknn_input_advanced for their meaning.
                    self.prepare_rknn_input_advanced(&audio_feats, 0, false)?;
                    rknn.run()?;
                    let asr_output = rknn.outputs_get::<f32>()?;
                    let asr_text = self.decode_asr_output(&asr_output)?;
                    return match parse_line(&asr_text) {
                        Some(vt) => Ok(vt),
                        None => Err(format!("Parse line failed, text is:{}, If u still get empty text, please check your vad config. This model only can infer 9 secs voice.", asr_text).into()),
                    };
                }
            }
            // Reached when use_rknn is set but either the feature is compiled
            // out or the session was never created.
            return Err("RKNN is enabled but model is not initialized".into());
        } else {
            let seq_len = audio_feats.shape()[0] as i64;
            let candle_pt_asr = self
                .candle_pt_asr
                .as_ref()
                .ok_or_else(|| std::io::Error::other("Candle ASR session is not initialized"))?;
            // NOTE(review): the literals 0 and 15 presumably select the
            // language query and the "woitn" text-norm query, mirroring the
            // RKNN path's embedding rows — confirm against CandlePtAsrSession.
            let (output_data, output_shape) = candle_pt_asr.run(&audio_feats, seq_len, 0, 15)?;
            let asr_text = self.decode_onnx_output(&output_data, &output_shape)?;
            return match parse_line(&asr_text) {
                Some(vt) => Ok(vt),
                None => Err(format!("Parse line failed, text is:{}", asr_text).into()),
            };
        }
    }
618
    /// Streaming inference: feeds incoming PCM chunks through the persistent
    /// VAD and yields one item per detected speech segment, plus `NoSpeech`
    /// placeholders for silence notifications.
    ///
    /// Each stream item must convert to a fixed-size `CHUNK_SIZE` array;
    /// other sizes fail the `try_into` below and are silently dropped.
    #[cfg(feature = "stream")]
    pub fn infer_stream<'a, S>(
        &'a mut self,
        input_stream: S,
    ) -> impl Stream<Item = Result<VoiceText, Box<dyn std::error::Error>>> + 'a
    where
        S: Stream<Item = Vec<i16>> + Unpin + 'a,
    {
        stream! {
            let mut stream = input_stream;
            while let Some(chunk) = stream.next().await {
                if let Ok(chunk_arr) = chunk.as_slice().try_into() {
                    if let Some(output) = self.silero_vad.process_chunk(chunk_arr) {
                        match output {
                            VadOutput::Segment(segment) => {
                                yield self.recognition(&segment);
                            },
                            VadOutput::SilenceNotification => {
                                yield Ok(VoiceText {
                                    language: SenseVoiceLanguage::NoSpeech,
                                    emotion: SenseVoiceEmo::Unknown,
                                    event: SenseVoiceEvent::Unknown,
                                    punctuation_normalization: SenseVoicePunctuationNormalization::Woitn,
                                    content: String::new(),
                                });
                            }
                        }
                    }
                } else {
                    // NOTE(review): wrong-sized chunks are discarded without
                    // any signal to the caller — consider yielding an error.
                }
            }
            // Flush any speech still buffered in the VAD at end of stream.
            if let Some(output) = self.silero_vad.finish() {
                match output {
                    VadOutput::Segment(segment) => {
                        yield self.recognition(&segment);
                    },
                    VadOutput::SilenceNotification => {
                        yield Ok(VoiceText {
                            language: SenseVoiceLanguage::NoSpeech,
                            emotion: SenseVoiceEmo::Unknown,
                            event: SenseVoiceEvent::Unknown,
                            punctuation_normalization: SenseVoicePunctuationNormalization::Woitn,
                            content: String::new(),
                        });
                    }
                }
            }
        }
    }
672
673 pub fn infer_file<P: AsRef<std::path::Path>>(
675 &self,
676 wav_path: P,
677 ) -> Result<Vec<VoiceText>, Box<dyn std::error::Error>> {
678 let mut wav_reader = WavReader::open(wav_path)?;
679 match wav_reader.spec().sample_rate {
680 8000 => (),
681 16000 => (),
682 _ => {
683 return Err(Box::new(SenseVoiceSmallError::new(
684 "Unsupported sample rate. Expect 8 kHz or 16 kHz.",
685 )))
686 }
687 };
688 if wav_reader.spec().sample_format != hound::SampleFormat::Int {
689 return Err(Box::new(SenseVoiceSmallError::new(
690 "Unsupported sample format. Expect Int.",
691 )));
692 }
693
694 let content = wav_reader
695 .samples()
696 .filter_map(|x| x.ok())
697 .collect::<Vec<i16>>();
698 if content.is_empty() {
699 return Err(Box::new(SenseVoiceSmallError::new(
700 "content is empty, check your audio file",
701 )));
702 }
703
704 self.infer_vec(content, wav_reader.spec().sample_rate)
705 }
706
    /// Decodes the RKNN model's flat logit buffer (laid out as
    /// `[1, n_vocab, n_seq]`) into text: per-frame argmax over the vocab
    /// axis, then CTC-style collapsing via `ids_to_text`.
    #[cfg(feature = "rknpu")]
    fn decode_asr_output(&self, output: &[f32]) -> Result<String, Box<dyn std::error::Error>> {
        let n_vocab = self.spp.len();
        // Reinterpret the flat buffer as (batch=1, vocab, time); errors if the
        // buffer length does not match.
        let output_array = ArrayView3::from_shape((1, n_vocab, self.n_seq), output)?;

        // Greedy argmax over the vocab axis for every time step.
        let token_ids: Vec<i32> = output_array
            .axis_iter(Axis(2))
            .into_iter()
            .map(|slice| {
                slice
                    .iter()
                    .enumerate()
                    .fold((0, f32::NEG_INFINITY), |(idx, max_val), (i, &val)| {
                        if val > max_val {
                            (i, val)
                        } else {
                            (idx, max_val)
                        }
                    })
                    .0 as i32
            })
            .collect();

        self.ids_to_text(token_ids)
    }
736
737 fn ids_to_text(&self, token_ids: Vec<i32>) -> Result<String, Box<dyn std::error::Error>> {
739 let mut unique_ids = Vec::new();
741 let mut prev_id = None;
742 for &id in token_ids.iter() {
743 if Some(id) != prev_id && id != 0 {
744 unique_ids.push(id as u32);
745 prev_id = Some(id);
746 } else if Some(id) != prev_id {
747 prev_id = Some(id);
748 }
749 }
750
751 let decoded_text = self.spp.decode_piece_ids(&unique_ids)?;
753 Ok(decoded_text)
754 }
755
756 fn decode_onnx_output(
758 &self,
759 output: &[f32],
760 shape: &[i64],
761 ) -> Result<String, Box<dyn std::error::Error>> {
762 let batch_size = shape[0] as usize;
769 if batch_size != 1 {
770 return Err("Batch size must be 1".into());
771 }
772
773 let n_vocab = self.spp.len(); let dim1 = shape[1] as usize;
779 let dim2 = shape[2] as usize;
780
781 let output_array = ArrayView3::from_shape(
782 (shape[0] as usize, shape[1] as usize, shape[2] as usize),
783 output,
784 )?;
785 let mut token_ids = Vec::new();
786
787 if dim1 == n_vocab {
788 for t in 0..dim2 {
790 let col = output_array.slice(s![0, .., t]);
792 let (best_idx, _) = col.iter().enumerate().fold(
794 (0, f32::NEG_INFINITY),
795 |(acc_idx, acc_val), (i, &val)| {
796 if val > acc_val {
797 (i, val)
798 } else {
799 (acc_idx, acc_val)
800 }
801 },
802 );
803 token_ids.push(best_idx as i32);
804 }
805 } else if dim2 == n_vocab {
806 for t in 0..dim1 {
808 let row = output_array.slice(s![0, t, ..]);
809 let (best_idx, _) = row.iter().enumerate().fold(
810 (0, f32::NEG_INFINITY),
811 |(acc_idx, acc_val), (i, &val)| {
812 if val > acc_val {
813 (i, val)
814 } else {
815 (acc_idx, acc_val)
816 }
817 },
818 );
819 token_ids.push(best_idx as i32);
820 }
821 } else {
822 return Err(format!(
823 "Unexpected output shape: {:?}, expected one dimension to be vocab size {}",
824 shape, n_vocab
825 )
826 .into());
827 }
828
829 self.ids_to_text(token_ids)
830 }
831
    /// Explicit teardown hook. Currently a no-op: every backend member
    /// releases its resources on `Drop`; the method remains so callers have a
    /// stable shutdown entry point.
    pub fn destroy(&self) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }
851
    /// Builds and uploads the RKNN input tensor: a prompt of embedding rows
    /// (language, event/emotion, text-normalization) concatenated with the
    /// scaled audio features, padded or truncated to `n_seq` frames of
    /// width 560, then flattened for the NPU.
    ///
    /// * `language` — row index into the embedding table for the language query.
    /// * `use_itn` — selects the text-norm query row (14 with ITN, 15 without);
    ///   row meanings follow the exported embedding layout — confirm against
    ///   the model card.
    #[cfg(feature = "rknpu")]
    fn prepare_rknn_input_advanced(
        &self,
        feats: &Array2<f32>,
        language: usize,
        use_itn: bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let embedding = self.embedding.as_ref().ok_or("Embedding not loaded")?;

        // Prompt rows pulled from the embedding table.
        let language_query = embedding.slice(s![language, ..]).insert_axis(Axis(0));
        let text_norm_idx = if use_itn { 14 } else { 15 };
        let text_norm_query = embedding.slice(s![text_norm_idx, ..]).insert_axis(Axis(0));
        // Rows 1..=2 serve as the event/emotion queries.
        let event_emo_query = embedding.slice(s![1..=2, ..]).to_owned();

        // NOTE(review): features are halved before concatenation — presumably
        // a scale the exported model expects; confirm against the exporter.
        let speech = feats.mapv(|x| x * 0.5);

        // Full input: [language | event/emo | text-norm | audio frames].
        let input_content = ndarray::concatenate(
            Axis(0),
            &[
                language_query.view(),
                event_emo_query.view(),
                text_norm_query.view(),
                speech.view(),
            ],
        )?;

        // Pad with zeros (or truncate) to the fixed n_seq frames the encoder
        // expects; audio beyond n_seq frames is silently dropped.
        let total_frames = input_content.shape()[0];
        let padded_input = if total_frames < self.n_seq {
            let mut padded = Array2::zeros((self.n_seq, 560));
            padded
                .slice_mut(s![..total_frames, ..])
                .assign(&input_content);
            padded
        } else {
            input_content.slice(s![..self.n_seq, ..]).to_owned()
        };

        // Add the batch axis and flatten to a contiguous buffer for upload.
        let input_3d: Array3<f32> = padded_input.insert_axis(Axis(0));
        let contiguous_input = input_3d.as_standard_layout();
        let flattened_input: Vec<f32> = contiguous_input
            .into_shape_with_order(1 * self.n_seq * 560)?
            .to_vec();
        if let Some(rknn) = &self.rknn {
            rknn.input_set_slice(
                0,
                &flattened_input,
                false,
                RknnTensorType::Float32,
                RknnTensorFormat::NCHW,
            )?;
        }
        Ok(())
    }
926}