1#![allow(clippy::all, clippy::pedantic, clippy::restriction, clippy::nursery)]
3
4#[cfg(feature = "realizar-gpu")]
6use crate::cuda;
7use crate::{
8 audio, detection, error, format, inference, model, progress, timestamps, tokenizer, vad,
9};
10pub use error::{WhisperError, WhisperResult};
11
/// Identifies a Whisper model size/variant.
///
/// `*En` variants are English-only; the others are multilingual.
/// Approximate parameter counts (see `WhisperApr::memory_size`): Tiny ~39M,
/// Base ~74M, Small ~244M, Medium ~769M, Large family ~1.55B, Turbo ~809M.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
    /// Multilingual tiny model (~39M parameters).
    Tiny,
    /// English-only tiny model.
    TinyEn,
    /// Multilingual base model (~74M parameters).
    Base,
    /// English-only base model.
    BaseEn,
    /// Multilingual small model (~244M parameters).
    Small,
    /// English-only small model.
    SmallEn,
    /// Multilingual medium model (~769M parameters).
    Medium,
    /// English-only medium model.
    MediumEn,
    /// Large model, unversioned (~1.55B parameters).
    Large,
    /// Large v1.
    LargeV1,
    /// Large v2.
    LargeV2,
    /// Large v3.
    LargeV3,
    /// Large v3 turbo (~809M parameters).
    LargeV3Turbo,
}
42
/// Token decoding strategy used during transcription.
#[derive(Debug, Clone, Copy, Default)]
pub enum DecodingStrategy {
    /// Pick the highest-probability token at each step (the default).
    #[default]
    Greedy,
    /// Beam search over multiple hypotheses.
    BeamSearch {
        /// Number of parallel hypotheses to keep.
        beam_size: usize,
        /// Softmax temperature applied during scoring.
        temperature: f32,
        /// Beam-search patience factor.
        patience: f32,
    },
    /// Temperature sampling.
    ///
    /// NOTE(review): `decode` currently falls back to a temperature-scaled
    /// greedy decoder and ignores `top_k`/`top_p` — confirm intended.
    Sampling {
        /// Sampling temperature.
        temperature: f32,
        /// Top-k cutoff (currently unused by `decode`).
        top_k: Option<usize>,
        /// Nucleus (top-p) cutoff (currently unused by `decode`).
        top_p: Option<f32>,
    },
}
68
/// Decoding task; selects the corresponding task special token in the
/// initial decoder sequence.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Task {
    /// Transcribe speech in the source language (the default).
    #[default]
    Transcribe,
    /// Emit the translate special token instead of the transcribe token.
    Translate,
}
78
/// Options controlling a transcription run.
#[derive(Debug, Clone, Default)]
pub struct TranscribeOptions {
    /// Target language code (e.g. `"en"`); `None` defaults to `"en"`.
    pub language: Option<String>,
    /// Whether to transcribe or translate (see [`Task`]).
    pub task: Task,
    /// Token decoding strategy.
    pub strategy: DecodingStrategy,
    /// Request word-level timestamps.
    /// NOTE(review): carried through chunked decoding but not visibly
    /// consumed in this module — confirm downstream usage.
    pub word_timestamps: bool,
    /// Collect timing/profiling statistics into the result.
    pub profile: bool,
    /// Optional text prompt, tokenized and prefixed (after a PREV token)
    /// to condition decoding.
    pub prompt: Option<String>,
    /// Words/phrases whose token sequences receive a logit boost.
    pub hotwords: Vec<String>,
}
99
/// A timestamped span of transcribed text.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "cli", derive(serde::Serialize, serde::Deserialize))]
pub struct Segment {
    /// Start time in seconds.
    pub start: f32,
    /// End time in seconds.
    pub end: f32,
    /// Transcribed text for this span.
    pub text: String,
    /// Token ids that produced `text`.
    pub tokens: Vec<u32>,
}
113
/// Timing data collected when `TranscribeOptions::profile` is enabled.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "cli", derive(serde::Serialize, serde::Deserialize))]
pub struct ProfilingStats {
    /// Wall-clock time for the whole transcription, in milliseconds.
    pub total_ms: f64,
    /// Named sub-timings and counters (e.g. `"mel_ms"`, `"encoder_ms"`,
    /// brick/BLIS metrics).
    pub breakdown: std::collections::HashMap<String, f64>,
    /// JSON inference trace, only populated with the `realizar-inference`
    /// feature enabled.
    #[cfg_attr(feature = "cli", serde(skip_serializing_if = "Option::is_none"))]
    pub trace_json: Option<String>,
}
126
/// Output of a transcription run.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "cli", derive(serde::Serialize, serde::Deserialize))]
pub struct TranscriptionResult {
    /// Full transcript text.
    pub text: String,
    /// Language the transcript was decoded with.
    pub language: String,
    /// Timestamped segments making up the transcript.
    pub segments: Vec<Segment>,
    /// Profiling statistics, present only when profiling was requested.
    #[cfg_attr(feature = "cli", serde(skip_serializing_if = "Option::is_none"))]
    pub profiling: Option<ProfilingStats>,
}
141
/// Results of transcribing a batch of audio inputs.
#[derive(Debug, Clone)]
pub struct BatchTranscriptionResult {
    /// One result per input, in input order.
    pub results: Vec<TranscriptionResult>,
    /// Total audio duration across the batch, in seconds.
    pub total_duration_secs: f32,
}
150
impl BatchTranscriptionResult {
    /// Number of results in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        self.results.len()
    }

    /// Returns `true` when the batch contains no results.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.results.is_empty()
    }

    /// Returns the result at `index`, or `None` if out of bounds.
    #[must_use]
    pub fn get(&self, index: usize) -> Option<&TranscriptionResult> {
        self.results.get(index)
    }

    /// Iterates over the results in input order.
    pub fn iter(&self) -> impl Iterator<Item = &TranscriptionResult> {
        self.results.iter()
    }

    /// Collects the transcript text of every result.
    #[must_use]
    pub fn texts(&self) -> Vec<&str> {
        self.results.iter().map(|r| r.text.as_str()).collect()
    }
}
181
/// Borrowed configuration for generating a transcript summary with an
/// LFM2 model.
pub struct SummarizeOptions<'a> {
    /// LFM2 language model used for summarization.
    pub model: &'a model::lfm2::Lfm2,
    /// Tokenizer matching `model`.
    pub tokenizer: &'a model::lfm2::Lfm2Tokenizer,
    /// Maximum number of summary tokens to generate (default 256).
    pub max_tokens: usize,
    /// Sampling temperature for generation (default 0.3).
    pub temperature: f32,
}
195
impl<'a> SummarizeOptions<'a> {
    /// Creates options with defaults: 256 max tokens, temperature 0.3.
    #[must_use]
    pub fn new(model: &'a model::lfm2::Lfm2, tokenizer: &'a model::lfm2::Lfm2Tokenizer) -> Self {
        Self {
            model,
            tokenizer,
            max_tokens: 256,
            temperature: 0.3,
        }
    }

    /// Sets the maximum number of generated summary tokens.
    #[must_use]
    pub const fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Sets the sampling temperature used for generation.
    #[must_use]
    pub const fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature;
        self
    }
}
226
/// Combined result of a transcription followed by LFM2 summarization.
#[derive(Debug, Clone)]
pub struct TranscribeSummaryResult {
    /// The underlying transcription result.
    pub transcription: TranscriptionResult,
    /// Generated summary text (may be empty; see `has_summary`).
    pub summary: String,
    /// Generation statistics from the LFM2 model, if collected.
    pub generation_stats: Option<model::lfm2::GenerationStats>,
}
239
impl TranscribeSummaryResult {
    /// Returns the full transcript text.
    #[must_use]
    pub fn transcript(&self) -> &str {
        &self.transcription.text
    }

    /// Returns the generated summary text.
    #[must_use]
    pub fn summary(&self) -> &str {
        &self.summary
    }

