1use bon::bon;
17use snafu::{ResultExt, Snafu};
18use svod_arch::ctc::CtcDecoder;
19use svod_arch::rnnt::{RnntDecoder, RnntOpts};
20use svod_tensor::PrepareConfig;
21
22pub use svod_arch::rnnt::Word;
23
24use crate::audio::{AudioChunk, EncoderBounds, MelConfig, MelSpectrogram, Splitter};
25use crate::gigaam::SubsamplingMode;
26use crate::gigaam::ctc::CtcHeadJit;
27use crate::gigaam::jit::GigaAmEncoderJit;
28use crate::gigaam::model::{GigaAm, Head};
29use crate::gigaam::rnnt::RnntStepBackend;
30use crate::jit::InputSpec;
31
/// Options controlling how a [`Transcriber`] decodes audio.
///
/// Build with `TranscribeOpts::builder()`, [`TranscribeOpts::from_env`] or `Default`;
/// any option left unset falls back to an `SVOD_*` environment variable.
#[derive(Clone, Debug)]
pub struct TranscribeOpts {
    /// Produce per-word timings in [`ChunkResult::words`].
    pub word_timestamps: bool,
    /// CTC models only: replace a configured greedy decoder with a beam-search decoder
    /// over the same vocabulary. Other configured decoders (and RNN-T heads) are unaffected.
    pub beam_decode: bool,
    /// Memory budget, in MiB, used to choose the encoder batch size. Each batch item is
    /// budgeted `n_heads * t_sub * t_sub` elements of the encoder input dtype.
    pub max_scores_mib: usize,
}
63
64impl Default for TranscribeOpts {
65 fn default() -> Self {
66 Self::builder().build()
67 }
68}
69
#[bon]
impl TranscribeOpts {
    /// Creates options through the `bon`-generated builder (`TranscribeOpts::builder()`).
    ///
    /// Every field left unset falls back to an environment variable:
    ///
    /// - `word_timestamps`: `true` only if `SVOD_TIMESTAMPS` is exactly `"1"`.
    /// - `beam_decode`: `true` only if `SVOD_BEAM_DECODE` is exactly `"1"`.
    /// - `max_scores_mib`: `SVOD_MAX_SCORES_MIB` parsed as `usize`; `256` if unset or unparsable.
    #[builder]
    pub fn builder(
        #[builder(default = std::env::var("SVOD_TIMESTAMPS").as_deref() == Ok("1"))] word_timestamps: bool,
        #[builder(default = std::env::var("SVOD_BEAM_DECODE").as_deref() == Ok("1"))] beam_decode: bool,
        #[builder(default = std::env::var("SVOD_MAX_SCORES_MIB").ok().and_then(|s| s.parse().ok()).unwrap_or(256))]
        max_scores_mib: usize,
    ) -> Self {
        Self { word_timestamps, beam_decode, max_scores_mib }
    }

