spec_ai_core/agent/
transcription_factory.rs

1//! Transcription Provider Factory
2//!
3//! Creates transcription provider instances based on configuration.
4
5use crate::agent::transcription::{TranscriptionProvider, TranscriptionProviderKind};
6use crate::agent::transcription_providers::MockTranscriptionProvider;
7#[cfg(feature = "vttrs")]
8use crate::agent::transcription_providers::VttRsProvider;
9use anyhow::{anyhow, Context, Result};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13/// Configuration for transcription providers
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TranscriptionProviderConfig {
16    /// Provider type (mock, vttrs, etc.)
17    pub provider: String,
18    /// Optional API key source (env:VAR_NAME, file:PATH, or direct key)
19    pub api_key_source: Option<String>,
20    /// Optional custom endpoint
21    pub endpoint: Option<String>,
22    /// Use on-device transcription (offline mode)
23    #[serde(default)]
24    pub on_device: bool,
25    /// Provider-specific settings
26    #[serde(default)]
27    pub settings: serde_json::Value,
28}
29
30impl Default for TranscriptionProviderConfig {
31    fn default() -> Self {
32        Self {
33            provider: "mock".to_string(),
34            api_key_source: None,
35            endpoint: None,
36            on_device: false,
37            settings: serde_json::Value::Null,
38        }
39    }
40}
41
42/// Create a transcription provider from configuration
43pub fn create_transcription_provider(
44    config: &TranscriptionProviderConfig,
45) -> Result<Arc<dyn TranscriptionProvider>> {
46    let provider_kind = TranscriptionProviderKind::from_str(&config.provider)
47        .ok_or_else(|| anyhow!("Unknown transcription provider: {}", config.provider))?;
48
49    match provider_kind {
50        TranscriptionProviderKind::Mock => {
51            // Create mock provider
52            let provider = MockTranscriptionProvider::new();
53            Ok(Arc::new(provider))
54        }
55
56        #[cfg(feature = "vttrs")]
57        TranscriptionProviderKind::VttRs => {
58            // On-device mode doesn't require API key
59            let api_key = if config.on_device {
60                String::new() // Empty API key for on-device mode
61            } else if let Some(source) = &config.api_key_source {
62                resolve_api_key(source)?
63            } else {
64                // Default to OPENAI_API_KEY or VTT_API_KEY environment variable
65                std::env::var("OPENAI_API_KEY")
66                    .or_else(|_| std::env::var("VTT_API_KEY"))
67                    .unwrap_or_default()
68            };
69
70            // Create VTT-RS provider
71            let mut provider = VttRsProvider::new(api_key);
72
73            // Set custom endpoint if specified
74            if let Some(endpoint) = &config.endpoint {
75                provider = provider.with_endpoint(endpoint.clone());
76            }
77
78            // Set on-device mode if enabled
79            if config.on_device {
80                provider = provider.with_on_device(true);
81            }
82
83            Ok(Arc::new(provider))
84        }
85    }
86}
87
88/// Create a transcription provider with just a provider kind string (for convenience)
89pub fn create_transcription_provider_simple(
90    provider_kind: &str,
91) -> Result<Arc<dyn TranscriptionProvider>> {
92    let config = TranscriptionProviderConfig {
93        provider: provider_kind.to_string(),
94        ..Default::default()
95    };
96    create_transcription_provider(&config)
97}
98
99/// Resolve API key from a source string
100///
101/// Supports the following formats:
102/// - `env:VAR_NAME` - Load from environment variable
103/// - `file:PATH` - Load from file
104/// - Any other string - Use as-is (direct API key)
105pub fn resolve_api_key(source: &str) -> Result<String> {
106    if let Some(env_var) = source.strip_prefix("env:") {
107        load_api_key_from_env(env_var)
108    } else if let Some(path) = source.strip_prefix("file:") {
109        load_api_key_from_file(path)
110    } else {
111        // Treat as direct API key
112        Ok(source.to_string())
113    }
114}
115
116/// Load API key from environment variable
117pub fn load_api_key_from_env(env_var: &str) -> Result<String> {
118    std::env::var(env_var).context(format!("Environment variable {} not set", env_var))
119}
120
121/// Load API key from file
122pub fn load_api_key_from_file(path: &str) -> Result<String> {
123    // Handle tilde expansion manually
124    let expanded_path = if let Some(stripped) = path.strip_prefix("~/") {
125        if let Some(home) = std::env::var_os("HOME") {
126            std::path::PathBuf::from(home).join(stripped)
127        } else {
128            std::path::PathBuf::from(path)
129        }
130    } else {
131        std::path::PathBuf::from(path)
132    };
133
134    std::fs::read_to_string(&expanded_path)
135        .context(format!("Failed to read API key from file: {}", path))
136        .map(|s| s.trim().to_string())
137}
138
#[cfg(test)]
mod tests {
    //! Unit tests for the provider factory and the API-key helpers.
    //!
    //! Tests that touch the environment or the filesystem use names unique to
    //! each test so they can run in parallel without interfering.

    use super::*;

