sherpa_transducers/
asr.rs

1use std::borrow::Cow;
2use std::ffi::CStr;
3use std::ops::DerefMut;
4use std::ptr::null;
5use std::sync::{Arc, Mutex, mpsc};
6
7use anyhow::{Result, anyhow, ensure};
8use sherpa_rs_sys::*;
9
10use crate::{DropCString, track_cstr};
11
12/// Configuration for a [Model]. See [Model::from_pretrained] for a simple way to get started.
13#[derive(Clone)]
14pub struct Config {
15    sample_rate: i32,
16    feature_dim: i32,
17    load: Arch,
18    tokens: String,
19    num_threads: i32,
20    provider: String,
21    debug: i32,
22    decoding_method: String,
23    max_active_paths: i32,
24    detect_endpoints: i32,
25    rule1_min_trailing_silence: f32,
26    rule2_min_trailing_silence: f32,
27    rule3_min_utterance_length: f32,
28}
29
/// Model architecture to load, together with the model files that architecture needs.
///
/// When passed to `Model::from_pretrained_arch`, each string names a file inside the model
/// repository and is replaced by the location of the downloaded file. Otherwise each string is
/// handed to sherpa-onnx as given.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Arch {
    /// Transducer model made of three separate networks.
    Transducer {
        /// Encoder model file.
        encoder: String,
        /// Decoder model file.
        decoder: String,
        /// Joiner model file.
        joiner: String,
    },

    /// Paraformer model made of two separate networks.
    Paraformer {
        /// Encoder model file.
        encoder: String,
        /// Decoder model file.
        decoder: String,
    },

    /// Zipformer2 CTC model stored in a single file.
    Zip2Ctc {
        /// Model file.
        model: String,
    },
}
47
48impl Config {
49    /// Make a new [Config] for a transducer model with reasonable defaults.
50    pub fn transducer(encoder: &str, decoder: &str, joiner: &str, tokens: &str) -> Self {
51        Self::new(
52            Arch::Transducer {
53                encoder: encoder.into(),
54                decoder: decoder.into(),
55                joiner: joiner.into(),
56            },
57            tokens,
58        )
59    }
60
61    /// Make a new [Config] for a paraformer model with reasonable defaults.
62    pub fn paraformer(encoder: &str, decoder: &str, tokens: &str) -> Self {
63        Self::new(
64            Arch::Paraformer {
65                encoder: encoder.into(),
66                decoder: decoder.into(),
67            },
68            tokens,
69        )
70    }
71
72    /// Make a new [Config] for a zipformer2 ctc model with reasonable defaults.
73    pub fn zipformer2_ctc(model: &str, tokens: &str) -> Self {
74        Self::new(Arch::Zip2Ctc { model: model.into() }, tokens)
75    }
76
77    fn new(load: Arch, tokens: &str) -> Self {
78        Self {
79            sample_rate: 16000,
80            feature_dim: 80,
81            load,
82            tokens: tokens.into(),
83            num_threads: crate::default_num_threads() as i32,
84            provider: crate::default_compute_provider().into(),
85            debug: 0,
86            decoding_method: "modified_beam_search".into(),
87            max_active_paths: 16,
88            detect_endpoints: 0,
89            rule1_min_trailing_silence: 0.,
90            rule2_min_trailing_silence: 0.,
91            rule3_min_utterance_length: 0.,
92        }
93    }
94
95    /// Set the model's sample rate - usually 16000 for most transducers.
96    pub fn sample_rate(mut self, rate: usize) -> Self {
97        self.sample_rate = rate as i32;
98        self
99    }
100
101    /// Set the model's feature dimension - usually 80 for most transducers.
102    pub fn feature_dim(mut self, dim: usize) -> Self {
103        self.feature_dim = dim as i32;
104        self
105    }
106
107    /// Set the number of threads to use. Defaults to physical core count or 8, whichever is smaller.
108    pub fn num_threads(mut self, n: usize) -> Self {
109        self.num_threads = n as i32;
110        self
111    }
112
113    /// Use CPU as the compute provider.
114    pub fn cpu(mut self) -> Self {
115        self.provider = "cpu".into();
116        self
117    }
118
119    /// Use CUDA as the compute provider. This requires CUDA 11.8.
120    #[cfg(feature = "cuda")]
121    #[cfg_attr(docsrs, doc(cfg(feature = "cuda")))]
122    pub fn cuda(mut self) -> Self {
123        self.provider = "cuda".into();
124        self
125    }
126
127    /// Use DirectML as the compute provider.
128    #[cfg(feature = "directml")]
129    #[cfg_attr(docsrs, doc(cfg(feature = "directml")))]
130    pub fn directml(mut self) -> Self {
131        self.provider = "directml".into();
132        self
133    }
134
135    /// Print debug messages at model load time.
136    pub fn debug(mut self, enable: bool) -> Self {
137        self.debug = if enable { 1 } else { 0 };
138        self
139    }
140
141    /// Take the symbol with largest posterior probability of each frame as the decoding result.
142    pub fn greedy_search(mut self) -> Self {
143        self.decoding_method = "greedy_search".into();
144        self
145    }
146
147    /// Keep topk states for each frame, then expand kept states with their own contexts to next frame.
148    pub fn modified_beam_search(mut self) -> Self {
149        self.decoding_method = "modified_beam_search".into();
150        self
151    }
152
153    /// Maximum number of active paths to keep when [Config::modified_beam_search] is used.
154    ///
155    /// Defaults to 16.
156    pub fn max_active_paths(mut self, n: usize) -> Self {
157        self.max_active_paths = n as i32;
158        self
159    }
160
161    /// Enable endpoint detection. Defaults to disabled.
162    pub fn detect_endpoints(mut self, enable: bool) -> Self {
163        self.detect_endpoints = if enable { 1 } else { 0 };
164        self
165    }
166
167    /// Detect endpoint if trailing silence is larger than this value even if nothing has been decoded.
168    pub fn rule1_min_trailing_silence(mut self, seconds: f32) -> Self {
169        self.rule1_min_trailing_silence = seconds;
170        self
171    }
172
173    /// Detect endpoint if trailing silence is larger than this value and a non-blank has been decoded.
174    pub fn rule2_min_trailing_silence(mut self, seconds: f32) -> Self {
175        self.rule2_min_trailing_silence = seconds;
176        self
177    }
178
179    /// Detect an endpoint if an utterance is larger than this value.
180    pub fn rule3_min_utterance_length(mut self, seconds: f32) -> Self {
181        self.rule3_min_utterance_length = seconds;
182        self
183    }
184
185    /// Build your very own [Model].
186    pub fn build(self) -> Result<Model> {
187        let mut config = online_config();
188
189        let mut _dcs = vec![];
190        let dcs = &mut _dcs;
191
192        config.feat_config.sample_rate = self.sample_rate;
193        config.feat_config.feature_dim = self.feature_dim;
194
195        match self.load {
196            Arch::Transducer { encoder, decoder, joiner } => {
197                config.model_config.transducer.encoder = track_cstr(dcs, &encoder);
198                config.model_config.transducer.decoder = track_cstr(dcs, &decoder);
199                config.model_config.transducer.joiner = track_cstr(dcs, &joiner);
200            }
201
202            Arch::Paraformer { encoder, decoder } => {
203                config.model_config.paraformer.encoder = track_cstr(dcs, &encoder);
204                config.model_config.paraformer.decoder = track_cstr(dcs, &decoder);
205            }
206
207            Arch::Zip2Ctc { model } => {
208                config.model_config.zipformer2_ctc.model = track_cstr(dcs, &model);
209            }
210        }
211
212        config.model_config.tokens = track_cstr(dcs, &self.tokens);
213        config.model_config.num_threads = self.num_threads;
214        config.model_config.provider = track_cstr(dcs, &self.provider);
215        config.model_config.debug = self.debug;
216        config.decoding_method = track_cstr(dcs, &self.decoding_method);
217        config.max_active_paths = self.max_active_paths;
218
219        // TODO: hotwords
220
221        let ptr = unsafe { SherpaOnnxCreateOnlineRecognizer(&config) };
222        ensure!(!ptr.is_null(), "failed to load transducer model");
223
224        let (tx, rx) = mpsc::channel();
225
226        let mut tdc = Model {
227            inner: Arc::new(ModelPtr { ptr, dcs: _dcs }),
228            sample_rate: self.sample_rate as usize,
229            chunk_size: 0,
230            tx,
231            rx: Arc::new(Mutex::new(rx)),
232        };
233
234        tdc.chunk_size = tdc.get_chunk_size()?;
235
236        Ok(tdc)
237    }
238}
239
240fn online_config() -> SherpaOnnxOnlineRecognizerConfig {
241    SherpaOnnxOnlineRecognizerConfig {
242        feat_config: SherpaOnnxFeatureConfig { sample_rate: 0, feature_dim: 0 },
243        model_config: SherpaOnnxOnlineModelConfig {
244            transducer: SherpaOnnxOnlineTransducerModelConfig {
245                encoder: null(),
246                decoder: null(),
247                joiner: null(),
248            },
249            paraformer: SherpaOnnxOnlineParaformerModelConfig {
250                encoder: null(),
251                decoder: null(),
252            },
253            zipformer2_ctc: SherpaOnnxOnlineZipformer2CtcModelConfig { model: null() },
254            tokens: null(),
255            tokens_buf: null(),
256            tokens_buf_size: 0,
257            num_threads: 0,
258            provider: null(),
259            debug: 0,
260            model_type: null(),
261            modeling_unit: null(),
262            bpe_vocab: null(),
263        },
264        decoding_method: null(),
265        max_active_paths: 0,
266        enable_endpoint: 0,
267        rule1_min_trailing_silence: 0.0,
268        rule2_min_trailing_silence: 0.0,
269        rule3_min_utterance_length: 0.0,
270        hotwords_file: null(),
271        hotwords_buf: null(),
272        hotwords_buf_size: 0,
273        hotwords_score: 0.0,
274        blank_penalty: 0.0,
275        rule_fsts: null(),
276        rule_fars: null(),
277        ctc_fst_decoder_config: SherpaOnnxOnlineCtcFstDecoderConfig {
278            graph: null(),
279            max_active: 0,
280        },
281    }
282}
283
284struct ModelPtr {
285    ptr: *const SherpaOnnxOnlineRecognizer,
286    // NOTE: unsure if sherpa-onnx accesses these pointers post-init; we err on the side of caution and
287    // keep them allocated until we drop the whole transducer.
288    #[allow(dead_code)]
289    dcs: Vec<DropCString>,
290}
291
292// SAFETY: thread locals? surely not
293unsafe impl Send for ModelPtr {}
294
295// SAFETY: afaik there is no interior mutability through &refs
296unsafe impl Sync for ModelPtr {}
297
298impl Drop for ModelPtr {
299    fn drop(&mut self) {
300        unsafe { SherpaOnnxDestroyOnlineRecognizer(self.ptr) }
301    }
302}
303
304/// Streaming zipformer transducer speech recognition model.
305#[derive(Clone)]
306pub struct Model {
307    inner: Arc<ModelPtr>,
308    sample_rate: usize,
309    chunk_size: usize,
310    tx: mpsc::Sender<OnlineStreamPtr>,
311    rx: Arc<Mutex<mpsc::Receiver<OnlineStreamPtr>>>,
312}
313
314impl Model {
315    /// Create a [Config] from a pretrained model on huggingface.
316    ///
317    /// ```no_run
318    /// # tokio_test::block_on(async {
319    /// use sherpa_transducers::asr;
320    ///
321    /// let model = asr::Model::from_pretrained("nytopop/nemo-conformer-transducer-en-80ms")
322    ///     .await?
323    ///     .build()?;
324    /// # Ok::<_, anyhow::Error>(())
325    /// # });
326    /// ```
327    #[cfg(feature = "download-models")]
328    #[cfg_attr(docsrs, doc(cfg(feature = "download-models")))]
329    pub async fn from_pretrained<S: AsRef<str>>(model: S) -> Result<Config> {
330        use hf_hub::api::tokio::ApiBuilder;
331        use tokio::fs;
332
333        let api = ApiBuilder::from_env().with_progress(true).build()?;
334        let repo = api.model(model.as_ref().into());
335        let conf = repo.get("config.json").await?;
336        let config = fs::read_to_string(conf).await?;
337
338        #[derive(serde::Deserialize)]
339        struct Conf {
340            kind: String,
341            arch: String,
342            decoding_method: Option<String>,
343        }
344
345        let Conf { kind, arch, decoding_method } = serde_json::from_str(&config)?;
346        ensure!(kind == "online_asr", "unknown model kind: {kind:?}");
347
348        let mut config = match arch.as_str() {
349            "transducer" => Config::transducer(
350                repo.get("encoder.onnx").await?.to_str().unwrap(),
351                repo.get("decoder.onnx").await?.to_str().unwrap(),
352                repo.get("joiner.onnx").await?.to_str().unwrap(),
353                repo.get("tokens.txt").await?.to_str().unwrap(),
354            ),
355
356            "paraformer" => Config::paraformer(
357                repo.get("encoder.onnx").await?.to_str().unwrap(),
358                repo.get("decoder.onnx").await?.to_str().unwrap(),
359                repo.get("tokens.txt").await?.to_str().unwrap(),
360            ),
361
362            "zipformer2_ctc" => Config::zipformer2_ctc(
363                repo.get("model.onnx").await?.to_str().unwrap(),
364                repo.get("tokens.txt").await?.to_str().unwrap(),
365            ),
366
367            _ => return Err(anyhow!("unknown model arch: {arch:?}")),
368        };
369
370        if let Some("greedy_search") = decoding_method.as_deref() {
371            config = config.greedy_search();
372        }
373
374        Ok(config)
375    }
376
377    /// Create a [Config] from a pretrained ASR model on huggingface without a `config.json`.
378    #[cfg(feature = "download-models")]
379    #[cfg_attr(docsrs, doc(cfg(feature = "download-models")))]
380    pub async fn from_pretrained_arch<S>(model: S, mut arch: Arch, tokens: S) -> Result<Config>
381    where
382        S: AsRef<str>,
383    {
384        use hf_hub::api::tokio::ApiBuilder;
385
386        let api = ApiBuilder::from_env().with_progress(true).build()?;
387        let repo = api.model(model.as_ref().into());
388
389        match &mut arch {
390            Arch::Transducer { encoder, decoder, joiner } => {
391                *encoder = repo.get(encoder).await?.to_str().unwrap().into();
392                *decoder = repo.get(decoder).await?.to_str().unwrap().into();
393                *joiner = repo.get(joiner).await?.to_str().unwrap().into();
394            }
395
396            Arch::Paraformer { encoder, decoder } => {
397                *encoder = repo.get(encoder).await?.to_str().unwrap().into();
398                *decoder = repo.get(decoder).await?.to_str().unwrap().into();
399            }
400
401            Arch::Zip2Ctc { model } => {
402                *model = repo.get(model).await?.to_str().unwrap().into();
403            }
404        }
405
406        let tokens = repo.get(tokens.as_ref()).await?;
407
408        Ok(Config::new(arch, tokens.to_str().unwrap()))
409    }
410
411    /// Make an [OnlineStream] for incremental speech recognition.
412    pub fn online_stream(&self) -> Result<OnlineStream> {
413        let tdc = self.clone();
414        let ptr = unsafe { SherpaOnnxCreateOnlineStream(self.as_ptr()) };
415        ensure!(!ptr.is_null(), "failed to create recognizer");
416
417        Ok(OnlineStream { tdc, ptr })
418    }
419
420    /// Make a [PhasedStream] for incremental speech recognition.
421    ///
422    /// Trades off increased compute utilization for lower latency transcriptions (sub chunk size).
423    pub fn phased_stream(&self, n_phase: usize) -> Result<PhasedStream> {
424        PhasedStream::new(n_phase, self)
425    }
426
427    /// Returns the native sample rate.
428    pub fn sample_rate(&self) -> usize {
429        self.sample_rate
430    }
431
432    /// Returns the chunk size at the native sample rate.
433    pub fn chunk_size(&self) -> usize {
434        self.chunk_size
435    }
436
437    fn get_chunk_size(&self) -> Result<usize> {
438        let mut s = self.online_stream()?;
439        let mut n = 0;
440
441        for _ in 0.. {
442            let mut k = 0;
443
444            while !s.is_ready() {
445                s.accept_waveform(self.sample_rate, &[0.]);
446                k += 1;
447            }
448            s.decode();
449
450            if n == k {
451                break;
452            }
453
454            n = k;
455        }
456
457        Ok(n)
458    }
459
460    // WARN: DO NOT MUTATE THROUGH THIS POINTER ON PAIN OF UNSOUNDNESS
461    fn as_ptr(&self) -> *const SherpaOnnxOnlineRecognizer {
462        self.inner.ptr
463    }
464}
465
466struct OnlineStreamPtr(*const SherpaOnnxOnlineStream);
467
468unsafe impl Send for OnlineStreamPtr {}
469
470unsafe impl Sync for OnlineStreamPtr {}
471
472/// Context state for streaming speech recognition.
473///
474/// You can do VAD if you want to reduce compute utilization, but feeding constant streaming audio into
475/// this is perfectly reasonable. Decoding is incremental and constant latency.
476///
477/// Created by [Model::online_stream].
478pub struct OnlineStream {
479    tdc: Model,
480    ptr: *const SherpaOnnxOnlineStream,
481}
482
483// SAFETY: thread locals? surely not
484unsafe impl Send for OnlineStream {}
485
486// SAFETY: afaik there is no interior mutability through &refs
487unsafe impl Sync for OnlineStream {}
488
489impl Drop for OnlineStream {
490    fn drop(&mut self) {
491        unsafe { SherpaOnnxDestroyOnlineStream(self.ptr) }
492    }
493}
494
495impl OnlineStream {
496    /// Flush extant buffers (feature frames) and signal that no further inputs will be made available.
497    ///
498    /// # Safety
499    /// Do not call [OnlineStream::accept_waveform] after calling this function.
500    ///
501    /// That restriction makes it quite useless, so ymmv. I have not observed any problems doing so so
502    /// long as an intervening call to [OnlineStream::reset] exists:
503    ///
504    /// ```skip
505    /// unsafe { s.flush_buffers() };
506    /// s.decode();
507    /// s.reset();
508    /// s.accept_waveform(16000, &[ ... ]);
509    /// ```
510    ///
511    /// Regardless, upstream docs state not to call [OnlineStream::accept_waveform] after, so do so at
512    /// your own risk.
513    pub unsafe fn flush_buffers(&mut self) {
514        // TODO: find answers to the following:
515        //
516        // 1. can stream state be recovered after calling this, or is it permanently kill?
517        // 2. flush -> reset  -> is it safe to accept_waveform again?
518        // 3. flush -> decode -> is it safe to accept_waveform again?
519        //
520        // after digging through the c sources, the rabbit hole continues on to kaldi which i didn't
521        // want to pull in just yet. another day.
522        unsafe { SherpaOnnxOnlineStreamInputFinished(self.ptr) }
523    }
524
525    /// Accept ((-1, 1)) normalized) input audio samples and buffer the computed feature frames.
526    pub fn accept_waveform(&mut self, sample_rate: usize, samples: &[f32]) {
527        unsafe {
528            SherpaOnnxOnlineStreamAcceptWaveform(
529                self.ptr,
530                sample_rate as i32,
531                samples.as_ptr(),
532                samples.len() as i32,
533            )
534        }
535    }
536
537    /// Returns true if there are enough feature frames for decoding.
538    pub fn is_ready(&self) -> bool {
539        unsafe { SherpaOnnxIsOnlineStreamReady(self.tdc.as_ptr(), self.ptr) == 1 }
540    }
541
542    /// Decode all available feature frames.
543    pub fn decode(&mut self) {
544        while self.is_ready() {
545            unsafe { self.decode_unchecked() }
546        }
547    }
548
549    /// Decode an unspecified number of feature frames.
550    ///
551    /// It is a logic error to call this function when [OnlineStream::is_ready] returns false.
552    ///
553    /// # Safety
554    /// Ensure [OnlineStream::is_ready] returns true. It is probably not ever worth eliding the check,
555    /// but hey, you do you.
556    pub unsafe fn decode_unchecked(&mut self) {
557        unsafe { SherpaOnnxDecodeOnlineStream(self.tdc.as_ptr(), self.ptr) }
558    }
559
560    /// Decode all available feature frames in the provided iterator of streams concurrently.
561    ///
562    /// This batches all operations together, and thus is superior to calling [OnlineStream::decode] on
563    /// every [OnlineStream] in separate threads (though it is not *invalid* to do so, if desired).
564    pub fn decode_batch<I: IntoIterator<Item = Q>, Q: DerefMut<Target = Self>>(streams: I) {
565        let mut streams = streams.into_iter().peekable();
566
567        // WARN: this may or may not be correct; what happens when [1..] have a different tdc? well, in
568        // that case something else is very sus, so let's silently ignore it and hope nobody does that.
569        let tdc = streams.peek().unwrap().tdc.as_ptr();
570
571        let mut masked: Vec<_> = streams
572            .filter_map(|s| s.is_ready().then_some(s.ptr))
573            .collect();
574
575        while !masked.is_empty() {
576            // only the masked subset of ready streams
577            unsafe {
578                SherpaOnnxDecodeMultipleOnlineStreams(tdc, masked.as_mut_ptr(), masked.len() as i32)
579            }
580
581            // remove any streams that aren't ready
582            masked.retain(|&ptr| unsafe { SherpaOnnxIsOnlineStreamReady(tdc, ptr) } == 1);
583        }
584    }
585
586    /// Decode all available feature frames in a shared concurrency context.
587    ///
588    /// This introduces a small amount of synchronization overhead in exchange for much better compute
589    /// utilization.
590    pub fn decode_shared(&mut self) {
591        // ensure our ptr is in the shared queue (for the case where we don't acquire the lock first)
592        self.tdc.tx.send(OnlineStreamPtr(self.ptr)).unwrap();
593
594        let que = self.tdc.rx.lock().unwrap();
595        let tdc = self.tdc.as_ptr();
596
597        let mut masked: Vec<_> = que
598            .try_iter()
599            .map(|p| p.0)
600            .filter(|&ptr| unsafe { SherpaOnnxIsOnlineStreamReady(tdc, ptr) } == 1)
601            .collect();
602
603        while !masked.is_empty() {
604            // only the masked subset of ready streams
605            unsafe {
606                SherpaOnnxDecodeMultipleOnlineStreams(tdc, masked.as_mut_ptr(), masked.len() as i32)
607            }
608
609            // remove any streams that aren't ready
610            masked.retain(|&ptr| unsafe { SherpaOnnxIsOnlineStreamReady(tdc, ptr) } == 1);
611        }
612    }
613
614    /// Returns recognition state since the last call to [OnlineStream::reset].
615    pub fn result(&self) -> Result<String> {
616        self.result_with(|cow| cow.into_owned())
617    }
618
619    /// Returns recognition state since the last call to [OnlineStream::reset].
620    pub fn result_with<F: FnOnce(Cow<'_, str>) -> R, R>(&self, f: F) -> Result<R> {
621        unsafe {
622            let res = SherpaOnnxGetOnlineStreamResult(self.tdc.as_ptr(), self.ptr);
623            ensure!(!res.is_null(), "failed to get online stream result");
624
625            let txt = (*res).text;
626            ensure!(!txt.is_null(), "failed to get online stream result");
627
628            let out = f(CStr::from_ptr(txt).to_string_lossy());
629
630            SherpaOnnxDestroyOnlineRecognizerResult(res);
631
632            Ok(out)
633        }
634    }
635
636    /// Returns true if an endpoint has been detected.
637    pub fn is_endpoint(&self) -> bool {
638        unsafe { SherpaOnnxOnlineStreamIsEndpoint(self.tdc.as_ptr(), self.ptr) == 1 }
639    }
640
641    /// Clear any extant neural network and decoder states.
642    pub fn reset(&mut self) {
643        unsafe { SherpaOnnxOnlineStreamReset(self.tdc.as_ptr(), self.ptr) }
644    }
645
646    /// Returns the native sample rate.
647    pub fn sample_rate(&self) -> usize {
648        self.tdc.sample_rate()
649    }
650
651    /// Returns the chunk size at the native sample rate.
652    ///
653    /// The stream becomes ready for decoding once this many samples have been accepted.
654    pub fn chunk_size(&self) -> usize {
655        self.tdc.chunk_size()
656    }
657}
658
659/// A wrapper around multiple phase-shifted [OnlineStream] states. The use case is latency reduction at
660/// the cost of additional compute load.
661///
662/// For example, a transducer with a chunk size of 320ms has worst-case transcription latency of 320ms;
663/// it must be fed with 320ms chunks of audio before producing any results. If an utterance lies at the
664/// beginning of a chunk, you must wait until the rest arrives before it can be transcribed.
665///
666/// In a [PhasedStream] with `n_phase == 2`, the worst-case latency is reduced to 160ms, though compute
667/// utilization is approximately doubled.
668///
669/// This does not mean latency due to compute is doubled; if used correctly, that remains constant. Let
670/// Q be the amount of time it takes to transcribe a 320ms chunk: we can feed the transducer with 160ms
671/// chunks and expect processing to take Q as well. Instead of paying Q every 320ms we now pay it every
672/// 160ms.
673///
674/// Likewise, with `n_phase == 3`, we could feed 106.7ms chunks and expect to pay Q every 106.7ms. More
675/// generally, the computational cost of transcribing a chunk remains constant while the chunk count in
676/// a given time window scales linearly with the number of phases.
677///
678/// For most zipformer transducers, RTF is favorable (Q is low) and the extra load can be an acceptable
679/// trade off for the observed latency improvement.
680///
681/// Created by [Model::phased_stream].
682// TODO: look into the underlying implementation to see if we can fuse beam states: having disconnected
683// beams is not super optimal even though it does work
684// TODO: support hooking into external continuous batch rendezvous point for moar throughput
685pub struct PhasedStream {
686    phase: Vec<OnlineStream>,
687    state: Vec<String>,
688    epoch: Vec<usize>,
689    flush: f32,
690}
691
692impl PhasedStream {
693    /// Make a new [PhasedStream].
694    fn new(n_phase: usize, transducer: &Model) -> Result<Self> {
695        let mut phase = vec![];
696        let mut epoch = vec![];
697
698        for i in 0..n_phase {
699            let mut p = transducer.online_stream()?;
700            let q = vec![0.; p.chunk_size() / n_phase * i];
701
702            // push p out of phase (it will stay that way forever)
703            p.accept_waveform(p.sample_rate(), &q);
704
705            epoch.push(p.chunk_size() / n_phase * i);
706            phase.push(p);
707        }
708
709        Ok(Self {
710            phase,
711            state: vec!["".into(); n_phase],
712            epoch,
713            flush: 0.,
714        })
715    }
716
717    /// Accept ((-1, 1)) normalized) input audio samples and buffer the computed feature frames.
718    pub fn accept_waveform(&mut self, sample_rate: usize, samples: &[f32]) {
719        for p in self.phase.iter_mut() {
720            p.accept_waveform(sample_rate, samples);
721        }
722
723        // convert to the native sample rate before incrementing
724        self.flush +=
725            sample_rate as f32 / self.phase[0].sample_rate() as f32 * samples.len() as f32;
726    }
727
728    /// Decode all available feature frames.
729    pub fn decode(&mut self) {
730        if self.flush == 0. {
731            return;
732        }
733
734        // WARN: technically batched but not really because they're out of phase. increasing our overall
735        // throughput would require synchronizing with online streams external to the local context
736        OnlineStream::decode_batch(&mut self.phase);
737
738        for i in 0..self.phase.len() {
739            self.epoch[i] += self.flush.round() as usize;
740        }
741
742        self.flush = 0.;
743    }
744
745    /// Returns recognition state since the last call to [PhasedStream::reset].
746    pub fn result(&mut self) -> Result<(usize, String)> {
747        for i in 0..self.phase.len() {
748            self.state[i] = self.phase[i].result()?;
749        }
750
751        let (i, _) = (0..self.phase.len())
752            .map(|i| (i, self.epoch[i] % self.phase[i].chunk_size()))
753            .min_by_key(|&(_, m)| m)
754            .unwrap();
755
756        Ok((self.epoch[i], self.state[i].clone()))
757    }
758
759    /// Clear any extant neural network and decoder states.
760    pub fn reset(&mut self) {
761        for p in self.phase.iter_mut() {
762            unsafe { p.flush_buffers() }
763            p.reset();
764        }
765    }
766}