spec_ai_core/agent/
transcription_factory.rs1use crate::agent::transcription::{TranscriptionProvider, TranscriptionProviderKind};
6use crate::agent::transcription_providers::MockTranscriptionProvider;
7#[cfg(feature = "vttrs")]
8use crate::agent::transcription_providers::VttRsProvider;
9use anyhow::{anyhow, Context, Result};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TranscriptionProviderConfig {
16 pub provider: String,
18 pub api_key_source: Option<String>,
20 pub endpoint: Option<String>,
22 #[serde(default)]
24 pub on_device: bool,
25 #[serde(default)]
27 pub settings: serde_json::Value,
28}
29
30impl Default for TranscriptionProviderConfig {
31 fn default() -> Self {
32 Self {
33 provider: "mock".to_string(),
34 api_key_source: None,
35 endpoint: None,
36 on_device: false,
37 settings: serde_json::Value::Null,
38 }
39 }
40}
41
42pub fn create_transcription_provider(
44 config: &TranscriptionProviderConfig,
45) -> Result<Arc<dyn TranscriptionProvider>> {
46 let provider_kind = TranscriptionProviderKind::from_str(&config.provider)
47 .ok_or_else(|| anyhow!("Unknown transcription provider: {}", config.provider))?;
48
49 match provider_kind {
50 TranscriptionProviderKind::Mock => {
51 let provider = MockTranscriptionProvider::new();
53 Ok(Arc::new(provider))
54 }
55
56 #[cfg(feature = "vttrs")]
57 TranscriptionProviderKind::VttRs => {
58 let api_key = if config.on_device {
60 String::new() } else if let Some(source) = &config.api_key_source {
62 resolve_api_key(source)?
63 } else {
64 std::env::var("OPENAI_API_KEY")
66 .or_else(|_| std::env::var("VTT_API_KEY"))
67 .unwrap_or_default()
68 };
69
70 let mut provider = VttRsProvider::new(api_key);
72
73 if let Some(endpoint) = &config.endpoint {
75 provider = provider.with_endpoint(endpoint.clone());
76 }
77
78 if config.on_device {
80 provider = provider.with_on_device(true);
81 }
82
83 Ok(Arc::new(provider))
84 }
85 }
86}
87
88pub fn create_transcription_provider_simple(
90 provider_kind: &str,
91) -> Result<Arc<dyn TranscriptionProvider>> {
92 let config = TranscriptionProviderConfig {
93 provider: provider_kind.to_string(),
94 ..Default::default()
95 };
96 create_transcription_provider(&config)
97}
98
99pub fn resolve_api_key(source: &str) -> Result<String> {
106 if let Some(env_var) = source.strip_prefix("env:") {
107 load_api_key_from_env(env_var)
108 } else if let Some(path) = source.strip_prefix("file:") {
109 load_api_key_from_file(path)
110 } else {
111 Ok(source.to_string())
113 }
114}
115
116pub fn load_api_key_from_env(env_var: &str) -> Result<String> {
118 std::env::var(env_var).context(format!("Environment variable {} not set", env_var))
119}
120
121pub fn load_api_key_from_file(path: &str) -> Result<String> {
123 let expanded_path = if let Some(stripped) = path.strip_prefix("~/") {
125 if let Some(home) = std::env::var_os("HOME") {
126 std::path::PathBuf::from(home).join(stripped)
127 } else {
128 std::path::PathBuf::from(path)
129 }
130 } else {
131 std::path::PathBuf::from(path)
132 };
133
134 std::fs::read_to_string(&expanded_path)
135 .context(format!("Failed to read API key from file: {}", path))
136 .map(|s| s.trim().to_string())
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the tests in this module that read or mutate the process
    /// environment. `std::env::set_var` / `remove_var` are `unsafe` because
    /// concurrent environment access from other threads is undefined behavior,
    /// and the test harness runs tests on multiple threads.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering the guard if an earlier test panicked
    /// while holding it.
    fn env_guard() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Per-process, per-test temp file path so parallel test runs don't collide.
    fn temp_key_path(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "spec_ai_transcription_{}_{}.key",
            tag,
            std::process::id()
        ))
    }

    #[test]
    fn test_create_mock_provider() {
        let config = TranscriptionProviderConfig {
            provider: "mock".to_string(),
            ..Default::default()
        };

        let provider = create_transcription_provider(&config).unwrap();
        assert_eq!(provider.kind(), TranscriptionProviderKind::Mock);
    }

    #[test]
    fn test_create_unknown_provider() {
        let config = TranscriptionProviderConfig {
            provider: "unknown-provider".to_string(),
            ..Default::default()
        };

        let result = create_transcription_provider(&config);
        assert!(result.is_err());
    }

    #[test]
    fn test_create_simple_mock_provider() {
        let provider = create_transcription_provider_simple("mock").unwrap();
        assert_eq!(provider.kind(), TranscriptionProviderKind::Mock);
    }

    #[test]
    fn test_load_api_key_from_env() {
        let _guard = env_guard();
        // SAFETY: env access in this module is serialized by ENV_LOCK. This
        // assumes no other test thread touches the environment concurrently.
        unsafe {
            std::env::set_var("TEST_TRANSCRIPTION_API_KEY", "env-key-value");
        }
        let key = load_api_key_from_env("TEST_TRANSCRIPTION_API_KEY");
        // Clean up before asserting so a failure doesn't leak the variable.
        // SAFETY: same as above; ENV_LOCK is still held.
        unsafe {
            std::env::remove_var("TEST_TRANSCRIPTION_API_KEY");
        }
        assert_eq!(key.unwrap(), "env-key-value");
    }

    #[test]
    fn test_load_api_key_from_env_missing() {
        let _guard = env_guard();
        let err = load_api_key_from_env("TEST_TRANSCRIPTION_API_KEY_UNSET").unwrap_err();
        assert!(err.to_string().contains("TEST_TRANSCRIPTION_API_KEY_UNSET"));
    }

    #[test]
    fn test_resolve_api_key_direct() {
        let key = resolve_api_key("sk-direct-api-key").unwrap();
        assert_eq!(key, "sk-direct-api-key");
    }

    #[test]
    fn test_resolve_api_key_from_env() {
        let _guard = env_guard();
        // SAFETY: env access in this module is serialized by ENV_LOCK. This
        // assumes no other test thread touches the environment concurrently.
        unsafe {
            std::env::set_var("TEST_RESOLVE_TRANSCRIPTION_KEY", "env-resolved-value");
        }
        let key = resolve_api_key("env:TEST_RESOLVE_TRANSCRIPTION_KEY");
        // SAFETY: same as above; ENV_LOCK is still held.
        unsafe {
            std::env::remove_var("TEST_RESOLVE_TRANSCRIPTION_KEY");
        }
        assert_eq!(key.unwrap(), "env-resolved-value");
    }

    #[test]
    fn test_resolve_api_key_from_file_trims_whitespace() {
        // temp_dir() reads the environment, so hold the env lock.
        let _guard = env_guard();
        let path = temp_key_path("resolve");
        std::fs::write(&path, "  file-key-value\n").unwrap();
        let key = resolve_api_key(&format!("file:{}", path.display()));
        let _ = std::fs::remove_file(&path);
        assert_eq!(key.unwrap(), "file-key-value");
    }

    #[test]
    fn test_load_api_key_from_missing_file() {
        let _guard = env_guard();
        let path = temp_key_path("missing");
        let _ = std::fs::remove_file(&path);
        assert!(load_api_key_from_file(path.to_str().unwrap()).is_err());
    }

    #[test]
    fn test_config_default() {
        let config = TranscriptionProviderConfig::default();
        assert_eq!(config.provider, "mock");
        assert!(config.api_key_source.is_none());
        assert!(config.endpoint.is_none());
        assert!(!config.on_device);
        assert!(config.settings.is_null());
    }

    #[test]
    fn test_config_serialization() {
        let config = TranscriptionProviderConfig {
            provider: "vttrs".to_string(),
            api_key_source: Some("env:OPENAI_API_KEY".to_string()),
            endpoint: Some("https://api.openai.com".to_string()),
            on_device: false,
            settings: serde_json::json!({"custom": "value"}),
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: TranscriptionProviderConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.provider, deserialized.provider);
        assert_eq!(config.api_key_source, deserialized.api_key_source);
        assert_eq!(config.endpoint, deserialized.endpoint);
        assert_eq!(config.on_device, deserialized.on_device);
        assert_eq!(config.settings, deserialized.settings);
    }
}