    /// Options with every field resolved from the environment (or the built-in fallbacks).
    /// Identical to `TranscribeOpts::default()`.
    pub fn from_env() -> Self {
        Self::builder().build()
    }
}
93
/// Output of [`Transcriber::transcribe`] / [`Transcriber::transcribe_chunks`].
#[derive(Clone, Debug)]
pub struct TranscribeResult {
    /// Texts of all non-empty chunks, joined with a single space.
    pub text: String,
    /// One entry per decoded chunk, in input order. Chunks that produced no mel
    /// frames are omitted.
    pub chunks: Vec<ChunkResult>,
}
102
103impl TranscribeResult {
104 pub fn words(&self) -> impl Iterator<Item = Word> + '_ {
107 self.chunks.iter().flat_map(|c| {
108 let offset = c.start_sec;
109 c.words.iter().flatten().map(move |w| Word {
110 text: w.text.clone(),
111 start: w.start + offset,
112 end: w.end + offset,
113 })
114 })
115 }
116}
117
/// Transcription of a single audio chunk.
#[derive(Clone, Debug)]
pub struct ChunkResult {
    /// Chunk start, in seconds from the start of the waveform.
    pub start_sec: f32,
    /// Chunk end, in seconds from the start of the waveform.
    pub end_sec: f32,
    /// Decoded text for this chunk (may be empty).
    pub text: String,
    /// Word timings in seconds relative to `start_sec`. `Some` only when word
    /// timestamps were requested.
    pub words: Option<Vec<Word>>,
}
129
/// Head-specific runtime state: the prepared head executor plus its text decoder.
// NOTE(review): variant sizes differ noticeably. Presumably boxing was judged not worth it
// because each `Transcriber` holds exactly one `HeadDecoder`. TODO confirm.
#[allow(clippy::large_enum_variant)]
pub(crate) enum HeadDecoder {
    /// CTC head: JIT that maps encoder output to logits, plus a greedy or beam CTC decoder.
    Ctc { jit: CtcHeadJit, decoder: CtcDecoder },
    /// RNN-T head: step backend driven by the RNN-T decoder. When `sentencepiece` is set,
    /// U+2581 markers in the decoded string are turned into spaces.
    Rnnt { backend: RnntStepBackend, decoder: RnntDecoder, sentencepiece: bool },
}
139
140pub(crate) fn ctc_frames_to_words(text: &str, frames: &[usize], frame_shift: f32) -> Vec<Word> {
148 let mut words: Vec<Word> = Vec::new();
149 let mut current = String::new();
150 let mut first_frame = 0usize;
151 let mut last_frame = 0usize;
152
153 let commit = |words: &mut Vec<Word>, current: &mut String, first: usize, last: usize| {
154 if !current.is_empty() {
155 words.push(Word {
156 text: std::mem::take(current),
157 start: first as f32 * frame_shift,
158 end: (last + 1) as f32 * frame_shift,
159 });
160 }
161 };
162
163 for (ch, &frame) in text.chars().zip(frames.iter()) {
164 if ch == ' ' {
165 commit(&mut words, &mut current, first_frame, last_frame);
166 continue;
167 }
168 if current.is_empty() {
169 first_frame = frame;
170 }
171 current.push(ch);
172 last_frame = frame;
173 }
174 commit(&mut words, &mut current, first_frame, last_frame);
175 words
176}
177
/// Converts one encoder item from channel-major `(d_model, t_exec_sub)` layout to
/// time-major `(actual_sub, d_model)`, keeping only the first `actual_sub` frames.
///
/// Element `src[d * t_exec_sub + t]` ends up at `out[t * d_model + d]`.
///
/// # Panics
/// Panics if some index `d * t_exec_sub + t` falls outside `src`.
fn transpose_dt_to_td(src: &[f32], d_model: usize, t_exec_sub: usize, actual_sub: usize) -> Vec<f32> {
    let mut time_major = Vec::with_capacity(actual_sub * d_model);
    time_major.extend(
        (0..actual_sub).flat_map(move |t| (0..d_model).map(move |d| src[d * t_exec_sub + t])),
    );
    time_major
}
190
191fn rnnt_decode_err<E: std::error::Error + 'static>(
192 e: svod_arch::rnnt::RnntDecodeError<crate::jit::JitError>,
193) -> TranscribeError<E> {
194 TranscribeError::RnntDecode { source: Box::new(e) }
195}
196
/// Errors produced while building a [`Transcriber`] or transcribing audio.
///
/// `E` is the [`Splitter`]'s error type. Foreign error types other than `E` and the CTC
/// decode error are boxed, presumably to keep `Result<_, TranscribeError<E>>` small.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum TranscribeError<E: std::error::Error + 'static> {
    /// The splitter failed to cut the waveform into chunks.
    #[snafu(display("splitter: {source}"))]
    Splitter { source: E },
    /// Preparing or executing a JIT (encoder, CTC head, or RNN-T step backend) failed.
    #[snafu(display("{source}"))]
    Jit {
        #[snafu(source(from(crate::jit::JitError, Box::new)))]
        source: Box<crate::jit::JitError>,
    },
    /// CTC decoding of the head logits failed.
    #[snafu(display("{source}"))]
    CtcDecode { source: svod_arch::ctc::DecodeError },
    /// RNN-T decoding failed (the decoder may surface backend JIT errors).
    #[snafu(display("{source}"))]
    RnntDecode { source: Box<svod_arch::rnnt::RnntDecodeError<crate::jit::JitError>> },
    /// Error from the GigaAM model layer.
    #[snafu(display("{source}"))]
    Model {
        #[snafu(source(from(crate::gigaam::error::Error, Box::new)))]
        source: Box<crate::gigaam::error::Error>,
    },
    /// Error from the tensor layer.
    #[snafu(display("{source}"))]
    Tensor {
        #[snafu(source(from(svod_tensor::error::Error, Box::new)))]
        source: Box<svod_tensor::error::Error>,
    },
    /// Device buffer access failed, e.g. viewing a JIT buffer as a typed array.
    #[snafu(display("{source}"))]
    Device {
        #[snafu(source(from(svod_device::error::Error, Box::new)))]
        source: Box<svod_device::error::Error>,
    },
    /// The input sample rate differs from the model's. Audio is not resampled automatically.
    #[snafu(display("WAV is {wav_sr} Hz, model expects {model_sr} Hz (resample first)"))]
    SampleRateMismatch { wav_sr: u32, model_sr: u32 },
    /// Chunk `idx` is longer than the prepared encoder can accept.
    #[snafu(display("chunk {idx} length {samples} samples exceeds encoder capacity {max_samples} samples"))]
    ChunkExceedsCapacity { idx: usize, samples: usize, max_samples: usize },
    /// Chunk `idx` ends past the end of the waveform.
    #[snafu(display("chunk {idx} end {end_sample} exceeds waveform length {waveform_len}"))]
    ChunkOutOfRange { idx: usize, end_sample: usize, waveform_len: usize },
}
238
/// Batched speech-to-text driver around a [`GigaAm`] model.
///
/// Construction prepares the encoder JIT (and, for CTC models, the head JIT) for a fixed
/// maximum batch size and mel length. Transcription splits audio with `S`, runs chunks
/// through the encoder in batches of up to that size, and decodes each chunk with the
/// model's head.
pub struct Transcriber<S: Splitter> {
    model: GigaAm,
    opts: TranscribeOpts,
    /// Chunker used by `transcribe`.
    splitter: S,
    /// Mel-spectrogram feature extractor configured from the model config.
    mel: MelSpectrogram,
    head_decoder: HeadDecoder,
    /// Encoder JIT prepared for `[max_batch, n_mels, max_t_mel]` mel input.
    encoder_jit: GigaAmEncoderJit,
    /// Largest batch the prepared JITs accept.
    max_batch: usize,
    /// Largest mel-frame length the prepared JITs accept.
    max_t_mel: usize,
}
256
impl<S: Splitter> Transcriber<S> {
    /// Builds a transcriber and prepares its JIT kernels.
    ///
    /// The maximum mel length (`max_t_mel`) is derived as follows:
    ///
    /// - Start from the longest chunk the splitter may produce.
    /// - Add `2 * subsampling_factor` frames of slack.
    /// - Round up to a power of two.
    /// - Clamp to `[subsampling_factor, max_mel_frames]`.
    ///
    /// The batch size is the largest one whose estimated `n_heads * t_sub^2` score
    /// footprint fits in `opts.max_scores_mib`, capped at `model.config.max_batch_size`.
    ///
    /// # Errors
    /// Returns [`TranscribeError::Jit`] if preparing the encoder, the CTC head, or the
    /// RNN-T step backend fails.
    pub fn new(model: GigaAm, splitter: S, opts: TranscribeOpts) -> Result<Self, TranscribeError<S::Error>> {
        let mel = MelSpectrogram::new(&MelConfig {
            sample_rate: model.config.sample_rate,
            n_fft: model.config.n_fft,
            hop_length: model.config.hop_length,
            win_length: model.config.win_length,
            n_mels: model.config.n_mels,
            center: model.config.mel_center,
        });

        let subsampling_factor = model.config.subsampling_factor;
        let hop_length = model.config.hop_length;
        let model_bounds = EncoderBounds {
            sample_rate: model.config.sample_rate as u32,
            hop_length,
            subsampling_factor,
            max_mel_frames: model.config.max_mel_frames,
        };
        // Longest chunk the splitter may emit, never more than the model itself accepts.
        let chunk_samples_cap = splitter.max_chunk_samples(&model_bounds).min(model_bounds.max_samples());
        // Mel frames for that chunk plus slack. Rounded up to a power of two and kept within
        // [subsampling_factor, max_mel_frames].
        let chunk_mel = (chunk_samples_cap / hop_length).saturating_add(2 * subsampling_factor);
        let max_t_mel = chunk_mel.max(1).next_power_of_two().min(model.config.max_mel_frames).max(subsampling_factor);

        // Batch size from a memory budget. Each batch item is estimated at
        // n_heads * t_sub * t_sub elements of the encoder's input dtype.
        let t_sub_max = (max_t_mel / subsampling_factor).max(1);
        let scores_dtype_bytes = model.encoder.input_dtype().bytes();
        let bytes_per_batch = model.config.n_heads * t_sub_max * t_sub_max * scores_dtype_bytes;
        let target_scores_bytes = opts.max_scores_mib * 1024 * 1024;
        let max_batch_by_memory = (target_scores_bytes / bytes_per_batch.max(1)).max(1);
        let max_batch = max_batch_by_memory.min(model.config.max_batch_size);

        let prepare_config = PrepareConfig::from_env();
        // Prepare the encoder once for the largest shapes. The actual `b` and `t` are
        // supplied per call through `execute_with_vars`.
        let mut encoder_jit = GigaAmEncoderJit::new(model.clone()).with_b_bound(max_batch).with_t_bound(max_t_mel);
        encoder_jit
            .prepare_with_config(
                InputSpec::f32(&[max_batch, model.config.n_mels, max_t_mel]),
                InputSpec::i32(&[max_batch]),
                &prepare_config,
            )
            .context(JitSnafu)?;

        let head_decoder = match &model.head {
            Head::Ctc(_) => {
                // `beam_decode` only upgrades a greedy decoder; any other configured decoder is kept.
                let decoder = if opts.beam_decode {
                    match &model.config.decoder {
                        CtcDecoder::Greedy(g) => CtcDecoder::Beam(Box::new(svod_arch::ctc::BeamDecoder::new(
                            g.vocabulary().to_vec(),
                            svod_arch::ctc::BeamOpts::default(),
                        ))),
                        other => other.clone(),
                    }
                } else {
                    model.config.decoder.clone()
                };
                // Effective kernel size of the subsampling convolutions (Conv2d always uses 3).
                let subs_kernel_size = match model.config.subsampling_mode {
                    SubsamplingMode::Conv1d => model.config.subs_kernel_size,
                    SubsamplingMode::Conv2d => 3,
                };
                let max_t_sub = subs_output_length(subs_kernel_size, max_t_mel);
                let mut jit = CtcHeadJit::new(model.clone()).with_b_bound(max_batch).with_t_sub_bound(max_t_sub);
                jit.prepare_with_config(InputSpec::f32(&[max_batch, model.config.d_model, max_t_sub]), &prepare_config)
                    .context(JitSnafu)?;
                HeadDecoder::Ctc { jit, decoder }
            }
            Head::Rnnt { runtime, .. } => {
                let backend = RnntStepBackend::from_model(model.clone()).context(JitSnafu)?;
                let decoder = RnntDecoder::new(
                    runtime.vocabulary.clone(),
                    RnntOpts { max_symbols_per_step: runtime.max_symbols_per_step },
                );
                HeadDecoder::Rnnt { backend, decoder, sentencepiece: runtime.sentencepiece }
            }
        };

        Ok(Self { model, opts, splitter, mel, head_decoder, encoder_jit, max_batch, max_t_mel })
    }

