Skip to main content

sensevoice_rs/
lib.rs

1pub mod asr_candle_pt;
2pub mod config;
3pub mod silero_vad;
4pub mod wavfrontend;
5
6use core::fmt;
7#[cfg(feature = "rknpu")]
8use std::{fs::File, io::BufReader};
9
10use hf_hub::api::sync::Api;
11use hound::WavReader;
12#[cfg(feature = "rknpu")]
13use ndarray::Axis;
14use ndarray::{s, ArrayView3};
15#[cfg(feature = "rknpu")]
16use ndarray::{Array2, Array3};
17#[cfg(feature = "rknpu")]
18use ndarray_npy::ReadNpyExt;
19use regex::Regex;
20#[cfg(feature = "rknpu")]
21use rknn_rs::prelude::{Rknn, RknnTensorFormat, RknnTensorType};
22use sentencepiece::SentencePieceProcessor;
23
24use asr_candle_pt::CandlePtAsrSession;
25use config::SenseVoiceConfig;
26use silero_vad::{VadConfig, VadOutput, VadProcessor, CHUNK_SIZE};
27use wavfrontend::{WavFrontend, WavFrontendConfig};
28
29#[cfg(feature = "stream")]
30use async_stream::stream;
31#[cfg(feature = "stream")]
32use futures::stream::Stream;
33#[cfg(feature = "stream")]
34use futures::StreamExt;
35
/// Represents supported languages for speech recognition.
///
/// This enum defines the language tags that can appear in `SenseVoiceSmall`
/// model output and are recognized by `parse_line`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SenseVoiceLanguage {
    /// English
    En,
    /// Chinese (Mandarin)
    Zh,
    /// Cantonese
    Yue,
    /// Japanese
    Ja,
    /// Korean
    Ko,
    /// No Speech / Silence
    NoSpeech,
}

/// Implementation of methods for `SenseVoiceLanguage`.
impl SenseVoiceLanguage {
    /// Converts a language tag (e.g. `"en"`, `"ZH"`) into a `SenseVoiceLanguage`.
    ///
    /// Matching is case-insensitive.
    ///
    /// # Returns
    ///
    /// `None` if the tag is not one of `en`, `zh`, `yue`, `ja`, `ko` or `nospeech`.
    fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "en" => Some(SenseVoiceLanguage::En),
            "zh" => Some(SenseVoiceLanguage::Zh),
            "yue" => Some(SenseVoiceLanguage::Yue),
            "ja" => Some(SenseVoiceLanguage::Ja),
            "ko" => Some(SenseVoiceLanguage::Ko),
            "nospeech" => Some(SenseVoiceLanguage::NoSpeech),
            _ => None,
        }
    }
}
80
/// Represents possible emotions detected in speech.
///
/// This enum defines the emotion tags that can appear in `SenseVoiceSmall`
/// model output and are recognized by `parse_line`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SenseVoiceEmo {
    /// Happy emotion
    Happy,
    /// Sad emotion
    Sad,
    /// Angry emotion
    Angry,
    /// Neutral emotion
    Neutral,
    /// Fearful emotion
    Fearful,
    /// Disgusted emotion
    Disgusted,
    /// Surprised emotion
    Surprised,
    /// Unknown emotion
    Unknown,
}

/// Implementation of methods for `SenseVoiceEmo`.
impl SenseVoiceEmo {
    /// Converts an emotion tag (e.g. `"HAPPY"`, `"sad"`) into a `SenseVoiceEmo`.
    ///
    /// Matching is case-insensitive. The "unknown" emotion is spelled `EMO_UNKNOWN`.
    ///
    /// # Returns
    ///
    /// `None` if the tag is not a recognized emotion.
    fn from_str(s: &str) -> Option<Self> {
        match s.to_uppercase().as_str() {
            "HAPPY" => Some(SenseVoiceEmo::Happy),
            "SAD" => Some(SenseVoiceEmo::Sad),
            "ANGRY" => Some(SenseVoiceEmo::Angry),
            "NEUTRAL" => Some(SenseVoiceEmo::Neutral),
            "FEARFUL" => Some(SenseVoiceEmo::Fearful),
            "DISGUSTED" => Some(SenseVoiceEmo::Disgusted),
            "SURPRISED" => Some(SenseVoiceEmo::Surprised),
            "EMO_UNKNOWN" => Some(SenseVoiceEmo::Unknown),
            _ => None,
        }
    }
}
131
/// Represents types of audio events detected in speech.
///
/// This enum defines the event tags that can appear in `SenseVoiceSmall`
/// model output and are recognized by `parse_line`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SenseVoiceEvent {
    /// Background music
    Bgm,
    /// Speech content
    Speech,
    /// Applause sound
    Applause,
    /// Laughter sound
    Laughter,
    /// Crying sound
    Cry,
    /// Sneezing sound
    Sneeze,
    /// Breathing sound
    Breath,
    /// Coughing sound
    Cough,
    /// Unknown event
    Unknown,
}