    /// Scratch file path in the system temp directory, namespaced by test tag
    /// and process id so concurrent runs do not collide.
    fn scratch_key_file(tag: &str) -> std::path::PathBuf {
        let file_name = format!("transcription_factory_{}_{}.key", tag, std::process::id());
        std::env::temp_dir().join(file_name)
    }

    #[test]
    fn test_create_mock_provider() {
        let config = TranscriptionProviderConfig {
            provider: "mock".to_string(),
            ..Default::default()
        };

        let provider = create_transcription_provider(&config).unwrap();
        assert_eq!(provider.kind(), TranscriptionProviderKind::Mock);
    }

    #[test]
    fn test_create_unknown_provider() {
        let config = TranscriptionProviderConfig {
            provider: "unknown-provider".to_string(),
            ..Default::default()
        };

        let result = create_transcription_provider(&config);
        assert!(result.is_err());
    }

    #[test]
    fn test_create_simple_mock_provider() {
        let provider = create_transcription_provider_simple("mock").unwrap();
        assert_eq!(provider.kind(), TranscriptionProviderKind::Mock);
    }

    #[test]
    fn test_load_api_key_from_env() {
        // SAFETY: the variable name is used by this test only. Assumes no
        // other thread in the test binary reads the environment through
        // non-Rust code while it is being modified.
        unsafe {
            std::env::set_var("TEST_TRANSCRIPTION_API_KEY", "env-key-value");
        }
        let key = load_api_key_from_env("TEST_TRANSCRIPTION_API_KEY").unwrap();
        assert_eq!(key, "env-key-value");
        // SAFETY: same reasoning as for the `set_var` call above.
        unsafe {
            std::env::remove_var("TEST_TRANSCRIPTION_API_KEY");
        }
    }

    #[test]
    fn test_load_api_key_from_env_missing() {
        // This name is never set by any test in this module.
        let result = load_api_key_from_env("TEST_TRANSCRIPTION_API_KEY_NEVER_SET");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_api_key_from_file_trims_whitespace() {
        let path = scratch_key_file("trim");
        std::fs::write(&path, "  file-key-value \n").unwrap();

        let key = load_api_key_from_file(&path.display().to_string());
        // Clean up before asserting so a failure does not leak the file.
        std::fs::remove_file(&path).unwrap();

        assert_eq!(key.unwrap(), "file-key-value");
    }

    #[test]
    fn test_load_api_key_from_file_missing() {
        // The scratch path is never created, so the read must fail.
        let path = scratch_key_file("missing");
        let result = load_api_key_from_file(&path.display().to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_api_key_direct() {
        let key = resolve_api_key("sk-direct-api-key").unwrap();
        assert_eq!(key, "sk-direct-api-key");
    }

    #[test]
    fn test_resolve_api_key_from_env() {
        // SAFETY: the variable name is used by this test only. Assumes no
        // other thread in the test binary reads the environment through
        // non-Rust code while it is being modified.
        unsafe {
            std::env::set_var("TEST_RESOLVE_TRANSCRIPTION_KEY", "env-resolved-value");
        }
        let key = resolve_api_key("env:TEST_RESOLVE_TRANSCRIPTION_KEY").unwrap();
        assert_eq!(key, "env-resolved-value");
        // SAFETY: same reasoning as for the `set_var` call above.
        unsafe {
            std::env::remove_var("TEST_RESOLVE_TRANSCRIPTION_KEY");
        }
    }

    #[test]
    fn test_resolve_api_key_from_file() {
        let path = scratch_key_file("resolve");
        std::fs::write(&path, "file-resolved-value\n").unwrap();

        let key = resolve_api_key(&format!("file:{}", path.display()));
        // Clean up before asserting so a failure does not leak the file.
        std::fs::remove_file(&path).unwrap();

        assert_eq!(key.unwrap(), "file-resolved-value");
    }

    #[test]
    fn test_config_default() {
        let config = TranscriptionProviderConfig::default();
        assert_eq!(config.provider, "mock");
        assert!(config.api_key_source.is_none());
        assert!(config.endpoint.is_none());
        assert!(!config.on_device);
        assert!(config.settings.is_null());
    }

    #[test]
    fn test_config_deserialize_applies_defaults() {
        // Only `provider` is supplied; every other field must fall back.
        let config: TranscriptionProviderConfig =
            serde_json::from_str(r#"{"provider": "mock"}"#).unwrap();

        assert_eq!(config.provider, "mock");
        assert!(config.api_key_source.is_none());
        assert!(config.endpoint.is_none());
        assert!(!config.on_device);
        assert!(config.settings.is_null());
    }

    #[test]
    fn test_config_serialization() {
        let config = TranscriptionProviderConfig {
            provider: "vttrs".to_string(),
            api_key_source: Some("env:OPENAI_API_KEY".to_string()),
            endpoint: Some("https://api.openai.com".to_string()),
            on_device: false,
            settings: serde_json::json!({"custom": "value"}),
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: TranscriptionProviderConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.provider, deserialized.provider);
        assert_eq!(config.api_key_source, deserialized.api_key_source);
        assert_eq!(config.endpoint, deserialized.endpoint);
        assert_eq!(config.on_device, deserialized.on_device);
        assert_eq!(config.settings, deserialized.settings);
    }
}