    /// Encoder limits for `sample_rate` at the model's full `max_mel_frames` capacity.
    ///
    /// # Errors
    /// [`TranscribeError::SampleRateMismatch`] if `sample_rate` differs from the model's.
    pub fn encoder_bounds(&self, sample_rate: u32) -> Result<EncoderBounds, TranscribeError<S::Error>> {
        self.bounds_with(sample_rate, self.model.config.max_mel_frames)
    }

    /// Encoder limits for `sample_rate` at the mel length the JITs were prepared for.
    fn prepared_bounds(&self, sample_rate: u32) -> Result<EncoderBounds, TranscribeError<S::Error>> {
        self.bounds_with(sample_rate, self.max_t_mel)
    }

    /// Checks `sample_rate` against the model, then builds bounds capped at `max_mel_frames`.
    fn bounds_with(&self, sample_rate: u32, max_mel_frames: usize) -> Result<EncoderBounds, TranscribeError<S::Error>> {
        if sample_rate as usize != self.model.config.sample_rate {
            return Err(TranscribeError::SampleRateMismatch {
                wav_sr: sample_rate,
                model_sr: self.model.config.sample_rate as u32,
            });
        }
        Ok(EncoderBounds {
            sample_rate,
            hop_length: self.model.config.hop_length,
            subsampling_factor: self.model.config.subsampling_factor,
            max_mel_frames,
        })
    }

