1use anyhow::{Context, Result, bail};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::PathBuf;
6
7use crate::paths;
8use crate::style;
9use crate::ui::Style;
10
/// Default settings read from the `[tl]` table of the config file.
///
/// Every field is optional here; [`resolve_config`] lets the matching
/// [`ResolveOptions`] value (CLI flag) take precedence over each one.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TlConfig {
    /// Default provider name; it is looked up as a key of [`ConfigFile::providers`].
    pub provider: Option<String>,
    /// Default model name used when no `--model` is given.
    pub model: Option<String>,
    /// Default target language used when no `--to` is given.
    pub to: Option<String>,
    /// Default style key, resolved via `style::resolve_style` against built-in and
    /// custom styles.
    pub style: Option<String>,
}
23
/// Connection settings for a single provider, stored as `[providers.<name>]`.
///
/// NOTE(review): the derived `Debug` prints `api_key` in plain text; consider a
/// redacting `Debug` impl if this type is ever logged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Endpoint URL of the provider (e.g. `http://localhost:11434` in the tests).
    pub endpoint: String,
    /// Literal API key from the config file; used only when the variable named by
    /// `api_key_env` is unset or empty (see `get_api_key`).
    #[serde(default)]
    pub api_key: Option<String>,
    /// Name of an environment variable to read the API key from; takes precedence
    /// over `api_key`.
    #[serde(default)]
    pub api_key_env: Option<String>,
    /// Known model names. When non-empty, `resolve_config` warns (but does not
    /// fail) if the selected model is not listed.
    #[serde(default)]
    pub models: Vec<String>,
}
41
42impl ProviderConfig {
43 pub fn get_api_key(&self) -> Option<String> {
45 if let Some(env_var) = &self.api_key_env
46 && let Ok(key) = std::env::var(env_var)
47 && !key.is_empty()
48 {
49 return Some(key);
50 }
51 self.api_key.clone()
52 }
53
54 pub const fn requires_api_key(&self) -> bool {
56 self.api_key.is_some() || self.api_key_env.is_some()
57 }
58}
59
/// A user-defined translation style from the `[styles.<name>]` tables.
///
/// These are passed to `style::resolve_style` together with the selected style
/// key; presumably they extend or override the built-in styles — confirm in
/// the `style` module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomStyle {
    /// Human-readable description of the style.
    pub description: String,
    /// Prompt text for the style; presumably what `resolve_style(..).prompt()`
    /// returns for custom styles — TODO confirm.
    pub prompt: String,
}
68
/// In-memory representation of the whole `config.toml` file.
///
/// Every section is `#[serde(default)]`, so an empty file (or one missing any
/// section) still deserializes.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConfigFile {
    /// The `[tl]` table with default run settings.
    #[serde(default)]
    pub tl: TlConfig,
    /// Provider definitions keyed by provider name (the value `--provider` selects).
    #[serde(default)]
    pub providers: HashMap<String, ProviderConfig>,
    /// Custom styles keyed by style name, handed to `style::resolve_style`.
    #[serde(default)]
    pub styles: HashMap<String, CustomStyle>,
}
84
/// Fully resolved settings for a single run, produced by [`resolve_config`].
///
/// `Debug` is implemented by hand so the API key is never written to logs or
/// panic messages: a present key is shown as `Some("<redacted>")`.
#[derive(Clone)]
pub struct ResolvedConfig {
    /// Name of the selected provider (a key of `ConfigFile::providers`).
    pub provider_name: String,
    /// Endpoint copied from the selected provider's configuration.
    pub endpoint: String,
    /// Model name (CLI value, else the `[tl]` default).
    pub model: String,
    /// API key for the provider, if one was found.
    pub api_key: Option<String>,
    /// Target language (CLI value, else the `[tl]` default).
    pub target_language: String,
    /// Style key as given by the user, when a style was selected.
    pub style_name: Option<String>,
    /// Prompt text of the resolved style, when a style was selected.
    pub style_prompt: Option<String>,
}