    /// Returns `true` when a non-empty summary was produced.
    #[must_use]
    pub fn has_summary(&self) -> bool {
        !self.summary.is_empty()
    }
}
259
/// End-to-end speech-to-text pipeline: audio frontend, encoder,
/// autoregressive decoder, and tokenizer, plus an optional input resampler.
#[derive(Debug, Clone)]
pub struct WhisperApr {
    /// Model hyperparameters and family information.
    config: model::ModelConfig,
    /// Audio encoder.
    encoder: model::Encoder,
    /// Autoregressive text decoder.
    decoder: model::Decoder,
    /// BPE (Whisper) or SentencePiece (Moonshine) tokenizer.
    tokenizer: tokenizer::Tokenizer,
    /// Mel filterbank frontend (mel-based models); `None` for Moonshine.
    mel_filters: Option<audio::MelFilterbank>,
    /// Learned convolution frontend (Moonshine); `None` for mel-based models.
    conv_stem: Option<audio::ConvStem>,
    /// Optional input-rate resampler; configured via `set_resampler`.
    resampler: Option<audio::SincResampler>,
    /// True when weights were loaded from a container rather than
    /// default-initialized by `from_config`.
    weights_loaded: bool,
}
299
300impl WhisperApr {
301 #[must_use]
    /// Builds a model with freshly constructed (unloaded) weights for
    /// `config`.
    ///
    /// Tokenizer and audio frontend follow the model family: Moonshine gets
    /// a SentencePiece tokenizer and a learned conv stem; other families get
    /// a BPE tokenizer and a mel filterbank. `has_weights()` reports `false`
    /// for models built this way.
    #[must_use]
    pub fn from_config(config: model::ModelConfig) -> Self {
        let encoder = model::Encoder::new(&config);
        let decoder = model::Decoder::new(&config);
        let tokenizer = if config.model_family == format::ModelFamily::Moonshine {
            tokenizer::Tokenizer::SentencePiece(
                tokenizer::SentencePieceTokenizer::moonshine_default(),
            )
        } else {
            tokenizer::Tokenizer::Bpe(tokenizer::BpeTokenizer::with_base_tokens())
        };

        // Exactly one audio frontend is populated, selected by the config.
        let (mel_filters, conv_stem) = match config.audio_frontend {
            model::AudioFrontend::MelFilterbank => (
                Some(audio::MelFilterbank::new(&audio::MelConfig {
                    n_mels: config.n_mels as usize,
                    ..audio::MelConfig::whisper()
                })),
                None,
            ),
            model::AudioFrontend::LearnedConv => (
                None,
                Some(audio::ConvStem::new(config.n_audio_state as usize)),
            ),
        };

        Self {
            config,
            encoder,
            decoder,
            tokenizer,
            mel_filters,
            conv_stem,
            resampler: None,
            weights_loaded: false,
        }
    }
343
    /// Creates an unweighted tiny model.
    #[must_use]
    pub fn tiny() -> Self {
        Self::from_config(model::ModelConfig::tiny())
    }

    /// Creates an unweighted base model.
    #[must_use]
    pub fn base() -> Self {
        Self::from_config(model::ModelConfig::base())
    }

    /// Creates an unweighted small model.
    #[must_use]
    pub fn small() -> Self {
        Self::from_config(model::ModelConfig::small())
    }

    /// Creates an unweighted medium model.
    #[must_use]
    pub fn medium() -> Self {
        Self::from_config(model::ModelConfig::medium())
    }

    /// Creates an unweighted large model.
    #[must_use]
    pub fn large() -> Self {
        Self::from_config(model::ModelConfig::large())
    }

    /// Creates an unweighted Moonshine tiny model.
    #[must_use]
    pub fn moonshine_tiny() -> Self {
        Self::from_config(model::ModelConfig::moonshine_tiny())
    }

    /// Creates an unweighted Moonshine base model.
    #[must_use]
    pub fn moonshine_base() -> Self {
        Self::from_config(model::ModelConfig::moonshine_base())
    }
385
    /// Returns the model configuration.
    #[must_use]
    pub const fn config(&self) -> &model::ModelConfig {
        &self.config
    }

    /// Returns the model size/variant identifier.
    #[must_use]
    pub const fn model_type(&self) -> ModelType {
        self.config.model_type
    }
397
398 const CHUNK_SAMPLES: usize = 30 * audio::SAMPLE_RATE as usize; const OVERLAP_SAMPLES: usize = 5 * audio::SAMPLE_RATE as usize; const MAX_SUBTITLE_SECS: f32 = 10.0;
406
    /// Transcribes PCM samples (at the model sample rate) into text.
    ///
    /// Audio longer than one 30 s chunk is processed via overlapping
    /// chunks; shorter input takes the single-pass path.
    ///
    /// # Errors
    /// Propagates audio-frontend, encoder, decoder, or tokenizer failures.
    pub fn transcribe(
        &self,
        audio: &[f32],
        options: TranscribeOptions,
    ) -> WhisperResult<TranscriptionResult> {
        if audio.len() > Self::CHUNK_SAMPLES {
            return self.transcribe_chunked(audio, options);
        }

        self.transcribe_single_chunk(audio, options)
    }
435
    /// Transcribes a single chunk of audio (at most ~30 s).
    ///
    /// When `options.profile` is set and the `std` feature is enabled,
    /// wall-clock, brick-level, BLIS, and page-fault statistics are
    /// collected into `TranscriptionResult::profiling`; with
    /// `realizar-inference`, a JSON inference trace is attached as well.
    ///
    /// # Errors
    /// Propagates audio-frontend, encoder, decoder, or tokenizer failures.
    fn transcribe_single_chunk(
        &self,
        audio: &[f32],
        options: TranscribeOptions,
    ) -> WhisperResult<TranscriptionResult> {
        // Profiling timers/collectors are only armed when profiling was
        // requested; they stay `None` otherwise so the hot path is unchanged.
        #[cfg(feature = "std")]
        let start_total = if options.profile {
            Some(std::time::Instant::now())
        } else {
            None
        };

        #[cfg(feature = "std")]
        let start_audio = start_total.map(|_| std::time::Instant::now());

        #[cfg(feature = "std")]
        let mut mel_ms: Option<f64> = None;

        #[cfg(feature = "std")]
        let mut brick_profiler: Option<trueno::BrickProfiler> = if options.profile {
            Some(trueno::BrickProfiler::enabled())
        } else {
            None
        };

        #[cfg(feature = "std")]
        let mut page_faults: Option<(u64, u64)> = None;

        #[cfg(feature = "std")]
        let mut blis_profiler_stats: Option<trueno::blis::BlisProfiler> = None;

        #[cfg(all(feature = "std", feature = "realizar-inference"))]
        let mut inference_tracer: Option<realizar::InferenceTracer> = if options.profile {
            let config = realizar::TraceConfig::enabled();
            let mut tracer = realizar::InferenceTracer::new(config);
            tracer.set_model_info(realizar::ModelInfo {
                name: format!("{:?}", self.config.model_type),
                num_layers: self.config.n_audio_layer as usize,
                hidden_dim: self.config.n_audio_state as usize,
                vocab_size: self.config.n_vocab as usize,
                num_heads: self.config.n_audio_head as usize,
                quant_type: Some("f32".into()),
            });
            Some(tracer)
        } else {
            None
        };

        // Front-end + encoder: mel path (optionally profiled) or conv stem.
        let audio_features = match self.config.audio_frontend {
            model::AudioFrontend::MelFilterbank => {
                #[cfg(feature = "std")]
                let mel_start = std::time::Instant::now();

                let mel = self.compute_mel(audio)?;

                #[cfg(feature = "std")]
                {
                    mel_ms = Some(mel_start.elapsed().as_secs_f64() * 1000.0);
                }

                // Profiled encode also captures BLIS stats and page-fault
                // deltas around the encoder forward pass.
                #[cfg(feature = "std")]
                if let Some(ref mut profiler) = brick_profiler {
                    crate::simd::enable_blis_profiling();
                    let (pf_minor_before, pf_major_before) = trueno::brick::get_page_faults();
                    #[cfg(feature = "realizar-inference")]
                    let features =
                        self.encode_profiled(&mel, profiler, inference_tracer.as_mut())?;
                    #[cfg(not(feature = "realizar-inference"))]
                    let features = self.encode_profiled(&mel, profiler)?;
                    blis_profiler_stats = crate::simd::take_blis_profiler();
                    let (pf_minor_after, pf_major_after) = trueno::brick::get_page_faults();
                    page_faults = Some((
                        pf_minor_after.saturating_sub(pf_minor_before),
                        pf_major_after.saturating_sub(pf_major_before),
                    ));
                    features
                } else {
                    self.encode(&mel)?
                }

                #[cfg(not(feature = "std"))]
                self.encode(&mel)?
            }
            model::AudioFrontend::LearnedConv => {
                let stem = self
                    .conv_stem
                    .as_ref()
                    .ok_or_else(|| WhisperError::Model("Moonshine requires ConvStem".into()))?;
                let stem_out = stem.forward(audio)?;
                self.encoder.forward(&stem_out)?
            }
        };

        #[cfg(feature = "std")]
        let enc_ms = start_audio.map(|s| s.elapsed().as_secs_f64() * 1000.0);

        let language = options.language.clone().unwrap_or_else(|| "en".to_string());

        let initial_tokens =
            self.build_initial_tokens(&language, options.task, options.prompt.as_deref());

        #[cfg(feature = "std")]
        let start_dec = start_total.map(|_| std::time::Instant::now());
        let tokens = self.decode(&audio_features, &initial_tokens, &options)?;
        #[cfg(feature = "std")]
        let dec_ms = start_dec.map(|s| s.elapsed().as_secs_f64() * 1000.0);

        let text = self.tokenizer.decode(&tokens)?;

        // Prefer timestamp tokens for segmentation; otherwise fall back to a
        // single full-duration segment split down to subtitle length.
        let segments = if timestamps::has_timestamps(&tokens) {
            timestamps::extract_segments(&tokens, |ts| self.tokenizer.decode(ts).ok())
        } else if !text.trim().is_empty() {
            let duration = audio.len() as f32 / audio::SAMPLE_RATE as f32;
            let single = vec![Segment {
                start: 0.0,
                end: duration,
                text: text.clone(),
                tokens: tokens.clone(),
            }];
            timestamps::split_long_segments(&single, Self::MAX_SUBTITLE_SECS)
        } else {
            Vec::new()
        };

        // Assemble the profiling breakdown only when timers were armed.
        #[cfg(feature = "std")]
        let profiling = start_total.map(|st| {
            let mut breakdown = std::collections::HashMap::new();
            let pairs: &[(&str, Option<f64>)] = &[
                ("mel_ms", mel_ms),
                ("audio_ms", enc_ms),
                ("decoder_ms", dec_ms),
            ];
            for &(key, val) in pairs {
                if let Some(ms) = val {
                    breakdown.insert(key.to_string(), ms);
                }
            }
            // Encoder time = full audio path minus mel computation.
            if let Some(audio) = enc_ms {
                let enc_total = audio - mel_ms.unwrap_or(0.0);
                breakdown.insert("encoder_ms".to_string(), enc_total);
            }
            if let Some(ref profiler) = brick_profiler {
                let cats = profiler.category_stats();
                let total_prof_ns = profiler.total_ns();
                let norm_ns = cats[trueno::BrickCategory::Norm as usize].total_ns;
                let attn_ns = cats[trueno::BrickCategory::Attention as usize].total_ns;
                let ffn_ns = cats[trueno::BrickCategory::Ffn as usize].total_ns;
                let other_ns = cats[trueno::BrickCategory::Other as usize].total_ns;
                breakdown.insert("brick_norm_ms".to_string(), norm_ns as f64 / 1_000_000.0);
                breakdown.insert("brick_attn_ms".to_string(), attn_ns as f64 / 1_000_000.0);
                breakdown.insert("brick_ffn_ms".to_string(), ffn_ns as f64 / 1_000_000.0);
                breakdown.insert("brick_other_ms".to_string(), other_ns as f64 / 1_000_000.0);
                breakdown.insert("brick_total_ns".to_string(), total_prof_ns as f64);
                // Per-brick metrics for a fixed set of representative bricks.
                for &brick_id in &[
                    trueno::BrickId::LayerNorm,
                    trueno::BrickId::AttentionScore,
                    trueno::BrickId::GateProjection,
                    trueno::BrickId::Embedding,
                ] {
                    let stats = profiler.brick_stats(brick_id);
                    if stats.count > 0 {
                        let name = brick_id.name();
                        breakdown.insert(
                            format!("brick_{name}_ms"),
                            stats.total_ns as f64 / 1_000_000.0,
                        );
                        breakdown.insert(format!("brick_{name}_count"), stats.count as f64);
                        breakdown.insert(
                            format!("brick_{name}_cycles_per_elem"),
                            stats.cycles_per_element(),
                        );
                        // Encode the textual diagnosis as a numeric code so
                        // it fits the f64-valued breakdown map.
                        let diagnosis = stats.diagnose_from_cycles();
                        let diag_code = match diagnosis {
                            "memory-bound (low IPC, likely cache misses)" => 1.0,
                            "compute-bound (efficient)" => 2.0,
                            "throttled or context-switched" => 3.0,
                            "balanced" => 4.0,
                            _ => 0.0,
                        };
                        breakdown.insert(format!("brick_{name}_bottleneck"), diag_code);
                    }
                }
                if let Some((minor, major)) = page_faults {
                    breakdown.insert("page_faults_minor".to_string(), minor as f64);
                    breakdown.insert("page_faults_major".to_string(), major as f64);
                }
            }
            if let Some(ref blis) = blis_profiler_stats {
                breakdown.insert("blis_macro_gflops".to_string(), blis.macro_stats.gflops());
                breakdown.insert(
                    "blis_macro_calls".to_string(),
                    blis.macro_stats.count as f64,
                );
                breakdown.insert(
                    "blis_macro_ns".to_string(),
                    blis.macro_stats.total_ns as f64,
                );
                breakdown.insert("blis_midi_gflops".to_string(), blis.midi_stats.gflops());
                breakdown.insert("blis_midi_calls".to_string(), blis.midi_stats.count as f64);
                breakdown.insert("blis_micro_gflops".to_string(), blis.micro_stats.gflops());
                breakdown.insert(
                    "blis_micro_calls".to_string(),
                    blis.micro_stats.count as f64,
                );
                breakdown.insert("blis_pack_ns".to_string(), blis.pack_stats.total_ns as f64);
                breakdown.insert("blis_pack_calls".to_string(), blis.pack_stats.count as f64);
                breakdown.insert("blis_total_gflops".to_string(), blis.total_gflops());
                // Packing share of total macro-kernel time, in percent.
                if blis.macro_stats.total_ns > 0 {
                    let pack_pct =
                        blis.pack_stats.total_ns as f64 / blis.macro_stats.total_ns as f64 * 100.0;
                    breakdown.insert("blis_pack_pct".to_string(), pack_pct);
                }
            }
            #[cfg(feature = "realizar-inference")]
            let trace_json = inference_tracer.as_ref().map(|t| t.to_json());
            #[cfg(not(feature = "realizar-inference"))]
            let trace_json: Option<String> = None;

            ProfilingStats {
                total_ms: st.elapsed().as_secs_f64() * 1000.0,
                breakdown,
                trace_json,
            }
        });

        #[cfg(not(feature = "std"))]
        let profiling = None;

        Ok(TranscriptionResult {
            text,
            language,
            segments,
            profiling,
        })
    }
697
698 fn transcribe_chunked(
718 &self,
719 audio: &[f32],
720 options: TranscribeOptions,
721 ) -> WhisperResult<TranscriptionResult> {
722 let chunk_size = Self::CHUNK_SAMPLES;
723 let overlap = Self::OVERLAP_SAMPLES;
724 let step = chunk_size; let language = options.language.clone().unwrap_or_else(|| "en".to_string());
728
729 let mut all_segments: Vec<Segment> = Vec::new();
730 let mut all_text = String::new();
731 let mut chunk_idx = 0;
732
733 let mut offset = 0;
735 while offset < audio.len() {
736 let chunk_end = (offset + chunk_size + overlap).min(audio.len());
738 let chunk = &audio[offset..chunk_end];
739
740 if chunk.len() < audio::SAMPLE_RATE as usize / 2 {
742 break;
743 }
744
745 let chunk_options = TranscribeOptions {
747 language: Some(language.clone()),
748 task: options.task,
749 strategy: options.strategy,
750 word_timestamps: options.word_timestamps,
751 profile: options.profile,
752 prompt: options.prompt.clone(),
753 hotwords: options.hotwords.clone(),
754 };
755
756 let chunk_result = self.transcribe_single_chunk(chunk, chunk_options)?;
757
758 let time_offset = offset as f32 / audio::SAMPLE_RATE as f32;
760
761 let chunk_text = if chunk_idx == 0 {
763 chunk_result.text.clone()
765 } else {
766 let overlap_ratio = overlap as f32 / chunk.len() as f32;
770 let words: Vec<&str> = chunk_result.text.split_whitespace().collect();
771 let skip_words = ((words.len() as f32) * overlap_ratio * 0.8) as usize;
772 words
773 .into_iter()
774 .skip(skip_words)
775 .collect::<Vec<_>>()
776 .join(" ")
777 };
778
779 if !chunk_text.is_empty() {
781 if !all_text.is_empty() {
782 all_text.push(' ');
783 }
784 all_text.push_str(&chunk_text);
785 }
786
787 for mut seg in chunk_result.segments {
789 seg.start += time_offset;
790 seg.end += time_offset;
791 all_segments.push(seg);
792 }
793
794 offset += step;
796 chunk_idx += 1;
797 }
798
799 let merged_segments = self.merge_overlapping_segments(all_segments);
801
802 let final_segments =
804 timestamps::split_long_segments(&merged_segments, Self::MAX_SUBTITLE_SECS);
805
806 Ok(TranscriptionResult {
807 text: all_text,
808 language,
809 segments: final_segments,
810 profiling: None,
811 })
812 }
813
814 fn merge_overlapping_segments(&self, segments: Vec<Segment>) -> Vec<Segment> {
819 if segments.is_empty() {
820 return segments;
821 }
822
823 let mut merged: Vec<Segment> = Vec::with_capacity(segments.len());
824 let mut current = segments[0].clone();
825
826 for seg in segments.into_iter().skip(1) {
827 if seg.start < current.end + 0.1 {
829 current.end = current.end.max(seg.end);
831 if !seg.text.is_empty() {
832 if !current.text.is_empty() {
833 current.text.push(' ');
834 }
835 current.text.push_str(&seg.text);
836 }
837 current.tokens.extend(seg.tokens);
838 } else {
839 merged.push(current);
841 current = seg;
842 }
843 }
844
845 merged.push(current);
847
848 merged
849 }
850
    /// Full forward pass (frontend → encoder → decoder) that records
    /// intermediate activations into `probe` for inspection/debugging.
    ///
    /// # Errors
    /// Propagates frontend, encoder, or decoder failures; returns a model
    /// error when a Moonshine configuration is missing its conv stem.
    pub fn forward_probed(
        &self,
        audio: &[f32],
        tokens: &[u32],
        probe: &mut crate::probe::ActivationProbe,
    ) -> WhisperResult<Vec<f32>> {
        let audio_features = match self.config.audio_frontend {
            model::AudioFrontend::MelFilterbank => {
                let mel = self.compute_mel(audio)?;
                self.encoder.forward_probed(&mel, probe)?
            }
            model::AudioFrontend::LearnedConv => {
                let stem = self
                    .conv_stem
                    .as_ref()
                    .ok_or_else(|| WhisperError::Model("Moonshine requires ConvStem".into()))?;
                let stem_out = stem.forward_probed(audio, probe)?;
                self.encoder.forward_probed(&stem_out, probe)?
            }
        };

        self.decoder.forward_probed(tokens, &audio_features, probe)
    }
893
894 fn eot_token(&self) -> u32 {
896 if self.config.model_family == format::ModelFamily::Moonshine {
897 2 } else {
899 tokenizer::special_tokens::EOT
900 }
901 }
902
903 pub fn compute_mel(&self, audio: &[f32]) -> WhisperResult<Vec<f32>> {
909 const N_SAMPLES_30S: usize = 480_000; const N_FRAMES: usize = 3000; const N_MELS: usize = 80;
912
913 let padded_audio = match audio.len().cmp(&N_SAMPLES_30S) {
915 std::cmp::Ordering::Equal => audio.to_vec(),
916 std::cmp::Ordering::Less => {
917 let mut padded = vec![0.0_f32; N_SAMPLES_30S];
919 padded[..audio.len()].copy_from_slice(audio);
920 padded
921 }
922 std::cmp::Ordering::Greater => {
923 audio[..N_SAMPLES_30S].to_vec()
925 }
926 };
927
928 let mel_fb = self.mel_filters.as_ref().ok_or_else(|| {
929 WhisperError::Audio("mel filterbank not available (Moonshine model?)".into())
930 })?;
931 let mut mel = mel_fb
932 .compute(&padded_audio)
933 .map_err(|e| WhisperError::Audio(e.to_string()))?;
934 let actual_frames = mel.len() / N_MELS;
935
936 if actual_frames < N_FRAMES {
938 let pad_value = -1.0_f32;
941 let mut padded_mel = vec![pad_value; N_FRAMES * N_MELS];
942 padded_mel[..mel.len()].copy_from_slice(&mel);
943 mel = padded_mel;
944 } else if actual_frames > N_FRAMES {
945 mel.truncate(N_FRAMES * N_MELS);
946 }
947
948 Ok(mel)
952 }
953
    /// Runs the encoder over a precomputed mel spectrogram.
    ///
    /// # Errors
    /// Propagates encoder failures.
    pub fn encode(&self, mel: &[f32]) -> WhisperResult<Vec<f32>> {
        self.encoder.forward_mel(mel)
    }

    /// Encoder forward pass with per-brick profiling and optional inference
    /// tracing (`realizar-inference` build).
    ///
    /// # Errors
    /// Propagates encoder failures.
    #[cfg(feature = "realizar-inference")]
    pub fn encode_profiled(
        &self,
        mel: &[f32],
        profiler: &mut trueno::BrickProfiler,
        tracer: Option<&mut realizar::InferenceTracer>,
    ) -> WhisperResult<Vec<f32>> {
        self.encoder.forward_mel_profiled(mel, profiler, tracer)
    }

    /// Encoder forward pass with per-brick profiling (no inference tracing
    /// without the `realizar-inference` feature).
    ///
    /// # Errors
    /// Propagates encoder failures.
    #[cfg(not(feature = "realizar-inference"))]
    pub fn encode_profiled(
        &self,
        mel: &[f32],
        profiler: &mut trueno::BrickProfiler,
    ) -> WhisperResult<Vec<f32>> {
        self.encoder.forward_mel_profiled(mel, profiler)
    }
985
986 fn get_initial_tokens(&self, language: &str, task: Task) -> Vec<u32> {
991 if self.config.model_family == format::ModelFamily::Moonshine {
993 return vec![1]; }
995
996 use tokenizer::special_tokens::{self, SpecialTokens};
998
999 let specials = SpecialTokens::for_vocab_size(self.config.n_vocab as usize);
1001
1002 let mut tokens = vec![specials.sot];
1003
1004 if specials.is_multilingual {
1008 let lang_offset = special_tokens::language_offset(language).unwrap_or(0);
1009 tokens.push(specials.lang_base + lang_offset);
1010 }
1011
1012 match task {
1014 Task::Transcribe => tokens.push(specials.transcribe),
1015 Task::Translate => tokens.push(special_tokens::TRANSLATE),
1016 }
1017
1018 tokens
1023 }
1024
1025 fn build_initial_tokens(&self, language: &str, task: Task, prompt: Option<&str>) -> Vec<u32> {
1030 let mut tokens = self.get_initial_tokens(language, task);
1031
1032 let prompt_text = match prompt {
1033 Some(p) if !p.is_empty() => p,
1034 _ => return tokens,
1035 };
1036
1037 let prompt_tokens = match self.tokenizer.encode(prompt_text) {
1038 Ok(t) if !t.is_empty() => t,
1039 _ => return tokens,
1040 };
1041
1042 let max_prompt = (self.config.n_text_ctx as usize / 2).min(224);
1044 let truncated = if prompt_tokens.len() > max_prompt {
1045 &prompt_tokens[prompt_tokens.len() - max_prompt..]
1046 } else {
1047 &prompt_tokens
1048 };
1049
1050 let mut prefix = Vec::with_capacity(1 + truncated.len() + tokens.len());
1052 prefix.push(tokenizer::special_tokens::PREV);
1053 prefix.extend_from_slice(truncated);
1054 prefix.append(&mut tokens);
1055 prefix
1056 }
1057
    /// Detects the spoken language of `audio`.
    ///
    /// Encodes the audio once, then gives the language detector a closure
    /// returning the decoder's last-position logits (padded with `-inf` if
    /// the decoder output is shorter than a full vocabulary row).
    ///
    /// # Errors
    /// Fails if mel computation, encoding, or the decoder forward pass
    /// fails.
    pub fn detect_language(&self, audio: &[f32]) -> WhisperResult<detection::LanguageProbs> {
        let mel = self.compute_mel(audio)?;

        let audio_features = self.encode(&mel)?;

        let n_vocab = self.config.n_vocab as usize;
        let logits_fn = |tokens: &[u32]| -> WhisperResult<Vec<f32>> {
            let all_logits = self.decoder.forward(tokens, &audio_features)?;
            let seq_len = tokens.len();
            let last_start = (seq_len - 1) * n_vocab;

            if all_logits.len() >= last_start + n_vocab {
                Ok(all_logits[last_start..last_start + n_vocab].to_vec())
            } else {
                // Defensive: pad a short logits buffer with -inf so the
                // detector always sees a full vocabulary row.
                let mut padded = vec![f32::NEG_INFINITY; n_vocab];
                let available = all_logits.len().saturating_sub(last_start);
                if available > 0 {
                    padded[..available].copy_from_slice(&all_logits[last_start..]);
                }
                Ok(padded)
            }
        };

        let detector = detection::LanguageDetector::new();
        detector.detect(logits_fn)
    }
1100
    /// Autoregressive decoding loop over encoded audio features.
    ///
    /// Builds a logits closure that feeds only not-yet-processed tokens
    /// through the decoder (reusing a KV cache and scratch buffers between
    /// calls), applies special-token suppression, optional hotword boosts,
    /// and repetition penalties, then dispatches to the decoder selected by
    /// `options.strategy`.
    ///
    /// # Errors
    /// Propagates decoder forward-pass failures and errors from the chosen
    /// decoding strategy.
    fn decode(
        &self,
        audio_features: &[f32],
        initial_tokens: &[u32],
        options: &TranscribeOptions,
    ) -> WhisperResult<Vec<u32>> {
        use std::cell::RefCell;

        let n_vocab = self.config.n_vocab as usize;
        let max_tokens = self.config.n_text_ctx as usize;

        // Interior mutability lets the logits closure update the cache and
        // the processed-token counter across calls.
        let cache = RefCell::new(self.decoder.create_kv_cache());
        let processed_count = RefCell::new(0usize);

        let suppressor = inference::WhisperTokenSuppressor::new()
            .with_timestamp_suppression(false)
            .with_vocab_size(n_vocab);

        // Pre-tokenize hotwords once; `None` when no usable hotwords exist.
        // Hotwords that fail to encode or produce no tokens are skipped.
        let hotword_booster = if !options.hotwords.is_empty() {
            let mut booster = crate::vocabulary::HotwordBooster::new();
            for word in &options.hotwords {
                if let Ok(tokens) = self.tokenizer.encode(word) {
                    if !tokens.is_empty() {
                        booster.add_hotword_with_tokens_default(word, tokens);
                    }
                }
            }
            if booster.is_empty() {
                None
            } else {
                Some(booster)
            }
        } else {
            None
        };

        let scratch = std::cell::RefCell::new(self.decoder.create_decoder_scratch());

        let logits_fn = |tokens: &[u32]| -> WhisperResult<Vec<f32>> {
            let seq_len = tokens.len();
            let already_processed = *processed_count.borrow();

            let mut logits = vec![f32::NEG_INFINITY; n_vocab];

            // Incremental decoding: only feed tokens the KV cache has not
            // seen; the last iteration's logits are the ones that matter.
            for &token in tokens.iter().take(seq_len).skip(already_processed) {
                logits = self.decoder.forward_one_with_scratch(
                    token,
                    audio_features,
                    &mut cache.borrow_mut(),
                    &mut scratch.borrow_mut(),
                )?;
            }

            *processed_count.borrow_mut() = seq_len;

            suppressor.apply(&mut logits);

            if let Some(ref booster) = hotword_booster {
                booster.apply_bias(&mut logits, tokens);
            }

            let eot_id = self.eot_token();
            Self::suppress_repetitions(&mut logits, tokens, eot_id, n_vocab);

            Ok(logits)
        };

        let eot = self.eot_token();
        match options.strategy {
            DecodingStrategy::Greedy => {
                let decoder = inference::GreedyDecoder::new(max_tokens);
                decoder.decode(logits_fn, initial_tokens, eot)
            }
            DecodingStrategy::BeamSearch {
                beam_size,
                temperature,
                patience,
            } => {
                let decoder = inference::BeamSearchDecoder::new(beam_size, max_tokens)
                    .with_temperature(temperature)
                    .with_patience(patience);
                decoder.decode(logits_fn, initial_tokens, eot)
            }
            DecodingStrategy::Sampling { temperature, .. } => {
                // top_k/top_p are not implemented here; fall back to a
                // temperature-scaled greedy decoder.
                let decoder =
                    inference::GreedyDecoder::new(max_tokens).with_temperature(temperature);
                decoder.decode(logits_fn, initial_tokens, eot)
            }
        }
    }
1207
    /// Applies repetition penalties to `logits` based on recently emitted
    /// text tokens.
    ///
    /// Control/special tokens are filtered out of `tokens` first: when
    /// `eot_id` is small (Moonshine-style vocabularies) text ids lie above
    /// it; otherwise (Whisper) text ids lie below it.
    fn suppress_repetitions(logits: &mut [f32], tokens: &[u32], eot_id: u32, n_vocab: usize) {
        let text_tokens: Vec<u32> = if eot_id <= 3 {
            tokens.iter().copied().filter(|&t| t > eot_id).collect()
        } else {
            tokens.iter().copied().filter(|&t| t < eot_id).collect()
        };

        // Mild penalty for any token seen in the last 50 text tokens.
        let window_start = text_tokens.len().saturating_sub(50);
        for &prev_tok in &text_tokens[window_start..] {
            if (prev_tok as usize) < n_vocab {
                logits[prev_tok as usize] -= 2.0;
            }
        }

        // Strong penalty against completing an already-seen trigram: for
        // every earlier window matching (prev2, prev1, x), suppress x.
        if text_tokens.len() >= 3 {
            let prev2 = text_tokens[text_tokens.len() - 2];
            let prev1 = text_tokens[text_tokens.len() - 1];
            for w in text_tokens[..text_tokens.len() - 2].windows(3) {
                if w[0] == prev2 && w[1] == prev1 && (w[2] as usize) < n_vocab {
                    logits[w[2] as usize] -= 10.0;
                }
            }
        }

        Self::suppress_degenerate_loops(logits, &text_tokens, eot_id, n_vocab);
    }
1244
    /// Detects degenerate repetition loops and either penalizes the looping
    /// tokens or forces decoding to end.
    ///
    /// Two triggers:
    /// 1. any 4-gram repeated 3+ times → force EOT (logit set to 100.0);
    /// 2. in the last 80 tokens, heavily repeated tokens are penalized, and
    ///    low overall diversity (< 35% unique) also forces EOT.
    fn suppress_degenerate_loops(
        logits: &mut [f32],
        text_tokens: &[u32],
        eot_id: u32,
        n_vocab: usize,
    ) {
        if text_tokens.len() >= 12 {
            let mut fourgram_counts = std::collections::HashMap::<[u32; 4], u32>::new();
            for w in text_tokens.windows(4) {
                *fourgram_counts.entry([w[0], w[1], w[2], w[3]]).or_insert(0) += 1;
            }
            if fourgram_counts.values().any(|&c| c >= 3) && (eot_id as usize) < n_vocab {
                logits[eot_id as usize] = 100.0;
                return;
            }
        }

        if text_tokens.len() >= 80 {
            let window = &text_tokens[text_tokens.len() - 80..];
            let mut freq = std::collections::HashMap::<u32, u32>::new();
            for &t in window {
                *freq.entry(t).or_insert(0) += 1;
            }
            // Penalty scales with how far past 7 occurrences a token is.
            for (&tok, &count) in &freq {
                if count >= 8 && (tok as usize) < n_vocab {
                    logits[tok as usize] -= (count as f32 - 7.0) * 3.0;
                }
            }
            if freq.len() as f32 / 80.0 < 0.35 && (eot_id as usize) < n_vocab {
                logits[eot_id as usize] = 100.0;
            }
        }
    }
1281
1282 pub fn set_resampler(&mut self, input_rate: u32) -> WhisperResult<()> {
1287 if input_rate == audio::SAMPLE_RATE {
1288 self.resampler = None;
1289 } else {
1290 self.resampler = Some(audio::SincResampler::new(input_rate, audio::SAMPLE_RATE)?);
1291 }
1292 Ok(())
1293 }
1294
    /// Resamples `audio` with the configured resampler, or returns an
    /// unchanged copy when no resampler is set.
    ///
    /// # Errors
    /// Propagates resampler failures.
    pub fn resample(&self, audio: &[f32]) -> WhisperResult<Vec<f32>> {
        self.resampler
            .as_ref()
            .map_or_else(|| Ok(audio.to_vec()), |resampler| resampler.resample(audio))
    }

    /// Returns the model's tokenizer.
    #[must_use]
    pub const fn tokenizer(&self) -> &tokenizer::Tokenizer {
        &self.tokenizer
    }
1310
1311 #[must_use]
1313 pub fn memory_size(&self) -> usize {
1314 let params = match self.config.model_type {
1316 ModelType::Tiny | ModelType::TinyEn => 39_000_000,
1317 ModelType::Base | ModelType::BaseEn => 74_000_000,
1318 ModelType::Small | ModelType::SmallEn => 244_000_000,
1319 ModelType::Medium | ModelType::MediumEn => 769_000_000,
1320 ModelType::Large | ModelType::LargeV1 | ModelType::LargeV2 | ModelType::LargeV3 => {
1321 1_550_000_000
1322 }
1323 ModelType::LargeV3Turbo => 809_000_000,
1324 };
1325 params * 4 }
1327
    /// Whether real weights have been loaded (as opposed to a bare
    /// `from_config` model).
    #[must_use]
    pub fn has_weights(&self) -> bool {
        self.weights_loaded
    }

    /// Loads a model from APR v2 container bytes without progress
    /// reporting.
    ///
    /// # Errors
    /// See [`Self::load_from_apr_with_progress`].
    pub fn load_from_apr(data: &[u8]) -> WhisperResult<Self> {
        Self::load_from_apr_with_progress(data, &mut progress::null_callback)
    }
1359
    /// Loads a model from APR v2 container bytes, reporting progress
    /// through `callback`.
    ///
    /// Phases: parse header/metadata → load encoder weights → load decoder
    /// weights (native f16, or f32 converted to f16) → build tokenizer →
    /// build the audio frontend (mel filterbank or conv stem).
    ///
    /// # Errors
    /// Returns [`WhisperError::Format`] when the container cannot be
    /// parsed.
    pub fn load_from_apr_with_progress(
        data: &[u8],
        callback: progress::ProgressCallback<'_>,
    ) -> WhisperResult<Self> {
        let mut tracker = progress::ProgressTracker::model_loading();

        callback(&tracker.to_progress());
        let reader = format::AprV2ReaderRef::from_bytes(data)
            .map_err(|e| error::WhisperError::Format(e.to_string()))?;
        let config = format::metadata_to_model_config(reader.metadata());
        tracker.next_phase();

        // Detect on-disk precision from the first real weight tensor;
        // names starting with "__" are auxiliary blobs (vocab, filters).
        let is_f16 = reader
            .tensor_names()
            .iter()
            .find(|n| n.ends_with(".weight") && !n.starts_with("__"))
            .and_then(|n| reader.get_tensor(n))
            .map_or(false, |t| t.dtype == format::TensorDType::F16);

        callback(&tracker.to_progress());
        let mut encoder = model::Encoder::new(&config);
        Self::load_encoder_weights(&reader, &mut encoder, &mut tracker, callback);
        encoder.finalize_weights();
        tracker.next_phase();

        callback(&tracker.to_progress());
        let mut decoder = model::Decoder::new(&config);
        if is_f16 {
            Self::load_decoder_weights_f16(&reader, &mut decoder, &mut tracker, callback);
        } else {
            // f32 weights are down-converted so both load paths end with an
            // f16 decoder.
            Self::load_decoder_weights(&reader, &mut decoder, &mut tracker, callback);
            decoder.convert_to_f16();
        }
        decoder.finalize_weights();
        tracker.next_phase();

        callback(&tracker.to_progress());
        let tokenizer = Self::build_tokenizer(&config, &reader);
        tracker.next_phase();

        callback(&tracker.to_progress());
        // Frontend: embedded mel filters are preferred; fall back to a
        // freshly computed filterbank when the container has none.
        let (mel_filters, conv_stem) = match config.audio_frontend {
            model::AudioFrontend::MelFilterbank => {
                let mel_config = audio::MelConfig {
                    n_mels: config.n_mels as usize,
                    ..audio::MelConfig::whisper()
                };
                let mf = Self::read_mel_filterbank(&reader).map_or_else(
                    || audio::MelFilterbank::new(&mel_config),
                    |fb| audio::MelFilterbank::from_filters(fb.data, &mel_config),
                );
                (Some(mf), None)
            }
            model::AudioFrontend::LearnedConv => {
                let d_model = config.n_audio_state as usize;
                let mut stem = audio::ConvStem::new(d_model);
                Self::load_conv_stem_weights(&reader, &mut stem);
                (None, Some(stem))
            }
        };
        tracker.complete();
        callback(&tracker.to_progress());

        Ok(Self {
            config,
            encoder,
            decoder,
            tokenizer,
            mel_filters,
            conv_stem,
            resampler: None,
            weights_loaded: true,
        })
    }
1470
1471 fn build_tokenizer(
1473 config: &model::ModelConfig,
1474 reader: &format::AprV2ReaderRef<'_>,
1475 ) -> tokenizer::Tokenizer {
1476 if config.model_family == format::ModelFamily::Moonshine {
1477 let mut sp = tokenizer::SentencePieceTokenizer::moonshine_default();
1478 if let Some(vocab) = Self::read_vocabulary(reader) {
1479 Self::populate_sentencepiece(&mut sp, &vocab);
1480 }
1481 tokenizer::Tokenizer::SentencePiece(sp)
1482 } else {
1483 tokenizer::Tokenizer::Bpe(Self::read_vocabulary(reader).map_or_else(
1484 tokenizer::BpeTokenizer::with_base_tokens,
1485 tokenizer::BpeTokenizer::from_vocabulary,
1486 ))
1487 }
1488 }
1489
1490 fn read_vocabulary(reader: &format::AprV2ReaderRef<'_>) -> Option<tokenizer::Vocabulary> {
1492 let raw = reader.get_tensor_data("__vocab__")?;
1493 tokenizer::Vocabulary::from_bytes(raw)
1494 }
1495
1496 fn read_mel_filterbank(
1498 reader: &format::AprV2ReaderRef<'_>,
1499 ) -> Option<format::MelFilterbankData> {
1500 let data = reader.get_tensor_as_f32("__mel_filters__")?;
1501 let entry = reader.get_tensor("__mel_filters__")?;
1502 let shape = &entry.shape;
1503 if shape.len() == 2 {
1504 Some(format::MelFilterbankData::new(
1505 shape[0] as u32,
1506 shape[1] as u32,
1507 data,
1508 ))
1509 } else {
1510 None
1511 }
1512 }
1513
1514 fn load_f16_raw(reader: &format::AprV2ReaderRef<'_>, name: &str) -> Option<Vec<u16>> {
1516 let raw = reader.get_tensor_data(name)?;
1517 Some(
1518 raw.chunks_exact(2)
1519 .map(|b| u16::from_le_bytes([b[0], b[1]]))
1520 .collect(),
1521 )
1522 }
1523
1524 fn populate_sentencepiece(
1526 sp: &mut tokenizer::SentencePieceTokenizer,
1527 vocab: &tokenizer::Vocabulary,
1528 ) {
1529 for id in 0..vocab.len() as u32 {
1530 if let Some(bytes) = vocab.get_bytes(id) {
1531 if let Ok(piece) = core::str::from_utf8(bytes) {
1532 if !piece.is_empty() {
1533 sp.add_piece(id, piece);
1534 }
1535 }
1536 }
1537 }
1538 }
1539
    /// Populate the encoder's parameters from tensors in the APRv2 archive.
    ///
    /// Tensors absent from the archive are skipped silently, leaving the
    /// freshly initialised values in place; length mismatches are truncated
    /// to the destination size. Per-layer progress is reported through
    /// `tracker` and `callback`.
    fn load_encoder_weights(
        reader: &format::AprV2ReaderRef<'_>,
        encoder: &mut model::Encoder,
        tracker: &mut progress::ProgressTracker,
        callback: progress::ProgressCallback<'_>,
    ) {
        let n_layers = encoder.n_layers();

        // Whisper-style conv frontend (conv1/conv2); only present for
        // mel-filterbank models.
        if let Some(conv_frontend) = encoder.conv_frontend_mut() {
            if let Some(weight) = reader.get_tensor_as_f32("encoder.conv1.weight") {
                let target = conv_frontend.conv1.weight_mut();
                let len = weight.len().min(target.len());
                target[..len].copy_from_slice(&weight[..len]);
            }
            if let Some(bias) = reader.get_tensor_as_f32("encoder.conv1.bias") {
                let target = conv_frontend.conv1.bias_mut();
                let len = bias.len().min(target.len());
                target[..len].copy_from_slice(&bias[..len]);
            }

            if let Some(weight) = reader.get_tensor_as_f32("encoder.conv2.weight") {
                let target = conv_frontend.conv2.weight_mut();
                let len = weight.len().min(target.len());
                target[..len].copy_from_slice(&weight[..len]);
            }
            if let Some(bias) = reader.get_tensor_as_f32("encoder.conv2.bias") {
                let target = conv_frontend.conv2.bias_mut();
                let len = bias.len().min(target.len());
                target[..len].copy_from_slice(&bias[..len]);
            }
        }

        // Positional embedding: HF export name first, then the native name.
        let pe_result = reader
            .get_tensor_as_f32("encoder.embed_positions.weight")
            .or_else(|| reader.get_tensor_as_f32("encoder.positional_embedding"));
        if let Some(pe) = pe_result {
            let target = encoder.positional_embedding_mut();
            let len = pe.len().min(target.len());
            target[..len].copy_from_slice(&pe[..len]);
        }

        // Moonshine models use RMS-style norms + GQA under "encoder.blocks.*";
        // classic Whisper uses LayerNorm + MHA under "encoder.layers.*".
        if !encoder.moonshine_blocks().is_empty() {
            for layer_idx in 0..n_layers {
                let progress = layer_idx as f32 / n_layers as f32;
                tracker.update_phase_progress(progress);
                callback(&tracker.to_progress());

                let block = &mut encoder.moonshine_blocks_mut()[layer_idx];

                Self::load_layernorm_nobias_weights(
                    reader,
                    &format!("encoder.blocks.{layer_idx}.ln1"),
                    &mut block.ln1,
                );

                Self::load_gqa_weights(
                    reader,
                    &format!("encoder.blocks.{layer_idx}.attn"),
                    &mut block.self_attn,
                );

                Self::load_layernorm_nobias_weights(
                    reader,
                    &format!("encoder.blocks.{layer_idx}.ln2"),
                    &mut block.ln2,
                );

                Self::load_mlp_weights(
                    reader,
                    &format!("encoder.blocks.{layer_idx}.ffn"),
                    &mut block.ffn,
                );
            }
        } else {
            for layer_idx in 0..n_layers {
                let progress = layer_idx as f32 / n_layers as f32;
                tracker.update_phase_progress(progress);
                callback(&tracker.to_progress());

                let block = &mut encoder.blocks_mut()[layer_idx];

                Self::load_layer_norm_weights(
                    reader,
                    &format!("encoder.layers.{layer_idx}.self_attn_layer_norm"),
                    &mut block.ln1,
                );

                Self::load_attention_weights(
                    reader,
                    &format!("encoder.layers.{layer_idx}.self_attn"),
                    &mut block.self_attn,
                );

                Self::load_layer_norm_weights(
                    reader,
                    &format!("encoder.layers.{layer_idx}.final_layer_norm"),
                    &mut block.ln2,
                );

                // FFN loader appends ".fc1"/".fc2" itself, hence the bare prefix.
                Self::load_ffn_weights(
                    reader,
                    &format!("encoder.layers.{layer_idx}"),
                    &mut block.ffn,
                );
            }
        }

        // Final norm: RMS variant when present, classic LayerNorm otherwise.
        if let Some(ln) = encoder.ln_post_rms_mut() {
            Self::load_layernorm_nobias_weights(reader, "encoder.layer_norm", ln);
        } else {
            Self::load_layer_norm_weights(reader, "encoder.layer_norm", encoder.ln_post_mut());
        }
    }
1672
    /// Populate the decoder's parameters (f32 path) from the APRv2 archive.
    ///
    /// Missing tensors are skipped silently; length mismatches are truncated
    /// to the destination size. Per-layer progress is reported through
    /// `tracker` and `callback`. Calls `decoder.finalize_weights()` at the
    /// end (the caller may also call it — presumably idempotent; confirm).
    fn load_decoder_weights(
        reader: &format::AprV2ReaderRef<'_>,
        decoder: &mut model::Decoder,
        tracker: &mut progress::ProgressTracker,
        callback: progress::ProgressCallback<'_>,
    ) {
        let n_layers = decoder.n_layers();

        // Token embedding: try HF export names before the native name.
        let te_result = reader
            .get_tensor_as_f32("decoder.embed_tokens.weight")
            .or_else(|| reader.get_tensor_as_f32("decoder.token_embedding.weight"))
            .or_else(|| reader.get_tensor_as_f32("decoder.token_embedding"));
        if let Some(te) = te_result {
            let target = decoder.token_embedding_mut();
            let len = te.len().min(target.len());
            target[..len].copy_from_slice(&te[..len]);
        }

        // Positional embedding, same name-fallback scheme.
        let pe_result = reader
            .get_tensor_as_f32("decoder.embed_positions.weight")
            .or_else(|| reader.get_tensor_as_f32("decoder.positional_embedding"));
        if let Some(pe) = pe_result {
            let target = decoder.positional_embedding_mut();
            let len = pe.len().min(target.len());
            target[..len].copy_from_slice(&pe[..len]);
        }

        // Moonshine layout ("decoder.blocks.*", RMS norms, GQA, gated MLP)
        // versus classic Whisper layout ("decoder.layers.*", LayerNorm, MHA).
        if !decoder.moonshine_blocks().is_empty() {
            for layer_idx in 0..n_layers {
                let progress = layer_idx as f32 / n_layers as f32;
                tracker.update_phase_progress(progress);
                callback(&tracker.to_progress());

                let block = &mut decoder.moonshine_blocks_mut()[layer_idx];

                Self::load_layernorm_nobias_weights(
                    reader,
                    &format!("decoder.blocks.{layer_idx}.ln1"),
                    &mut block.ln1,
                );

                Self::load_gqa_weights(
                    reader,
                    &format!("decoder.blocks.{layer_idx}.attn"),
                    &mut block.self_attn,
                );

                Self::load_layernorm_nobias_weights(
                    reader,
                    &format!("decoder.blocks.{layer_idx}.ln_cross"),
                    &mut block.ln_cross,
                );

                Self::load_gqa_weights(
                    reader,
                    &format!("decoder.blocks.{layer_idx}.cross_attn"),
                    &mut block.cross_attn,
                );

                Self::load_layernorm_nobias_weights(
                    reader,
                    &format!("decoder.blocks.{layer_idx}.ln2"),
                    &mut block.ln2,
                );

                Self::load_gated_mlp_weights(
                    reader,
                    &format!("decoder.blocks.{layer_idx}.ffn"),
                    &mut block.ffn,
                );
            }

            // Moonshine final norm is optional (loaded only when present).
            if let Some(ln) = decoder.ln_post_rms_mut() {
                Self::load_layernorm_nobias_weights(reader, "decoder.ln_post", ln);
            }
        } else {
            for layer_idx in 0..n_layers {
                let progress = layer_idx as f32 / n_layers as f32;
                tracker.update_phase_progress(progress);
                callback(&tracker.to_progress());

                let block = &mut decoder.blocks_mut()[layer_idx];

                Self::load_layer_norm_weights(
                    reader,
                    &format!("decoder.layers.{layer_idx}.self_attn_layer_norm"),
                    &mut block.ln1,
                );

                Self::load_attention_weights(
                    reader,
                    &format!("decoder.layers.{layer_idx}.self_attn"),
                    &mut block.self_attn,
                );

                Self::load_layer_norm_weights(
                    reader,
                    &format!("decoder.layers.{layer_idx}.encoder_attn_layer_norm"),
                    &mut block.ln2,
                );

                Self::load_attention_weights(
                    reader,
                    &format!("decoder.layers.{layer_idx}.encoder_attn"),
                    &mut block.cross_attn,
                );

                Self::load_layer_norm_weights(
                    reader,
                    &format!("decoder.layers.{layer_idx}.final_layer_norm"),
                    &mut block.ln3,
                );

                // FFN loader appends ".fc1"/".fc2" itself, hence the bare prefix.
                Self::load_ffn_weights(
                    reader,
                    &format!("decoder.layers.{layer_idx}"),
                    &mut block.ffn,
                );
            }

            Self::load_layer_norm_weights(reader, "decoder.layer_norm", decoder.ln_post_mut());
        }

        decoder.finalize_weights();
    }
1822
    /// Populate the decoder's parameters from an archive that stores the
    /// projection/FFN weights in f16.
    ///
    /// Embeddings, norms, and biases are still read as f32; attention and FFN
    /// weight matrices are loaded as raw f16 words. Only the classic Whisper
    /// layer layout is handled here (no Moonshine branch). Embeddings are
    /// converted to f16 after loading.
    fn load_decoder_weights_f16(
        reader: &format::AprV2ReaderRef<'_>,
        decoder: &mut model::Decoder,
        tracker: &mut progress::ProgressTracker,
        callback: progress::ProgressCallback<'_>,
    ) {
        let n_layers = decoder.n_layers();

        // Token embedding (f32): HF export names first, then the native name.
        let te_result = reader
            .get_tensor_as_f32("decoder.embed_tokens.weight")
            .or_else(|| reader.get_tensor_as_f32("decoder.token_embedding.weight"))
            .or_else(|| reader.get_tensor_as_f32("decoder.token_embedding"));
        if let Some(te) = te_result {
            let target = decoder.token_embedding_mut();
            let len = te.len().min(target.len());
            target[..len].copy_from_slice(&te[..len]);
        }

        // Positional embedding (f32), same fallback scheme.
        let pe_result = reader
            .get_tensor_as_f32("decoder.embed_positions.weight")
            .or_else(|| reader.get_tensor_as_f32("decoder.positional_embedding"));
        if let Some(pe) = pe_result {
            let target = decoder.positional_embedding_mut();
            let len = pe.len().min(target.len());
            target[..len].copy_from_slice(&pe[..len]);
        }

        for layer_idx in 0..n_layers {
            let progress = layer_idx as f32 / n_layers as f32;
            tracker.update_phase_progress(progress);
            callback(&tracker.to_progress());

            let block = &mut decoder.blocks_mut()[layer_idx];

            Self::load_layer_norm_weights(
                reader,
                &format!("decoder.layers.{layer_idx}.self_attn_layer_norm"),
                &mut block.ln1,
            );

            Self::load_attention_weights_f16(
                reader,
                &format!("decoder.layers.{layer_idx}.self_attn"),
                &mut block.self_attn,
            );

            Self::load_layer_norm_weights(
                reader,
                &format!("decoder.layers.{layer_idx}.encoder_attn_layer_norm"),
                &mut block.ln2,
            );

            Self::load_attention_weights_f16(
                reader,
                &format!("decoder.layers.{layer_idx}.encoder_attn"),
                &mut block.cross_attn,
            );

            Self::load_layer_norm_weights(
                reader,
                &format!("decoder.layers.{layer_idx}.final_layer_norm"),
                &mut block.ln3,
            );

            // FFN loader appends ".fc1"/".fc2" itself, hence the bare prefix.
            Self::load_ffn_weights_f16(
                reader,
                &format!("decoder.layers.{layer_idx}"),
                &mut block.ffn,
            );
        }

        Self::load_layer_norm_weights(reader, "decoder.layer_norm", decoder.ln_post_mut());

        decoder.finalize_weights();
        decoder.convert_embeddings_to_f16();
    }
1915
1916 fn load_layer_norm_weights(
1918 reader: &format::AprV2ReaderRef<'_>,
1919 prefix: &str,
1920 ln: &mut model::LayerNorm,
1921 ) {
1922 if let Some(weight) = reader.get_tensor_as_f32(&format!("{prefix}.weight")) {
1923 let len = weight.len().min(ln.weight.len());
1924 ln.weight[..len].copy_from_slice(&weight[..len]);
1925 }
1926 if let Some(bias) = reader.get_tensor_as_f32(&format!("{prefix}.bias")) {
1927 let len = bias.len().min(ln.bias.len());
1928 ln.bias[..len].copy_from_slice(&bias[..len]);
1929 }
1930 }
1931
1932 fn load_attention_weights(
1934 reader: &format::AprV2ReaderRef<'_>,
1935 prefix: &str,
1936 attn: &mut model::MultiHeadAttention,
1937 ) {
1938 if let Some(q_weight) = reader.get_tensor_as_f32(&format!("{prefix}.q_proj.weight")) {
1940 attn.set_query_weight(&q_weight);
1941 }
1942 if let Some(q_bias) = reader.get_tensor_as_f32(&format!("{prefix}.q_proj.bias")) {
1943 attn.set_query_bias(&q_bias);
1944 }
1945 if let Some(k_weight) = reader.get_tensor_as_f32(&format!("{prefix}.k_proj.weight")) {
1946 attn.set_key_weight(&k_weight);
1947 }
1948 if let Some(k_bias) = reader.get_tensor_as_f32(&format!("{prefix}.k_proj.bias")) {
1949 attn.set_key_bias(&k_bias);
1950 }
1951 if let Some(v_weight) = reader.get_tensor_as_f32(&format!("{prefix}.v_proj.weight")) {
1952 attn.set_value_weight(&v_weight);
1953 }
1954 if let Some(v_bias) = reader.get_tensor_as_f32(&format!("{prefix}.v_proj.bias")) {
1955 attn.set_value_bias(&v_bias);
1956 }
1957 if let Some(out_weight) = reader.get_tensor_as_f32(&format!("{prefix}.out_proj.weight")) {
1958 attn.set_out_weight(&out_weight);
1959 }
1960 if let Some(out_bias) = reader.get_tensor_as_f32(&format!("{prefix}.out_proj.bias")) {
1961 attn.set_out_bias(&out_bias);
1962 }
1963 }
1964
1965 fn load_ffn_weights(
1967 reader: &format::AprV2ReaderRef<'_>,
1968 prefix: &str,
1969 ffn: &mut model::FeedForward,
1970 ) {
1971 if let Some(fc1_weight) = reader.get_tensor_as_f32(&format!("{prefix}.fc1.weight")) {
1972 ffn.fc1.set_weight(&fc1_weight);
1973 }
1974 if let Some(fc1_bias) = reader.get_tensor_as_f32(&format!("{prefix}.fc1.bias")) {
1975 ffn.fc1.set_bias(&fc1_bias);
1976 }
1977 if let Some(fc2_weight) = reader.get_tensor_as_f32(&format!("{prefix}.fc2.weight")) {
1978 ffn.fc2.set_weight(&fc2_weight);
1979 }
1980 if let Some(fc2_bias) = reader.get_tensor_as_f32(&format!("{prefix}.fc2.bias")) {
1981 ffn.fc2.set_bias(&fc2_bias);
1982 }
1983 }
1984
1985 fn load_attention_weights_f16(
1990 reader: &format::AprV2ReaderRef<'_>,
1991 prefix: &str,
1992 attn: &mut model::MultiHeadAttention,
1993 ) {
1994 if let Some(q_weight) = Self::load_f16_raw(reader, &format!("{prefix}.q_proj.weight")) {
1996 attn.set_query_weight_f16(&q_weight);
1997 }
1998 if let Some(k_weight) = Self::load_f16_raw(reader, &format!("{prefix}.k_proj.weight")) {
1999 attn.set_key_weight_f16(&k_weight);
2000 }
2001 if let Some(v_weight) = Self::load_f16_raw(reader, &format!("{prefix}.v_proj.weight")) {
2002 attn.set_value_weight_f16(&v_weight);
2003 }
2004 if let Some(out_weight) = Self::load_f16_raw(reader, &format!("{prefix}.out_proj.weight")) {
2005 attn.set_out_weight_f16(&out_weight);
2006 }
2007 if let Some(q_bias) = reader.get_tensor_as_f32(&format!("{prefix}.q_proj.bias")) {
2009 attn.set_query_bias(&q_bias);
2010 }
2011 if let Some(k_bias) = reader.get_tensor_as_f32(&format!("{prefix}.k_proj.bias")) {
2012 attn.set_key_bias(&k_bias);
2013 }
2014 if let Some(v_bias) = reader.get_tensor_as_f32(&format!("{prefix}.v_proj.bias")) {
2015 attn.set_value_bias(&v_bias);
2016 }
2017 if let Some(out_bias) = reader.get_tensor_as_f32(&format!("{prefix}.out_proj.bias")) {
2018 attn.set_out_bias(&out_bias);
2019 }
2020 }
2021
2022 fn load_ffn_weights_f16(
2026 reader: &format::AprV2ReaderRef<'_>,
2027 prefix: &str,
2028 ffn: &mut model::FeedForward,
2029 ) {
2030 if let Some(fc1_weight) = Self::load_f16_raw(reader, &format!("{prefix}.fc1.weight")) {
2031 ffn.fc1.set_weight_f16(&fc1_weight);
2032 }
2033 if let Some(fc1_bias) = reader.get_tensor_as_f32(&format!("{prefix}.fc1.bias")) {
2034 ffn.fc1.set_bias(&fc1_bias);
2035 }
2036 if let Some(fc2_weight) = Self::load_f16_raw(reader, &format!("{prefix}.fc2.weight")) {
2037 ffn.fc2.set_weight_f16(&fc2_weight);
2038 }
2039 if let Some(fc2_bias) = reader.get_tensor_as_f32(&format!("{prefix}.fc2.bias")) {
2040 ffn.fc2.set_bias(&fc2_bias);
2041 }
2042 }
2043
2044 #[allow(dead_code)]
2046 fn load_rms_norm_weights(
2047 reader: &format::AprV2ReaderRef<'_>,
2048 prefix: &str,
2049 rms: &mut model::lfm2::layer::RmsNorm,
2050 ) {
2051 if let Some(weight) = reader.get_tensor_as_f32(&format!("{prefix}.weight")) {
2052 let len = weight.len().min(rms.weight.len());
2053 rms.weight[..len].copy_from_slice(&weight[..len]);
2054 }
2055 }
2056
2057 fn load_layernorm_nobias_weights(
2059 reader: &format::AprV2ReaderRef<'_>,
2060 prefix: &str,
2061 ln: &mut model::lfm2::layer::LayerNormNoBias,
2062 ) {
2063 if let Some(weight) = reader.get_tensor_as_f32(&format!("{prefix}.weight")) {
2064 let len = weight.len().min(ln.weight.len());
2065 ln.weight[..len].copy_from_slice(&weight[..len]);
2066 }
2067 }
2068
2069 fn load_gqa_weights(
2071 reader: &format::AprV2ReaderRef<'_>,
2072 prefix: &str,
2073 gqa: &mut model::lfm2::gqa::GroupedQueryAttention,
2074 ) {
2075 if let Some(w) = reader.get_tensor_as_f32(&format!("{prefix}.q.weight")) {
2076 let len = w.len().min(gqa.w_q.len());
2077 gqa.w_q[..len].copy_from_slice(&w[..len]);
2078 }
2079 if let Some(w) = reader.get_tensor_as_f32(&format!("{prefix}.k.weight")) {
2080 let len = w.len().min(gqa.w_k.len());
2081 gqa.w_k[..len].copy_from_slice(&w[..len]);
2082 }
2083 if let Some(w) = reader.get_tensor_as_f32(&format!("{prefix}.v.weight")) {
2084 let len = w.len().min(gqa.w_v.len());
2085 gqa.w_v[..len].copy_from_slice(&w[..len]);
2086 }
2087 if let Some(w) = reader.get_tensor_as_f32(&format!("{prefix}.o.weight")) {
2088 let len = w.len().min(gqa.w_o.len());
2089 gqa.w_o[..len].copy_from_slice(&w[..len]);
2090 }
2091 }
2092
2093 fn load_mlp_weights(
2095 reader: &format::AprV2ReaderRef<'_>,
2096 prefix: &str,
2097 ffn: &mut model::lfm2::mlp::MlpFfn,
2098 ) {
2099 if let Some(w) = reader.get_tensor_as_f32(&format!("{prefix}.fc1.weight")) {
2100 let len = w.len().min(ffn.fc1.len());
2101 ffn.fc1[..len].copy_from_slice(&w[..len]);
2102 }
2103 if let Some(b) = reader.get_tensor_as_f32(&format!("{prefix}.fc1.bias")) {
2104 ffn.b1 = Some(b);
2105 }
2106 if let Some(w) = reader.get_tensor_as_f32(&format!("{prefix}.fc2.weight")) {
2107 let len = w.len().min(ffn.fc2.len());
2108 ffn.fc2[..len].copy_from_slice(&w[..len]);
2109 }
2110 if let Some(b) = reader.get_tensor_as_f32(&format!("{prefix}.fc2.bias")) {
2111 ffn.b2 = Some(b);
2112 }
2113 }
2114
2115 fn load_gated_mlp_weights(
2117 reader: &format::AprV2ReaderRef<'_>,
2118 prefix: &str,
2119 ffn: &mut model::lfm2::mlp::GatedMlpFfn,
2120 ) {
2121 if let Some(w) = reader.get_tensor_as_f32(&format!("{prefix}.fc1.weight")) {
2122 let len = w.len().min(ffn.fc1.len());
2123 ffn.fc1[..len].copy_from_slice(&w[..len]);
2124 }
2125 if let Some(b) = reader.get_tensor_as_f32(&format!("{prefix}.fc1.bias")) {
2126 ffn.b1 = Some(b);
2127 }
2128 if let Some(w) = reader.get_tensor_as_f32(&format!("{prefix}.fc2.weight")) {
2129 let len = w.len().min(ffn.fc2.len());
2130 ffn.fc2[..len].copy_from_slice(&w[..len]);
2131 }
2132 if let Some(b) = reader.get_tensor_as_f32(&format!("{prefix}.fc2.bias")) {
2133 ffn.b2 = Some(b);
2134 }
2135 }
2136
    /// Load the learned-conv audio frontend (conv1..conv3 + groupnorm) used
    /// by `AudioFrontend::LearnedConv` models. Missing tensors are skipped;
    /// length mismatches are truncated to the destination size.
    ///
    /// NOTE(review): conv1 loads only a weight here while conv2/conv3 load
    /// weight and bias — presumably the first conv is bias-free in this
    /// architecture; confirm against the model exporter.
    fn load_conv_stem_weights(reader: &format::AprV2ReaderRef<'_>, stem: &mut audio::ConvStem) {
        if let Some(w) = reader.get_tensor_as_f32("encoder.conv1.weight") {
            let target = stem.conv1.weight_mut();
            let len = w.len().min(target.len());
            target[..len].copy_from_slice(&w[..len]);
        }

        if let Some(w) = reader.get_tensor_as_f32("encoder.conv2.weight") {
            let target = stem.conv2.weight_mut();
            let len = w.len().min(target.len());
            target[..len].copy_from_slice(&w[..len]);
        }
        if let Some(b) = reader.get_tensor_as_f32("encoder.conv2.bias") {
            let target = stem.conv2.bias_mut();
            let len = b.len().min(target.len());
            target[..len].copy_from_slice(&b[..len]);
        }

        if let Some(w) = reader.get_tensor_as_f32("encoder.conv3.weight") {
            let target = stem.conv3.weight_mut();
            let len = w.len().min(target.len());
            target[..len].copy_from_slice(&w[..len]);
        }
        if let Some(b) = reader.get_tensor_as_f32("encoder.conv3.bias") {
            let target = stem.conv3.bias_mut();
            let len = b.len().min(target.len());
            target[..len].copy_from_slice(&b[..len]);
        }

        // Group normalization applied after the conv stack.
        if let Some(w) = reader.get_tensor_as_f32("encoder.groupnorm.weight") {
            let len = w.len().min(stem.groupnorm.weight.len());
            stem.groupnorm.weight[..len].copy_from_slice(&w[..len]);
        }
        if let Some(b) = reader.get_tensor_as_f32("encoder.groupnorm.bias") {
            let len = b.len().min(stem.groupnorm.bias.len());
            stem.groupnorm.bias[..len].copy_from_slice(&b[..len]);
        }

    }
2183
    /// Mutable access to the audio encoder.
    pub fn encoder_mut(&mut self) -> &mut model::Encoder {
        &mut self.encoder
    }
2188
    /// Mutable access to the text decoder.
    pub fn decoder_mut(&mut self) -> &mut model::Decoder {
        &mut self.decoder
    }
2193
2194 pub fn encoder(&self) -> &model::Encoder {
2196 &self.encoder
2197 }
2198
2199 pub fn decoder(&self) -> &model::Decoder {
2201 &self.decoder
2202 }
2203
2204 #[cfg(feature = "realizar-gpu")]
2214 pub fn into_cuda(self, device_ordinal: i32) -> WhisperResult<cuda::WhisperCuda> {
2215 let mel_filters = self.mel_filters.ok_or_else(|| {
2216 WhisperError::Audio("CUDA requires mel filterbank (Whisper model)".into())
2217 })?;
2218 let bpe_tokenizer = match self.tokenizer {
2219 tokenizer::Tokenizer::Bpe(bpe) => bpe,
2220 tokenizer::Tokenizer::SentencePiece(_) => {
2221 return Err(WhisperError::Model(
2222 "CUDA backend requires Whisper BPE tokenizer".into(),
2223 ));
2224 }
2225 };
2226 cuda::WhisperCuda::new_with_components(
2227 self.encoder,
2228 self.decoder,
2229 self.config,
2230 bpe_tokenizer,
2231 mel_filters,
2232 device_ordinal,
2233 )
2234 }
2235
    /// Learned-conv audio frontend, if this model uses one.
    #[must_use]
    pub const fn conv_stem(&self) -> Option<&audio::ConvStem> {
        self.conv_stem.as_ref()
    }
2241
    /// Mel filterbank, if this model uses the mel-spectrogram frontend.
    #[must_use]
    pub const fn mel_filters(&self) -> Option<&audio::MelFilterbank> {
        self.mel_filters.as_ref()
    }
2247
2248 pub fn transcribe_batch(
2277 &self,
2278 audio_batch: &[Vec<f32>],
2279 options: TranscribeOptions,
2280 ) -> WhisperResult<BatchTranscriptionResult> {
2281 if audio_batch.is_empty() {
2282 return Err(WhisperError::Audio("empty batch".into()));
2283 }
2284
2285 let start_time = std::time::Instant::now();
2286 let mut results = Vec::with_capacity(audio_batch.len());
2287
2288 for audio in audio_batch {
2290 let result = self.transcribe(audio, options.clone())?;
2291 results.push(result);
2292 }
2293
2294 let total_duration_secs = start_time.elapsed().as_secs_f32();
2295
2296 Ok(BatchTranscriptionResult {
2297 results,
2298 total_duration_secs,
2299 })
2300 }
2301
    /// Transcribe a pre-assembled [`audio::AudioBatch`].
    ///
    /// Mel spectrograms are computed in one preprocessing pass, then each
    /// clip is encoded and decoded sequentially. When `options.language` is
    /// unset, "en" is used (no language auto-detection on this path).
    ///
    /// NOTE(review): this uses `audio::MelConfig::default()` rather than a
    /// config derived from the model's `n_mels` — confirm for models with a
    /// non-default mel count.
    ///
    /// # Errors
    /// Returns an error for an empty batch or on any pipeline failure.
    pub fn transcribe_audio_batch(
        &self,
        batch: &audio::AudioBatch,
        options: TranscribeOptions,
    ) -> WhisperResult<BatchTranscriptionResult> {
        if batch.is_empty() {
            return Err(WhisperError::Audio("empty batch".into()));
        }

        let start_time = std::time::Instant::now();

        let preprocessor = audio::BatchPreprocessor::new(audio::MelConfig::default());
        let mel_result = preprocessor.process_batch(batch)?;

        let mut results = Vec::with_capacity(batch.len());
        let language = options.language.clone().unwrap_or_else(|| "en".to_string());

        for mel in &mel_result.mels {
            let audio_features = self.encode(mel)?;

            let initial_tokens = self.get_initial_tokens(&language, options.task);

            let tokens = self.decode(&audio_features, &initial_tokens, &options)?;

            // Split into timestamped segments only when the decoded stream
            // actually contains timestamp tokens.
            let segments = if timestamps::has_timestamps(&tokens) {
                timestamps::extract_segments(&tokens, |ts| self.tokenizer.decode(ts).ok())
            } else {
                Vec::new()
            };

            let text = self.tokenizer.decode(&tokens)?;

            results.push(TranscriptionResult {
                text,
                language: language.clone(),
                segments,
                profiling: None,
            });
        }

        let total_duration_secs = start_time.elapsed().as_secs_f32();

        Ok(BatchTranscriptionResult {
            results,
            total_duration_secs,
        })
    }
2370
2371 #[must_use]
2373 pub fn create_audio_batch(audio_segments: &[Vec<f32>]) -> audio::AudioBatch {
2374 let mut batch = audio::AudioBatch::with_default_config();
2375 for segment in audio_segments {
2376 batch.add_segment(segment.clone());
2377 }
2378 batch
2379 }
2380
    /// Transcribe a batch with a single batched encoder pass.
    ///
    /// Mel spectrograms are computed per clip, all clips are encoded together
    /// via `Encoder::forward_batch`, then decoding proceeds sequentially.
    /// When `options.language` is unset, "en" is used.
    ///
    /// # Errors
    /// Returns an error for an empty batch or on any pipeline failure.
    pub fn transcribe_batch_optimized(
        &self,
        audio_batch: &[Vec<f32>],
        options: TranscribeOptions,
    ) -> WhisperResult<BatchTranscriptionResult> {
        if audio_batch.is_empty() {
            return Err(WhisperError::Audio("empty batch".into()));
        }

        let start_time = std::time::Instant::now();

        let mut mels = Vec::with_capacity(audio_batch.len());
        for audio in audio_batch {
            let mel = self.compute_mel(audio)?;
            mels.push(mel);
        }

        // One batched forward pass over all spectrograms.
        let encoder_outputs = self.encoder.forward_batch(&mels)?;

        let mut results = Vec::with_capacity(audio_batch.len());
        let language = options.language.clone().unwrap_or_else(|| "en".to_string());

        for features in &encoder_outputs {
            let initial_tokens = self.get_initial_tokens(&language, options.task);
            let tokens = self.decode(features, &initial_tokens, &options)?;

            // Segment extraction only when timestamp tokens are present.
            let segments = if timestamps::has_timestamps(&tokens) {
                timestamps::extract_segments(&tokens, |ts| self.tokenizer.decode(ts).ok())
            } else {
                Vec::new()
            };

            let text = self.tokenizer.decode(&tokens)?;

            results.push(TranscriptionResult {
                text,
                language: language.clone(),
                segments,
                profiling: None,
            });
        }

        let total_duration_secs = start_time.elapsed().as_secs_f32();

        Ok(BatchTranscriptionResult {
            results,
            total_duration_secs,
        })
    }
2447
    /// Run voice-activity detection, then transcribe only the detected
    /// speech regions.
    ///
    /// Each speech region is cut out of `audio` (at `audio::SAMPLE_RATE`)
    /// and transcribed independently; the per-region texts are joined with
    /// single spaces. An input with no detected speech yields an empty
    /// result rather than an error.
    ///
    /// # Errors
    /// Propagates any failure from transcribing a speech region.
    pub fn transcribe_with_vad(
        &self,
        audio: &[f32],
        options: TranscribeOptions,
        vad_config: Option<vad::VadConfig>,
    ) -> WhisperResult<VadTranscriptionResult> {
        let start_time = std::time::Instant::now();

        let config = vad_config.unwrap_or_default();
        let mut vad = vad::VoiceActivityDetector::new(config);

        let speech_segments = vad.detect(audio);

        // No speech at all: return an empty result, not an error.
        if speech_segments.is_empty() {
            return Ok(VadTranscriptionResult {
                text: String::new(),
                language: options.language.unwrap_or_else(|| "en".to_string()),
                segments: Vec::new(),
                speech_segments: Vec::new(),
                total_duration_secs: start_time.elapsed().as_secs_f32(),
                speech_duration_secs: 0.0,
            });
        }

        // Convert segment times (seconds) to sample ranges and copy out the
        // speech audio; zero-length ranges are dropped.
        let sample_rate = audio::SAMPLE_RATE as f32;
        let mut speech_audios = Vec::with_capacity(speech_segments.len());
        let mut speech_duration = 0.0f32;

        for segment in &speech_segments {
            let start_sample = (segment.start * sample_rate) as usize;
            let end_sample = ((segment.end * sample_rate) as usize).min(audio.len());

            if end_sample > start_sample {
                speech_audios.push((
                    segment.start,
                    segment.end,
                    audio[start_sample..end_sample].to_vec(),
                ));
                speech_duration += segment.duration();
            }
        }

        let mut all_segments = Vec::new();
        let mut full_text = String::new();
        let language = options.language.clone().unwrap_or_else(|| "en".to_string());

        for (seg_start, seg_end, speech_audio) in &speech_audios {
            let result = self.transcribe(speech_audio, options.clone())?;

            // Keep the tokens of the first sub-segment only (if any).
            let segment = VadSpeechSegment {
                start: *seg_start,
                end: *seg_end,
                text: result.text.clone(),
                tokens: result
                    .segments
                    .first()
                    .map(|s| s.tokens.clone())
                    .unwrap_or_default(),
            };

            if !full_text.is_empty() {
                full_text.push(' ');
            }
            full_text.push_str(&result.text);

            all_segments.push(segment);
        }

        let total_duration_secs = start_time.elapsed().as_secs_f32();

        Ok(VadTranscriptionResult {
            text: full_text,
            language,
            segments: all_segments,
            speech_segments: speech_segments
                .into_iter()
                .map(|s| (s.start, s.end))
                .collect(),
            total_duration_secs,
            speech_duration_secs: speech_duration,
        })
    }
2570
    /// Detect silence, invert it into speech regions, and transcribe each
    /// region independently, joining the texts with single spaces.
    ///
    /// An input with no speech regions yields an empty result rather than an
    /// error.
    ///
    /// # Errors
    /// Propagates any failure from transcribing a speech region.
    pub fn transcribe_with_silence_detection(
        &self,
        audio: &[f32],
        options: TranscribeOptions,
        silence_config: Option<vad::SilenceConfig>,
    ) -> WhisperResult<VadTranscriptionResult> {
        let start_time = std::time::Instant::now();

        let config = silence_config.unwrap_or_default();
        let mut detector = vad::SilenceDetector::new(config, audio::SAMPLE_RATE);

        // 480-sample analysis frames (30 ms if SAMPLE_RATE is 16 kHz —
        // presumably; confirm against audio::SAMPLE_RATE).
        let frame_size = 480; let silence_segments = detector.detect(audio, frame_size);

        // Speech = complement of silence over the clip.
        let speech_segments = self.invert_silence_segments(&silence_segments, audio.len());

        if speech_segments.is_empty() {
            return Ok(VadTranscriptionResult {
                text: String::new(),
                language: options.language.unwrap_or_else(|| "en".to_string()),
                segments: Vec::new(),
                speech_segments: Vec::new(),
                total_duration_secs: start_time.elapsed().as_secs_f32(),
                speech_duration_secs: 0.0,
            });
        }

        let sample_rate = audio::SAMPLE_RATE as f32;
        let mut all_segments = Vec::new();
        let mut full_text = String::new();
        let mut speech_duration = 0.0f32;
        let language = options.language.clone().unwrap_or_else(|| "en".to_string());

        for (start, end) in &speech_segments {
            let start_sample = (start * sample_rate) as usize;
            let end_sample = ((end * sample_rate) as usize).min(audio.len());

            if end_sample > start_sample {
                let speech_audio = &audio[start_sample..end_sample];
                let result = self.transcribe(speech_audio, options.clone())?;

                // Keep the tokens of the first sub-segment only (if any).
                let segment = VadSpeechSegment {
                    start: *start,
                    end: *end,
                    text: result.text.clone(),
                    tokens: result
                        .segments
                        .first()
                        .map(|s| s.tokens.clone())
                        .unwrap_or_default(),
                };

                if !full_text.is_empty() {
                    full_text.push(' ');
                }
                full_text.push_str(&result.text);

                speech_duration += end - start;
                all_segments.push(segment);
            }
        }

        let total_duration_secs = start_time.elapsed().as_secs_f32();

        Ok(VadTranscriptionResult {
            text: full_text,
            language,
            segments: all_segments,
            speech_segments,
            total_duration_secs,
            speech_duration_secs: speech_duration,
        })
    }
2663
2664 fn invert_silence_segments(
2666 &self,
2667 silence_segments: &[vad::SilenceSegment],
2668 audio_len: usize,
2669 ) -> Vec<(f32, f32)> {
2670 let _ = self; let sample_rate = audio::SAMPLE_RATE as f32;
2672 let total_duration = audio_len as f32 / sample_rate;
2673 let mut speech_segments = Vec::new();
2674 let mut current_pos = 0.0f32;
2675
2676 for silence in silence_segments {
2677 if silence.start > current_pos {
2678 speech_segments.push((current_pos, silence.start));
2679 }
2680 current_pos = silence.end;
2681 }
2682
2683 if current_pos < total_duration {
2685 speech_segments.push((current_pos, total_duration));
2686 }
2687
2688 speech_segments
2689 }
2690
    /// Transcribe an in-progress audio buffer for streaming use.
    ///
    /// Buffers shorter than half a second return an empty interim result
    /// without running the model.
    ///
    /// NOTE(review): `confidence` is a hard-coded placeholder (0.0 for too
    /// short, 1.0 otherwise) — no real confidence is computed here.
    ///
    /// # Errors
    /// Propagates any failure from the underlying `transcribe` call.
    pub fn transcribe_partial(
        &self,
        partial_audio: &[f32],
        options: TranscribeOptions,
        is_final: bool,
    ) -> WhisperResult<PartialTranscriptionResult> {
        let start_time = std::time::Instant::now();

        // Below 0.5 s of audio, skip inference entirely.
        let min_samples = (audio::SAMPLE_RATE as f32 * 0.5) as usize;
        if partial_audio.len() < min_samples {
            return Ok(PartialTranscriptionResult {
                text: String::new(),
                language: options.language.unwrap_or_else(|| "en".to_string()),
                is_final,
                confidence: 0.0,
                duration_secs: partial_audio.len() as f32 / audio::SAMPLE_RATE as f32,
                processing_time_secs: start_time.elapsed().as_secs_f32(),
            });
        }

        let result = self.transcribe(partial_audio, options)?;

        let processing_time = start_time.elapsed().as_secs_f32();

        Ok(PartialTranscriptionResult {
            text: result.text,
            language: result.language,
            is_final,
            confidence: 1.0, duration_secs: partial_audio.len() as f32 / audio::SAMPLE_RATE as f32,
            processing_time_secs: processing_time,
        })
    }
2762
2763 #[must_use]
2800 pub fn create_streaming_session(
2801 &self,
2802 options: TranscribeOptions,
2803 input_sample_rate: u32,
2804 ) -> StreamingSession<'_> {
2805 let streaming_config = audio::StreamingConfig::with_sample_rate(input_sample_rate);
2806 let processor = audio::StreamingProcessor::new(streaming_config);
2807
2808 StreamingSession {
2809 whisper: self,
2810 processor,
2811 options,
2812 last_partial_text: String::new(),
2813 }
2814 }
2815
    /// Transcribe `audio`, then summarize the transcript with the language
    /// model supplied in `summarize_options`.
    ///
    /// A transcript that is empty (after trimming) skips generation and
    /// returns an empty summary with no stats.
    ///
    /// # Errors
    /// Propagates failures from transcription or from text generation.
    pub fn transcribe_and_summarize(
        &self,
        audio: &[f32],
        transcribe_options: TranscribeOptions,
        summarize_options: SummarizeOptions<'_>,
    ) -> WhisperResult<TranscribeSummaryResult> {
        let transcription = self.transcribe(audio, transcribe_options)?;

        // Nothing to summarize — short-circuit without touching the LM.
        if transcription.text.trim().is_empty() {
            return Ok(TranscribeSummaryResult {
                transcription,
                summary: String::new(),
                generation_stats: None,
            });
        }

        let input_tokens = summarize_options.tokenizer.encode(&transcription.text);

        // The `true`-returning callback accepts every generated token
        // (no early stopping).
        let (output_tokens, stats) = summarize_options.model.generate_with_stats(
            &input_tokens,
            summarize_options.max_tokens,
            summarize_options.temperature,
            Some(|_token: u32, _idx: usize| true), )?;

        let summary = summarize_options.tokenizer.decode(&output_tokens);

        Ok(TranscribeSummaryResult {
            transcription,
            summary,
            generation_stats: Some(stats),
        })
    }