    /// Splits `waveform` with the configured splitter and transcribes every chunk.
    ///
    /// # Errors
    /// - [`TranscribeError::SampleRateMismatch`] for a wrong sample rate.
    /// - [`TranscribeError::Splitter`] if splitting fails.
    /// - Anything `transcribe_chunks` returns.
    pub fn transcribe(
        &mut self,
        waveform: &[f32],
        sample_rate: u32,
    ) -> Result<TranscribeResult, TranscribeError<S::Error>> {
        let bounds = self.encoder_bounds(sample_rate)?;
        let chunks = self.splitter.split(waveform, &bounds).context(SplitterSnafu)?;
        self.transcribe_chunks(waveform, sample_rate, &chunks)
    }

    /// Transcribes caller-supplied `chunks` of `waveform`, running up to `max_batch`
    /// chunks per encoder call.
    ///
    /// Chunks whose sample range yields no mel frames are skipped and do not appear in
    /// the result.
    ///
    /// # Errors
    /// - [`TranscribeError::SampleRateMismatch`] if `sample_rate` differs from the model's.
    /// - [`TranscribeError::ChunkOutOfRange`] if a chunk ends past the waveform.
    /// - [`TranscribeError::ChunkExceedsCapacity`] if a chunk is longer than the prepared
    ///   encoder accepts.
    /// - JIT, device, or decoder errors raised while running the model.
    pub fn transcribe_chunks(
        &mut self,
        waveform: &[f32],
        sample_rate: u32,
        chunks: &[AudioChunk],
    ) -> Result<TranscribeResult, TranscribeError<S::Error>> {
        let max_samples = self.prepared_bounds(sample_rate)?.max_samples();
        // Validate every chunk before doing any work.
        // NOTE(review): a chunk with `start_sample > end_sample` is not rejected here. It
        // counts as 0 samples; if `num_frames(0)` is non-zero, the waveform slice below
        // would panic. TODO confirm whether this should be an error.
        for (idx, chunk) in chunks.iter().enumerate() {
            if chunk.end_sample > waveform.len() {
                return Err(TranscribeError::ChunkOutOfRange {
                    idx,
                    end_sample: chunk.end_sample,
                    waveform_len: waveform.len(),
                });
            }
            let samples = chunk.end_sample.saturating_sub(chunk.start_sample);
            if samples > max_samples {
                return Err(TranscribeError::ChunkExceedsCapacity { idx, samples, max_samples });
            }
        }

        let n_mels = self.mel.n_mels();
        if chunks.is_empty() {
            return Ok(TranscribeResult { text: String::new(), chunks: Vec::new() });
        }

        let sample_rate_hz = self.model.config.sample_rate;
        let d_model = self.model.config.d_model;
        // Same effective subsampling kernel size as used when preparing the CTC head in `new`.
        let subs_kernel_size = match self.model.config.subsampling_mode {
            SubsamplingMode::Conv1d => self.model.config.subs_kernel_size,
            SubsamplingMode::Conv2d => 3,
        };
        let max_t_mel = self.max_t_mel;
        let max_t_sub = subs_output_length(subs_kernel_size, max_t_mel);
        let max_batch = self.max_batch;
        let want_words = self.opts.word_timestamps;

        // Per chunk: (start_sample, end_sample, mel_len, start_sec, end_sec).
        // Chunks that yield no mel frames are dropped here.
        let chunks_meta: Vec<(usize, usize, usize, f32, f32)> = chunks
            .iter()
            .filter_map(|c| {
                let mel_len = self.mel.num_frames(c.end_sample.saturating_sub(c.start_sample));
                if mel_len == 0 {
                    return None;
                }
                let start_sec = c.start_sample as f32 / sample_rate_hz as f32;
                let end_sec = c.end_sample as f32 / sample_rate_hz as f32;
                Some((c.start_sample, c.end_sample, mel_len, start_sec, end_sec))
            })
            .collect();
        if chunks_meta.is_empty() {
            return Ok(TranscribeResult { text: String::new(), chunks: Vec::new() });
        }

        let num_chunks = chunks_meta.len();
        let mut chunk_results: Vec<ChunkResult> = Vec::with_capacity(num_chunks);
        for chunk_batch_start in (0..num_chunks).step_by(max_batch) {
            let b = (num_chunks - chunk_batch_start).min(max_batch);
            // Valid mel length of each batch item; filled while staging the mels below.
            let mut chunk_lengths = vec![0usize; b];

            // Each chunk's mel spectrogram as a contiguous (1, n_mels, valid) array.
            let batch_mels: Vec<Vec<f32>> = (0..b)
                .map(|bi| {
                    let &(start_sample, end_sample, valid, _, _) = &chunks_meta[chunk_batch_start + bi];
                    let mut chunk_mel = ndarray::Array3::<f32>::zeros((1, n_mels, valid));
                    {
                        let mut view = chunk_mel.view_mut().into_dyn();
                        self.mel.forward_into(&waveform[start_sample..end_sample], &mut view);
                    }
                    chunk_mel.as_slice().expect("contiguous chunk mel").to_vec()
                })
                .collect();

            {
                // Stage the mels into the encoder input buffer, written with a
                // [max_batch, n_mels, max_t_mel] layout and zero-padded past each chunk's length.
                let buf = self.encoder_jit.mel_mut().context(JitSnafu)?;
                let mut view = buf.as_array_mut::<f32>().context(DeviceSnafu)?;
                let slice = view.as_slice_mut().expect("contiguous mel buffer");
                slice.fill(0.0);
                for (bi, chunk_len) in chunk_lengths.iter_mut().enumerate() {
                    let &(_, _, valid, _, _) = &chunks_meta[chunk_batch_start + bi];
                    *chunk_len = valid;
                    let chunk_mel = &batch_mels[bi];
                    for mel_bin in 0..n_mels {
                        let src = mel_bin * valid;
                        let dst = ((bi * n_mels) + mel_bin) * max_t_mel;
                        slice[dst..dst + valid].copy_from_slice(&chunk_mel[src..src + valid]);
                    }
                }
            }
            {
                // Per-item valid mel lengths; unused batch slots are left at 0.
                let buf = self.encoder_jit.lengths_mut().context(JitSnafu)?;
                let mut view = buf.as_array_mut::<i32>().context(DeviceSnafu)?;
                let slice = view.as_slice_mut().expect("contiguous lengths buffer");
                slice.fill(0);
                for (i, len) in chunk_lengths.iter().enumerate() {
                    slice[i] = *len as i32;
                }
            }

            // Run the encoder only as far as the longest chunk in this batch.
            let t_exec = chunk_lengths.iter().copied().max().unwrap_or(1).max(1);
            let t_exec_sub = subs_output_length(subs_kernel_size, t_exec);
            self.encoder_jit.execute_with_vars(&[("b", b as i64), ("t", t_exec as i64)]).context(JitSnafu)?;

            match &mut self.head_decoder {
                HeadDecoder::Ctc { jit, decoder } => {
                    {
                        // Encoder output is read as a packed (b, d_model, t_exec_sub) array. It is
                        // copied into the head input, which is written with the prepared
                        // (max_batch, d_model, max_t_sub) layout.
                        let n = b * d_model * t_exec_sub;
                        let src_flat =
                            self.encoder_jit.output().context(JitSnafu)?.as_array::<f32>().context(DeviceSnafu)?;
                        let src_3d = src_flat
                            .slice(ndarray::s![0..n])
                            .into_shape_with_order((b, d_model, t_exec_sub))
                            .expect("encoder output reshape");
                        let dst_flat =
                            jit.encoded_mut().context(JitSnafu)?.as_array_mut::<f32>().context(DeviceSnafu)?;
                        let mut dst_3d = dst_flat
                            .into_shape_with_order((max_batch, d_model, max_t_sub))
                            .expect("head input reshape");
                        dst_3d.slice_mut(ndarray::s![0..b, 0..d_model, 0..t_exec_sub]).assign(&src_3d);
                    }
                    jit.execute_with_vars(&[("b", b as i64), ("t_sub", t_exec_sub as i64)]).context(JitSnafu)?;

                    // Logits are read as `b` consecutive items of t_exec_sub * total_vocab values.
                    let total_vocab = decoder.total_vocab();
                    let item_stride = t_exec_sub * total_vocab;
                    let logits_buf = jit.output().context(JitSnafu)?;
                    let logits = logits_buf.as_array::<f32>().context(DeviceSnafu)?;
                    let flat = logits.as_slice().expect("contiguous head logits");
                    for (bi, mel_len) in chunk_lengths.iter().enumerate() {
                        // Output frames that are backed by this chunk's own audio.
                        let actual_sub = subs_output_length(subs_kernel_size, *mel_len);
                        let &(start_sample, end_sample, _, start_sec, end_sec) = &chunks_meta[chunk_batch_start + bi];
                        // Seconds per output frame, from the chunk's real duration.
                        let chunk_duration_sec = (end_sample - start_sample) as f32 / sample_rate_hz as f32;
                        let frame_shift = chunk_duration_sec / (actual_sub.max(1) as f32);

                        let item_slice = &flat[bi * item_stride..bi * item_stride + item_stride];

                        let (text, frames) = if want_words {
                            let (text, frames) = decoder
                                .decode_with_timestamps(item_slice, t_exec_sub, actual_sub)
                                .context(CtcDecodeSnafu)?;
                            (text, Some(frames))
                        } else {
                            let text = decoder.decode(item_slice, t_exec_sub, actual_sub).context(CtcDecodeSnafu)?;
                            (text, None)
                        };
                        let words = want_words.then(|| {
                            let frames = frames.as_deref().unwrap_or(&[]);
                            ctc_frames_to_words(&text, frames, frame_shift)
                        });
                        chunk_results.push(ChunkResult { start_sec, end_sec, text, words });
                    }
                }
                HeadDecoder::Rnnt { backend, decoder, sentencepiece } => {
                    // Encoder output is read as `b` packed items of d_model * t_exec_sub values.
                    let item_stride = d_model * t_exec_sub;
                    let enc_buf = self.encoder_jit.output().context(JitSnafu)?;
                    let enc = enc_buf.as_array::<f32>().context(DeviceSnafu)?;
                    let flat = enc.as_slice().expect("contiguous encoder output");
                    for (bi, mel_len) in chunk_lengths.iter().enumerate() {
                        let actual_sub = subs_output_length(subs_kernel_size, *mel_len);
                        let &(start_sample, end_sample, _, start_sec, end_sec) = &chunks_meta[chunk_batch_start + bi];
                        let chunk_duration_sec = (end_sample - start_sample) as f32 / sample_rate_hz as f32;
                        let frame_shift = chunk_duration_sec / (actual_sub.max(1) as f32);

                        let item_slice = &flat[bi * item_stride..bi * item_stride + item_stride];
                        // The RNN-T decoder is fed time-major frames, trimmed to this chunk's length.
                        let frames = transpose_dt_to_td(item_slice, d_model, t_exec_sub, actual_sub);

                        // Explicitly typed reborrow for the decoder's backend argument.
                        let backend: &mut RnntStepBackend = backend;
                        let (raw, emissions) = if want_words {
                            let (s, e) = decoder
                                .decode_with_timestamps(&frames, actual_sub, actual_sub, d_model, backend)
                                .map_err(rnnt_decode_err)?;
                            (s, e)
                        } else {
                            let s = decoder
                                .decode(&frames, actual_sub, actual_sub, d_model, backend)
                                .map_err(rnnt_decode_err)?;
                            (s, Vec::new())
                        };
                        let words = want_words.then(|| decoder.frames_to_words(&emissions, frame_shift));
                        // SentencePiece-style vocabularies mark spaces with U+2581.
                        let text = if *sentencepiece { raw.replace('\u{2581}', " ").trim().to_string() } else { raw };
                        chunk_results.push(ChunkResult { start_sec, end_sec, text, words });
                    }
                }
            }
        }

        // Join the non-empty chunk texts with single spaces.
        let text =
            chunk_results.iter().map(|c| c.text.as_str()).filter(|s| !s.is_empty()).collect::<Vec<_>>().join(" ");
        Ok(TranscribeResult { text, chunks: chunk_results })
    }
}
588
/// Number of frames left after the encoder's two stride-2 subsampling layers.
///
/// Each layer pads `(kernel_size - 1) / 2` frames on both sides and uses stride 2:
/// `len = floor((len + 2 * pad - kernel_size) / 2) + 1`.
///
/// If the padded input is shorter than the kernel, the layer yields no frames and the
/// result is `0`. Without this check the `usize` subtraction would underflow: a panic in
/// debug builds, a huge bogus length in release builds.
fn subs_output_length(kernel_size: usize, mel_frames: usize) -> usize {
    const STRIDE: usize = 2;
    const LAYERS: usize = 2;
    // `saturating_sub` keeps a degenerate `kernel_size == 0` from underflowing.
    let pad = kernel_size.saturating_sub(1) / 2;
    (0..LAYERS).fold(mel_frames, |len, _| match (len + 2 * pad).checked_sub(kernel_size) {
        Some(span) => span / STRIDE + 1,
        None => 0,
    })
}