/// Implementation of methods for `SenseVoiceEvent`.
impl SenseVoiceEvent {
    /// Converts an event tag (e.g. `"BGM"`, `"laughter"`) into a `SenseVoiceEvent`.
    ///
    /// Matching is case-insensitive. The "unknown" event is spelled `EVENT_UNK`.
    ///
    /// # Returns
    ///
    /// `None` if the tag is not a recognized event.
    fn from_str(s: &str) -> Option<Self> {
        match s.to_uppercase().as_str() {
            "BGM" => Some(SenseVoiceEvent::Bgm),
            "SPEECH" => Some(SenseVoiceEvent::Speech),
            "APPLAUSE" => Some(SenseVoiceEvent::Applause),
            "LAUGHTER" => Some(SenseVoiceEvent::Laughter),
            "CRY" => Some(SenseVoiceEvent::Cry),
            "SNEEZE" => Some(SenseVoiceEvent::Sneeze),
            "BREATH" => Some(SenseVoiceEvent::Breath),
            "COUGH" => Some(SenseVoiceEvent::Cough),
            "EVENT_UNK" => Some(SenseVoiceEvent::Unknown),
            _ => None,
        }
    }
}
185
/// Represents the text-normalization tag of a transcription.
///
/// The model output carries either an inverse-text-normalization (ITN) tag,
/// meaning the text includes punctuation, or a "without ITN" tag.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SenseVoicePunctuationNormalization {
    /// Include punctuation in the text (tag `withitn`, also accepted as `with`)
    With,
    /// Exclude punctuation from the text (tag `woitn`)
    Woitn,
}

/// Implementation of methods for `SenseVoicePunctuationNormalization`.
impl SenseVoicePunctuationNormalization {
    /// Converts a normalization tag (e.g. `"withitn"`, `"WOITN"`) into a variant.
    ///
    /// Matching is case-insensitive.
    ///
    /// # Returns
    ///
    /// `None` if the tag is not `withitn`, `with` or `woitn`.
    fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            // The SenseVoice vocabulary spells the ITN tag `withitn` (text-norm
            // index 14); without this arm, parsing fails whenever ITN is requested.
            // `with` is kept for backward compatibility.
            "withitn" | "with" => Some(SenseVoicePunctuationNormalization::With),
            "woitn" => Some(SenseVoicePunctuationNormalization::Woitn),
            _ => None,
        }
    }
}
218
219/// Represents a segment of audio with its transcribed text and associated metadata.
220///
221/// This structure holds the transcription result of an audio segment, including timing, language, emotion, event, and normalization details.
222#[derive(Debug)]
223pub struct VoiceText {
224    /// The language of the transcribed text.
225    pub language: SenseVoiceLanguage,
226    /// The detected emotion in the audio segment.
227    pub emotion: SenseVoiceEmo,
228    /// The type of audio event in the segment.
229    pub event: SenseVoiceEvent,
230    /// Indicates whether punctuation is included in the transcribed text.
231    pub punctuation_normalization: SenseVoicePunctuationNormalization,
232    /// The transcribed text of the audio segment.
233    pub content: String,
234}
235
236/// Parses a string line into a `VoiceText` instance based on a specific format.
237///
238/// The expected format is: `<|language|><|emotion|><|event|><|punctuation|><content>`
239///
240/// # Arguments
241///
242/// * `line` - The string to parse (e.g., "<|zh|><|HAPPY|><|BGM|><|woitn|>Hello").
243/// * `start_ms` - Start time of the segment in milliseconds.
244/// * `end_ms` - End time of the segment in milliseconds.
245///
246/// # Returns
247///
248/// An `Option<VoiceText>` where `None` indicates parsing failure due to invalid format or unrecognized tags.
249fn parse_line(line: &str) -> Option<VoiceText> {
250    let re = Regex::new(r"^<\|(.*?)\|><\|(.*?)\|><\|(.*?)\|><\|(.*?)\|>(.*)$").unwrap();
251    if let Some(caps) = re.captures(line) {
252        let lang_str = &caps[1];
253        let emo_str = &caps[2];
254        let event_str = &caps[3];
255        let punct_str = &caps[4];
256        let content = &caps[5];
257
258        let language = SenseVoiceLanguage::from_str(lang_str)?;
259        let emotion = SenseVoiceEmo::from_str(emo_str)?;
260        let event = SenseVoiceEvent::from_str(event_str)?;
261        let punctuation_normalization = SenseVoicePunctuationNormalization::from_str(punct_str)?;
262
263        Some(VoiceText {
264            language,
265            emotion,
266            event,
267            punctuation_normalization,
268            content: content.to_string(),
269        })
270    } else {
271        None
272    }
273}
274
/// Error raised by `SenseVoiceSmall` for conditions that have no more specific
/// error type, such as rejected audio input.
#[derive(Debug)]
struct SenseVoiceSmallError {
    /// Human-readable description of what went wrong.
    message: String,
}

