use std::borrow::Cow;
use std::ffi::CStr;
use std::ops::DerefMut;
use std::ptr::null;
use std::sync::{Arc, Mutex, mpsc};

use anyhow::{Result, anyhow, ensure};
use sherpa_rs_sys::*;

use crate::{DropCString, track_cstr};

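/// Builder for a streaming (online) sherpa-onnx recognizer.
///
/// Defaults: 16 kHz input, 80-dim features, `modified_beam_search` decoding,
/// endpoint detection disabled.
///
/// A minimal usage sketch (file paths are placeholders, and the `use` path is
/// omitted because it depends on how this module is exposed):
///
/// ```ignore
/// let model = Config::transducer("encoder.onnx", "decoder.onnx", "joiner.onnx", "tokens.txt")
///     .num_threads(2)
///     .detect_endpoints(true)
///     .build()?;
/// ```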
#[derive(Clone)]
pub struct Config {
    sample_rate: i32,
    feature_dim: i32,
    load: Arch,
    tokens: String,
    num_threads: i32,
    provider: String,
    debug: i32,
    decoding_method: String,
    max_active_paths: i32,
    detect_endpoints: i32,
    rule1_min_trailing_silence: f32,
    rule2_min_trailing_silence: f32,
    rule3_min_utterance_length: f32,
}

#[derive(Clone)]
pub enum Arch {
    Transducer {
        encoder: String,
        decoder: String,
        joiner: String,
    },

    Paraformer {
        encoder: String,
        decoder: String,
    },

    Zip2Ctc {
        model: String,
    },
}

impl Config {
    pub fn transducer(encoder: &str, decoder: &str, joiner: &str, tokens: &str) -> Self {
        Self::new(
            Arch::Transducer {
                encoder: encoder.into(),
                decoder: decoder.into(),
                joiner: joiner.into(),
            },
            tokens,
        )
    }

    pub fn paraformer(encoder: &str, decoder: &str, tokens: &str) -> Self {
        Self::new(
            Arch::Paraformer {
                encoder: encoder.into(),
                decoder: decoder.into(),
            },
            tokens,
        )
    }

    pub fn zipformer2_ctc(model: &str, tokens: &str) -> Self {
        Self::new(Arch::Zip2Ctc { model: model.into() }, tokens)
    }

    fn new(load: Arch, tokens: &str) -> Self {
        Self {
            sample_rate: 16000,
            feature_dim: 80,
            load,
            tokens: tokens.into(),
            num_threads: crate::default_num_threads() as i32,
            provider: crate::default_compute_provider().into(),
            debug: 0,
            decoding_method: "modified_beam_search".into(),
            max_active_paths: 16,
            detect_endpoints: 0,
            rule1_min_trailing_silence: 0.,
            rule2_min_trailing_silence: 0.,
            rule3_min_utterance_length: 0.,
        }
    }

    pub fn sample_rate(mut self, rate: usize) -> Self {
        self.sample_rate = rate as i32;
        self
    }

    pub fn feature_dim(mut self, dim: usize) -> Self {
        self.feature_dim = dim as i32;
        self
    }

    pub fn num_threads(mut self, n: usize) -> Self {
        self.num_threads = n as i32;
        self
    }

    pub fn cpu(mut self) -> Self {
        self.provider = "cpu".into();
        self
    }

    #[cfg(feature = "cuda")]
    #[cfg_attr(docsrs, doc(cfg(feature = "cuda")))]
    pub fn cuda(mut self) -> Self {
        self.provider = "cuda".into();
        self
    }

    #[cfg(feature = "directml")]
    #[cfg_attr(docsrs, doc(cfg(feature = "directml")))]
    pub fn directml(mut self) -> Self {
        self.provider = "directml".into();
        self
    }

    pub fn debug(mut self, enable: bool) -> Self {
        self.debug = if enable { 1 } else { 0 };
        self
    }

    pub fn greedy_search(mut self) -> Self {
        self.decoding_method = "greedy_search".into();
        self
    }

    pub fn modified_beam_search(mut self) -> Self {
        self.decoding_method = "modified_beam_search".into();
        self
    }

    pub fn max_active_paths(mut self, n: usize) -> Self {
        self.max_active_paths = n as i32;
        self
    }

    pub fn detect_endpoints(mut self, enable: bool) -> Self {
        self.detect_endpoints = if enable { 1 } else { 0 };
        self
    }

    pub fn rule1_min_trailing_silence(mut self, seconds: f32) -> Self {
        self.rule1_min_trailing_silence = seconds;
        self
    }

    pub fn rule2_min_trailing_silence(mut self, seconds: f32) -> Self {
        self.rule2_min_trailing_silence = seconds;
        self
    }

    pub fn rule3_min_utterance_length(mut self, seconds: f32) -> Self {
        self.rule3_min_utterance_length = seconds;
        self
    }

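    /// Consumes the configuration, creates the underlying sherpa-onnx online
    /// recognizer, and probes its chunk size. Returns an error if the
    /// recognizer cannot be created from the given files.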
    pub fn build(self) -> Result<Model> {
        let mut config = online_config();

        let mut _dcs = vec![];
        let dcs = &mut _dcs;

        config.feat_config.sample_rate = self.sample_rate;
        config.feat_config.feature_dim = self.feature_dim;

        match self.load {
            Arch::Transducer { encoder, decoder, joiner } => {
                config.model_config.transducer.encoder = track_cstr(dcs, &encoder);
                config.model_config.transducer.decoder = track_cstr(dcs, &decoder);
                config.model_config.transducer.joiner = track_cstr(dcs, &joiner);
            }

            Arch::Paraformer { encoder, decoder } => {
                config.model_config.paraformer.encoder = track_cstr(dcs, &encoder);
                config.model_config.paraformer.decoder = track_cstr(dcs, &decoder);
            }

            Arch::Zip2Ctc { model } => {
                config.model_config.zipformer2_ctc.model = track_cstr(dcs, &model);
            }
        }

        config.model_config.tokens = track_cstr(dcs, &self.tokens);
        config.model_config.num_threads = self.num_threads;
        config.model_config.provider = track_cstr(dcs, &self.provider);
        config.model_config.debug = self.debug;
        config.decoding_method = track_cstr(dcs, &self.decoding_method);
        config.max_active_paths = self.max_active_paths;
        config.enable_endpoint = self.detect_endpoints;
        config.rule1_min_trailing_silence = self.rule1_min_trailing_silence;
        config.rule2_min_trailing_silence = self.rule2_min_trailing_silence;
        config.rule3_min_utterance_length = self.rule3_min_utterance_length;

        let ptr = unsafe { SherpaOnnxCreateOnlineRecognizer(&config) };
        ensure!(!ptr.is_null(), "failed to create online recognizer");

        let (tx, rx) = mpsc::channel();

        let mut tdc = Model {
            inner: Arc::new(ModelPtr { ptr, dcs: _dcs }),
            sample_rate: self.sample_rate as usize,
            chunk_size: 0,
            tx,
            rx: Arc::new(Mutex::new(rx)),
        };

        tdc.chunk_size = tdc.get_chunk_size()?;

        Ok(tdc)
    }
}

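/// A fully zeroed/nulled `SherpaOnnxOnlineRecognizerConfig`, used as the
/// starting point in [`Config::build`].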
fn online_config() -> SherpaOnnxOnlineRecognizerConfig {
    SherpaOnnxOnlineRecognizerConfig {
        feat_config: SherpaOnnxFeatureConfig { sample_rate: 0, feature_dim: 0 },
        model_config: SherpaOnnxOnlineModelConfig {
            transducer: SherpaOnnxOnlineTransducerModelConfig {
                encoder: null(),
                decoder: null(),
                joiner: null(),
            },
            paraformer: SherpaOnnxOnlineParaformerModelConfig {
                encoder: null(),
                decoder: null(),
            },
            zipformer2_ctc: SherpaOnnxOnlineZipformer2CtcModelConfig { model: null() },
            tokens: null(),
            tokens_buf: null(),
            tokens_buf_size: 0,
            num_threads: 0,
            provider: null(),
            debug: 0,
            model_type: null(),
            modeling_unit: null(),
            bpe_vocab: null(),
        },
        decoding_method: null(),
        max_active_paths: 0,
        enable_endpoint: 0,
        rule1_min_trailing_silence: 0.0,
        rule2_min_trailing_silence: 0.0,
        rule3_min_utterance_length: 0.0,
        hotwords_file: null(),
        hotwords_buf: null(),
        hotwords_buf_size: 0,
        hotwords_score: 0.0,
        blank_penalty: 0.0,
        rule_fsts: null(),
        rule_fars: null(),
        ctc_fst_decoder_config: SherpaOnnxOnlineCtcFstDecoderConfig {
            graph: null(),
            max_active: 0,
        },
    }
}

struct ModelPtr {
    ptr: *const SherpaOnnxOnlineRecognizer,
    #[allow(dead_code)]
    dcs: Vec<DropCString>,
}

unsafe impl Send for ModelPtr {}

unsafe impl Sync for ModelPtr {}

impl Drop for ModelPtr {
    fn drop(&mut self) {
        unsafe { SherpaOnnxDestroyOnlineRecognizer(self.ptr) }
    }
}

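/// A loaded streaming recognizer. Cloning is cheap: the underlying
/// sherpa-onnx recognizer is reference-counted and shared between clones.
///
/// A rough sketch of streaming decode (placeholder audio; the `use` path is
/// omitted):
///
/// ```ignore
/// let mut stream = model.online_stream()?;
/// stream.accept_waveform(16000, &samples);
/// stream.decode();
/// println!("{}", stream.result()?);
/// ```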
#[derive(Clone)]
pub struct Model {
    inner: Arc<ModelPtr>,
    sample_rate: usize,
    chunk_size: usize,
    tx: mpsc::Sender<OnlineStreamPtr>,
    rx: Arc<Mutex<mpsc::Receiver<OnlineStreamPtr>>>,
}

impl Model {
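    /// Downloads a model from the Hugging Face Hub and turns it into a [`Config`].
    ///
    /// The repository is expected to contain a `config.json` with `kind`
    /// (must be `"online_asr"`), `arch` (`"transducer"`, `"paraformer"`, or
    /// `"zipformer2_ctc"`), and optionally `decoding_method`, alongside the
    /// corresponding `*.onnx` files and `tokens.txt`.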
    #[cfg(feature = "download-models")]
    #[cfg_attr(docsrs, doc(cfg(feature = "download-models")))]
    pub async fn from_pretrained<S: AsRef<str>>(model: S) -> Result<Config> {
        use hf_hub::api::tokio::ApiBuilder;
        use tokio::fs;

        let api = ApiBuilder::from_env().with_progress(true).build()?;
        let repo = api.model(model.as_ref().into());
        let conf = repo.get("config.json").await?;
        let config = fs::read_to_string(conf).await?;

        #[derive(serde::Deserialize)]
        struct Conf {
            kind: String,
            arch: String,
            decoding_method: Option<String>,
        }

        let Conf { kind, arch, decoding_method } = serde_json::from_str(&config)?;
        ensure!(kind == "online_asr", "unknown model kind: {kind:?}");

        let mut config = match arch.as_str() {
            "transducer" => Config::transducer(
                repo.get("encoder.onnx").await?.to_str().unwrap(),
                repo.get("decoder.onnx").await?.to_str().unwrap(),
                repo.get("joiner.onnx").await?.to_str().unwrap(),
                repo.get("tokens.txt").await?.to_str().unwrap(),
            ),

            "paraformer" => Config::paraformer(
                repo.get("encoder.onnx").await?.to_str().unwrap(),
                repo.get("decoder.onnx").await?.to_str().unwrap(),
                repo.get("tokens.txt").await?.to_str().unwrap(),
            ),

            "zipformer2_ctc" => Config::zipformer2_ctc(
                repo.get("model.onnx").await?.to_str().unwrap(),
                repo.get("tokens.txt").await?.to_str().unwrap(),
            ),

            _ => return Err(anyhow!("unknown model arch: {arch:?}")),
        };

        if let Some("greedy_search") = decoding_method.as_deref() {
            config = config.greedy_search();
        }

        Ok(config)
    }

    #[cfg(feature = "download-models")]
    #[cfg_attr(docsrs, doc(cfg(feature = "download-models")))]
    pub async fn from_pretrained_arch<S>(model: S, mut arch: Arch, tokens: S) -> Result<Config>
    where
        S: AsRef<str>,
    {
        use hf_hub::api::tokio::ApiBuilder;

        let api = ApiBuilder::from_env().with_progress(true).build()?;
        let repo = api.model(model.as_ref().into());

        match &mut arch {
            Arch::Transducer { encoder, decoder, joiner } => {
                *encoder = repo.get(encoder).await?.to_str().unwrap().into();
                *decoder = repo.get(decoder).await?.to_str().unwrap().into();
                *joiner = repo.get(joiner).await?.to_str().unwrap().into();
            }

            Arch::Paraformer { encoder, decoder } => {
                *encoder = repo.get(encoder).await?.to_str().unwrap().into();
                *decoder = repo.get(decoder).await?.to_str().unwrap().into();
            }

            Arch::Zip2Ctc { model } => {
                *model = repo.get(model).await?.to_str().unwrap().into();
            }
        }

        let tokens = repo.get(tokens.as_ref()).await?;

        Ok(Config::new(arch, tokens.to_str().unwrap()))
    }

    pub fn online_stream(&self) -> Result<OnlineStream> {
        let tdc = self.clone();
        let ptr = unsafe { SherpaOnnxCreateOnlineStream(self.as_ptr()) };
        ensure!(!ptr.is_null(), "failed to create online stream");

        Ok(OnlineStream { tdc, ptr })
    }

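    /// Creates a [`PhasedStream`] backed by `n_phase` streams whose chunk
    /// boundaries are staggered; see [`PhasedStream`] for details.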
    pub fn phased_stream(&self, n_phase: usize) -> Result<PhasedStream> {
        PhasedStream::new(n_phase, self)
    }

    pub fn sample_rate(&self) -> usize {
        self.sample_rate
    }

    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

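    /// Empirically determines the model's chunk size by feeding zero-valued
    /// samples one at a time until the stream reports it is ready to decode,
    /// repeating until the count stabilises.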
    fn get_chunk_size(&self) -> Result<usize> {
        let mut s = self.online_stream()?;
        let mut n = 0;

        for _ in 0.. {
            let mut k = 0;

            while !s.is_ready() {
                s.accept_waveform(self.sample_rate, &[0.]);
                k += 1;
            }
            s.decode();

            if n == k {
                break;
            }

            n = k;
        }

        Ok(n)
    }

    fn as_ptr(&self) -> *const SherpaOnnxOnlineRecognizer {
        self.inner.ptr
    }
}

struct OnlineStreamPtr(*const SherpaOnnxOnlineStream);

unsafe impl Send for OnlineStreamPtr {}

unsafe impl Sync for OnlineStreamPtr {}

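/// A single decoding stream bound to a [`Model`]. Feed audio with
/// [`accept_waveform`](Self::accept_waveform), then call
/// [`decode`](Self::decode) and read the transcript with
/// [`result`](Self::result).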
pub struct OnlineStream {
    tdc: Model,
    ptr: *const SherpaOnnxOnlineStream,
}

unsafe impl Send for OnlineStream {}

unsafe impl Sync for OnlineStream {}

impl Drop for OnlineStream {
    fn drop(&mut self) {
        unsafe { SherpaOnnxDestroyOnlineStream(self.ptr) }
    }
}

impl OnlineStream {
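    /// Signals to sherpa-onnx that no more audio will be pushed to this
    /// stream (wraps `SherpaOnnxOnlineStreamInputFinished`), allowing the
    /// remaining buffered samples to be decoded.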
    pub unsafe fn flush_buffers(&mut self) {
        unsafe { SherpaOnnxOnlineStreamInputFinished(self.ptr) }
    }

    pub fn accept_waveform(&mut self, sample_rate: usize, samples: &[f32]) {
        unsafe {
            SherpaOnnxOnlineStreamAcceptWaveform(
                self.ptr,
                sample_rate as i32,
                samples.as_ptr(),
                samples.len() as i32,
            )
        }
    }

    pub fn is_ready(&self) -> bool {
        unsafe { SherpaOnnxIsOnlineStreamReady(self.tdc.as_ptr(), self.ptr) == 1 }
    }

    pub fn decode(&mut self) {
        while self.is_ready() {
            unsafe { self.decode_unchecked() }
        }
    }

    pub unsafe fn decode_unchecked(&mut self) {
        unsafe { SherpaOnnxDecodeOnlineStream(self.tdc.as_ptr(), self.ptr) }
    }

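    /// Decodes several streams from the same [`Model`] together, repeatedly
    /// batching whichever of them still have enough audio buffered. An empty
    /// batch is a no-op.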
    pub fn decode_batch<I: IntoIterator<Item = Q>, Q: DerefMut<Target = Self>>(streams: I) {
        let mut streams = streams.into_iter().peekable();

        // An empty batch is a no-op.
        let Some(first) = streams.peek() else { return };
        let tdc = first.tdc.as_ptr();

        let mut masked: Vec<_> = streams
            .filter_map(|s| s.is_ready().then_some(s.ptr))
            .collect();

        while !masked.is_empty() {
            unsafe {
                SherpaOnnxDecodeMultipleOnlineStreams(tdc, masked.as_mut_ptr(), masked.len() as i32)
            }

            masked.retain(|&ptr| unsafe { SherpaOnnxIsOnlineStreamReady(tdc, ptr) } == 1);
        }
    }

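    /// Like [`decode_batch`](Self::decode_batch), but coordinates through the
    /// owning [`Model`]: the stream registers itself on a shared channel and
    /// then batch-decodes every registered stream that is currently ready.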
    pub fn decode_shared(&mut self) {
        self.tdc.tx.send(OnlineStreamPtr(self.ptr)).unwrap();

        let que = self.tdc.rx.lock().unwrap();
        let tdc = self.tdc.as_ptr();

        let mut masked: Vec<_> = que
            .try_iter()
            .map(|p| p.0)
            .filter(|&ptr| unsafe { SherpaOnnxIsOnlineStreamReady(tdc, ptr) } == 1)
            .collect();

        while !masked.is_empty() {
            unsafe {
                SherpaOnnxDecodeMultipleOnlineStreams(tdc, masked.as_mut_ptr(), masked.len() as i32)
            }

            masked.retain(|&ptr| unsafe { SherpaOnnxIsOnlineStreamReady(tdc, ptr) } == 1);
        }
    }

    pub fn result(&self) -> Result<String> {
        self.result_with(|cow| cow.into_owned())
    }

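    /// Borrows the current transcript as a `Cow<'_, str>` and passes it to
    /// `f`, freeing the underlying sherpa-onnx result before returning `f`'s
    /// output. Use [`result`](Self::result) to get an owned `String`.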
    pub fn result_with<F: FnOnce(Cow<'_, str>) -> R, R>(&self, f: F) -> Result<R> {
        unsafe {
            let res = SherpaOnnxGetOnlineStreamResult(self.tdc.as_ptr(), self.ptr);
            ensure!(!res.is_null(), "failed to get online stream result");

            let txt = (*res).text;
            ensure!(!txt.is_null(), "online stream result has no text");

            let out = f(CStr::from_ptr(txt).to_string_lossy());

            SherpaOnnxDestroyOnlineRecognizerResult(res);

            Ok(out)
        }
    }

    pub fn is_endpoint(&self) -> bool {
        unsafe { SherpaOnnxOnlineStreamIsEndpoint(self.tdc.as_ptr(), self.ptr) == 1 }
    }

    pub fn reset(&mut self) {
        unsafe { SherpaOnnxOnlineStreamReset(self.tdc.as_ptr(), self.ptr) }
    }

    pub fn sample_rate(&self) -> usize {
        self.tdc.sample_rate()
    }

    pub fn chunk_size(&self) -> usize {
        self.tdc.chunk_size()
    }
}

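/// A set of [`OnlineStream`]s over the same audio, each pre-padded with a
/// different amount of silence so their chunk boundaries are staggered: after
/// any given sample, one of the phases is close to completing a chunk.
/// [`result`](Self::result) reports the transcript of the phase that most
/// recently crossed a chunk boundary, along with that phase's running sample
/// count.
///
/// A rough usage sketch (placeholder audio):
///
/// ```ignore
/// let mut stream = model.phased_stream(4)?;
/// stream.accept_waveform(16000, &samples);
/// stream.decode();
/// let (samples_seen, text) = stream.result()?;
/// ```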
pub struct PhasedStream {
    phase: Vec<OnlineStream>,
    state: Vec<String>,
    epoch: Vec<usize>,
    flush: f32,
}

impl PhasedStream {
    fn new(n_phase: usize, transducer: &Model) -> Result<Self> {
        let mut phase = vec![];
        let mut epoch = vec![];

        for i in 0..n_phase {
            let mut p = transducer.online_stream()?;
            let q = vec![0.; p.chunk_size() / n_phase * i];

            p.accept_waveform(p.sample_rate(), &q);

            epoch.push(p.chunk_size() / n_phase * i);
            phase.push(p);
        }

        Ok(Self {
            phase,
            state: vec!["".into(); n_phase],
            epoch,
            flush: 0.,
        })
    }

    pub fn accept_waveform(&mut self, sample_rate: usize, samples: &[f32]) {
        for p in self.phase.iter_mut() {
            p.accept_waveform(sample_rate, samples);
        }

        self.flush +=
            sample_rate as f32 / self.phase[0].sample_rate() as f32 * samples.len() as f32;
    }

    pub fn decode(&mut self) {
        if self.flush == 0. {
            return;
        }

        OnlineStream::decode_batch(&mut self.phase);

        for i in 0..self.phase.len() {
            self.epoch[i] += self.flush.round() as usize;
        }

        self.flush = 0.;
    }

    pub fn result(&mut self) -> Result<(usize, String)> {
        for i in 0..self.phase.len() {
            self.state[i] = self.phase[i].result()?;
        }

        let (i, _) = (0..self.phase.len())
            .map(|i| (i, self.epoch[i] % self.phase[i].chunk_size()))
            .min_by_key(|&(_, m)| m)
            .unwrap();

        Ok((self.epoch[i], self.state[i].clone()))
    }

    pub fn reset(&mut self) {
        for p in self.phase.iter_mut() {
            unsafe { p.flush_buffers() }
            p.reset();
        }
    }
}