spec_ai_core/agent/
transcription.rs1use anyhow::Result;
7use async_trait::async_trait;
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use std::pin::Pin;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TranscriptionConfig {
15 pub duration_secs: Option<u64>,
17 pub chunk_duration_secs: f64,
19 pub model: String,
21 pub out_file: Option<String>,
23 pub language: Option<String>,
25 pub endpoint: Option<String>,
27}
28
29impl Default for TranscriptionConfig {
30 fn default() -> Self {
31 Self {
32 duration_secs: Some(30),
33 chunk_duration_secs: 5.0,
34 model: "whisper-1".to_string(),
35 out_file: None,
36 language: None,
37 endpoint: None,
38 }
39 }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub enum TranscriptionEvent {
45 Transcription {
47 chunk_id: usize,
49 text: String,
51 timestamp: std::time::SystemTime,
53 },
54 Error {
56 chunk_id: usize,
58 message: String,
60 },
61 Started {
63 timestamp: std::time::SystemTime,
65 },
66 Completed {
68 timestamp: std::time::SystemTime,
70 total_chunks: usize,
72 },
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct TranscriptionStats {
78 pub duration_secs: f64,
80 pub total_chunks: usize,
82 pub successful_chunks: usize,
84 pub error_count: usize,
86 pub total_chars: usize,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct TranscriptionProviderMetadata {
93 pub name: String,
95 pub supported_models: Vec<String>,
97 pub supports_streaming: bool,
99 pub supported_languages: Vec<String>,
101}
102
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
105#[serde(rename_all = "lowercase")]
106pub enum TranscriptionProviderKind {
107 Mock,
108 #[cfg(feature = "vttrs")]
109 VttRs,
110}
111
112impl TranscriptionProviderKind {
113 pub fn from_str(s: &str) -> Option<Self> {
114 match s.to_lowercase().as_str() {
115 "mock" => Some(TranscriptionProviderKind::Mock),
116 #[cfg(feature = "vttrs")]
117 "vttrs" | "vtt-rs" => Some(TranscriptionProviderKind::VttRs),
118 _ => None,
119 }
120 }
121
122 pub fn as_str(&self) -> &'static str {
123 match self {
124 TranscriptionProviderKind::Mock => "mock",
125 #[cfg(feature = "vttrs")]
126 TranscriptionProviderKind::VttRs => "vttrs",
127 }
128 }
129}
130
131#[async_trait]
133pub trait TranscriptionProvider: Send + Sync {
134 async fn start_transcription(
136 &self,
137 config: &TranscriptionConfig,
138 ) -> Result<Pin<Box<dyn Stream<Item = Result<TranscriptionEvent>> + Send>>>;
139
140 fn metadata(&self) -> TranscriptionProviderMetadata;
142
143 fn kind(&self) -> TranscriptionProviderKind;
145
146 async fn health_check(&self) -> Result<bool> {
148 Ok(true)
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `kind` survives the `as_str` → `from_str` round trip and a
    /// serde JSON round trip, and that the serde name equals `as_str`.
    fn assert_kind_round_trips(kind: TranscriptionProviderKind) {
        assert_eq!(TranscriptionProviderKind::from_str(kind.as_str()), Some(kind));

        let json = serde_json::to_string(&kind).unwrap();
        assert_eq!(json, format!("\"{}\"", kind.as_str()));
        let back: TranscriptionProviderKind = serde_json::from_str(&json).unwrap();
        assert_eq!(back, kind);
    }

    #[test]
    fn test_provider_kind_from_str() {
        assert_eq!(
            TranscriptionProviderKind::from_str("mock"),
            Some(TranscriptionProviderKind::Mock)
        );
        assert_eq!(
            TranscriptionProviderKind::from_str("Mock"),
            Some(TranscriptionProviderKind::Mock)
        );
        assert_eq!(
            TranscriptionProviderKind::from_str("MOCK"),
            Some(TranscriptionProviderKind::Mock)
        );
        assert_eq!(TranscriptionProviderKind::from_str("invalid"), None);
        assert_eq!(TranscriptionProviderKind::from_str(""), None);
    }

    #[test]
    fn test_provider_kind_as_str() {
        assert_eq!(TranscriptionProviderKind::Mock.as_str(), "mock");
    }

    #[test]
    fn test_provider_kind_round_trip() {
        assert_kind_round_trips(TranscriptionProviderKind::Mock);
    }

    #[cfg(feature = "vttrs")]
    #[test]
    fn test_provider_kind_vttrs() {
        assert_eq!(
            TranscriptionProviderKind::from_str("vttrs"),
            Some(TranscriptionProviderKind::VttRs)
        );
        assert_eq!(
            TranscriptionProviderKind::from_str("vtt-rs"),
            Some(TranscriptionProviderKind::VttRs)
        );
        assert_eq!(
            TranscriptionProviderKind::from_str("VttRs"),
            Some(TranscriptionProviderKind::VttRs)
        );
        assert_eq!(TranscriptionProviderKind::VttRs.as_str(), "vttrs");
        assert_kind_round_trips(TranscriptionProviderKind::VttRs);
    }

    #[test]
    fn test_transcription_config_default() {
        let config = TranscriptionConfig::default();
        assert_eq!(config.duration_secs, Some(30));
        assert_eq!(config.chunk_duration_secs, 5.0);
        assert_eq!(config.model, "whisper-1");
        assert_eq!(config.out_file, None);
        assert_eq!(config.language, None);
        assert_eq!(config.endpoint, None);
    }

    #[test]
    fn test_transcription_config_serialization() {
        let config = TranscriptionConfig {
            duration_secs: Some(60),
            chunk_duration_secs: 3.0,
            model: "whisper-large".to_string(),
            out_file: Some("/tmp/transcript.txt".to_string()),
            language: Some("en".to_string()),
            endpoint: None,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: TranscriptionConfig = serde_json::from_str(&json).unwrap();

        // Check every field, not just a sample, so a dropped or renamed field
        // in the serde representation is caught.
        assert_eq!(config.duration_secs, deserialized.duration_secs);
        assert_eq!(config.chunk_duration_secs, deserialized.chunk_duration_secs);
        assert_eq!(config.model, deserialized.model);
        assert_eq!(config.out_file, deserialized.out_file);
        assert_eq!(config.language, deserialized.language);
        assert_eq!(config.endpoint, deserialized.endpoint);
    }

    #[test]
    fn test_transcription_event_serialization() {
        let event = TranscriptionEvent::Error {
            chunk_id: 3,
            message: "decoder failed".to_string(),
        };

        let json = serde_json::to_string(&event).unwrap();
        match serde_json::from_str::<TranscriptionEvent>(&json).unwrap() {
            TranscriptionEvent::Error { chunk_id, message } => {
                assert_eq!(chunk_id, 3);
                assert_eq!(message, "decoder failed");
            }
            other => panic!("unexpected event after round trip: {:?}", other),
        }
    }
}