impl std::fmt::Debug for ResolvedConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedConfig")
            .field("provider_name", &self.provider_name)
            .field("endpoint", &self.endpoint)
            .field("model", &self.model)
            // Reveal only whether a key is present, never the secret itself.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("target_language", &self.target_language)
            .field("style_name", &self.style_name)
            .field("style_prompt", &self.style_prompt)
            .finish()
    }
}
103
/// Per-run overrides, as supplied by CLI flags (`--to`, `--provider`, `--model`,
/// and presumably a style flag — confirm against the CLI definition).
///
/// `None` means "not given": [`resolve_config`] then falls back to the matching
/// value in the `[tl]` table of the config file.
#[derive(Debug, Clone, Default)]
pub struct ResolveOptions {
    /// Target language override.
    pub to: Option<String>,
    /// Provider name override.
    pub provider: Option<String>,
    /// Model name override.
    pub model: Option<String>,
    /// Style key override.
    pub style: Option<String>,
}
118
119pub fn resolve_config(
128 options: &ResolveOptions,
129 config_file: &ConfigFile,
130) -> Result<ResolvedConfig> {
131 let provider_name = options
133 .provider
134 .as_ref()
135 .or(config_file.tl.provider.as_ref())
136 .cloned()
137 .ok_or_else(|| {
138 anyhow::anyhow!(
139 "Missing required configuration: 'provider'\n\n\
140 Please provide it via:\n \
141 - CLI option: tl --provider <name>\n \
142 - Config file: ~/.config/tl/config.toml"
143 )
144 })?;
145
146 let provider_config = config_file.providers.get(&provider_name).ok_or_else(|| {
148 let available: Vec<_> = config_file.providers.keys().collect();
149 if available.is_empty() {
150 anyhow::anyhow!(
151 "Provider '{provider_name}' not found\n\n\
152 No providers configured. Add providers to ~/.config/tl/config.toml"
153 )
154 } else {
155 anyhow::anyhow!(
156 "Provider '{provider_name}' not found\n\n\
157 Available providers:\n \
158 - {}\n\n\
159 Add providers to ~/.config/tl/config.toml",
160 available
161 .iter()
162 .map(|s| s.as_str())
163 .collect::<Vec<_>>()
164 .join("\n - ")
165 )
166 }
167 })?;
168
169 let model = options
171 .model
172 .as_ref()
173 .or(config_file.tl.model.as_ref())
174 .cloned()
175 .ok_or_else(|| {
176 anyhow::anyhow!(
177 "Missing required configuration: 'model'\n\n\
178 Please provide it via:\n \
179 - CLI option: tl --model <name>\n \
180 - Config file: ~/.config/tl/config.toml"
181 )
182 })?;
183
184 if !provider_config.models.is_empty() && !provider_config.models.contains(&model) {
186 eprintln!(
187 "{} Model '{}' is not in the configured models list for '{}'\n\
188 Configured models: {}\n\
189 Proceeding anyway...\n",
190 Style::warning("Warning:"),
191 model,
192 provider_name,
193 provider_config.models.join(", ")
194 );
195 }
196
197 let target_language = options
199 .to
200 .as_ref()
201 .or(config_file.tl.to.as_ref())
202 .cloned()
203 .ok_or_else(|| {
204 anyhow::anyhow!(
205 "Missing required configuration: 'to' (target language)\n\n\
206 Please provide it via:\n \
207 - CLI option: tl --to <lang>\n \
208 - Config file: ~/.config/tl/config.toml"
209 )
210 })?;
211
212 let api_key = provider_config.get_api_key();
214
215 if provider_config.requires_api_key() && api_key.is_none() {
217 let env_var = provider_config.api_key_env.as_deref().unwrap_or("API_KEY");
218 bail!(
219 "Provider '{provider_name}' requires an API key\n\n\
220 Set the {env_var} environment variable:\n \
221 export {env_var}=\"your-api-key\"\n\n\
222 Or set api_key in ~/.config/tl/config.toml"
223 );
224 }
225
226 let style_key = options.style.as_ref().or(config_file.tl.style.as_ref());
228
229 let (style_name, style_prompt) = if let Some(key) = style_key {
230 let resolved =
231 style::resolve_style(key, &config_file.styles).map_err(|e| anyhow::anyhow!("{e}"))?;
232 (Some(key.clone()), Some(resolved.prompt().to_string()))
233 } else {
234 (None, None)
235 };
236
237 Ok(ResolvedConfig {
238 provider_name,
239 endpoint: provider_config.endpoint.clone(),
240 model,
241 api_key,
242 target_language,
243 style_name,
244 style_prompt,
245 })
246}
247
/// Loads and saves the `config.toml` file at a fixed path.
pub struct ConfigManager {
    /// Location of the config file; `new` sets it to `config.toml` inside
    /// `paths::config_dir()`.
    config_path: PathBuf,
}
252
253impl ConfigManager {
254 pub fn new() -> Result<Self> {
259 Ok(Self {
260 config_path: paths::config_dir()?.join("config.toml"),
261 })
262 }
263
264 pub const fn config_path(&self) -> &PathBuf {
265 &self.config_path
266 }
267
268 pub fn load(&self) -> Result<ConfigFile> {
269 let contents = fs::read_to_string(&self.config_path).with_context(|| {
270 format!("Failed to read config file: {}", self.config_path.display())
271 })?;
272
273 let config_file: ConfigFile =
274 toml::from_str(&contents).with_context(|| "Failed to parse config file")?;
275
276 Ok(config_file)
277 }
278
279 pub fn save(&self, config: &ConfigFile) -> Result<()> {
280 if let Some(parent) = self.config_path.parent() {
281 fs::create_dir_all(parent).with_context(|| {
282 format!("Failed to create config directory: {}", parent.display())
283 })?;
284 }
285
286 let contents = toml::to_string_pretty(config).context("Failed to serialize config")?;
287
288 fs::write(&self.config_path, contents).with_context(|| {
289 format!(
290 "Failed to write config file: {}",
291 self.config_path.display()
292 )
293 })?;
294
295 Ok(())
296 }
297
298 pub fn load_or_default(&self) -> ConfigFile {
299 self.load().unwrap_or_default()
300 }
301}
302
#[cfg(test)]
// `unwrap` is acceptable in tests: panicking is how a test reports failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use tempfile::TempDir;

    /// Serializes every test in this module that reads or mutates process
    /// environment variables.
    ///
    /// The test harness runs tests on several threads by default, and
    /// `std::env::set_var` / `std::env::remove_var` are only sound while no other
    /// thread accesses the environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Takes [`ENV_LOCK`], recovering the guard if a previous holder panicked.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Builds a manager whose config file lives inside `temp_dir`.
    fn create_test_manager(temp_dir: &TempDir) -> ConfigManager {
        ConfigManager {
            config_path: temp_dir.path().join("config.toml"),
        }
    }

    #[test]
    fn test_save_and_load_config() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        let mut providers = HashMap::new();
        providers.insert(
            "ollama".to_string(),
            ProviderConfig {
                endpoint: "http://localhost:11434".to_string(),
                api_key: None,
                api_key_env: None,
                models: vec!["gemma3:12b".to_string(), "llama3.2".to_string()],
            },
        );

        let config = ConfigFile {
            tl: TlConfig {
                provider: Some("ollama".to_string()),
                model: Some("gemma3:12b".to_string()),
                to: Some("ja".to_string()),
                style: None,
            },
            providers,
            styles: HashMap::new(),
        };

        manager.save(&config).unwrap();
        let loaded = manager.load().unwrap();

        assert_eq!(loaded.tl.provider, Some("ollama".to_string()));
        assert_eq!(loaded.tl.model, Some("gemma3:12b".to_string()));
        assert_eq!(loaded.tl.to, Some("ja".to_string()));
        assert!(loaded.providers.contains_key("ollama"));
    }

    #[test]
    fn test_load_nonexistent_config() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        let result = manager.load();
        assert!(result.is_err());
    }

    #[test]
    fn test_provider_get_api_key_from_env() {
        let _env = lock_env();
        // SAFETY: ENV_LOCK serializes this with the other environment-touching tests
        // in this module. NOTE(review): assumes no test outside this module reads the
        // environment concurrently — confirm.
        unsafe {
            std::env::set_var("TEST_API_KEY", "test-key-value");
        }

        let provider = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: Some("fallback-key".to_string()),
            api_key_env: Some("TEST_API_KEY".to_string()),
            models: vec![],
        };
        let key = provider.get_api_key();

        // Clean up before asserting so a failing assertion does not leak the variable.
        // SAFETY: same as above; ENV_LOCK is still held.
        unsafe {
            std::env::remove_var("TEST_API_KEY");
        }

        assert_eq!(key, Some("test-key-value".to_string()));
    }

    #[test]
    fn test_provider_get_api_key_fallback() {
        let _env = lock_env();
        // SAFETY: ENV_LOCK serializes this with the other environment-touching tests
        // in this module. NOTE(review): assumes no test outside this module reads the
        // environment concurrently — confirm.
        unsafe {
            std::env::remove_var("NONEXISTENT_KEY");
        }

        let provider = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: Some("fallback-key".to_string()),
            api_key_env: Some("NONEXISTENT_KEY".to_string()),
            models: vec![],
        };

        assert_eq!(provider.get_api_key(), Some("fallback-key".to_string()));
    }

    #[test]
    fn test_provider_requires_api_key() {
        let provider_with_key = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: Some("key".to_string()),
            api_key_env: None,
            models: vec![],
        };
        assert!(provider_with_key.requires_api_key());

        let provider_with_env = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: None,
            api_key_env: Some("API_KEY".to_string()),
            models: vec![],
        };
        assert!(provider_with_env.requires_api_key());

        let provider_without = ProviderConfig {
            endpoint: "http://localhost:11434".to_string(),
            api_key: None,
            api_key_env: None,
            models: vec![],
        };
        assert!(!provider_without.requires_api_key());
    }

    /// CLI options selecting the `ollama` provider with a model from its list.
    fn create_test_options() -> ResolveOptions {
        ResolveOptions {
            to: Some("ja".to_string()),
            provider: Some("ollama".to_string()),
            model: Some("gemma3:12b".to_string()),
            style: None,
        }
    }

    /// A config with a keyless `ollama` provider and an `openrouter` provider whose
    /// key comes from an environment variable that is expected to be unset.
    fn create_test_config() -> ConfigFile {
        let mut providers = HashMap::new();
        providers.insert(
            "ollama".to_string(),
            ProviderConfig {
                endpoint: "http://localhost:11434".to_string(),
                api_key: None,
                api_key_env: None,
                models: vec!["gemma3:12b".to_string()],
            },
        );
        providers.insert(
            "openrouter".to_string(),
            ProviderConfig {
                endpoint: "https://openrouter.ai/api".to_string(),
                api_key: None,
                api_key_env: Some("TL_TEST_NONEXISTENT_API_KEY".to_string()),
                models: vec!["gpt-4o".to_string()],
            },
        );

        ConfigFile {
            tl: TlConfig {
                provider: Some("ollama".to_string()),
                model: Some("gemma3:12b".to_string()),
                to: Some("ja".to_string()),
                style: None,
            },
            providers,
            styles: HashMap::new(),
        }
    }

    #[test]
    fn test_resolve_config_with_cli_options() {
        let options = create_test_options();
        let config = create_test_config();

        let resolved = resolve_config(&options, &config).unwrap();

        assert_eq!(resolved.provider_name, "ollama");
        assert_eq!(resolved.endpoint, "http://localhost:11434");
        assert_eq!(resolved.model, "gemma3:12b");
        assert_eq!(resolved.target_language, "ja");
        assert!(resolved.api_key.is_none());
    }

    #[test]
    fn test_resolve_config_cli_overrides_file() {
        let mut options = create_test_options();
        options.to = Some("en".to_string());
        options.model = Some("llama3".to_string());

        let config = create_test_config();

        let resolved = resolve_config(&options, &config).unwrap();

        assert_eq!(resolved.target_language, "en");
        assert_eq!(resolved.model, "llama3");
    }

    #[test]
    fn test_resolve_config_falls_back_to_file() {
        let options = ResolveOptions::default();
        let config = create_test_config();

        let resolved = resolve_config(&options, &config).unwrap();

        assert_eq!(resolved.provider_name, "ollama");
        assert_eq!(resolved.model, "gemma3:12b");
        assert_eq!(resolved.target_language, "ja");
    }

    #[test]
    fn test_resolve_config_missing_provider() {
        let options = ResolveOptions {
            to: Some("ja".to_string()),
            provider: None,
            model: Some("model".to_string()),
            style: None,
        };
        let config = ConfigFile::default();

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("'provider'"));
    }

    #[test]
    fn test_resolve_config_provider_not_found() {
        let mut options = create_test_options();
        options.provider = Some("nonexistent".to_string());

        let config = create_test_config();

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_resolve_config_missing_model() {
        let mut options = create_test_options();
        options.model = None;

        let mut config = create_test_config();
        config.tl.model = None;

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("'model'"));
    }

    #[test]
    fn test_resolve_config_missing_target_language() {
        let mut options = create_test_options();
        options.to = None;

        let mut config = create_test_config();
        config.tl.to = None;

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        // A bare "to" would match almost any message (e.g. "config.toml"); check
        // the quoted key name instead.
        assert!(result.unwrap_err().to_string().contains("'to'"));
    }

    #[test]
    fn test_resolve_config_api_key_required_but_missing() {
        // Resolving `openrouter` reads TL_TEST_NONEXISTENT_API_KEY from the environment.
        let _env = lock_env();
        let mut options = create_test_options();
        options.provider = Some("openrouter".to_string());

        let config = create_test_config();

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key"));
    }
}