impl SenseVoiceSmallError {
    /// Wraps `message` in a new `SenseVoiceSmallError`.
    ///
    /// # Arguments
    ///
    /// * `message` - Text describing the failure; it is copied into the error.
    pub fn new(message: &str) -> Self {
        Self {
            message: String::from(message),
        }
    }
}

/// Renders the error as `SenseVoiceSmallError: <message>`.
impl fmt::Display for SenseVoiceSmallError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SenseVoiceSmallError: ")?;
        f.write_str(&self.message)
    }
}

/// Lets the error travel through `Box<dyn std::error::Error>` results.
impl std::error::Error for SenseVoiceSmallError {}
311
/// Core structure of the SenseVoiceSmall speech-recognition pipeline.
///
/// It holds:
/// - the audio frontend;
/// - the SentencePiece tokenizer;
/// - one ASR backend;
/// - the Silero VAD configuration, plus a persistent VAD processor when the
///   `stream` feature is enabled.
///
/// The backend is either RKNN (feature `rknpu`, loaded by `init`) or the
/// Candle `.pt` session. `use_rknn` selects which one `recognition` uses.
#[derive(Debug)]
pub struct SenseVoiceSmall {
    // Feature extractor applied to every speech segment before inference.
    asr_frontend: WavFrontend,
    // Fixed encoder sequence length used to pad/truncate the RKNN input
    // (171 when loaded by `init`, 0 on the Candle path).
    #[cfg(feature = "rknpu")]
    n_seq: usize,
    // Tokenizer that turns decoded token ids back into text.
    spp: SentencePieceProcessor,

    // RKNN specific fields
    #[cfg(feature = "rknpu")]
    rknn: Option<Rknn>,
    // Prompt-query embedding table (560 columns). Rows are selected by language,
    // event/emotion and text-normalization index when building the RKNN input.
    #[cfg(feature = "rknpu")]
    embedding: Option<ndarray::Array2<f32>>,

    // Candle backend session; `None` when the RKNN backend is in use.
    candle_pt_asr: Option<CandlePtAsrSession>,

    // VAD configuration. `infer_vec` builds a fresh `VadProcessor` from it on
    // every call, so batch inference never inherits streaming state.
    vad_config: VadConfig,
    // Stateful VAD processor used by `infer_stream`. It is a field, so its
    // state is kept between streaming calls rather than recreated.
    #[cfg(feature = "stream")]
    silero_vad: VadProcessor,

