whispr/models.rs

//! Model types for OpenAI Audio API.

use serde::{Deserialize, Serialize};

/// Available voices for text-to-speech.
///
/// Previews of the voices are available at [OpenAI's Text to Speech guide](https://platform.openai.com/docs/guides/text-to-speech).
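///
/// # Examples
///
/// A minimal usage sketch (the import path assumes `models` is a public
/// module of the `whispr` crate):
///
/// ```
/// use whispr::models::Voice;
///
/// let voice = Voice::default();
/// assert_eq!(voice, Voice::Alloy);
/// assert_eq!(voice.as_str(), "alloy");
/// assert_eq!(serde_json::to_string(&voice).unwrap(), "\"alloy\"");
/// ```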
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum Voice {
    /// Alloy voice
    #[default]
    Alloy,
    /// Ash voice
    Ash,
    /// Ballad voice
    Ballad,
    /// Coral voice
    Coral,
    /// Echo voice
    Echo,
    /// Fable voice
    Fable,
    /// Nova voice
    Nova,
    /// Onyx voice
    Onyx,
    /// Sage voice
    Sage,
    /// Shimmer voice
    Shimmer,
    /// Verse voice
    Verse,
}

impl Voice {
    /// Returns the string identifier for this voice.
    pub fn as_str(&self) -> &'static str {
        match self {
            Voice::Alloy => "alloy",
            Voice::Ash => "ash",
            Voice::Ballad => "ballad",
            Voice::Coral => "coral",
            Voice::Echo => "echo",
            Voice::Fable => "fable",
            Voice::Nova => "nova",
            Voice::Onyx => "onyx",
            Voice::Sage => "sage",
            Voice::Shimmer => "shimmer",
            Voice::Verse => "verse",
        }
    }
}

impl std::fmt::Display for Voice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Available text-to-speech models.
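///
/// # Examples
///
/// ```
/// use whispr::models::TtsModel;
///
/// let model = TtsModel::default();
/// assert_eq!(model, TtsModel::Gpt4oMiniTts);
/// assert_eq!(serde_json::to_string(&model).unwrap(), "\"gpt-4o-mini-tts\"");
/// ```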
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum TtsModel {
    /// GPT-4o Mini TTS - newest and most capable, supports instructions
    #[default]
    #[serde(rename = "gpt-4o-mini-tts")]
    Gpt4oMiniTts,
    /// TTS-1 - optimized for speed
    #[serde(rename = "tts-1")]
    Tts1,
    /// TTS-1 HD - optimized for quality
    #[serde(rename = "tts-1-hd")]
    Tts1Hd,
}

impl TtsModel {
    /// Returns the string identifier for this model.
    pub fn as_str(&self) -> &'static str {
        match self {
            TtsModel::Gpt4oMiniTts => "gpt-4o-mini-tts",
            TtsModel::Tts1 => "tts-1",
            TtsModel::Tts1Hd => "tts-1-hd",
        }
    }

    /// Returns whether this model supports the `instructions` parameter.
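    ///
    /// # Examples
    ///
    /// ```
    /// use whispr::models::TtsModel;
    ///
    /// assert!(TtsModel::Gpt4oMiniTts.supports_instructions());
    /// assert!(!TtsModel::Tts1.supports_instructions());
    /// ```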
    pub fn supports_instructions(&self) -> bool {
        matches!(self, TtsModel::Gpt4oMiniTts)
    }
}

impl std::fmt::Display for TtsModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Available audio output formats for text-to-speech.
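///
/// # Examples
///
/// A short sketch of the helper methods:
///
/// ```
/// use whispr::models::AudioFormat;
///
/// let format = AudioFormat::default();
/// assert_eq!(format, AudioFormat::Mp3);
/// assert_eq!(format.mime_type(), "audio/mpeg");
/// assert_eq!(format.extension(), "mp3");
/// ```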
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum AudioFormat {
    /// MP3 format - default, good for general use
    #[default]
    Mp3,
    /// Opus format - good for internet streaming, low latency
    Opus,
    /// AAC format - preferred by YouTube, Android, iOS
    Aac,
    /// FLAC format - lossless compression
    Flac,
    /// WAV format - uncompressed, good for low-latency applications
    Wav,
    /// PCM format - raw samples at 24 kHz, 16-bit signed, little-endian
    Pcm,
}

impl AudioFormat {
    /// Returns the string identifier for this format.
    pub fn as_str(&self) -> &'static str {
        match self {
            AudioFormat::Mp3 => "mp3",
            AudioFormat::Opus => "opus",
            AudioFormat::Aac => "aac",
            AudioFormat::Flac => "flac",
            AudioFormat::Wav => "wav",
            AudioFormat::Pcm => "pcm",
        }
    }

    /// Returns the MIME type for this audio format.
    pub fn mime_type(&self) -> &'static str {
        match self {
            AudioFormat::Mp3 => "audio/mpeg",
            AudioFormat::Opus => "audio/opus",
            AudioFormat::Aac => "audio/aac",
            AudioFormat::Flac => "audio/flac",
            AudioFormat::Wav => "audio/wav",
            AudioFormat::Pcm => "audio/pcm",
        }
    }

    /// Returns the file extension for this format.
    pub fn extension(&self) -> &'static str {
        match self {
            AudioFormat::Mp3 => "mp3",
            AudioFormat::Opus => "opus",
            AudioFormat::Aac => "aac",
            AudioFormat::Flac => "flac",
            AudioFormat::Wav => "wav",
            AudioFormat::Pcm => "pcm",
        }
    }
}

impl std::fmt::Display for AudioFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Available transcription (speech-to-text) models.
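///
/// # Examples
///
/// ```
/// use whispr::models::TranscriptionModel;
///
/// let model = TranscriptionModel::default();
/// assert_eq!(model, TranscriptionModel::Gpt4oTranscribe);
/// assert_eq!(serde_json::to_string(&model).unwrap(), "\"gpt-4o-transcribe\"");
/// ```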
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum TranscriptionModel {
    /// GPT-4o Transcribe - high quality transcription
    #[default]
    #[serde(rename = "gpt-4o-transcribe")]
    Gpt4oTranscribe,
    /// GPT-4o Mini Transcribe - faster, smaller model
    #[serde(rename = "gpt-4o-mini-transcribe")]
    Gpt4oMiniTranscribe,
    /// Whisper-1 - powered by the open-source Whisper V2 model
176    #[serde(rename = "whisper-1")]
177    Whisper1,
178    /// GPT-4o Transcribe with Diarization - includes speaker labels
179    #[serde(rename = "gpt-4o-transcribe-diarize")]
180    Gpt4oTranscribeDiarize,
181}
182
183impl TranscriptionModel {
184    /// Returns the string identifier for this model.
185    pub fn as_str(&self) -> &'static str {
186        match self {
187            TranscriptionModel::Gpt4oTranscribe => "gpt-4o-transcribe",
188            TranscriptionModel::Gpt4oMiniTranscribe => "gpt-4o-mini-transcribe",
189            TranscriptionModel::Whisper1 => "whisper-1",
190            TranscriptionModel::Gpt4oTranscribeDiarize => "gpt-4o-transcribe-diarize",
191        }
192    }
193
194    /// Returns whether this model supports streaming.
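    ///
    /// # Examples
    ///
    /// ```
    /// use whispr::models::TranscriptionModel;
    ///
    /// assert!(TranscriptionModel::Gpt4oTranscribe.supports_streaming());
    /// assert!(!TranscriptionModel::Whisper1.supports_streaming());
    /// ```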
    pub fn supports_streaming(&self) -> bool {
        !matches!(self, TranscriptionModel::Whisper1)
    }

    /// Returns whether this model supports the `prompt` parameter.
    pub fn supports_prompt(&self) -> bool {
        !matches!(self, TranscriptionModel::Gpt4oTranscribeDiarize)
    }

    /// Returns whether this model supports diarization.
    pub fn supports_diarization(&self) -> bool {
        matches!(self, TranscriptionModel::Gpt4oTranscribeDiarize)
    }
}

impl std::fmt::Display for TranscriptionModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Response format for transcription output.
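///
/// # Examples
///
/// The serde representation uses `snake_case`:
///
/// ```
/// use whispr::models::TranscriptionResponseFormat;
///
/// let format = TranscriptionResponseFormat::VerboseJson;
/// assert_eq!(serde_json::to_string(&format).unwrap(), "\"verbose_json\"");
/// assert_eq!(format.as_str(), "verbose_json");
/// ```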
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum TranscriptionResponseFormat {
    /// JSON format (default)
    #[default]
    Json,
    /// Plain text format
    Text,
    /// SRT subtitle format (whisper-1 only)
    Srt,
    /// Verbose JSON with additional metadata (whisper-1 only)
    VerboseJson,
    /// VTT subtitle format (whisper-1 only)
    Vtt,
    /// Diarized JSON with speaker labels (gpt-4o-transcribe-diarize only)
    DiarizedJson,
}

impl TranscriptionResponseFormat {
    /// Returns the string identifier for this format.
    pub fn as_str(&self) -> &'static str {
        match self {
            TranscriptionResponseFormat::Json => "json",
            TranscriptionResponseFormat::Text => "text",
            TranscriptionResponseFormat::Srt => "srt",
            TranscriptionResponseFormat::VerboseJson => "verbose_json",
            TranscriptionResponseFormat::Vtt => "vtt",
            TranscriptionResponseFormat::DiarizedJson => "diarized_json",
        }
    }
}

impl std::fmt::Display for TranscriptionResponseFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

/// Token usage information.
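///
/// # Examples
///
/// A deserialization sketch; the JSON payload below is illustrative, not a
/// verbatim API response. Every field is optional, so missing keys become
/// `None`:
///
/// ```
/// use whispr::models::Usage;
///
/// let json = r#"{"type": "tokens", "input_tokens": 14, "output_tokens": 45, "total_tokens": 59}"#;
/// let usage: Usage = serde_json::from_str(json).unwrap();
/// assert_eq!(usage.total_tokens, Some(59));
/// assert!(usage.seconds.is_none());
/// ```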
#[derive(Debug, Clone, Deserialize)]
pub struct Usage {
    /// Type of usage tracking
    #[serde(rename = "type")]
    pub usage_type: Option<String>,
    /// Number of input tokens
    pub input_tokens: Option<u32>,
    /// Number of output tokens
    pub output_tokens: Option<u32>,
    /// Total number of tokens
    pub total_tokens: Option<u32>,
    /// Duration in seconds (for duration-based billing)
    pub seconds: Option<u32>,
    /// Input token details
    pub input_token_details: Option<InputTokenDetails>,
}

/// Details about input token usage.
#[derive(Debug, Clone, Deserialize)]
pub struct InputTokenDetails {
    /// Number of text tokens
    pub text_tokens: Option<u32>,
    /// Number of audio tokens
    pub audio_tokens: Option<u32>,
}

/// Response from the transcription API.
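///
/// # Examples
///
/// Only `text` is required; every other field defaults to `None`. A minimal
/// sketch with an illustrative payload:
///
/// ```
/// use whispr::models::TranscriptionResponse;
///
/// let response: TranscriptionResponse =
///     serde_json::from_str(r#"{"text": "Hello, world."}"#).unwrap();
/// assert_eq!(response.text, "Hello, world.");
/// assert!(response.segments.is_none());
/// ```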
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionResponse {
    /// The transcribed text
    pub text: String,
    /// Token usage information (when available)
    #[serde(default)]
    pub usage: Option<Usage>,
    /// The language of the input audio (verbose_json only)
    #[serde(default)]
    pub language: Option<String>,
    /// The duration of the input audio in seconds (verbose_json only)
    #[serde(default)]
    pub duration: Option<f64>,
    /// Segments of the transcribed text (verbose_json and diarized_json)
    #[serde(default)]
    pub segments: Option<Vec<TranscriptionSegment>>,
    /// Words with timestamps (verbose_json with word timestamps)
    #[serde(default)]
    pub words: Option<Vec<TranscriptionWord>>,
    /// The task that was performed
    #[serde(default)]
    pub task: Option<String>,
}

/// A segment of transcribed text.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionSegment {
    /// Segment ID (an integer or a string, depending on the model)
    pub id: Option<serde_json::Value>,
    /// Start time in seconds
    pub start: Option<f64>,
    /// End time in seconds
    pub end: Option<f64>,
    /// Transcribed text for this segment
    pub text: String,
    /// Speaker label (diarization only)
    #[serde(default)]
    pub speaker: Option<String>,
    /// Seek position (whisper-1 verbose_json)
    #[serde(default)]
    pub seek: Option<u32>,
    /// Tokens for this segment (whisper-1 verbose_json)
    #[serde(default)]
    pub tokens: Option<Vec<u32>>,
    /// Temperature used (whisper-1 verbose_json)
    #[serde(default)]
    pub temperature: Option<f64>,
    /// Average log probability (whisper-1 verbose_json)
    #[serde(default)]
    pub avg_logprob: Option<f64>,
    /// Compression ratio (whisper-1 verbose_json)
    #[serde(default)]
    pub compression_ratio: Option<f64>,
    /// No speech probability (whisper-1 verbose_json)
    #[serde(default)]
    pub no_speech_prob: Option<f64>,
}

/// A word with timestamp information.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionWord {
    /// The word
    pub word: String,
    /// Start time in seconds
    pub start: f64,
    /// End time in seconds
    pub end: f64,
}

/// Log probability information for a token.
#[allow(dead_code)]
#[derive(Debug, Clone, Deserialize)]
pub struct LogProb {
    /// The token
    pub token: String,
    /// The log probability
    pub logprob: f64,
    /// Byte representation
    #[serde(default)]
    pub bytes: Option<Vec<u8>>,
}

/// Timestamp granularity options for transcription.
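///
/// # Examples
///
/// ```
/// use whispr::models::TimestampGranularity;
///
/// assert_eq!(TimestampGranularity::Word.as_str(), "word");
/// assert_eq!(
///     serde_json::to_string(&TimestampGranularity::Segment).unwrap(),
///     "\"segment\""
/// );
/// ```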
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TimestampGranularity {
    /// Word-level timestamps
    Word,
    /// Segment-level timestamps
    Segment,
}

impl TimestampGranularity {
    /// Returns the string identifier for this granularity.
    pub fn as_str(&self) -> &'static str {
        match self {
            TimestampGranularity::Word => "word",
            TimestampGranularity::Segment => "segment",
        }
    }
}

/// Supported input audio formats for transcription.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum InputAudioFormat {
    /// FLAC audio
    Flac,
    /// MP3 audio
    Mp3,
    /// MP4 audio
    Mp4,
    /// MPEG audio
    Mpeg,
    /// MPGA audio
    Mpga,
    /// M4A audio
    M4a,
    /// OGG audio
    Ogg,
    /// WAV audio
    Wav,
    /// WebM audio
    Webm,
}

impl InputAudioFormat {
    /// Returns the MIME type for this format.
    pub fn mime_type(&self) -> &'static str {
        match self {
            InputAudioFormat::Flac => "audio/flac",
            InputAudioFormat::Mp3 => "audio/mpeg",
            InputAudioFormat::Mp4 => "audio/mp4",
            InputAudioFormat::Mpeg => "audio/mpeg",
            InputAudioFormat::Mpga => "audio/mpeg",
            InputAudioFormat::M4a => "audio/mp4",
            InputAudioFormat::Ogg => "audio/ogg",
            InputAudioFormat::Wav => "audio/wav",
            InputAudioFormat::Webm => "audio/webm",
        }
    }

    /// Attempts to detect the format from a file extension.
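    ///
    /// The comparison is case-insensitive; unknown extensions return `None`.
    ///
    /// # Examples
    ///
    /// ```
    /// use whispr::models::InputAudioFormat;
    ///
    /// assert_eq!(InputAudioFormat::from_extension("MP3"), Some(InputAudioFormat::Mp3));
    /// assert_eq!(InputAudioFormat::from_extension("m4a").map(|f| f.mime_type()), Some("audio/mp4"));
    /// assert_eq!(InputAudioFormat::from_extension("txt"), None);
    /// ```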
    pub fn from_extension(ext: &str) -> Option<Self> {
        match ext.to_lowercase().as_str() {
            "flac" => Some(InputAudioFormat::Flac),
            "mp3" => Some(InputAudioFormat::Mp3),
            "mp4" => Some(InputAudioFormat::Mp4),
            "mpeg" => Some(InputAudioFormat::Mpeg),
            "mpga" => Some(InputAudioFormat::Mpga),
            "m4a" => Some(InputAudioFormat::M4a),
            "ogg" => Some(InputAudioFormat::Ogg),
            "wav" => Some(InputAudioFormat::Wav),
            "webm" => Some(InputAudioFormat::Webm),
            _ => None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_voice_serialization() {
        let voice = Voice::Alloy;
        let serialized = serde_json::to_string(&voice).unwrap();
        assert_eq!(serialized, "\"alloy\"");
    }

    #[test]
    fn test_tts_model_serialization() {
        let model = TtsModel::Gpt4oMiniTts;
        let serialized = serde_json::to_string(&model).unwrap();
        assert_eq!(serialized, "\"gpt-4o-mini-tts\"");
    }

    #[test]
    fn test_audio_format_mime_types() {
        assert_eq!(AudioFormat::Mp3.mime_type(), "audio/mpeg");
        assert_eq!(AudioFormat::Wav.mime_type(), "audio/wav");
    }

    #[test]
    fn test_transcription_model_features() {
        assert!(TranscriptionModel::Gpt4oTranscribe.supports_streaming());
        assert!(!TranscriptionModel::Whisper1.supports_streaming());
        assert!(TranscriptionModel::Gpt4oTranscribeDiarize.supports_diarization());
    }
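
    // Additional coverage sketches: extension detection and response
    // deserialization. The JSON payload is illustrative, not a verbatim
    // API response; all absent fields should deserialize to `None`.
    #[test]
    fn test_input_format_from_extension() {
        assert_eq!(
            InputAudioFormat::from_extension("FLAC"),
            Some(InputAudioFormat::Flac)
        );
        assert_eq!(InputAudioFormat::from_extension("txt"), None);
    }

    #[test]
    fn test_transcription_response_deserialization() {
        let response: TranscriptionResponse =
            serde_json::from_str(r#"{"text": "Hello"}"#).unwrap();
        assert_eq!(response.text, "Hello");
        assert!(response.usage.is_none());
    }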
}