2888}
2889
/// Result of a single partial (streaming) transcription pass.
#[derive(Debug, Clone)]
pub struct PartialTranscriptionResult {
    /// Transcribed text; empty when the buffer was too short to decode.
    pub text: String,
    /// Detected or requested language code (falls back to `"en"`).
    pub language: String,
    /// Whether this result was produced for a final chunk.
    pub is_final: bool,
    /// Confidence score; 0.0 for skipped (too-short) buffers, otherwise a
    /// fixed 1.0 placeholder from `transcribe_partial`.
    pub confidence: f32,
    /// Duration of the input audio in seconds (samples / sample rate).
    pub duration_secs: f32,
    /// Wall-clock time spent producing this result, in seconds.
    pub processing_time_secs: f32,
}
2906
2907impl PartialTranscriptionResult {
2908 #[must_use]
2910 pub fn has_text(&self) -> bool {
2911 !self.text.is_empty()
2912 }
2913
2914 #[must_use]
2916 pub fn is_empty_interim(&self) -> bool {
2917 self.text.is_empty() && !self.is_final
2918 }
2919
2920 #[must_use]
2922 pub fn real_time_factor(&self) -> f32 {
2923 if self.duration_secs <= 0.0 {
2924 0.0
2925 } else {
2926 self.processing_time_secs / self.duration_secs
2927 }
2928 }
2929}
2930
/// Incremental transcription session borrowing a [`WhisperApr`] model.
#[derive(Debug)]
pub struct StreamingSession<'a> {
    // Model used for every partial/final transcription.
    whisper: &'a WhisperApr,
    // Buffers and chunks incoming audio.
    processor: audio::StreamingProcessor,
    // Options cloned into each `transcribe_partial` call.
    options: TranscribeOptions,
    // Last interim text emitted; used to suppress duplicate partials.
    last_partial_text: String,
}
2946
2947impl StreamingSession<'_> {
2948 pub fn push(&mut self, audio: &[f32]) -> WhisperResult<Option<PartialTranscriptionResult>> {
2959 self.processor.push_audio(audio);
2960 self.processor.process();
2961
2962 if self.processor.has_partial() {
2964 if let Some(partial_audio) = self.processor.get_partial() {
2965 let result =
2966 self.whisper
2967 .transcribe_partial(&partial_audio, self.options.clone(), false)?;
2968
2969 if result.text != self.last_partial_text {
2971 result.text.clone_into(&mut self.last_partial_text);
2972 return Ok(Some(result));
2973 }
2974 }
2975 }
2976
2977 Ok(None)
2978 }
2979
2980 #[must_use]
2982 pub fn has_chunk(&self) -> bool {
2983 self.processor.has_chunk()
2984 }
2985
2986 #[must_use]
2988 pub fn has_events(&self) -> bool {
2989 self.processor.has_events()
2990 }
2991
2992 pub fn drain_events(&mut self) -> Vec<audio::StreamingEvent> {
2994 self.processor.drain_events()
2995 }
2996
2997 pub fn finalize(&mut self) -> WhisperResult<PartialTranscriptionResult> {
3005 let chunk = self
3006 .processor
3007 .get_chunk()
3008 .ok_or_else(|| WhisperError::Audio("no chunk ready for finalization".into()))?;
3009
3010 let result = self
3011 .whisper
3012 .transcribe_partial(&chunk, self.options.clone(), true)?;
3013 self.last_partial_text.clear();
3014
3015 Ok(result)
3016 }
3017
3018 pub fn flush(&mut self) -> WhisperResult<Option<PartialTranscriptionResult>> {
3026 if let Some(chunk) = self.processor.flush() {
3027 let result = self
3028 .whisper
3029 .transcribe_partial(&chunk, self.options.clone(), true)?;
3030 self.last_partial_text.clear();
3031 Ok(Some(result))
3032 } else {
3033 Ok(None)
3034 }
3035 }
3036
3037 pub fn reset(&mut self) {
3039 self.processor.reset();
3040 self.last_partial_text.clear();
3041 }
3042
3043 #[must_use]
3045 pub fn state(&self) -> audio::ProcessorState {
3046 self.processor.state()
3047 }
3048
3049 #[must_use]
3051 pub fn chunk_progress(&self) -> f32 {
3052 self.processor.chunk_progress()
3053 }
3054
3055 #[must_use]
3057 pub fn partial_duration(&self) -> f32 {
3058 self.processor.partial_duration()
3059 }
3060
3061 pub fn set_partial_threshold(&mut self, seconds: f32) {
3063 self.processor.set_partial_threshold(seconds);
3064 }
3065}
3066
/// Transcription produced with voice-activity detection applied.
#[derive(Debug, Clone)]
pub struct VadTranscriptionResult {
    /// Combined transcript across all speech segments.
    pub text: String,
    /// Detected or requested language code.
    pub language: String,
    /// Per-segment transcriptions.
    pub segments: Vec<VadSpeechSegment>,
    /// Detected speech spans as `(start, end)` pairs, in seconds.
    pub speech_segments: Vec<(f32, f32)>,
    /// Total input audio duration in seconds.
    pub total_duration_secs: f32,
    /// Total duration classified as speech, in seconds.
    pub speech_duration_secs: f32,
}
3083
3084impl VadTranscriptionResult {
3085 #[must_use]
3087 pub fn num_segments(&self) -> usize {
3088 self.segments.len()
3089 }
3090
3091 #[must_use]
3093 pub fn has_speech(&self) -> bool {
3094 !self.segments.is_empty()
3095 }
3096
3097 #[must_use]
3099 pub fn silence_ratio(&self, audio_duration: f32) -> f32 {
3100 if audio_duration <= 0.0 {
3101 return 1.0;
3102 }
3103 1.0 - (self.speech_duration_secs / audio_duration)
3104 }
3105
3106 #[must_use]
3108 pub fn first_segment(&self) -> Option<&VadSpeechSegment> {
3109 self.segments.first()
3110 }
3111
3112 #[must_use]
3114 pub fn last_segment(&self) -> Option<&VadSpeechSegment> {
3115 self.segments.last()
3116 }
3117
3118 pub fn iter(&self) -> impl Iterator<Item = &VadSpeechSegment> {
3120 self.segments.iter()
3121 }
3122}
3123
/// One VAD-detected speech span with its transcription.
#[derive(Debug, Clone)]
pub struct VadSpeechSegment {
    /// Segment start time in seconds.
    pub start: f32,
    /// Segment end time in seconds.
    pub end: f32,
    /// Transcribed text for this span (may be empty).
    pub text: String,
    /// Token ids produced for this span.
    pub tokens: Vec<u32>,
}
3136
impl VadSpeechSegment {
    /// Length of the segment in seconds (`end - start`).
    #[must_use]
    pub fn duration(&self) -> f32 {
        self.end - self.start
    }

