Skip to main content

whisperforge_core/
lib.rs

1use anyhow::Result;
2use burn::tensor::{Tensor, backend::Backend};
3
4pub mod attn_extract;
5pub mod audio;
6pub mod audio_capture;
7pub mod decoding;
8pub mod embed;
9pub mod kv_cache;
10pub mod language;
11pub mod load;
12pub mod model;
13pub mod stream_decode;
14pub mod streaming;
15pub mod transcribe;
16pub mod vad_silero;
17
18#[cfg(feature = "file-io")]
19pub mod audio_stream;
20
21#[cfg(feature = "cubecl-stft")]
22pub mod stft_gpu;
23
24pub use attn_extract::forward_decoder_with_cross_attn;
25pub use audio::batch_mel_spectrograms;
26pub use audio::compute_mel_from_samples;
27pub use audio::prepare_centered_samples_raw;
28pub use audio_capture::{CaptureSource, FakeMic, MicCapture, list_input_devices};
29pub use decoding::{BeamSearchDecoder, DecodingConfig, GreedyDecoder, HybridDecoder};
30pub use embed::extract_speaker_embedding;
31pub use kv_cache::{KvCache, forward_decoder_cached};
32pub use language::{LANGUAGE_CODES, Task, detect_language, language_token_id, task_token_id};
33pub use load::{load_config_from_bytes, load_whisper_from_bytes};
34pub use model::{AudioEncoderConfig, TextDecoderConfig, Whisper, WhisperConfig};
35pub use stream_decode::{
36    DecodeContext, QualityGate, TokenEmit, avg_logprob, decode_window, passes_quality_gate,
37};
38pub use streaming::{
39    Chunker, CommitDelta, Committer, EndpointConfig, Endpointer, PromptContext, StreamWindow,
40    WindowConfig,
41};
42pub use transcribe::{WhisperTranscriber, transcribe_audio};
43pub use vad_silero::{SileroVad, ensure_silero_model};
44
45#[cfg(feature = "file-io")]
46pub use audio::load_audio_file;
47#[cfg(feature = "file-io")]
48pub use audio_stream::{AudioChunk, AudioChunkIterator};
49#[cfg(feature = "file-io")]
50pub use load::{load_config, load_whisper};
51
52#[cfg(feature = "cubecl-stft")]
53pub use stft_gpu::compute_stft_power_gpu;
54
55#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
56pub struct TranscriptionSegment {
57    pub start: f32,
58    pub end: f32,
59    pub text: String,
60    pub tokens: Vec<u32>,
61    pub confidence: f32,
62    /// Per-token timestamps in seconds derived from cross-attention peaks.
63    /// Populated by `transcribe_with_timestamps`; empty for plain `transcribe`.
64    #[serde(default, skip_serializing_if = "Vec::is_empty")]
65    pub token_timestamps: Vec<f32>,
66    /// Speaker label assigned by diarization (e.g. `"SPEAKER_00"`).
67    /// `None` when diarization was not requested.
68    #[serde(default, skip_serializing_if = "Option::is_none")]
69    pub speaker: Option<String>,
70}
71
72#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
73pub struct TranscriptionResult {
74    pub text: String,
75    pub segments: Vec<TranscriptionSegment>,
76    pub language: Option<String>,
77}
78
79pub trait WhisperInference<B: Backend> {
80    fn transcribe(&self, mel_features: Tensor<B, 3>) -> Result<TranscriptionResult>;
81    fn transcribe_with_timestamps(&self, mel_features: Tensor<B, 3>)
82    -> Result<TranscriptionResult>;
83}
84
/// Space-separated list of Whisper special-token strings.
///
/// Layout, in order: `<|startoftranscript|>`, the task tokens `<|translate|>`
/// and `<|transcribe|>`, 100 language tokens (`<|en|>` … `<|yue|>`), then
/// `<|notimestamps|>`. Tokens are joined by a single ASCII space with no
/// leading or trailing whitespace; `concat!` assembles the exact same
/// `&'static str` at compile time.
///
/// NOTE(review): tokens such as `<|endoftext|>` or `<|startofprev|>` are not
/// listed — confirm that is intentional for every consumer of this constant.
pub const SPECIAL_TOKENS: &str = concat!(
    "<|startoftranscript|> <|translate|> <|transcribe|> ",
    // Language tokens, ten per line.
    "<|en|> <|zh|> <|de|> <|es|> <|ru|> <|ko|> <|fr|> <|ja|> <|pt|> <|tr|> ",
    "<|pl|> <|ca|> <|nl|> <|ar|> <|sv|> <|it|> <|id|> <|hi|> <|fi|> <|vi|> ",
    "<|he|> <|uk|> <|el|> <|ms|> <|cs|> <|ro|> <|da|> <|hu|> <|ta|> <|no|> ",
    "<|th|> <|ur|> <|hr|> <|bg|> <|lt|> <|la|> <|mi|> <|ml|> <|cy|> <|sk|> ",
    "<|te|> <|fa|> <|lv|> <|bn|> <|sr|> <|az|> <|sl|> <|kn|> <|et|> <|mk|> ",
    "<|br|> <|eu|> <|is|> <|hy|> <|ne|> <|mn|> <|bs|> <|kk|> <|sq|> <|sw|> ",
    "<|gl|> <|mr|> <|pa|> <|si|> <|km|> <|sn|> <|yo|> <|so|> <|af|> <|oc|> ",
    "<|ka|> <|be|> <|tg|> <|sd|> <|gu|> <|am|> <|yi|> <|lo|> <|uz|> <|fo|> ",
    "<|ht|> <|ps|> <|tk|> <|nn|> <|mt|> <|sa|> <|lb|> <|my|> <|bo|> <|tl|> ",
    "<|mg|> <|as|> <|tt|> <|haw|> <|ln|> <|ha|> <|ba|> <|jw|> <|su|> <|yue|> ",
    "<|notimestamps|>",
);