// spec_ai_core/agent/transcription.rs

//! Transcription provider abstraction layer.
//!
//! This module defines the core traits and types used to integrate with different
//! transcription providers behind one unified interface that hides provider-specific details.
5
6use anyhow::Result;
7use async_trait::async_trait;
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use std::pin::Pin;
11
12/// Configuration for transcription requests
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TranscriptionConfig {
15    /// Duration to record in seconds (None = continuous until stopped)
16    pub duration_secs: Option<u64>,
17    /// Audio chunk duration in seconds
18    pub chunk_duration_secs: f64,
19    /// Model to use for transcription (e.g., "whisper-1")
20    pub model: String,
21    /// Optional output file path for transcript
22    pub out_file: Option<String>,
23    /// Language code (e.g., "en", "es", "fr")
24    pub language: Option<String>,
25    /// Custom API endpoint (if different from default)
26    pub endpoint: Option<String>,
27}
28
29impl Default for TranscriptionConfig {
30    fn default() -> Self {
31        Self {
32            duration_secs: Some(30),
33            chunk_duration_secs: 5.0,
34            model: "whisper-1".to_string(),
35            out_file: None,
36            language: None,
37            endpoint: None,
38        }
39    }
40}
41
42/// Event emitted during transcription
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub enum TranscriptionEvent {
45    /// Successful transcription of an audio chunk
46    Transcription {
47        /// Chunk identifier
48        chunk_id: usize,
49        /// Transcribed text
50        text: String,
51        /// Timestamp when this chunk was processed
52        timestamp: std::time::SystemTime,
53    },
54    /// Error during transcription
55    Error {
56        /// Chunk identifier where error occurred
57        chunk_id: usize,
58        /// Error message
59        message: String,
60    },
61    /// Transcription session started
62    Started {
63        /// Start timestamp
64        timestamp: std::time::SystemTime,
65    },
66    /// Transcription session completed
67    Completed {
68        /// End timestamp
69        timestamp: std::time::SystemTime,
70        /// Total chunks processed
71        total_chunks: usize,
72    },
73}
74
75/// Transcription session statistics
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct TranscriptionStats {
78    /// Total duration recorded in seconds
79    pub duration_secs: f64,
80    /// Total chunks processed
81    pub total_chunks: usize,
82    /// Number of successful transcriptions
83    pub successful_chunks: usize,
84    /// Number of errors
85    pub error_count: usize,
86    /// Total characters transcribed
87    pub total_chars: usize,
88}
89
90/// Provider metadata
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct TranscriptionProviderMetadata {
93    /// Provider name
94    pub name: String,
95    /// Supported models
96    pub supported_models: Vec<String>,
97    /// Supports streaming
98    pub supports_streaming: bool,
99    /// Supported languages
100    pub supported_languages: Vec<String>,
101}
102
103/// Types of transcription providers
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
105#[serde(rename_all = "lowercase")]
106pub enum TranscriptionProviderKind {
107    Mock,
108    #[cfg(feature = "vttrs")]
109    VttRs,
110}
111
112impl TranscriptionProviderKind {
113    pub fn from_str(s: &str) -> Option<Self> {
114        match s.to_lowercase().as_str() {
115            "mock" => Some(TranscriptionProviderKind::Mock),
116            #[cfg(feature = "vttrs")]
117            "vttrs" | "vtt-rs" => Some(TranscriptionProviderKind::VttRs),
118            _ => None,
119        }
120    }
121
122    pub fn as_str(&self) -> &'static str {
123        match self {
124            TranscriptionProviderKind::Mock => "mock",
125            #[cfg(feature = "vttrs")]
126            TranscriptionProviderKind::VttRs => "vttrs",
127        }
128    }
129}
130
131/// Core trait that all transcription providers must implement
132#[async_trait]
133pub trait TranscriptionProvider: Send + Sync {
134    /// Start a transcription session and return a stream of events
135    async fn start_transcription(
136        &self,
137        config: &TranscriptionConfig,
138    ) -> Result<Pin<Box<dyn Stream<Item = Result<TranscriptionEvent>> + Send>>>;
139
140    /// Get provider metadata
141    fn metadata(&self) -> TranscriptionProviderMetadata;
142
143    /// Get the provider kind
144    fn kind(&self) -> TranscriptionProviderKind;
145
146    /// Check if the provider is available and configured correctly
147    async fn health_check(&self) -> Result<bool> {
148        Ok(true)
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_kind_from_str() {
        assert_eq!(
            TranscriptionProviderKind::from_str("mock"),
            Some(TranscriptionProviderKind::Mock)
        );
        assert_eq!(
            TranscriptionProviderKind::from_str("Mock"),
            Some(TranscriptionProviderKind::Mock)
        );
        assert_eq!(
            TranscriptionProviderKind::from_str("MOCK"),
            Some(TranscriptionProviderKind::Mock)
        );
        assert_eq!(TranscriptionProviderKind::from_str("invalid"), None);
        assert_eq!(TranscriptionProviderKind::from_str(""), None);
    }

    #[test]
    fn test_provider_kind_as_str() {
        assert_eq!(TranscriptionProviderKind::Mock.as_str(), "mock");
    }

    /// `as_str` output must be accepted by `from_str` and map back to the same variant.
    #[test]
    fn test_provider_kind_as_str_round_trips() {
        let kind = TranscriptionProviderKind::Mock;
        assert_eq!(TranscriptionProviderKind::from_str(kind.as_str()), Some(kind));
    }

    /// The serde representation (`rename_all = "lowercase"`) must agree with `as_str`.
    #[test]
    fn test_provider_kind_serde_name_matches_as_str() {
        let kind = TranscriptionProviderKind::Mock;
        let json = serde_json::to_string(&kind).unwrap();
        assert_eq!(json, format!("\"{}\"", kind.as_str()));
        let back: TranscriptionProviderKind = serde_json::from_str(&json).unwrap();
        assert_eq!(back, kind);
    }

    #[test]
    fn test_transcription_config_default() {
        let config = TranscriptionConfig::default();
        assert_eq!(config.duration_secs, Some(30));
        assert_eq!(config.chunk_duration_secs, 5.0);
        assert_eq!(config.model, "whisper-1");
        assert!(config.out_file.is_none());
        assert!(config.language.is_none());
        assert!(config.endpoint.is_none());
    }

    /// Every field must survive a JSON round trip, not just a subset.
    #[test]
    fn test_transcription_config_serialization() {
        let config = TranscriptionConfig {
            duration_secs: Some(60),
            chunk_duration_secs: 3.0,
            model: "whisper-large".to_string(),
            out_file: Some("/tmp/transcript.txt".to_string()),
            language: Some("en".to_string()),
            endpoint: None,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: TranscriptionConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.duration_secs, deserialized.duration_secs);
        assert_eq!(config.chunk_duration_secs, deserialized.chunk_duration_secs);
        assert_eq!(config.model, deserialized.model);
        assert_eq!(config.out_file, deserialized.out_file);
        assert_eq!(config.language, deserialized.language);
        assert_eq!(config.endpoint, deserialized.endpoint);
    }

    /// Event timestamps (`SystemTime`) must round-trip through JSON exactly.
    #[test]
    fn test_transcription_event_serialization_round_trip() {
        let timestamp =
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_millis(1_500);
        let event = TranscriptionEvent::Transcription {
            chunk_id: 3,
            text: "hello".to_string(),
            timestamp,
        };

        let json = serde_json::to_string(&event).unwrap();
        match serde_json::from_str::<TranscriptionEvent>(&json).unwrap() {
            TranscriptionEvent::Transcription {
                chunk_id,
                text,
                timestamp: decoded,
            } => {
                assert_eq!(chunk_id, 3);
                assert_eq!(text, "hello");
                assert_eq!(decoded, timestamp);
            }
            other => panic!("unexpected event after round trip: {:?}", other),
        }
    }
}