    /// `true` when the segment produced any transcribed text.
    #[must_use]
    pub fn has_text(&self) -> bool {
        !self.text.is_empty()
    }
}
3150
3151#[cfg(test)]
3152mod tests {
3153 use super::*;
3154
    // TranscribeOptions::default() carries no language, prompt, or hotwords.
    #[test]
    fn test_default_options() {
        let options = TranscribeOptions::default();
        assert!(options.language.is_none());
        assert_eq!(options.task, Task::Transcribe);
        assert!(!options.word_timestamps);
        assert!(options.prompt.is_none());
        assert!(options.hotwords.is_empty());
    }

    // A prompt set via struct-update syntax is retained verbatim.
    #[test]
    fn test_options_with_prompt() {
        let options = TranscribeOptions {
            prompt: Some("This lecture covers AWS, YAML, and Rust programming.".into()),
            ..TranscribeOptions::default()
        };
        assert_eq!(
            options.prompt.as_deref(),
            Some("This lecture covers AWS, YAML, and Rust programming.")
        );
    }

    // Hotwords preserve insertion order and count.
    #[test]
    fn test_options_with_hotwords() {
        let options = TranscribeOptions {
            hotwords: vec!["AWS".into(), "YAML".into(), "SIMD".into(), "Rust".into()],
            ..TranscribeOptions::default()
        };
        assert_eq!(options.hotwords.len(), 4);
        assert_eq!(options.hotwords[0], "AWS");
    }

    // Language, prompt, and hotwords can all be set together.
    #[test]
    fn test_options_with_prompt_and_hotwords() {
        let options = TranscribeOptions {
            language: Some("en".into()),
            prompt: Some("Technical lecture on cloud computing".into()),
            hotwords: vec!["Kubernetes".into(), "Docker".into()],
            ..TranscribeOptions::default()
        };
        assert!(options.prompt.is_some());
        assert_eq!(options.hotwords.len(), 2);
        assert_eq!(options.language, Some("en".into()));
    }

    // Greedy decoding is the default strategy.
    #[test]
    fn test_decoding_strategy_default() {
        let strategy = DecodingStrategy::default();
        assert!(matches!(strategy, DecodingStrategy::Greedy));
    }
3205
    // Tiny model reports its type and 4 encoder layers.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_tiny() {
        let whisper = WhisperApr::tiny();
        assert_eq!(whisper.model_type(), ModelType::Tiny);
        assert_eq!(whisper.config().n_audio_layer, 4);
    }

    // Base model reports its type and 6 encoder layers.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_base() {
        let whisper = WhisperApr::base();
        assert_eq!(whisper.model_type(), ModelType::Base);
        assert_eq!(whisper.config().n_audio_layer, 6);
    }

    // Memory footprint is ordered by model size and >100 MB even for tiny.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_memory_size() {
        let tiny = WhisperApr::tiny();
        let base = WhisperApr::base();

        assert!(tiny.memory_size() < base.memory_size());
        assert!(tiny.memory_size() > 100_000_000); }

    // Initial token sequence starts with SOT; translate task adds TRANSLATE.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_initial_tokens() {
        let whisper = WhisperApr::tiny();

        let tokens = whisper.get_initial_tokens("en", Task::Transcribe);
        assert_eq!(tokens[0], tokenizer::special_tokens::SOT);
        assert!(tokens.len() >= 3); let translate_tokens = whisper.get_initial_tokens("es", Task::Translate);
        assert!(translate_tokens.contains(&tokenizer::special_tokens::TRANSLATE));
    }

    // A non-16 kHz rate installs a resampler; 16 kHz removes it again.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_set_resampler() {
        let mut whisper = WhisperApr::tiny();

        assert!(whisper.resampler.is_none());

        whisper.set_resampler(44100).expect("should succeed");
        assert!(whisper.resampler.is_some());

        whisper.set_resampler(16000).expect("should succeed");
        assert!(whisper.resampler.is_none());
    }

    // With no resampler configured, resample() is an identity pass-through.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_resample_passthrough() {
        let whisper = WhisperApr::tiny();
        let audio = vec![0.1, 0.2, 0.3, 0.4];

        let resampled = whisper.resample(&audio).expect("should succeed");
        assert_eq!(resampled, audio);
    }

    // Downsampling 32 kHz -> 16 kHz roughly halves the sample count.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_resample_with_resampler() {
        let mut whisper = WhisperApr::tiny();
        whisper.set_resampler(32000).expect("should succeed");

        let n_samples = 16000;
        let audio: Vec<f32> = (0..n_samples)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / 32000.0).sin())
            .collect();

        let resampled = whisper.resample(&audio).expect("should succeed");

        assert!(resampled.len() > n_samples / 3);
        assert!(resampled.len() < n_samples);
    }

    // Tokenizer vocabulary includes at least the 256 base byte tokens.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_tokenizer() {
        let whisper = WhisperApr::tiny();
        let tokenizer = whisper.tokenizer();

        let vocab_size = tokenizer.vocab_size();
        assert!(
            vocab_size >= 256,
            "vocab_size should include base byte tokens"
        );
    }