    // Selects the RKNN backend in `recognition`. It is true only when the
    // instance was built by `init` with the `rknpu` feature.
    use_rknn: bool,
}
342
343/// Implementation of methods for `SenseVoiceSmall`.
344impl SenseVoiceSmall {
345    /// Initializes a new `SenseVoiceSmall` instance.
346    ///
347    /// If the `rknpu` feature is enabled, it initializes the RKNN backend using the default RKNN model.
348    /// Otherwise, it initializes the Candle backend using the official `model.pt`.
349    ///
350    /// # Arguments
351    ///
352    /// * `vadconfig` - Configuration for VAD (Silero VAD).
353    ///
354    /// # Errors
355    ///
356    /// Returns an error if model files cannot be loaded.
357    pub fn init(vadconfig: VadConfig) -> Result<Self, Box<dyn std::error::Error>> {
358        #[cfg(feature = "rknpu")]
359        {
360            // RKNN Path
361            let model_path = "happyme531/SenseVoiceSmall-RKNN2";
362
363            let api = Api::new().unwrap();
364            let repo = api.model(model_path.to_string());
365
366            // FSMN VAD files removed
367            let embedding_path = repo.get("embedding.npy")?;
368            let rknn_path = repo.get("sense-voice-encoder.rknn")?;
369            let sentence_path = repo.get("chn_jpn_yue_eng_ko_spectok.bpe.model")?;
370            let am_path = repo.get("am.mvn")?;
371
372            let config = SenseVoiceConfig {
373                model_path: rknn_path,
374                tokenizer_path: sentence_path,
375                cmvn_path: Some(am_path),
376            };
377
378            let embedding_file = File::open(embedding_path)?;
379            let embedding_reader = BufReader::new(embedding_file);
380            let embedding: Array2<f32> = Array2::read_npy(embedding_reader)?;
381            assert_eq!(embedding.shape()[1], 560, "Embedding dimension must be 560");
382
383            let rknn = Rknn::rknn_init(config.model_path)?;
384            let spp = SentencePieceProcessor::open(config.tokenizer_path)?;
385
386            let n_seq = 171;
387
388            // vad_frontend removed
389
390            let asr_frontend = WavFrontend::new(WavFrontendConfig {
391                lfr_m: 7,
392                cmvn_file: Some(
393                    config
394                        .cmvn_path
395                        .as_ref()
396                        .unwrap()
397                        .to_str()
398                        .unwrap()
399                        .to_owned(),
400                ),
401                ..Default::default()
402            })?;
403
404            #[cfg(feature = "stream")]
405            let silero_vad = VadProcessor::new(vadconfig)?;
406
407            Ok(SenseVoiceSmall {
408                asr_frontend,
409                n_seq,
410                spp,
411                rknn: Some(rknn),
412                embedding: Some(embedding),
413                candle_pt_asr: None,
414                vad_config: vadconfig,
415                #[cfg(feature = "stream")]
416                silero_vad,
417                use_rknn: true,
418            })
419        }
420        #[cfg(not(feature = "rknpu"))]
421        {
422            Self::init_official_model_pt(vadconfig)
423        }
424    }
425
426    /// Initializes using official `FunAudioLLM/SenseVoiceSmall` assets (`model.pt` path).
427    /// This uses the native Candle PT backend.
428    pub fn init_official_model_pt(
429        vadconfig: VadConfig,
430    ) -> Result<Self, Box<dyn std::error::Error>> {
431        let api = Api::new().unwrap();
432        let repo = api.model("FunAudioLLM/SenseVoiceSmall".to_owned());
433
434        let config = SenseVoiceConfig {
435            model_path: repo.get("model.pt")?,
436            tokenizer_path: repo.get("chn_jpn_yue_eng_ko_spectok.bpe.model")?,
437            cmvn_path: Some(repo.get("am.mvn")?),
438        };
439
440        Self::init_with_config(config, vadconfig)
441    }
442
443    /// Initializes a new `SenseVoiceSmall` instance with custom configuration.
444    ///
445    /// # Arguments
446    ///
447    /// * `config` - The configuration containing file paths.
448    /// * `vadconfig` - Configuration for VAD.
449    ///
450    /// # Errors
451    ///
452    /// Returns an error if model files cannot be loaded.
453    pub fn init_with_config(
454        config: SenseVoiceConfig,
455        vadconfig: VadConfig,
456    ) -> Result<Self, Box<dyn std::error::Error>> {
457        #[cfg(feature = "rknpu")]
458        {
459            // If rknpu feature is enabled, we check if it is an RKNN model
460            let is_rknn_model = config
461                .model_path
462                .extension()
463                .map_or(false, |ext| ext == "rknn");
464            if is_rknn_model {
465                return Err("Manual loading of RKNN models via init_with_config is not fully supported yet (missing embedding path). Use init() for default RKNN model.".into());
466            }
467        }
468
469        let is_pt_model = config
470            .model_path
471            .extension()
472            .and_then(|ext| ext.to_str())
473            .map(|ext| ext.eq_ignore_ascii_case("pt"))
474            .unwrap_or(false);
475        if !is_pt_model {
476            return Err(std::io::Error::other(
477                "Candle ASR now only supports official .pt model paths.",
478            )
479            .into());
480        }
481        let candle_pt_asr = Some(CandlePtAsrSession::new(&config.model_path)?);
482
483        let spp = SentencePieceProcessor::open(&config.tokenizer_path)?;
484
485        let asr_frontend = WavFrontend::new(WavFrontendConfig {
486            lfr_m: 7,
487            cmvn_file: config.cmvn_path.map(|p| p.to_string_lossy().to_string()),
488            ..Default::default()
489        })?;
490
491        #[cfg(feature = "stream")]
492        let silero_vad = VadProcessor::new(vadconfig)?;
493
494        #[cfg(feature = "rknpu")]
495        let n_seq = 0;
496
497        Ok(SenseVoiceSmall {
498            asr_frontend,
499            #[cfg(feature = "rknpu")]
500            n_seq,
501            spp,
502            #[cfg(feature = "rknpu")]
503            rknn: None,
504            #[cfg(feature = "rknpu")]
505            embedding: None,
506            candle_pt_asr,
507            vad_config: vadconfig,
508            #[cfg(feature = "stream")]
509            silero_vad,
510            use_rknn: false,
511        })
512    }
513
514    /// Updates the silence notification threshold for VAD.
515    /// If `ms` is Some, a NoSpeech event will be emitted once after `ms` milliseconds of continuous dropped audio (waiting state).
516    #[cfg(feature = "stream")]
517    pub fn set_vad_silence_notification(&mut self, ms: Option<u32>) {
518        self.silero_vad.set_notify_silence_after_ms(ms);
519    }
520
521    /// Performs speech recognition on a vector of audio samples.
522    pub fn infer_vec(
523        &self,
524        content: Vec<i16>,
525        _sample_rate: u32, // Silero VAD config has sample_rate
526    ) -> Result<Vec<VoiceText>, Box<dyn std::error::Error>> {
527        // Use Silero VAD to segment audio
528        let mut vad = VadProcessor::new(self.vad_config)?;
529        let mut ret = Vec::new();
530
531        let chunk_size = CHUNK_SIZE;
532        // Pad content to multiple of chunk_size
533        let mut padded_content = content.clone();
534        let remainder = padded_content.len() % chunk_size;
535        if remainder != 0 {
536            padded_content.extend(std::iter::repeat(0).take(chunk_size - remainder));
537        }
538
539        for chunk in padded_content.chunks_exact(chunk_size) {
540            let chunk_arr: &[i16; CHUNK_SIZE] = chunk.try_into()?;
541            if let Some(output) = vad.process_chunk(chunk_arr) {
542                match output {
543                    VadOutput::Segment(segment) => {
544                        let vt = self.recognition(&segment)?;
545                        ret.push(vt);
546                    }
547                    VadOutput::SilenceNotification => {
548                        // For batch infer, usually we don't need intermediate notifications,
549                        // but if configured in vad_config, we respect it.
550                        ret.push(VoiceText {
551                            language: SenseVoiceLanguage::NoSpeech,
552                            emotion: SenseVoiceEmo::Unknown,
553                            event: SenseVoiceEvent::Unknown,
554                            punctuation_normalization: SenseVoicePunctuationNormalization::Woitn,
555                            content: String::new(),
556                        });
557                    }
558                }
559            }
560        }
561
562        if let Some(output) = vad.finish() {
563            match output {
564                VadOutput::Segment(segment) => {
565                    let vt = self.recognition(&segment)?;
566                    ret.push(vt);
567                }
568                VadOutput::SilenceNotification => {
569                    // Should not happen in finish usually, but handle it
570                    ret.push(VoiceText {
571                        language: SenseVoiceLanguage::NoSpeech,
572                        emotion: SenseVoiceEmo::Unknown,
573                        event: SenseVoiceEvent::Unknown,
574                        punctuation_normalization: SenseVoicePunctuationNormalization::Woitn,
575                        content: String::new(),
576                    });
577                }
578            }
579        }
580
581        Ok(ret)
582    }
583
584    pub fn recognition(&self, segment: &[i16]) -> Result<VoiceText, Box<dyn std::error::Error>> {
585        // 提取特徵
586        let audio_feats = self.asr_frontend.extract_features(segment)?;
587
588        if self.use_rknn {
589            #[cfg(feature = "rknpu")]
590            {
591                if let Some(rknn) = &self.rknn {
592                    // 準備 RKNN 輸入
593                    self.prepare_rknn_input_advanced(&audio_feats, 0, false)?;
594                    rknn.run()?;
595                    let asr_output = rknn.outputs_get::<f32>()?;
596                    let asr_text = self.decode_asr_output(&asr_output)?;
597                    return match parse_line(&asr_text) {
598                        Some(vt) => Ok(vt),
599                        None => Err(format!("Parse line failed, text is:{}, If u still get empty text, please check your vad config. This model only can infer 9 secs voice.", asr_text).into()),
600                    };
601                }
602            }
603            return Err("RKNN is enabled but model is not initialized".into());
604        } else {
605            let seq_len = audio_feats.shape()[0] as i64;
606            let candle_pt_asr = self
607                .candle_pt_asr
608                .as_ref()
609                .ok_or_else(|| std::io::Error::other("Candle ASR session is not initialized"))?;
610            let (output_data, output_shape) = candle_pt_asr.run(&audio_feats, seq_len, 0, 15)?;
611            let asr_text = self.decode_onnx_output(&output_data, &output_shape)?;
612            return match parse_line(&asr_text) {
613                Some(vt) => Ok(vt),
614                None => Err(format!("Parse line failed, text is:{}", asr_text).into()),
615            };
616        }
617    }
618
619    #[cfg(feature = "stream")]
620    pub fn infer_stream<'a, S>(
621        &'a mut self,
622        input_stream: S,
623    ) -> impl Stream<Item = Result<VoiceText, Box<dyn std::error::Error>>> + 'a
624    where
625        S: Stream<Item = Vec<i16>> + Unpin + 'a,
626    {
627        stream! {
628        let mut stream = input_stream;
629        while let Some(chunk) = stream.next().await {
630            // Ensure chunk is 512 samples. If stream provides different sizes, we might need buffering.
631            // For now, assuming the stream provides correct chunks or we try to convert.
632            // process_chunk expects &[i16; 512].
633            if let Ok(chunk_arr) = chunk.as_slice().try_into() {
634                if let Some(output) = self.silero_vad.process_chunk(chunk_arr) {
635                    match output {
636                        VadOutput::Segment(segment) => {
637                            yield self.recognition(&segment);
638                        },
639                        VadOutput::SilenceNotification => {
640                            yield Ok(VoiceText {
641                                language: SenseVoiceLanguage::NoSpeech,
642                                emotion: SenseVoiceEmo::Unknown,
643                                event: SenseVoiceEvent::Unknown,
644                                punctuation_normalization: SenseVoicePunctuationNormalization::Woitn,
645                                content: String::new(),
646                            });
647                        }
648                    }
649                }
650            } else {
651                 // Handle mismatch size? For now ignore or log?
652            }
653        }
654        if let Some(output) = self.silero_vad.finish() {
655            match output {
656                VadOutput::Segment(segment) => {
657                    yield self.recognition(&segment);
658                },
659                VadOutput::SilenceNotification => {
660                     yield Ok(VoiceText {
661                        language: SenseVoiceLanguage::NoSpeech,
662                        emotion: SenseVoiceEmo::Unknown,
663                        event: SenseVoiceEvent::Unknown,
664                        punctuation_normalization: SenseVoicePunctuationNormalization::Woitn,
665                        content: String::new(),
666                    });
667                }
668            }
669        }
670        }
671    }
672
673    /// Performs speech recognition on an audio file.
674    pub fn infer_file<P: AsRef<std::path::Path>>(
675        &self,
676        wav_path: P,
677    ) -> Result<Vec<VoiceText>, Box<dyn std::error::Error>> {
678        let mut wav_reader = WavReader::open(wav_path)?;
679        match wav_reader.spec().sample_rate {
680            8000 => (),
681            16000 => (),
682            _ => {
683                return Err(Box::new(SenseVoiceSmallError::new(
684                    "Unsupported sample rate. Expect 8 kHz or 16 kHz.",
685                )))
686            }
687        };
688        if wav_reader.spec().sample_format != hound::SampleFormat::Int {
689            return Err(Box::new(SenseVoiceSmallError::new(
690                "Unsupported sample format. Expect Int.",
691            )));
692        }
693
694        let content = wav_reader
695            .samples()
696            .filter_map(|x| x.ok())
697            .collect::<Vec<i16>>();
698        if content.is_empty() {
699            return Err(Box::new(SenseVoiceSmallError::new(
700                "content is empty, check your audio file",
701            )));
702        }
703
704        self.infer_vec(content, wav_reader.spec().sample_rate)
705    }
706
707    /// Decodes RKNN output into a transcribed text string.
708    #[cfg(feature = "rknpu")]
709    fn decode_asr_output(&self, output: &[f32]) -> Result<String, Box<dyn std::error::Error>> {
710        // 解析為 [1, n_vocab, n_seq]
711        let n_vocab = self.spp.len();
712        // RKNN n_seq is fixed 171
713        let output_array = ArrayView3::from_shape((1, n_vocab, self.n_seq), output)?;
714
715        // 在 n_vocab 維度(Axis(1))上取 argmax
716        let token_ids: Vec<i32> = output_array
717            .axis_iter(Axis(2)) // 沿著 n_seq=171 維度迭代
718            .into_iter()
719            .map(|slice| {
720                slice
721                    .iter()
722                    .enumerate()
723                    .fold((0, f32::NEG_INFINITY), |(idx, max_val), (i, &val)| {
724                        if val > max_val {
725                            (i, val)
726                        } else {
727                            (idx, max_val)
728                        }
729                    })
730                    .0 as i32 // 提取最大值的索引
731            })
732            .collect();
733
734        self.ids_to_text(token_ids)
735    }
736
737    /// Helper to convert token IDs to text
738    fn ids_to_text(&self, token_ids: Vec<i32>) -> Result<String, Box<dyn std::error::Error>> {
739        // 移除連續重複的 token 和 blank_id=0
740        let mut unique_ids = Vec::new();
741        let mut prev_id = None;
742        for &id in token_ids.iter() {
743            if Some(id) != prev_id && id != 0 {
744                unique_ids.push(id as u32);
745                prev_id = Some(id);
746            } else if Some(id) != prev_id {
747                prev_id = Some(id);
748            }
749        }
750
751        // 解碼為文本
752        let decoded_text = self.spp.decode_piece_ids(&unique_ids)?;
753        Ok(decoded_text)
754    }
755
756    /// Decodes ONNX output.
757    fn decode_onnx_output(
758        &self,
759        output: &[f32],
760        shape: &[i64],
761    ) -> Result<String, Box<dyn std::error::Error>> {
762        // Shape is likely [1, T, Vocab] or [1, Vocab, T].
763        // If T is dynamic, we use it.
764        // Assuming [1, T, Vocab] which is common for CTC/Frame-level outputs from generic inference.
765        // But RKNN was [1, Vocab, T].
766        // Let's assume standard SenseVoice ONNX matches pytorch output: [Batch, Time, Vocab].
767
768        let batch_size = shape[0] as usize;
769        if batch_size != 1 {
770            return Err("Batch size must be 1".into());
771        }
772
773        // Guessing layout based on dimensions. Vocab size is ~25055.
774        // If dim 1 is ~25000, then it is [B, V, T].
775        // If dim 2 is ~25000, then it is [B, T, V].
776
777        let n_vocab = self.spp.len(); // ~25055
778        let dim1 = shape[1] as usize;
779        let dim2 = shape[2] as usize;
780
781        let output_array = ArrayView3::from_shape(
782            (shape[0] as usize, shape[1] as usize, shape[2] as usize),
783            output,
784        )?;
785        let mut token_ids = Vec::new();
786
787        if dim1 == n_vocab {
788            // [B, V, T] - iterate over T (dim 2)
789            for t in 0..dim2 {
790                // slice at time t: [1, V]
791                let col = output_array.slice(s![0, .., t]);
792                // argmax over V
793                let (best_idx, _) = col.iter().enumerate().fold(
794                    (0, f32::NEG_INFINITY),
795                    |(acc_idx, acc_val), (i, &val)| {
796                        if val > acc_val {
797                            (i, val)
798                        } else {
799                            (acc_idx, acc_val)
800                        }
801                    },
802                );
803                token_ids.push(best_idx as i32);
804            }
805        } else if dim2 == n_vocab {
806            // [B, T, V] - iterate over T (dim 1)
807            for t in 0..dim1 {
808                let row = output_array.slice(s![0, t, ..]);
809                let (best_idx, _) = row.iter().enumerate().fold(
810                    (0, f32::NEG_INFINITY),
811                    |(acc_idx, acc_val), (i, &val)| {
812                        if val > acc_val {
813                            (i, val)
814                        } else {
815                            (acc_idx, acc_val)
816                        }
817                    },
818                );
819                token_ids.push(best_idx as i32);
820            }
821        } else {
822            return Err(format!(
823                "Unexpected output shape: {:?}, expected one dimension to be vocab size {}",
824                shape, n_vocab
825            )
826            .into());
827        }
828
829        self.ids_to_text(token_ids)
830    }
831
832    /// Destroys the `SenseVoiceSmall` instance, releasing associated resources.
833    ///
834    /// This method ensures that the RKNN model resources are properly cleaned up.
835    ///
836    /// # Errors
837    ///
838    /// Returns an error if the RKNN model destruction fails.
839    ///
840    /// # Example
841    ///
842    /// ```
843    /// use sensevoice_rs::SenseVoiceSmall;
844    ///
845    /// let svs = SenseVoiceSmall::init().expect("Failed to initialize");
846    /// svs.destroy().expect("Failed to destroy SenseVoiceSmall");
847    /// ```
848    pub fn destroy(&self) -> Result<(), Box<dyn std::error::Error>> {
849        Ok(())
850    }
851
    /// Builds the RKNN encoder input for one segment and binds it to input 0.
    ///
    /// The input is a concatenation along the frame axis of:
    /// - the prompt-embedding row for the language (`language`);
    /// - the event/emotion rows (1..=2);
    /// - the text-normalization row (14 when `use_itn`, otherwise 15);
    /// - the audio features scaled by 0.5.
    ///
    /// The result is zero-padded or truncated to `[n_seq, 560]`, given a batch
    /// axis, and flattened in row-major order.
    ///
    /// # Arguments
    ///
    /// * `feats` - A 2D array of audio features. Its column count must match
    ///   the embedding width for the concatenation to succeed (presumably 560 —
    ///   confirm against the frontend).
    /// * `language` - Row of the embedding table used as the language query (0 for auto).
    /// * `use_itn` - Whether to request inverse text normalization.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the embedding is not loaded;
    /// - concatenation or reshaping fails;
    /// - setting the RKNN input fails.
    ///
    /// NOTE(review): if `self.rknn` is `None`, the input is silently not set.
    /// `recognition` only calls this after checking `self.rknn`.
    ///
    /// NOTE(review): the tensor is declared `NCHW` although it is rank 3
    /// (`[1, n_seq, 560]`). Confirm this matches the RKNN model's expectation.
    #[cfg(feature = "rknpu")]
    fn prepare_rknn_input_advanced(
        &self,
        feats: &Array2<f32>,
        language: usize,
        use_itn: bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // Fetch the prompt embedding table (only loaded on the RKNN init path).
        let embedding = self.embedding.as_ref().ok_or("Embedding not loaded")?;

        let language_query = embedding.slice(s![language, ..]).insert_axis(Axis(0));
        let text_norm_idx = if use_itn { 14 } else { 15 };
        let text_norm_query = embedding.slice(s![text_norm_idx, ..]).insert_axis(Axis(0));
        let event_emo_query = embedding.slice(s![1..=2, ..]).to_owned();

        // Scale the speech features by 0.5.
        let speech = feats.mapv(|x| x * 0.5);

        // Concatenate along the frame axis (axis 0).
        let input_content = ndarray::concatenate(
            Axis(0),
            &[
                language_query.view(),
                event_emo_query.view(),
                text_norm_query.view(),
                speech.view(),
            ],
        )?;

        // Zero-pad or truncate to [n_seq, 560].
        let total_frames = input_content.shape()[0];
        let padded_input = if total_frames < self.n_seq {
            let mut padded = Array2::zeros((self.n_seq, 560));
            padded
                .slice_mut(s![..total_frames, ..])
                .assign(&input_content);
            padded
        } else {
            input_content.slice(s![..self.n_seq, ..]).to_owned()
        };
        // Add batch dimension
        let input_3d: Array3<f32> = padded_input.insert_axis(Axis(0)); // [1, n_seq , 560]

        // Ensure contiguous memory and flatten to 1D
        let contiguous_input = input_3d.as_standard_layout(); // Row-major contiguous
        let flattened_input: Vec<f32> = contiguous_input
            .into_shape_with_order(1 * self.n_seq * 560)? // Flatten to n_seq * 560 elements (95760 when n_seq = 171)
            .to_vec(); // Owned Vec<f32>

        if let Some(rknn) = &self.rknn {
            rknn.input_set_slice(
                0, // Input index; set it according to the model's input index.
                &flattened_input,
                false, // Usually false, unless the model needs special handling.
                RknnTensorType::Float32,
                RknnTensorFormat::NCHW,
            )?;
        }
        Ok(())
    }
926}