3309
    // Language can be set explicitly alongside defaulted fields.
    #[test]
    fn test_transcribe_options_with_language() {
        let options = TranscribeOptions {
            language: Some("es".to_string()),
            task: Task::Transcribe,
            strategy: DecodingStrategy::Greedy,
            word_timestamps: false,
            profile: false,
            ..Default::default()
        };

        assert_eq!(options.language, Some("es".to_string()));
    }

    // Beam-search strategy round-trips through the options struct.
    #[test]
    fn test_transcribe_options_beam_search() {
        let options = TranscribeOptions {
            language: None,
            task: Task::Transcribe,
            strategy: DecodingStrategy::BeamSearch {
                beam_size: 5,
                temperature: 0.0,
                patience: 1.0,
            },
            word_timestamps: false,
            profile: false,
            ..Default::default()
        };

        assert!(matches!(
            options.strategy,
            DecodingStrategy::BeamSearch { .. }
        ));
    }

    // Segment fields are stored as provided.
    #[test]
    fn test_segment_struct() {
        let segment = Segment {
            start: 0.0,
            end: 2.5,
            text: "Hello world".to_string(),
            tokens: vec![1, 2, 3],
        };

        assert!((segment.start - 0.0).abs() < f32::EPSILON);
        assert!((segment.end - 2.5).abs() < f32::EPSILON);
        assert_eq!(segment.text, "Hello world");
        assert_eq!(segment.tokens.len(), 3);
    }

    // TranscriptionResult fields are stored as provided.
    #[test]
    fn test_transcription_result_struct() {
        let result = TranscriptionResult {
            text: "Test transcription".to_string(),
            language: "en".to_string(),
            segments: vec![],
            profiling: None,
        };

        assert_eq!(result.text, "Test transcription");
        assert_eq!(result.language, "en");
        assert!(result.segments.is_empty());
    }
3373
    // A synthetic APR blob loads successfully as a tiny model.
    #[test]
    fn test_load_from_apr_basic() {
        let data = format::create_test_apr();

        let result = WhisperApr::load_from_apr(&data);
        assert!(result.is_ok());

        let whisper = result.expect("should load");
        assert_eq!(whisper.model_type(), ModelType::Tiny);
    }

    // The progress callback fires at least once during loading.
    #[test]
    fn test_load_from_apr_with_progress_callback() {
        let data = format::create_test_apr();

        let mut progress_updates = Vec::new();
        let mut callback = |p: &progress::Progress| {
            progress_updates.push(p.percent());
        };

        let result = WhisperApr::load_from_apr_with_progress(&data, &mut callback);
        assert!(result.is_ok());

        assert!(!progress_updates.is_empty());
    }

    // Corrupting the first magic byte makes loading fail.
    #[test]
    fn test_load_from_apr_invalid_magic() {
        let mut data = format::create_test_apr();
        data[0] = b'X'; let result = WhisperApr::load_from_apr(&data);
        assert!(result.is_err());
    }

    // A truncated input (magic bytes only) is rejected.
    #[test]
    fn test_load_from_apr_too_short() {
        let data = vec![b'A', b'P', b'R', b'1']; let result = WhisperApr::load_from_apr(&data);
        assert!(result.is_err());
    }
3422
    // Mutable encoder accessor exposes the expected layer count.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_encoder_mut_accessor() {
        let mut whisper = WhisperApr::tiny();
        let encoder = whisper.encoder_mut();
        assert_eq!(encoder.n_layers(), 4);
    }

    // Mutable decoder accessor exposes the expected layer count.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_decoder_mut_accessor() {
        let mut whisper = WhisperApr::tiny();
        let decoder = whisper.decoder_mut();
        assert_eq!(decoder.n_layers(), 4);
    }

    // Tiny config: 4+4 layers, 384-wide audio/text states.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_full_pipeline_tiny_model() {
        let whisper = WhisperApr::tiny();
        assert_eq!(whisper.model_type(), ModelType::Tiny);

        assert_eq!(whisper.config().n_audio_layer, 4);
        assert_eq!(whisper.config().n_text_layer, 4);
        assert_eq!(whisper.config().n_audio_state, 384);
        assert_eq!(whisper.config().n_text_state, 384);
    }

    // Base config: 6+6 layers, 512-wide audio/text states.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_full_pipeline_base_model() {
        let whisper = WhisperApr::base();
        assert_eq!(whisper.model_type(), ModelType::Base);

        assert_eq!(whisper.config().n_audio_layer, 6);
        assert_eq!(whisper.config().n_text_layer, 6);
        assert_eq!(whisper.config().n_audio_state, 512);
        assert_eq!(whisper.config().n_text_state, 512);
    }

    // All three decoding strategies are constructible via TranscribeOptions.
    #[test]
    fn test_transcribe_options_all_strategies() {
        let opts_greedy = TranscribeOptions::default();
        assert!(matches!(opts_greedy.strategy, DecodingStrategy::Greedy));

        let opts_beam = TranscribeOptions {
            language: Some("en".to_string()),
            task: Task::Transcribe,
            strategy: DecodingStrategy::BeamSearch {
                beam_size: 5,
                temperature: 0.0,
                patience: 1.0,
            },
            word_timestamps: false,
            profile: false,
            ..Default::default()
        };
        assert!(matches!(
            opts_beam.strategy,
            DecodingStrategy::BeamSearch { .. }
        ));

        let opts_sampling = TranscribeOptions {
            language: None,
            task: Task::Translate,
            strategy: DecodingStrategy::Sampling {
                temperature: 0.7,
                top_k: Some(50),
                top_p: Some(0.9),
            },
            word_timestamps: true,
            profile: false,
            ..Default::default()
        };
        assert!(matches!(
            opts_sampling.strategy,
            DecodingStrategy::Sampling { .. }
        ));
    }

    // Config-derived memory/parameter estimates scale with model size.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_memory_estimation_consistency() {
        let tiny = WhisperApr::tiny();
        let base = WhisperApr::base();

        assert!(base.config().weights_memory_mb() > tiny.config().weights_memory_mb());
        assert!(base.config().peak_memory_mb() > tiny.config().peak_memory_mb());
        assert!(base.config().parameter_count() > tiny.config().parameter_count());
    }
3525
    // SIMD helpers: elementwise add, normalized softmax, 2x2 matmul shape.
    #[test]
    fn test_simd_operations_integration() {
        use crate::simd;

        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let sum = simd::add(&a, &b);
        assert_eq!(sum.len(), 4);
        assert!((sum[0] - 6.0).abs() < 1e-5);
        assert!((sum[3] - 12.0).abs() < 1e-5);

        let softmax = simd::softmax(&a);
        let sum_softmax: f32 = softmax.iter().sum();
        assert!((sum_softmax - 1.0).abs() < 1e-5);

        let mat_a = vec![1.0, 2.0, 3.0, 4.0]; let mat_b = vec![5.0, 6.0, 7.0, 8.0]; let result = simd::matmul(&mat_a, &mat_b, 2, 2, 2);
        assert_eq!(result.len(), 4);
    }

    // Memory pool: returning and re-getting a buffer counts as a hit; the
    // thread-local pool tracks allocations.
    #[test]
    fn test_memory_pool_integration() {
        use crate::memory::{get_buffer, pool_stats, return_buffer, MemoryPool};

        let pool = MemoryPool::new();

        let buf1 = pool.get(1024);
        assert_eq!(buf1.len(), 1024);
        pool.return_buffer(buf1);

        let buf2 = pool.get(1024);
        assert_eq!(buf2.len(), 1024);

        let stats = pool.stats();
        assert_eq!(stats.hits, 1);

        let tlbuf = get_buffer(512);
        assert_eq!(tlbuf.len(), 512);
        return_buffer(tlbuf);

        let tl_stats = pool_stats();
        assert!(tl_stats.allocations > 0);
    }

    // Sinc resampler output length tracks the 44.1 kHz -> 16 kHz ratio.
    #[test]
    fn test_audio_resampling_integration() {
        use audio::SincResampler;

        let resampler = SincResampler::new(44100, 16000).expect("resampler should work");

        let duration_ms = 100;
        let samples_at_44100 = (44100 * duration_ms) / 1000;
        let input: Vec<f32> = (0..samples_at_44100)
            .map(|i| {
                let t = i as f32 / 44100.0;
                (2.0 * std::f32::consts::PI * 440.0 * t).sin()
            })
            .collect();

        let output = resampler.resample(&input).expect("resample should work");

        let expected_len = (input.len() as f32 * 16000.0 / 44100.0) as usize;
        assert!(
            (output.len() as i32 - expected_len as i32).abs() < 10,
            "output len {} vs expected ~{}",
            output.len(),
            expected_len
        );
    }

    // Default VAD config has positive thresholds.
    #[test]
    fn test_vad_integration() {
        use vad::VadConfig;

        let config = VadConfig::default();
        assert!(config.energy_threshold > 0.0);
        assert!(config.zcr_threshold > 0.0);
        assert!(config.min_speech_frames > 0);
    }
3619
    // An empty segment list means no speech even when text is set.
    #[test]
    fn test_vad_transcription_result_new() {
        let result = VadTranscriptionResult {
            text: "hello world".to_string(),
            language: "en".to_string(),
            segments: vec![],
            speech_segments: vec![],
            total_duration_secs: 1.0,
            speech_duration_secs: 0.5,
        };

        assert_eq!(result.text, "hello world");
        assert_eq!(result.language, "en");
        assert!(!result.has_speech());
        assert_eq!(result.num_segments(), 0);
    }

    // first_segment/last_segment reflect segment order.
    #[test]
    fn test_vad_transcription_result_with_segments() {
        let result = VadTranscriptionResult {
            text: "hello world".to_string(),
            language: "en".to_string(),
            segments: vec![
                VadSpeechSegment {
                    start: 0.0,
                    end: 1.0,
                    text: "hello".to_string(),
                    tokens: vec![1, 2],
                },
                VadSpeechSegment {
                    start: 1.5,
                    end: 2.5,
                    text: "world".to_string(),
                    tokens: vec![3, 4],
                },
            ],
            speech_segments: vec![(0.0, 1.0), (1.5, 2.5)],
            total_duration_secs: 3.0,
            speech_duration_secs: 2.0,
        };

        assert!(result.has_speech());
        assert_eq!(result.num_segments(), 2);
        assert!(result.first_segment().is_some());
        assert!(result.last_segment().is_some());
        assert_eq!(
            result.first_segment().map(|s| &s.text),
            Some(&"hello".to_string())
        );
        assert_eq!(
            result.last_segment().map(|s| &s.text),
            Some(&"world".to_string())
        );
    }

    // 0.5 s speech in 2.0 s of audio -> 75% silence.
    #[test]
    fn test_vad_transcription_result_silence_ratio() {
        let result = VadTranscriptionResult {
            text: String::new(),
            language: "en".to_string(),
            segments: vec![],
            speech_segments: vec![],
            total_duration_secs: 1.0,
            speech_duration_secs: 0.5,
        };

        let ratio = result.silence_ratio(2.0);
        assert!((ratio - 0.75).abs() < 0.01); }

    // Zero audio duration is treated as fully silent.
    #[test]
    fn test_vad_transcription_result_silence_ratio_zero_duration() {
        let result = VadTranscriptionResult {
            text: String::new(),
            language: "en".to_string(),
            segments: vec![],
            speech_segments: vec![],
            total_duration_secs: 0.0,
            speech_duration_secs: 0.0,
        };

        let ratio = result.silence_ratio(0.0);
        assert!((ratio - 1.0).abs() < 0.01);
    }

    // iter() yields segments in order.
    #[test]
    fn test_vad_transcription_result_iter() {
        let result = VadTranscriptionResult {
            text: "a b".to_string(),
            language: "en".to_string(),
            segments: vec![
                VadSpeechSegment {
                    start: 0.0,
                    end: 1.0,
                    text: "a".to_string(),
                    tokens: vec![1],
                },
                VadSpeechSegment {
                    start: 1.0,
                    end: 2.0,
                    text: "b".to_string(),
                    tokens: vec![2],
                },
            ],
            speech_segments: vec![(0.0, 1.0), (1.0, 2.0)],
            total_duration_secs: 2.0,
            speech_duration_secs: 2.0,
        };

        let texts: Vec<_> = result.iter().map(|s| s.text.as_str()).collect();
        assert_eq!(texts, vec!["a", "b"]);
    }

    // duration() is end - start.
    #[test]
    fn test_vad_speech_segment_duration() {
        let segment = VadSpeechSegment {
            start: 1.5,
            end: 3.0,
            text: "test".to_string(),
            tokens: vec![1, 2, 3],
        };

        assert!((segment.duration() - 1.5).abs() < 0.01);
    }

    // has_text() distinguishes empty from non-empty segment text.
    #[test]
    fn test_vad_speech_segment_has_text() {
        let with_text = VadSpeechSegment {
            start: 0.0,
            end: 1.0,
            text: "hello".to_string(),
            tokens: vec![1],
        };
        let empty = VadSpeechSegment {
            start: 0.0,
            end: 1.0,
            text: String::new(),
            tokens: vec![],
        };

        assert!(with_text.has_text());
        assert!(!empty.has_text());
    }

    // Pure silence yields an empty VAD transcription.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_transcribe_with_vad_silence_only() {
        let whisper = WhisperApr::tiny();
        let silence = vec![0.0; 16000]; let result = whisper
            .transcribe_with_vad(&silence, TranscribeOptions::default(), None)
            .expect("should succeed");

        assert!(!result.has_speech());
        assert_eq!(result.num_segments(), 0);
        assert!(result.text.is_empty());
    }

    // SilenceConfig builder setters store the given thresholds.
    #[test]
    fn test_transcribe_with_silence_detection_config() {
        let config = vad::SilenceConfig::new()
            .with_min_silence_duration(0.3)
            .with_max_silence_duration(2.0)
            .with_silence_threshold(0.001);

        assert!((config.min_silence_duration - 0.3).abs() < 0.01);
        assert!((config.max_silence_duration - 2.0).abs() < 0.01);
        assert!((config.silence_threshold - 0.001).abs() < 0.001);
    }
3795
    // No silence segments -> one speech span covering the whole buffer (1 s).
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_invert_silence_segments_empty() {
        let whisper = WhisperApr::tiny();
        let audio_len = 16000; let silence_segments: Vec<vad::SilenceSegment> = vec![];
        let speech = whisper.invert_silence_segments(&silence_segments, audio_len);

        assert_eq!(speech.len(), 1);
        assert!((speech[0].0 - 0.0).abs() < 0.01);
        assert!((speech[0].1 - 1.0).abs() < 0.01);
    }

    // One interior silence splits 2 s of audio into two speech spans.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_invert_silence_segments_single() {
        let whisper = WhisperApr::tiny();
        let audio_len = 32000; let silence_segments = vec![vad::SilenceSegment {
            start: 0.5,
            end: 1.5,
            noise_floor: 0.001,
        }];
        let speech = whisper.invert_silence_segments(&silence_segments, audio_len);

        assert_eq!(speech.len(), 2);
        assert!((speech[0].0 - 0.0).abs() < 0.01); assert!((speech[0].1 - 0.5).abs() < 0.01); assert!((speech[1].0 - 1.5).abs() < 0.01); assert!((speech[1].1 - 2.0).abs() < 0.01); }

    // Two interior silences split 3 s of audio into three speech spans.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_invert_silence_segments_multiple() {
        let whisper = WhisperApr::tiny();
        let audio_len = 48000; let silence_segments = vec![
            vad::SilenceSegment {
                start: 0.5,
                end: 1.0,
                noise_floor: 0.001,
            },
            vad::SilenceSegment {
                start: 2.0,
                end: 2.5,
                noise_floor: 0.001,
            },
        ];
        let speech = whisper.invert_silence_segments(&silence_segments, audio_len);

        assert_eq!(speech.len(), 3);
    }

    // Special-token ids are all non-zero.
    #[test]
    fn test_tokenizer_integration() {
        use tokenizer::special_tokens;

        assert!(special_tokens::SOT > 0);
        assert!(special_tokens::EOT > 0);
        assert!(special_tokens::TRANSCRIBE > 0);
        assert!(special_tokens::TRANSLATE > 0);
        assert!(special_tokens::NO_TIMESTAMPS > 0);
    }

    // BeamSearchDecoder stores its beam size.
    #[test]
    fn test_inference_beam_search_integration() {
        use inference::BeamSearchDecoder;

        let decoder = BeamSearchDecoder::new(5, 448); assert_eq!(decoder.beam_size(), 5);
    }

    // A fresh (or reset) decompressor holds no data.
    #[test]
    fn test_format_decompression_integration() {
        use format::Decompressor;

        let mut decompressor = Decompressor::new();

        assert!(decompressor.is_empty());

        decompressor.reset();
        assert!(decompressor.is_empty());
    }

    // Segment start/end subtract to the expected duration.
    #[test]
    fn test_timestamps_generation() {
        let segment = Segment {
            start: 1.5,
            end: 3.25,
            text: "This is a test".to_string(),
            tokens: vec![1, 2, 3, 4],
        };

        let duration = segment.end - segment.start;
        assert!((duration - 1.75).abs() < 1e-5);
    }

    // Language support lookups and display names for common languages.
    #[test]
    fn test_language_detection_integration() {
        use detection::{is_supported, language_name};

        assert!(is_supported("en"));
        assert!(is_supported("es"));
        assert!(is_supported("ja"));
        assert!(!is_supported("invalid_lang"));

        assert_eq!(language_name("en"), Some("English"));
        assert_eq!(language_name("es"), Some("Spanish"));
        assert_eq!(language_name("zh"), Some("Chinese"));

        for lang in &["en", "es", "fr", "de", "it", "ja", "zh", "ko", "pt", "ru"] {
            assert!(
                language_name(lang).is_some(),
                "language {} should have a name",
                lang
            );
        }
    }

    // A fresh KV cache starts empty for both attention kinds.
    #[test]
    fn test_model_kv_cache_integration() {
        use model::{Decoder, ModelConfig};

        let config = ModelConfig::tiny();
        let decoder = Decoder::new(&config);

        let cache = decoder.create_kv_cache();
        assert!(cache.self_attn_cache.iter().all(|c| c.is_empty()));
        assert!(cache.cross_attn_cache.iter().all(|c| c.is_empty()));
    }

    // Progress percentage math and byte-formatting helpers.
    #[test]
    fn test_progress_tracking_integration() {
        use progress::{format_bytes, Progress};

        let progress = Progress::new(50, 100);
        assert_eq!(progress.percent(), 50.0);
        assert_eq!(progress.current, 50);
        assert_eq!(progress.total, 100);

        let kb_str = format_bytes(1024);
        assert!(kb_str.contains("KB"), "should contain KB: {}", kb_str);
        let mb_str = format_bytes(1024 * 1024);
        assert!(mb_str.contains("MB"), "should contain MB: {}", mb_str);
    }
3962
    // ModelType variants Debug-format to their names; Copy and Eq hold.
    #[test]
    fn test_model_type_variants() {
        let tiny = ModelType::Tiny;
        let tiny_en = ModelType::TinyEn;
        let base = ModelType::Base;
        let base_en = ModelType::BaseEn;
        let small = ModelType::Small;

        assert!(format!("{tiny:?}").contains("Tiny"));
        assert!(format!("{tiny_en:?}").contains("TinyEn"));
        assert!(format!("{base:?}").contains("Base"));
        assert!(format!("{base_en:?}").contains("BaseEn"));
        assert!(format!("{small:?}").contains("Small"));

        let tiny_clone = tiny;
        assert_eq!(tiny_clone, ModelType::Tiny);

        assert_eq!(tiny, ModelType::Tiny);
        assert_ne!(tiny, base);
    }

    // Task variants Debug-format, copy, compare, and default to Transcribe.
    #[test]
    fn test_task_variants() {
        let transcribe = Task::Transcribe;
        let translate = Task::Translate;

        assert!(format!("{transcribe:?}").contains("Transcribe"));
        assert!(format!("{translate:?}").contains("Translate"));

        let transcribe_clone = transcribe;
        assert_eq!(transcribe_clone, Task::Transcribe);

        assert_eq!(transcribe, Task::Transcribe);
        assert_ne!(transcribe, translate);

        assert_eq!(Task::default(), Task::Transcribe);
    }

    // Sampling strategy Debug-formats and clones.
    #[test]
    fn test_decoding_strategy_sampling() {
        let sampling = DecodingStrategy::Sampling {
            temperature: 0.8,
            top_k: Some(50),
            top_p: Some(0.9),
        };

        let debug_str = format!("{sampling:?}");
        assert!(debug_str.contains("Sampling"));

        let cloned = sampling.clone();
        assert!(matches!(cloned, DecodingStrategy::Sampling { .. }));
    }

    // BeamSearch strategy Debug-formats.
    #[test]
    fn test_decoding_strategy_beam_search() {
        let beam = DecodingStrategy::BeamSearch {
            beam_size: 5,
            temperature: 0.0,
            patience: 1.0,
        };

        let debug_str = format!("{beam:?}");
        assert!(debug_str.contains("BeamSearch"));
    }

    // Cloned options preserve language, task, and flags.
    #[test]
    fn test_transcribe_options_clone() {
        let options = TranscribeOptions {
            language: Some("fr".to_string()),
            task: Task::Translate,
            strategy: DecodingStrategy::Greedy,
            word_timestamps: true,
            profile: false,
            ..Default::default()
        };

        let cloned = options.clone();
        assert_eq!(cloned.language, Some("fr".to_string()));
        assert_eq!(cloned.task, Task::Translate);
        assert!(cloned.word_timestamps);
    }

    // TranscribeOptions Debug output names the type.
    #[test]
    fn test_transcribe_options_debug() {
        let options = TranscribeOptions::default();
        let debug_str = format!("{options:?}");
        assert!(debug_str.contains("TranscribeOptions"));
    }

    // Cloned segments keep text and tokens.
    #[test]
    fn test_segment_clone() {
        let segment = Segment {
            start: 1.0,
            end: 2.0,
            text: "test".to_string(),
            tokens: vec![1, 2],
        };

        let cloned = segment.clone();
        assert_eq!(cloned.text, "test");
        assert_eq!(cloned.tokens, vec![1, 2]);
    }

    // Segment Debug output includes its text.
    #[test]
    fn test_segment_debug() {
        let segment = Segment {
            start: 0.0,
            end: 1.0,
            text: "hello".to_string(),
            tokens: vec![],
        };

        let debug_str = format!("{segment:?}");
        assert!(debug_str.contains("hello"));
    }

    // Cloned results keep text and the segment list.
    #[test]
    fn test_transcription_result_clone() {
        let result = TranscriptionResult {
            text: "hello world".to_string(),
            language: "en".to_string(),
            segments: vec![Segment {
                start: 0.0,
                end: 1.0,
                text: "hello".to_string(),
                tokens: vec![1],
            }],
            profiling: None,
        };

        let cloned = result.clone();
        assert_eq!(cloned.text, "hello world");
        assert_eq!(cloned.segments.len(), 1);
    }

    // TranscriptionResult Debug output names the type.
    #[test]
    fn test_transcription_result_debug() {
        let result = TranscriptionResult {
            text: "test".to_string(),
            language: "en".to_string(),
            segments: vec![],
            profiling: None,
        };

        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("TranscriptionResult"));
    }
4123
    // WhisperApr Debug output names the type.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_debug() {
        let whisper = WhisperApr::tiny();
        let debug_str = format!("{whisper:?}");
        assert!(debug_str.contains("WhisperApr"));
    }

    // Memory sizes are positive and ordered by model size.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_memory_size_all_models() {
        let tiny = WhisperApr::tiny();
        let base = WhisperApr::base();

        assert!(tiny.memory_size() < base.memory_size());

        assert!(tiny.memory_size() > 0);
        assert!(base.memory_size() > 0);
    }

    // Sampling strategy with no top-k/top-p is accepted.
    #[test]
    fn test_transcribe_options_sampling_strategy() {
        let options = TranscribeOptions {
            language: None,
            task: Task::Transcribe,
            strategy: DecodingStrategy::Sampling {
                temperature: 1.0,
                top_k: None,
                top_p: None,
            },
            word_timestamps: false,
            profile: false,
            ..Default::default()
        };

        assert!(matches!(
            options.strategy,
            DecodingStrategy::Sampling { .. }
        ));
    }

    // word_timestamps flag round-trips through the options struct.
    #[test]
    fn test_transcribe_options_word_timestamps_enabled() {
        let options = TranscribeOptions {
            language: Some("en".to_string()),
            task: Task::Transcribe,
            strategy: DecodingStrategy::default(),
            word_timestamps: true,
            profile: false,
            ..Default::default()
        };

        assert!(options.word_timestamps);
    }

    // Remaining ModelType variants Debug-format to their names.
    #[test]
    fn test_model_type_small() {
        let small = ModelType::Small;
        assert!(format!("{small:?}").contains("Small"));
    }

    #[test]
    fn test_model_type_medium() {
        let medium = ModelType::Medium;
        assert!(format!("{medium:?}").contains("Medium"));
    }

    #[test]
    fn test_model_type_medium_en() {
        let medium_en = ModelType::MediumEn;
        assert!(format!("{medium_en:?}").contains("MediumEn"));
    }

    #[test]
    fn test_model_type_large() {
        let large = ModelType::Large;
        assert!(format!("{large:?}").contains("Large"));
    }

    #[test]
    fn test_model_type_large_v1() {
        let large_v1 = ModelType::LargeV1;
        assert!(format!("{large_v1:?}").contains("LargeV1"));
    }

    #[test]
    fn test_model_type_large_v2() {
        let large_v2 = ModelType::LargeV2;
        assert!(format!("{large_v2:?}").contains("LargeV2"));
    }

    #[test]
    fn test_model_type_large_v3() {
        let large_v3 = ModelType::LargeV3;
        assert!(format!("{large_v3:?}").contains("LargeV3"));
    }

    // Small config: 12 encoder layers, 768-wide audio state.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_small() {
        let whisper = WhisperApr::small();
        assert_eq!(whisper.model_type(), ModelType::Small);
        assert_eq!(whisper.config().n_audio_layer, 12);
        assert_eq!(whisper.config().n_audio_state, 768);
    }

    // Medium config: 24 encoder layers, 1024-wide audio state.
    #[test]
    #[ignore = "Allocates large model - run with --ignored"]
    fn test_whisper_medium() {
        let whisper = WhisperApr::medium();
        assert_eq!(whisper.model_type(), ModelType::Medium);
        assert_eq!(whisper.config().n_audio_layer, 24);
        assert_eq!(whisper.config().n_audio_state, 1024);
    }
4244
4245 #[test]
4246 #[ignore = "Allocates large model - run with --ignored"]
4247 fn test_whisper_large() {
4248 let whisper = WhisperApr::large();
4249 assert_eq!(whisper.model_type(), ModelType::Large);
4250 assert_eq!(whisper.config().n_audio_layer, 32);
4251 assert_eq!(whisper.config().n_audio_state, 1280);
4252 }
4253
4254 #[test]
4255 #[ignore = "Allocates large model - run with --ignored"]
4256 fn test_whisper_memory_size_all_extended_models() {
4257 let small = WhisperApr::small();
4258 let medium = WhisperApr::medium();
4259 let large = WhisperApr::large();
4260
4261 assert!(small.memory_size() < medium.memory_size());
4263 assert!(medium.memory_size() < large.memory_size());
4264
4265 assert!(small.memory_size() > 0);
4267 assert!(medium.memory_size() > 0);
4268 assert!(large.memory_size() > 0);
4269 }
4270
4271 #[test]
4272 #[ignore = "Allocates large model - run with --ignored"]
4273 fn test_extended_model_memory_size_estimates() {
4274 let small = WhisperApr::small();
4275 let medium = WhisperApr::medium();
4276 let large = WhisperApr::large();
4277
4278 assert!(small.memory_size() > 900_000_000);
4280 assert!(small.memory_size() < 1_200_000_000);
4281
4282 assert!(medium.memory_size() > 2_500_000_000);
4284 assert!(medium.memory_size() < 3_500_000_000);
4285
4286 assert!(large.memory_size() > 5_000_000_000);
4288 assert!(large.memory_size() < 7_000_000_000);
4289 }
4290
4291 #[test]
4296 fn test_batch_transcription_result_len() {
4297 let result = BatchTranscriptionResult {
4298 results: vec![
4299 TranscriptionResult {
4300 text: "Hello".to_string(),
4301 language: "en".to_string(),
4302 segments: vec![],
4303 profiling: None,
4304 },
4305 TranscriptionResult {
4306 text: "World".to_string(),
4307 language: "en".to_string(),
4308 segments: vec![],
4309 profiling: None,
4310 },
4311 ],
4312 total_duration_secs: 1.5,
4313 };
4314
4315 assert_eq!(result.len(), 2);
4316 assert!(!result.is_empty());
4317 }
4318
4319 #[test]
4320 fn test_batch_transcription_result_empty() {
4321 let result = BatchTranscriptionResult {
4322 results: vec![],
4323 total_duration_secs: 0.0,
4324 };
4325
4326 assert!(result.is_empty());
4327 assert_eq!(result.len(), 0);
4328 }
4329
4330 #[test]
4331 fn test_batch_transcription_result_get() {
4332 let result = BatchTranscriptionResult {
4333 results: vec![
4334 TranscriptionResult {
4335 text: "First".to_string(),
4336 language: "en".to_string(),
4337 segments: vec![],
4338 profiling: None,
4339 },
4340 TranscriptionResult {
4341 text: "Second".to_string(),
4342 language: "es".to_string(),
4343 segments: vec![],
4344 profiling: None,
4345 },
4346 ],
4347 total_duration_secs: 2.0,
4348 };
4349
4350 assert!(result.get(0).is_some());
4351 assert_eq!(result.get(0).map(|r| r.text.as_str()), Some("First"));
4352 assert_eq!(result.get(1).map(|r| r.language.as_str()), Some("es"));
4353 assert!(result.get(2).is_none());
4354 }
4355
4356 #[test]
4357 fn test_batch_transcription_result_texts() {
4358 let result = BatchTranscriptionResult {
4359 results: vec![
4360 TranscriptionResult {
4361 text: "One".to_string(),
4362 language: "en".to_string(),
4363 segments: vec![],
4364 profiling: None,
4365 },
4366 TranscriptionResult {
4367 text: "Two".to_string(),
4368 language: "en".to_string(),
4369 segments: vec![],
4370 profiling: None,
4371 },
4372 TranscriptionResult {
4373 text: "Three".to_string(),
4374 language: "en".to_string(),
4375 segments: vec![],
4376 profiling: None,
4377 },
4378 ],
4379 total_duration_secs: 3.0,
4380 };
4381
4382 let texts = result.texts();
4383 assert_eq!(texts, vec!["One", "Two", "Three"]);
4384 }
4385
4386 #[test]
4387 fn test_batch_transcription_result_iter() {
4388 let result = BatchTranscriptionResult {
4389 results: vec![
4390 TranscriptionResult {
4391 text: "A".to_string(),
4392 language: "en".to_string(),
4393 segments: vec![],
4394 profiling: None,
4395 },
4396 TranscriptionResult {
4397 text: "B".to_string(),
4398 language: "en".to_string(),
4399 segments: vec![],
4400 profiling: None,
4401 },
4402 ],
4403 total_duration_secs: 1.0,
4404 };
4405
4406 let collected: Vec<&str> = result.iter().map(|r| r.text.as_str()).collect();
4407 assert_eq!(collected, vec!["A", "B"]);
4408 }
4409
4410 #[test]
4411 #[ignore = "Allocates large model - run with --ignored"]
4412 fn test_transcribe_batch_empty() {
4413 let whisper = WhisperApr::tiny();
4414 let result = whisper.transcribe_batch(&[], TranscribeOptions::default());
4415 assert!(result.is_err());
4416 }
4417
4418 #[test]
4419 #[ignore = "Allocates large model - run with --ignored"]
4420 fn test_transcribe_audio_batch_empty() {
4421 let whisper = WhisperApr::tiny();
4422 let batch = audio::AudioBatch::with_default_config();
4423 let result = whisper.transcribe_audio_batch(&batch, TranscribeOptions::default());
4424 assert!(result.is_err());
4425 }
4426
4427 #[test]
4428 #[ignore = "Allocates large model - run with --ignored"]
4429 fn test_transcribe_batch_optimized_empty() {
4430 let whisper = WhisperApr::tiny();
4431 let result = whisper.transcribe_batch_optimized(&[], TranscribeOptions::default());
4432 assert!(result.is_err());
4433 }
4434
4435 #[test]
4436 fn test_create_audio_batch() {
4437 let segments = vec![vec![0.1_f32, 0.2, 0.3], vec![0.4_f32, 0.5]];
4438
4439 let batch = WhisperApr::create_audio_batch(&segments);
4440 assert_eq!(batch.len(), 2);
4441 assert!(!batch.is_empty());
4442 }
4443
4444 #[test]
4445 fn test_create_audio_batch_empty() {
4446 let segments: Vec<Vec<f32>> = vec![];
4447 let batch = WhisperApr::create_audio_batch(&segments);
4448 assert!(batch.is_empty());
4449 }
4450
4451 #[test]
4452 fn test_batch_transcription_result_duration() {
4453 let result = BatchTranscriptionResult {
4454 results: vec![],
4455 total_duration_secs: 5.25,
4456 };
4457
4458 assert!((result.total_duration_secs - 5.25).abs() < f32::EPSILON);
4459 }
4460
4461 #[test]
4466 fn test_partial_transcription_result_new() {
4467 let result = PartialTranscriptionResult {
4468 text: "hello".to_string(),
4469 language: "en".to_string(),
4470 is_final: false,
4471 confidence: 0.95,
4472 duration_secs: 1.5,
4473 processing_time_secs: 0.3,
4474 };
4475
4476 assert_eq!(result.text, "hello");
4477 assert_eq!(result.language, "en");
4478 assert!(!result.is_final);
4479 assert!((result.confidence - 0.95).abs() < 0.01);
4480 }
4481
4482 #[test]
4483 fn test_partial_transcription_result_has_text() {
4484 let with_text = PartialTranscriptionResult {
4485 text: "hello".to_string(),
4486 language: "en".to_string(),
4487 is_final: false,
4488 confidence: 1.0,
4489 duration_secs: 1.0,
4490 processing_time_secs: 0.1,
4491 };
4492 let empty = PartialTranscriptionResult {
4493 text: String::new(),
4494 language: "en".to_string(),
4495 is_final: false,
4496 confidence: 0.0,
4497 duration_secs: 0.5,
4498 processing_time_secs: 0.05,
4499 };
4500
4501 assert!(with_text.has_text());
4502 assert!(!empty.has_text());
4503 }
4504
4505 #[test]
4506 fn test_partial_transcription_result_is_empty_interim() {
4507 let empty_interim = PartialTranscriptionResult {
4508 text: String::new(),
4509 language: "en".to_string(),
4510 is_final: false,
4511 confidence: 0.0,
4512 duration_secs: 0.5,
4513 processing_time_secs: 0.05,
4514 };
4515 let empty_final = PartialTranscriptionResult {
4516 text: String::new(),
4517 language: "en".to_string(),
4518 is_final: true,
4519 confidence: 0.0,
4520 duration_secs: 0.5,
4521 processing_time_secs: 0.05,
4522 };
4523 let with_text = PartialTranscriptionResult {
4524 text: "hello".to_string(),
4525 language: "en".to_string(),
4526 is_final: false,
4527 confidence: 1.0,
4528 duration_secs: 1.0,
4529 processing_time_secs: 0.1,
4530 };
4531
4532 assert!(empty_interim.is_empty_interim());
4533 assert!(!empty_final.is_empty_interim()); assert!(!with_text.is_empty_interim()); }
4536
4537 #[test]
4538 fn test_partial_transcription_result_real_time_factor() {
4539 let result = PartialTranscriptionResult {
4540 text: "hello".to_string(),
4541 language: "en".to_string(),
4542 is_final: false,
4543 confidence: 1.0,
4544 duration_secs: 2.0,
4545 processing_time_secs: 0.5,
4546 };
4547
4548 assert!((result.real_time_factor() - 0.25).abs() < 0.01);
4550 }
4551
4552 #[test]
4553 fn test_partial_transcription_result_real_time_factor_zero_duration() {
4554 let result = PartialTranscriptionResult {
4555 text: String::new(),
4556 language: "en".to_string(),
4557 is_final: false,
4558 confidence: 0.0,
4559 duration_secs: 0.0,
4560 processing_time_secs: 0.0,
4561 };
4562
4563 assert!((result.real_time_factor() - 0.0).abs() < 0.01);
4564 }
4565
4566 #[test]
4567 fn test_partial_transcription_result_debug_clone() {
4568 let result = PartialTranscriptionResult {
4569 text: "test".to_string(),
4570 language: "en".to_string(),
4571 is_final: true,
4572 confidence: 0.9,
4573 duration_secs: 1.0,
4574 processing_time_secs: 0.1,
4575 };
4576
4577 let debug_str = format!("{result:?}");
4578 assert!(debug_str.contains("PartialTranscriptionResult"));
4579
4580 let cloned = result.clone();
4581 assert_eq!(cloned.text, "test");
4582 assert!(cloned.is_final);
4583 }
4584
4585 #[test]
4586 #[ignore = "Allocates large model - run with --ignored"]
4587 fn test_transcribe_partial_too_short() {
4588 let whisper = WhisperApr::tiny();
4589 let short_audio = vec![0.0; 4000]; let result = whisper
4592 .transcribe_partial(&short_audio, TranscribeOptions::default(), false)
4593 .expect("should succeed with empty result");
4594
4595 assert!(result.text.is_empty());
4596 assert!(!result.is_final);
4597 assert!((result.confidence - 0.0).abs() < 0.01);
4598 }
4599
4600 #[test]
4601 #[ignore = "Allocates large model - run with --ignored"]
4602 fn test_encode_3_second_chunk() {
4603 let whisper = WhisperApr::tiny();
4606 let audio = vec![0.0; 48000]; let mel = whisper
4610 .compute_mel(&audio)
4611 .expect("mel computation should succeed");
4612
4613 let result = whisper.encode(&mel);
4615 assert!(
4616 result.is_ok(),
4617 "encode should succeed for 3s audio mel: {:?}",
4618 result.err()
4619 );
4620
4621 let encoded = result.expect("encode should succeed");
4623 let d_model = whisper.config().n_text_state as usize; assert_eq!(
4625 encoded.len() % d_model,
4626 0,
4627 "encoded output should be multiple of d_model"
4628 );
4629 }
4630
4631 #[test]
4632 #[ignore = "Allocates large model - run with --ignored"]
4633 fn test_create_streaming_session() {
4634 let whisper = WhisperApr::tiny();
4635 let session = whisper.create_streaming_session(TranscribeOptions::default(), 44100);
4636
4637 assert_eq!(session.state(), audio::ProcessorState::WaitingForSpeech);
4638 assert!((session.chunk_progress() - 0.0).abs() < 0.01);
4639 assert!(!session.has_chunk());
4640 assert!(!session.has_events());
4641 }
4642
4643 #[test]
4644 #[ignore = "Allocates large model - run with --ignored"]
4645 fn test_streaming_session_reset() {
4646 let whisper = WhisperApr::tiny();
4647 let mut session = whisper.create_streaming_session(TranscribeOptions::default(), 16000);
4648
4649 session.push(&vec![0.1; 1000]).expect("push should work");
4651
4652 session.reset();
4654
4655 assert_eq!(session.state(), audio::ProcessorState::WaitingForSpeech);
4656 assert!((session.partial_duration() - 0.0).abs() < 0.01);
4657 }
4658
4659 #[test]
4660 #[ignore = "Allocates large model - run with --ignored"]
4661 fn test_streaming_session_set_partial_threshold() {
4662 let whisper = WhisperApr::tiny();
4663 let mut session = whisper.create_streaming_session(TranscribeOptions::default(), 16000);
4664
4665 session.set_partial_threshold(5.0);
4666 }
4669
4670 #[test]
4671 #[ignore = "Allocates large model - run with --ignored"]
4672 fn test_streaming_session_finalize_no_chunk() {
4673 let whisper = WhisperApr::tiny();
4674 let mut session = whisper.create_streaming_session(TranscribeOptions::default(), 16000);
4675
4676 let result = session.finalize();
4677 assert!(result.is_err());
4678 }
4679
4680 #[test]
4681 #[ignore = "Allocates large model - run with --ignored"]
4682 fn test_streaming_session_flush_empty() {
4683 let whisper = WhisperApr::tiny();
4684 let mut session = whisper.create_streaming_session(TranscribeOptions::default(), 16000);
4685
4686 let result = session.flush().expect("flush should work");
4687 assert!(result.is_none());
4688 }
4689
4690 #[test]
4691 #[ignore = "Allocates large model - run with --ignored"]
4692 fn test_streaming_session_drain_events() {
4693 let whisper = WhisperApr::tiny();
4694 let mut session = whisper.create_streaming_session(TranscribeOptions::default(), 16000);
4695
4696 session.reset();
4698
4699 let events = session.drain_events();
4700 assert!(!events.is_empty());
4701 assert!(events
4702 .iter()
4703 .any(|e| matches!(e, audio::StreamingEvent::Reset)));
4704 }
4705
4706 #[test]
4707 #[ignore = "Allocates large model - run with --ignored"]
4708 fn test_streaming_session_push_silence() {
4709 let whisper = WhisperApr::tiny();
4710 let mut session = whisper.create_streaming_session(TranscribeOptions::default(), 16000);
4711
4712 let result = session.push(&vec![0.0; 16000]).expect("push should work");
4714 assert!(result.is_none()); }
4716
4717 #[test]
4718 #[ignore = "Allocates large model - run with --ignored"]
4719 fn test_streaming_session_debug() {
4720 let whisper = WhisperApr::tiny();
4721 let session = whisper.create_streaming_session(TranscribeOptions::default(), 16000);
4722
4723 let debug_str = format!("{session:?}");
4724 assert!(debug_str.contains("StreamingSession"));
4725 }
4726
4727 #[test]
4732 #[ignore = "Allocates large model - run with --ignored"]
4733 fn test_streaming_session_state() {
4734 let whisper = WhisperApr::tiny();
4735 let session = whisper.create_streaming_session(TranscribeOptions::default(), 16000);
4736
4737 let state = session.state();
4739 assert_eq!(state, audio::ProcessorState::WaitingForSpeech);
4740 }
4741
4742 #[test]
4743 #[ignore = "Allocates large model - run with --ignored"]
4744 fn test_streaming_session_chunk_progress() {
4745 let whisper = WhisperApr::tiny();
4746 let session = whisper.create_streaming_session(TranscribeOptions::default(), 16000);
4747
4748 let progress = session.chunk_progress();
4749 assert!(progress >= 0.0 && progress <= 1.0);
4750 }
4751
4752 #[test]
4753 #[ignore = "Allocates large model - run with --ignored"]
4754 fn test_streaming_session_partial_duration() {
4755 let whisper = WhisperApr::tiny();
4756 let session = whisper.create_streaming_session(TranscribeOptions::default(), 16000);
4757
4758 let duration = session.partial_duration();
4759 assert!(duration >= 0.0);
4760 }
4761
4762 #[test]
4763 fn test_partial_transcription_result_rtf_with_zero_processing() {
4764 let result = PartialTranscriptionResult {
4765 text: "test".to_string(),
4766 language: "en".to_string(),
4767 is_final: true,
4768 confidence: 0.9,
4769 duration_secs: 5.0,
4770 processing_time_secs: 0.0,
4771 };
4772
4773 let rtf = result.real_time_factor();
4774 assert!(rtf >= 0.0);
4775 }
4776
4777 #[test]
4778 fn test_vad_transcription_result_methods_empty() {
4779 let result = VadTranscriptionResult {
4780 text: "Hello world".to_string(),
4781 language: "en".to_string(),
4782 segments: vec![],
4783 speech_segments: vec![],
4784 total_duration_secs: 5.0,
4785 speech_duration_secs: 0.0,
4786 };
4787
4788 assert_eq!(result.num_segments(), 0);
4789 assert!(!result.has_speech());
4790 }
4791
4792 #[test]
4793 fn test_batch_transcription_result_defaults() {
4794 let result = BatchTranscriptionResult {
4795 results: vec![],
4796 total_duration_secs: 0.0,
4797 };
4798
4799 assert_eq!(result.len(), 0);
4800 assert!(result.is_empty());
4801 assert!(result.get(0).is_none());
4802 assert!(result.texts().is_empty());
4803 }
4804
4805 #[test]
4806 #[ignore = "Allocates large model - run with --ignored"]
4807 fn test_whisper_config_accessors() {
4808 let whisper = WhisperApr::tiny();
4809 let config = whisper.config();
4810
4811 assert!(config.n_vocab > 0);
4812 assert!(config.n_audio_ctx > 0);
4813 assert!(config.n_text_ctx > 0);
4814 }
4815
4816 #[test]
4817 fn test_transcribe_options_all_fields() {
4818 let options = TranscribeOptions {
4819 language: Some("fr".to_string()),
4820 task: Task::Translate,
4821 strategy: DecodingStrategy::BeamSearch {
4822 beam_size: 3,
4823 temperature: 0.2,
4824 patience: 1.5,
4825 },
4826 word_timestamps: true,
4827 profile: false,
4828 prompt: Some("domain-specific prompt".into()),
4829 hotwords: vec!["hotword1".into(), "hotword2".into()],
4830 };
4831
4832 assert_eq!(options.language, Some("fr".to_string()));
4833 assert_eq!(options.task, Task::Translate);
4834 assert!(options.word_timestamps);
4835 }
4836
4837 #[test]
4838 fn test_segment_with_tokens() {
4839 let segment = Segment {
4840 text: "Hello".to_string(),
4841 start: 0.0,
4842 end: 1.0,
4843 tokens: vec![1, 2, 3, 4, 5],
4844 };
4845
4846 assert_eq!(segment.tokens.len(), 5);
4847 assert!((segment.end - segment.start - 1.0).abs() < f32::EPSILON);
4848 }
4849
4850 #[test]
4851 fn test_model_from_config() {
4852 let config = model::ModelConfig::tiny();
4853 let whisper = WhisperApr::from_config(config);
4854
4855 assert_eq!(whisper.model_type(), ModelType::Tiny);
4856 }
4857
4858 #[test]
4859 fn test_decoding_strategy_variants() {
4860 let greedy = DecodingStrategy::Greedy;
4861 assert!(matches!(greedy, DecodingStrategy::Greedy));
4862
4863 let sampling = DecodingStrategy::Sampling {
4864 temperature: 0.5,
4865 top_k: Some(40),
4866 top_p: Some(0.9),
4867 };
4868 if let DecodingStrategy::Sampling {
4869 temperature,
4870 top_k,
4871 top_p,
4872 } = sampling
4873 {
4874 assert!((temperature - 0.5).abs() < f32::EPSILON);
4875 assert_eq!(top_k, Some(40));
4876 assert_eq!(top_p, Some(0.9));
4877 }
4878 }
4879
4880 #[test]
4881 fn test_task_variants_eq() {
4882 assert_eq!(Task::Transcribe, Task::Transcribe);
4883 assert_ne!(Task::Transcribe, Task::Translate);
4884 }
4885
4886 #[test]
4887 fn test_model_type_all_variants() {
4888 let variants = vec![
4889 ModelType::Tiny,
4890 ModelType::TinyEn,
4891 ModelType::Base,
4892 ModelType::BaseEn,
4893 ModelType::Small,
4894 ModelType::SmallEn,
4895 ModelType::Medium,
4896 ModelType::MediumEn,
4897 ModelType::Large,
4898 ModelType::LargeV1,
4899 ModelType::LargeV2,
4900 ModelType::LargeV3,
4901 ];
4902
4903 for variant in variants {
4904 let debug_str = format!("{variant:?}");
4905 assert!(!debug_str.is_empty());
4906 }
4907 }
4908
4909 #[test]
4910 fn test_vad_speech_segment_empty_text() {
4911 let segment = VadSpeechSegment {
4912 start: 0.0,
4913 end: 1.0,
4914 text: String::new(),
4915 tokens: vec![],
4916 };
4917
4918 assert!(!segment.has_text());
4919 assert!((segment.duration() - 1.0).abs() < f32::EPSILON);
4920 }
4921
4922 #[test]
4923 fn test_transcription_result_empty() {
4924 let result = TranscriptionResult {
4925 text: String::new(),
4926 language: "en".to_string(),
4927 segments: vec![],
4928 profiling: None,
4929 };
4930
4931 assert!(result.text.is_empty());
4932 assert!(result.segments.is_empty());
4933 }
4934
4935 #[test]
4936 fn test_batch_transcription_result_iter_coverage() {
4937 let results = vec![
4938 TranscriptionResult {
4939 text: "First".to_string(),
4940 language: "en".to_string(),
4941 segments: vec![],
4942 profiling: None,
4943 },
4944 TranscriptionResult {
4945 text: "Second".to_string(),
4946 language: "en".to_string(),
4947 segments: vec![],
4948 profiling: None,
4949 },
4950 ];
4951
4952 let batch = BatchTranscriptionResult {
4953 results,
4954 total_duration_secs: 1.0,
4955 };
4956
4957 let mut count = 0;
4958 for result in batch.iter() {
4959 count += 1;
4960 assert!(!result.text.is_empty());
4961 }
4962 assert_eq!(count, 2);
4963 }
4964
4965 #[test]
4966 #[ignore = "Allocates large model - run with --ignored"]
4967 fn test_whisper_clone() {
4968 let whisper = WhisperApr::tiny();
4969 let cloned = whisper.clone();
4970
4971 assert_eq!(whisper.model_type(), cloned.model_type());
4972 assert_eq!(whisper.memory_size(), cloned.memory_size());
4973 }
4974
4975 #[test]
4976 fn test_vad_transcription_result_first_last_segment() {
4977 let result = VadTranscriptionResult {
4978 text: "hello world".to_string(),
4979 language: "en".to_string(),
4980 segments: vec![
4981 VadSpeechSegment {
4982 start: 0.0,
4983 end: 1.0,
4984 text: "hello".to_string(),
4985 tokens: vec![1],
4986 },
4987 VadSpeechSegment {
4988 start: 1.5,
4989 end: 2.5,
4990 text: "world".to_string(),
4991 tokens: vec![2],
4992 },
4993 ],
4994 speech_segments: vec![(0.0, 1.0), (1.5, 2.5)],
4995 total_duration_secs: 3.0,
4996 speech_duration_secs: 2.0,
4997 };
4998
4999 let first = result.first_segment().expect("first segment");
5000 assert_eq!(first.text, "hello");
5001
5002 let last = result.last_segment().expect("last segment");
5003 assert_eq!(last.text, "world");
5004 }
5005
5006 #[test]
5007 fn test_vad_transcription_result_iter_segments() {
5008 let result = VadTranscriptionResult {
5009 text: "test".to_string(),
5010 language: "en".to_string(),
5011 segments: vec![
5012 VadSpeechSegment {
5013 start: 0.0,
5014 end: 1.0,
5015 text: "a".to_string(),
5016 tokens: vec![1],
5017 },
5018 VadSpeechSegment {
5019 start: 1.0,
5020 end: 2.0,
5021 text: "b".to_string(),
5022 tokens: vec![2],
5023 },
5024 ],
5025 speech_segments: vec![(0.0, 1.0), (1.0, 2.0)],
5026 total_duration_secs: 2.0,
5027 speech_duration_secs: 2.0,
5028 };
5029
5030 let mut count = 0;
5031 for segment in result.iter() {
5032 count += 1;
5033 assert!(segment.has_text());
5034 }
5035 assert_eq!(count, 2);
5036 }
5037
5038 #[test]
5039 fn test_partial_transcription_result_methods() {
5040 let result = PartialTranscriptionResult {
5041 text: "hello".to_string(),
5042 language: "en".to_string(),
5043 is_final: false,
5044 confidence: 0.85,
5045 duration_secs: 2.0,
5046 processing_time_secs: 0.5,
5047 };
5048
5049 assert!(result.has_text());
5050 assert!((result.real_time_factor() - 0.25).abs() < f32::EPSILON);
5051 }
5052
5053 #[test]
5054 fn test_partial_transcription_result_zero_duration() {
5055 let result = PartialTranscriptionResult {
5056 text: "".to_string(),
5057 language: "en".to_string(),
5058 is_final: true,
5059 confidence: 0.0,
5060 duration_secs: 0.0,
5061 processing_time_secs: 0.1,
5062 };
5063
5064 assert!(!result.has_text());
5065 assert!((result.real_time_factor() - 0.0).abs() < f32::EPSILON);
5066 }
5067
    /// End-to-end smoke test: load an INT8 `.apr` model from disk, transcribe
    /// 3 s of near-silent audio, and sanity-check the output and throughput.
    #[test]
    #[ignore = "Requires model file - run with --ignored"]
    fn test_e2e_transcribe_with_int8_model() {
        let model_path = std::path::Path::new("models/whisper-tiny-int8.apr");
        // Skip (not fail) when the model artifact is absent from the checkout.
        if !model_path.exists() {
            eprintln!(
                "Skipping E2E test: model file not found at {:?}",
                model_path
            );
            return;
        }

        let model_data = std::fs::read(model_path).expect("Failed to read model file");
        eprintln!("Loaded model: {} bytes", model_data.len());

        let whisper = WhisperApr::load_from_apr(&model_data).expect("Failed to load model");
        eprintln!("Model loaded: {:?}", whisper.model_type());

        // 3 seconds of audio at 16 kHz.
        let sample_rate = 16000;
        let duration_secs = 3.0;
        let num_samples = (sample_rate as f32 * duration_secs) as usize;

        // Deterministic low-amplitude "noise" (two mixed sinusoids) rather than
        // pure zeros, so the pipeline sees a non-degenerate signal.
        let audio: Vec<f32> = (0..num_samples)
            .map(|i| {
                let noise = ((i as f32 * 0.1).sin() * 0.001) + ((i as f32 * 0.37).cos() * 0.001);
                noise
            })
            .collect();

        eprintln!("Generated {} samples of test audio", audio.len());

        let start = std::time::Instant::now();
        let result = whisper.transcribe(&audio, TranscribeOptions::default());
        let elapsed = start.elapsed();

        eprintln!("Transcription completed in {:?}", elapsed);

        match result {
            Ok(transcription) => {
                eprintln!("Result: '{}'", transcription.text);
                eprintln!("Language: {}", transcription.language);
                eprintln!("Segments: {}", transcription.segments.len());

                // Near-silence should not produce a long transcript.
                assert!(
                    transcription.text.len() < 100,
                    "Unexpected long transcription for silence"
                );
            }
            Err(e) => {
                panic!("Transcription failed: {e:?}");
            }
        }

        // Real-time factor = processing time / audio duration (lower is faster).
        let rtf = elapsed.as_secs_f32() / duration_secs;
        eprintln!("Real-time factor: {rtf:.2}x");

        assert!(rtf < 50.0, "RTF {rtf} is too slow, SIMD may not be working");
    }
5144
5145 #[test]
5150 fn test_summarize_options_default_params() {
5151 use model::lfm2::{Lfm2, Lfm2Tokenizer};
5152
5153 let config = format::apr2::Lfm2Config {
5155 hidden_size: 64,
5156 num_layers: 2,
5157 num_q_heads: 4,
5158 num_kv_heads: 2,
5159 intermediate_size: 128,
5160 vocab_size: 1000,
5161 max_seq_len: 512,
5162 rope_theta: 10000.0,
5163 conv_dimension: 32,
5164 layer_types: vec![
5165 format::apr2::LayerType::Convolution {
5166 kernel_size: 4,
5167 cache_len: 3,
5168 },
5169 format::apr2::LayerType::Attention { use_gqa: true },
5170 ],
5171 };
5172 let model = Lfm2::new(config).expect("Model creation should succeed");
5173 let tokenizer = Lfm2Tokenizer::new();
5174
5175 let options = SummarizeOptions::new(&model, &tokenizer);
5176
5177 assert_eq!(options.max_tokens, 256);
5178 assert!((options.temperature - 0.3).abs() < f32::EPSILON);
5179 }
5180
5181 #[test]
5182 fn test_summarize_options_builder() {
5183 use model::lfm2::{Lfm2, Lfm2Tokenizer};
5184
5185 let config = format::apr2::Lfm2Config {
5186 hidden_size: 64,
5187 num_layers: 2,
5188 num_q_heads: 4,
5189 num_kv_heads: 2,
5190 intermediate_size: 128,
5191 vocab_size: 1000,
5192 max_seq_len: 512,
5193 rope_theta: 10000.0,
5194 conv_dimension: 32,
5195 layer_types: vec![
5196 format::apr2::LayerType::Convolution {
5197 kernel_size: 4,
5198 cache_len: 3,
5199 },
5200 format::apr2::LayerType::Attention { use_gqa: true },
5201 ],
5202 };
5203 let model = Lfm2::new(config).expect("Model creation should succeed");
5204 let tokenizer = Lfm2Tokenizer::new();
5205
5206 let options = SummarizeOptions::new(&model, &tokenizer)
5207 .with_max_tokens(512)
5208 .with_temperature(0.7);
5209
5210 assert_eq!(options.max_tokens, 512);
5211 assert!((options.temperature - 0.7).abs() < f32::EPSILON);
5212 }
5213
5214 #[test]
5215 fn test_transcribe_summary_result_accessors() {
5216 let transcription = TranscriptionResult {
5217 text: "Hello world".to_string(),
5218 language: "en".to_string(),
5219 segments: vec![],
5220 profiling: None,
5221 };
5222
5223 let result = TranscribeSummaryResult {
5224 transcription,
5225 summary: "Summary text".to_string(),
5226 generation_stats: None,
5227 };
5228
5229 assert_eq!(result.transcript(), "Hello world");
5230 assert_eq!(result.summary(), "Summary text");
5231 assert!(result.has_summary());
5232 }
5233
5234 #[test]
5235 fn test_transcribe_summary_result_empty_summary() {
5236 let transcription = TranscriptionResult {
5237 text: "Hello world".to_string(),
5238 language: "en".to_string(),
5239 segments: vec![],
5240 profiling: None,
5241 };
5242
5243 let result = TranscribeSummaryResult {
5244 transcription,
5245 summary: String::new(),
5246 generation_stats: None,
5247 };
5248
5249 assert!(!result.has_summary());
5250 }
5251
5252 #[test]
5253 #[ignore = "Slow test - run with --ignored"]
5254 fn test_transcribe_and_summarize_empty_audio() {
5255 use model::lfm2::{Lfm2, Lfm2Tokenizer};
5256
5257 let whisper = WhisperApr::tiny();
5258 let config = format::apr2::Lfm2Config {
5260 hidden_size: 64,
5261 num_layers: 2,
5262 num_q_heads: 4,
5263 num_kv_heads: 2,
5264 intermediate_size: 128,
5265 vocab_size: 1000,
5266 max_seq_len: 2048, rope_theta: 10000.0,
5268 conv_dimension: 32,
5269 layer_types: vec![
5270 format::apr2::LayerType::Convolution {
5271 kernel_size: 4,
5272 cache_len: 3,
5273 },
5274 format::apr2::LayerType::Attention { use_gqa: true },
5275 ],
5276 };
5277 let lfm2 = Lfm2::new(config).expect("Model creation should succeed");
5278 let tokenizer = Lfm2Tokenizer::new();
5279
5280 let audio = vec![0.0f32; 16000]; let transcribe_options = TranscribeOptions::default();
5283 let summarize_options = SummarizeOptions::new(&lfm2, &tokenizer).with_max_tokens(8); let result = whisper
5286 .transcribe_and_summarize(&audio, transcribe_options, summarize_options)
5287 .expect("Should not fail");
5288
5289 if result.transcription.text.trim().is_empty() {
5292 assert!(!result.has_summary());
5293 }
5294 }
5295
5296 #[test]
5297 #[ignore = "Slow test - run with --ignored"]
5298 fn test_transcribe_and_summarize_integration() {
5299 use model::lfm2::{Lfm2, Lfm2Tokenizer};
5300
5301 let whisper = WhisperApr::tiny();
5302 let config = format::apr2::Lfm2Config {
5304 hidden_size: 64,
5305 num_layers: 2,
5306 num_q_heads: 4,
5307 num_kv_heads: 2,
5308 intermediate_size: 128,
5309 vocab_size: 1000,
5310 max_seq_len: 2048, rope_theta: 10000.0,
5312 conv_dimension: 32,
5313 layer_types: vec![
5314 format::apr2::LayerType::Convolution {
5315 kernel_size: 4,
5316 cache_len: 3,
5317 },
5318 format::apr2::LayerType::Attention { use_gqa: true },
5319 ],
5320 };
5321 let lfm2 = Lfm2::new(config).expect("Model creation should succeed");
5322 let tokenizer = Lfm2Tokenizer::new();
5323
5324 let sample_rate = 16000;
5326 let duration_secs = 1.0;
5327 let num_samples = (sample_rate as f32 * duration_secs) as usize;
5328 let audio: Vec<f32> = (0..num_samples)
5329 .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32).sin())
5330 .collect();
5331
5332 let transcribe_options = TranscribeOptions::default();
5333 let summarize_options = SummarizeOptions::new(&lfm2, &tokenizer).with_max_tokens(8); let result =
5337 whisper.transcribe_and_summarize(&audio, transcribe_options, summarize_options);
5338
5339 assert!(result.is_ok());
5342
5343 let result = result.expect("Should succeed");
5344 assert_eq!(result.transcription.language, "en");
5346 }
5347
5348 #[test]
5353 fn test_chunk_constants() {
5354 assert_eq!(WhisperApr::CHUNK_SAMPLES, 30 * 16000);
5356
5357 assert_eq!(WhisperApr::OVERLAP_SAMPLES, 5 * 16000);
5359 }
5360
5361 #[test]
5362 #[ignore = "Allocates large model - run with --ignored"]
5363 fn test_short_audio_uses_single_chunk() {
5364 let _whisper = WhisperApr::tiny();
5365
5366 let short_audio = vec![0.0_f32; 10 * 16000]; assert!(short_audio.len() <= WhisperApr::CHUNK_SAMPLES);
5369
5370 }
5373
5374 #[test]
5375 #[ignore = "Allocates large model - run with --ignored"]
5376 fn test_long_audio_uses_chunking() {
5377 let _whisper = WhisperApr::tiny();
5378
5379 let long_audio = vec![0.0_f32; 60 * 16000]; assert!(long_audio.len() > WhisperApr::CHUNK_SAMPLES);
5382
5383 }
5386
5387 #[test]
5388 #[ignore = "Allocates large model - run with --ignored"]
5389 fn test_merge_overlapping_segments_empty() {
5390 let whisper = WhisperApr::tiny();
5391 let segments: Vec<Segment> = vec![];
5392 let merged = whisper.merge_overlapping_segments(segments);
5393 assert!(merged.is_empty());
5394 }
5395
5396 #[test]
5397 #[ignore = "Allocates large model - run with --ignored"]
5398 fn test_merge_overlapping_segments_single() {
5399 let whisper = WhisperApr::tiny();
5400 let segments = vec![Segment {
5401 start: 0.0,
5402 end: 5.0,
5403 text: "Hello".to_string(),
5404 tokens: vec![1, 2, 3],
5405 }];
5406 let merged = whisper.merge_overlapping_segments(segments);
5407 assert_eq!(merged.len(), 1);
5408 assert_eq!(merged[0].text, "Hello");
5409 }
5410
5411 #[test]
5412 #[ignore = "Allocates large model - run with --ignored"]
5413 fn test_merge_overlapping_segments_no_overlap() {
5414 let whisper = WhisperApr::tiny();
5415 let segments = vec![
5416 Segment {
5417 start: 0.0,
5418 end: 5.0,
5419 text: "Hello".to_string(),
5420 tokens: vec![1],
5421 },
5422 Segment {
5423 start: 10.0,
5424 end: 15.0,
5425 text: "World".to_string(),
5426 tokens: vec![2],
5427 },
5428 ];
5429 let merged = whisper.merge_overlapping_segments(segments);
5430 assert_eq!(merged.len(), 2);
5431 assert_eq!(merged[0].text, "Hello");
5432 assert_eq!(merged[1].text, "World");
5433 }
5434
5435 #[test]
5436 #[ignore = "Allocates large model - run with --ignored"]
5437 fn test_merge_overlapping_segments_with_overlap() {
5438 let whisper = WhisperApr::tiny();
5439 let segments = vec![
5440 Segment {
5441 start: 0.0,
5442 end: 5.0,
5443 text: "Hello".to_string(),
5444 tokens: vec![1],
5445 },
5446 Segment {
5447 start: 4.9, end: 10.0,
5449 text: "World".to_string(),
5450 tokens: vec![2],
5451 },
5452 ];
5453 let merged = whisper.merge_overlapping_segments(segments);
5454 assert_eq!(merged.len(), 1);
5455 assert_eq!(merged[0].text, "Hello World");
5456 assert_eq!(merged[0].start, 0.0);
5457 assert_eq!(merged[0].end, 10.0);
5458 assert_eq!(merged[0].tokens, vec![1, 2]);
5459 }
5460
5461 #[test]
5462 #[ignore = "Allocates large model - run with --ignored"]
5463 fn test_merge_overlapping_segments_multiple_overlaps() {
5464 let whisper = WhisperApr::tiny();
5465 let segments = vec![
5466 Segment {
5467 start: 0.0,
5468 end: 5.0,
5469 text: "A".to_string(),
5470 tokens: vec![1],
5471 },
5472 Segment {
5473 start: 4.95,
5474 end: 10.0,
5475 text: "B".to_string(),
5476 tokens: vec![2],
5477 },
5478 Segment {
5479 start: 9.95,
5480 end: 15.0,
5481 text: "C".to_string(),
5482 tokens: vec![3],
5483 },
5484 Segment {
5485 start: 20.0, end: 25.0,
5487 text: "D".to_string(),
5488 tokens: vec![4],
5489 },
5490 ];
5491 let merged = whisper.merge_overlapping_segments(segments);
5492 assert_eq!(merged.len(), 2);
5493 assert_eq!(merged[0].text, "A B C");
5494 assert_eq!(merged[0].end, 15.0);
5495 assert_eq!(merged[1].text, "D");
5496 assert_eq!(merged[1].start, 20.0);
5497 }
5498
5499 #[test]
5500 fn test_chunk_boundary_calculation() {
5501 let total_samples = 90 * 16000; let chunk_size = WhisperApr::CHUNK_SAMPLES; let overlap = WhisperApr::OVERLAP_SAMPLES; let chunk1_end = (0 + chunk_size + overlap).min(total_samples);
5508 assert_eq!(chunk1_end, 35 * 16000);
5509
5510 let chunk2_start = chunk_size; let chunk2_end = (chunk2_start + chunk_size + overlap).min(total_samples);
5513 assert_eq!(chunk2_end, 65 * 16000);
5514
5515 let chunk3_start = 2 * chunk_size; let chunk3_end = (chunk3_start + chunk_size + overlap).min(total_samples);
5518 assert_eq!(chunk3_end, total_samples);
5519 }
5520
5521 #[test]
5522 #[ignore = "Allocates large model - run with --ignored"]
5523 fn test_falsification_point_25_long_audio_full_transcription() {
5524 let _whisper = WhisperApr::tiny();
5529
5530 let ten_minutes_samples = 10 * 60 * 16000;
5532 assert!(ten_minutes_samples > WhisperApr::CHUNK_SAMPLES);
5533
5534 let expected_chunks =
5536 (ten_minutes_samples + WhisperApr::CHUNK_SAMPLES - 1) / WhisperApr::CHUNK_SAMPLES;
5537 assert_eq!(expected_chunks, 20); let mut offset = 0;
5541 let mut chunk_count = 0;
5542 while offset < ten_minutes_samples {
5543 let chunk_end = (offset + WhisperApr::CHUNK_SAMPLES + WhisperApr::OVERLAP_SAMPLES)
5544 .min(ten_minutes_samples);
5545 let _chunk_len = chunk_end - offset;
5546 offset += WhisperApr::CHUNK_SAMPLES;
5547 chunk_count += 1;
5548 }
5549 assert_eq!(chunk_count, 20);
5550 }
5551
5552 #[test]
5553 fn test_falsification_point_30_streaming_consistency() {
5554 let overlap_seconds = WhisperApr::OVERLAP_SAMPLES as f32 / 16000.0;
5560 assert!((overlap_seconds - 5.0).abs() < 0.01);
5561
5562 let expected_overlap_words = overlap_seconds * 2.5;
5565 assert!(expected_overlap_words >= 10.0);
5566 }
5567
5568 #[test]
5573 #[ignore = "Allocates large model - run with --ignored"]
5574 fn test_moonshine_tiny_constructs() {
5575 let model = WhisperApr::moonshine_tiny();
5576 assert_eq!(model.config.n_audio_state, 288);
5577 assert_eq!(model.config.n_text_state, 288);
5578 assert_eq!(model.config.n_audio_layer, 6);
5579 assert_eq!(model.config.n_text_layer, 6);
5580 assert_eq!(model.config.n_audio_head, 8);
5581 assert_eq!(model.config.n_text_head, 8);
5582 }
5583
5584 #[test]
5585 #[ignore = "Allocates large model - run with --ignored"]
5586 fn test_moonshine_tiny_encoder_has_moonshine_blocks() {
5587 let model = WhisperApr::moonshine_tiny();
5588 assert!(model.encoder.moonshine_blocks().len() > 0);
5590 assert!(model.encoder.blocks().is_empty());
5591 assert!(model.encoder.rope().is_some());
5592 }
5593
5594 #[test]
5595 #[ignore = "Allocates large model - run with --ignored"]
5596 fn test_moonshine_tiny_decoder_has_moonshine_blocks() {
5597 let model = WhisperApr::moonshine_tiny();
5598 assert!(model.decoder.moonshine_blocks().len() > 0);
5600 assert!(model.decoder.blocks().is_empty());
5601 assert!(model.decoder.rope().is_some());
5602 }
5603
5604 #[test]
5605 #[ignore = "Allocates large model - run with --ignored"]
5606 fn test_moonshine_tiny_has_sentencepiece_tokenizer() {
5607 let model = WhisperApr::moonshine_tiny();
5608 match model.tokenizer() {
5609 tokenizer::Tokenizer::SentencePiece(_) => {} tokenizer::Tokenizer::Bpe(_) => panic!("Moonshine should use SentencePiece tokenizer"),
5611 }
5612 }
5613
5614 #[test]
5615 #[ignore = "Allocates large model - run with --ignored"]
5616 fn test_moonshine_tiny_has_conv_stem() {
5617 let model = WhisperApr::moonshine_tiny();
5618 assert!(model.conv_stem.is_some());
5619 assert!(model.mel_filters.is_none());
5620 }
5621
5622 #[test]
5623 #[ignore = "Allocates large model - run with --ignored"]
5624 fn test_moonshine_tiny_initial_tokens() {
5625 let model = WhisperApr::moonshine_tiny();
5626 let tokens = model.get_initial_tokens("en", crate::Task::Transcribe);
5627 assert_eq!(tokens, vec![1]);
5629 }
5630
5631 #[test]
5632 #[ignore = "Allocates large model - run with --ignored"]
5633 fn test_moonshine_tiny_eot_token() {
5634 let model = WhisperApr::moonshine_tiny();
5635 assert_eq!(model.eot_token(), 2); }
5637
5638 #[test]
5639 #[ignore = "Allocates large model - run with --ignored"]
5640 fn test_moonshine_encoder_forward_shape() {
5641 let model = WhisperApr::moonshine_tiny();
5642 let d_model = 288;
5643 let seq_len = 7; let features = vec![0.1_f32; seq_len * d_model];
5647 let output = model.encoder.forward(&features).expect("encoder forward");
5648
5649 assert_eq!(output.len(), seq_len * d_model);
5651 assert!(output.iter().all(|v| v.is_finite()));
5652 }
5653
5654 #[test]
5655 #[ignore = "Allocates large model - run with --ignored"]
5656 fn test_moonshine_decoder_forward_shape() {
5657 let model = WhisperApr::moonshine_tiny();
5658 let d_model = 288;
5659 let n_vocab = model.config.n_vocab as usize;
5660
5661 let enc_seq_len = 7;
5663 let encoder_output = vec![0.1_f32; enc_seq_len * d_model];
5664
5665 let tokens = vec![1_u32, 100];
5667 let logits = model
5668 .decoder
5669 .forward(&tokens, &encoder_output)
5670 .expect("decoder forward");
5671
5672 assert_eq!(logits.len(), tokens.len() * n_vocab);
5674 assert!(logits.iter().all(|v| v.is_finite()));
5675 }
5676
5677 #[test]
5678 #[ignore = "Allocates large model - run with --ignored"]
5679 fn test_moonshine_variable_length_input() {
5680 let model = WhisperApr::moonshine_tiny();
5681 let d_model = 288;
5682
5683 for seq_len in [3, 7, 15, 31] {
5685 let features = vec![0.1_f32; seq_len * d_model];
5686 let output = model.encoder.forward(&features).expect("encoder forward");
5687 assert_eq!(
5688 output.len(),
5689 seq_len * d_model,
5690 "seq_len={seq_len} output mismatch"
5691 );
5692 }
5693 }
5694
5695 #[test]
5696 #[ignore = "Allocates large model - run with --ignored"]
5697 fn test_whisper_tiny_unaffected() {
5698 let model = WhisperApr::tiny();
5700 assert!(model.encoder.blocks().len() > 0);
5701 assert!(model.encoder.moonshine_blocks().is_empty());
5702 assert!(model.encoder.rope().is_none());
5703 assert!(model.decoder.blocks().len() > 0);
5704 assert!(model.decoder.moonshine_blocks().is_empty());
5705 assert!(model.decoder.rope().is_none());
5706 assert!(model.mel_filters.is_some());
5707 assert!(model.conv_stem.is_none());
5708 match model.tokenizer() {
5709 tokenizer::Tokenizer::Bpe(_) => {} tokenizer::Tokenizer::SentencePiece(_) => {
5711 panic!("Whisper should use BPE tokenizer")
5712 }
5713 }
5714 }
5715
5716 #[test]
5721 fn test_moonshine_encoder_uses_rmsnorm_final() {
5722 let config = model::ModelConfig::moonshine_tiny();
5724 let encoder = model::Encoder::new(&config);
5725 assert!(
5726 encoder.ln_post_rms().is_some(),
5727 "Moonshine encoder must use RmsNorm for final layer norm"
5728 );
5729 }
5730
5731 #[test]
5732 fn test_whisper_encoder_uses_layernorm_final() {
5733 let config = model::ModelConfig::tiny();
5735 let encoder = model::Encoder::new(&config);
5736 assert!(
5737 encoder.ln_post_rms().is_none(),
5738 "Whisper encoder must use LayerNorm, not RmsNorm"
5739 );
5740 }
5741
5742 #[test]
5743 #[ignore = "Allocates large model - run with --ignored"]
5744 fn test_moonshine_decoder_forward_one_sets_cross_attn_cached() {
5745 let model = WhisperApr::moonshine_tiny();
5747 let d_model = 288;
5748 let n_layers = model.decoder.n_layers();
5749 let max_len = model.decoder.max_len();
5750 let enc_seq_len = 7;
5751 let encoder_output = vec![0.1_f32; enc_seq_len * d_model];
5752
5753 let mut cache = model::DecoderKVCache::new(n_layers, d_model, max_len);
5754 assert!(!cache.cross_attn_cached);
5755
5756 let _logits = model
5758 .decoder
5759 .forward_one(1, &encoder_output, &mut cache)
5760 .expect("forward_one");
5761
5762 assert!(
5763 cache.cross_attn_cached,
5764 "cross_attn_cached must be true after forward_one for Moonshine"
5765 );
5766 }
5767
5768 #[test]
5769 #[ignore = "Allocates large model - run with --ignored"]
5770 fn test_moonshine_decoder_forward_one_cache_has_layers() {
5771 let config = model::ModelConfig::moonshine_tiny();
5773 let decoder = model::Decoder::new(&config);
5774
5775 let cache =
5777 model::DecoderKVCache::new(decoder.n_layers(), decoder.d_model(), decoder.max_len());
5778 assert!(
5779 !cache.self_attn_cache.is_empty(),
5780 "Moonshine decoder cache must have at least 1 layer"
5781 );
5782 }
5783
5784 #[test]
5785 fn test_moonshine_decoder_is_finalized() {
5786 let config = model::ModelConfig::moonshine_tiny();
5788 let decoder = model::Decoder::new(&config);
5789 assert!(
5791 decoder.is_finalized(),
5792 "Moonshine decoder should be considered finalized"
5793 );
5794 }
5795
5796 #[test]
5797 fn test_whisper_decoder_is_finalized_initially_false() {
5798 let config = model::ModelConfig::tiny();
5800 let decoder = model::Decoder::new(&config);
5801 assert!(
5803 !decoder.is_finalized(),
5804 "Whisper decoder should NOT be finalized before finalize_weights()"
5805 );
5806 }
5807
5808 #[test]
5809 #[ignore = "Allocates large model - run with --ignored"]
5810 fn test_moonshine_decoder_forward_traced() {
5811 let model = WhisperApr::moonshine_tiny();
5813 let d_model = 288;
5814 let enc_seq_len = 7;
5815 let encoder_output = vec![0.1_f32; enc_seq_len * d_model];
5816
5817 let (logits, trace) = model
5818 .decoder
5819 .forward_traced(&[1, 2], &encoder_output)
5820 .expect("forward_traced");
5821
5822 let layer_traces: Vec<_> = trace
5824 .iter()
5825 .filter(|(name, _)| name.starts_with("layer_"))
5826 .collect();
5827 assert!(
5828 !layer_traces.is_empty(),
5829 "forward_traced must produce layer traces for Moonshine"
5830 );
5831
5832 assert!(!logits.is_empty());
5833 assert!(logits.iter().all(|v| v.is_finite()));
5834 }
5835
5836 #[test]
5837 fn test_moonshine_encoder_forward_batch_uses_forward() {
5838 let config = model::ModelConfig::moonshine_tiny();
5840 let encoder = model::Encoder::new(&config);
5841 let d_model = 288;
5842
5843 let batch = vec![vec![0.1_f32; 7 * d_model], vec![0.1_f32; 5 * d_model]];
5845 let results = encoder.forward_batch(&batch).expect("forward_batch");
5846 assert_eq!(results.len(), 2);
5847 assert_eq!(results[0].len(), 7 * d_model);
5848 assert_eq!(results[1].len(), 5 * d_model);
5849 }
5850}