1use anyhow::{Context, Result, bail};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs;
5use std::path::PathBuf;
6
7use crate::paths;
8use crate::style;
9use crate::ui::Style;
10
11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
13pub struct TlConfig {
14 pub provider: Option<String>,
16 pub model: Option<String>,
18 pub to: Option<String>,
20 pub style: Option<String>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ProviderConfig {
29 pub endpoint: String,
31 #[serde(default)]
33 pub api_key: Option<String>,
34 #[serde(default)]
36 pub api_key_env: Option<String>,
37 #[serde(default)]
39 pub models: Vec<String>,
40}
41
42impl ProviderConfig {
43 pub fn get_api_key(&self) -> Option<String> {
45 if let Some(env_var) = &self.api_key_env
46 && let Ok(key) = std::env::var(env_var)
47 && !key.is_empty()
48 {
49 return Some(key);
50 }
51 self.api_key.clone()
52 }
53
54 pub const fn requires_api_key(&self) -> bool {
56 self.api_key.is_some() || self.api_key_env.is_some()
57 }
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct CustomStyle {
63 pub description: String,
65 pub prompt: String,
67}
68
69#[derive(Debug, Clone, Default, Serialize, Deserialize)]
73pub struct ConfigFile {
74 #[serde(default)]
76 pub tl: TlConfig,
77 #[serde(default)]
79 pub providers: HashMap<String, ProviderConfig>,
80 #[serde(default)]
82 pub styles: HashMap<String, CustomStyle>,
83}
84
/// The fully merged settings for one run, produced by `resolve_config`.
#[derive(Clone)]
pub struct ResolvedConfig {
    /// Name of the selected provider entry.
    pub provider_name: String,
    /// Endpoint URL taken from the selected provider.
    pub endpoint: String,
    /// Model to request.
    pub model: String,
    /// API key obtained from the provider configuration, if any.
    pub api_key: Option<String>,
    /// Language to translate into.
    pub target_language: String,
    /// Style key that was selected, if any.
    pub style_name: Option<String>,
    /// Prompt text of the selected style, if any.
    pub style_prompt: Option<String>,
}

// `Debug` is implemented by hand so the resolved API key never appears in
// `{:?}` output (logs, panic messages, `unwrap_err` diagnostics); only its
// presence is shown.
impl std::fmt::Debug for ResolvedConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedConfig")
            .field("provider_name", &self.provider_name)
            .field("endpoint", &self.endpoint)
            .field("model", &self.model)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("target_language", &self.target_language)
            .field("style_name", &self.style_name)
            .field("style_prompt", &self.style_prompt)
            .finish()
    }
}
103
/// Per-invocation overrides handed to `resolve_config`.
///
/// A `None` field means "not given here"; the config file's `tl` value is
/// used instead.
#[derive(Clone, Debug, Default)]
pub struct ResolveOptions {
    /// Target language override (`--to`).
    pub to: Option<String>,
    /// Provider name override (`--provider`).
    pub provider: Option<String>,
    /// Model name override (`--model`).
    pub model: Option<String>,
    /// Style key override.
    pub style: Option<String>,
}
118
119pub fn resolve_config(
128 options: &ResolveOptions,
129 config_file: &ConfigFile,
130) -> Result<ResolvedConfig> {
131 let provider_name = options
133 .provider
134 .as_ref()
135 .or(config_file.tl.provider.as_ref())
136 .cloned()
137 .ok_or_else(|| {
138 anyhow::anyhow!(
139 "Missing required configuration: 'provider'\n\n\
140 Please provide it via:\n \
141 - CLI option: tl --provider <name>\n \
142 - Config file: ~/.config/tl/config.toml"
143 )
144 })?;
145
146 let provider_config = config_file.providers.get(&provider_name).ok_or_else(|| {
148 let available: Vec<_> = config_file.providers.keys().collect();
149 if available.is_empty() {
150 anyhow::anyhow!(
151 "Provider '{provider_name}' not found\n\n\
152 No providers configured. Add providers to ~/.config/tl/config.toml"
153 )
154 } else {
155 anyhow::anyhow!(
156 "Provider '{provider_name}' not found\n\n\
157 Available providers:\n \
158 - {}\n\n\
159 Add providers to ~/.config/tl/config.toml",
160 available
161 .iter()
162 .map(|s| s.as_str())
163 .collect::<Vec<_>>()
164 .join("\n - ")
165 )
166 }
167 })?;
168
169 let model = options
171 .model
172 .as_ref()
173 .or(config_file.tl.model.as_ref())
174 .cloned()
175 .ok_or_else(|| {
176 anyhow::anyhow!(
177 "Missing required configuration: 'model'\n\n\
178 Please provide it via:\n \
179 - CLI option: tl --model <name>\n \
180 - Config file: ~/.config/tl/config.toml"
181 )
182 })?;
183
184 if !provider_config.models.is_empty() && !provider_config.models.contains(&model) {
186 eprintln!(
187 "{} Model '{}' is not in the configured models list for '{}'\n\
188 Configured models: {}\n\
189 Proceeding anyway...\n",
190 Style::warning("Warning:"),
191 model,
192 provider_name,
193 provider_config.models.join(", ")
194 );
195 }
196
197 let target_language = options
199 .to
200 .as_ref()
201 .or(config_file.tl.to.as_ref())
202 .cloned()
203 .ok_or_else(|| {
204 anyhow::anyhow!(
205 "Missing required configuration: 'to' (target language)\n\n\
206 Please provide it via:\n \
207 - CLI option: tl --to <lang>\n \
208 - Config file: ~/.config/tl/config.toml"
209 )
210 })?;
211
212 let api_key = provider_config.get_api_key();
214
215 if provider_config.requires_api_key() && api_key.is_none() {
217 let env_var = provider_config.api_key_env.as_deref().unwrap_or("API_KEY");
218 bail!(
219 "Provider '{provider_name}' requires an API key\n\n\
220 Set the {env_var} environment variable:\n \
221 export {env_var}=\"your-api-key\"\n\n\
222 Or set api_key in ~/.config/tl/config.toml"
223 );
224 }
225
226 let style_key = options.style.as_ref().or(config_file.tl.style.as_ref());
228
229 let (style_name, style_prompt) = if let Some(key) = style_key {
230 let resolved =
231 style::resolve_style(key, &config_file.styles).map_err(|e| anyhow::anyhow!("{e}"))?;
232 (Some(key.clone()), Some(resolved.prompt().to_string()))
233 } else {
234 (None, None)
235 };
236
237 Ok(ResolvedConfig {
238 provider_name,
239 endpoint: provider_config.endpoint.clone(),
240 model,
241 api_key,
242 target_language,
243 style_name,
244 style_prompt,
245 })
246}
247
/// Reads and writes the tool's `config.toml` at a fixed location.
#[derive(Debug, Clone)]
pub struct ConfigManager {
    /// Full path of the config file.
    config_path: PathBuf,
}
252
253impl ConfigManager {
254 pub fn new() -> Result<Self> {
259 Ok(Self {
260 config_path: paths::config_dir()?.join("config.toml"),
261 })
262 }
263
264 pub const fn config_path(&self) -> &PathBuf {
265 &self.config_path
266 }
267
268 pub fn load(&self) -> Result<ConfigFile> {
269 let contents = fs::read_to_string(&self.config_path).with_context(|| {
270 format!("Failed to read config file: {}", self.config_path.display())
271 })?;
272
273 let config_file: ConfigFile =
274 toml::from_str(&contents).with_context(|| "Failed to parse config file")?;
275
276 Ok(config_file)
277 }
278
279 pub fn save(&self, config: &ConfigFile) -> Result<()> {
280 if let Some(parent) = self.config_path.parent() {
281 fs::create_dir_all(parent).with_context(|| {
282 format!("Failed to create config directory: {}", parent.display())
283 })?;
284 }
285
286 let contents = toml::to_string_pretty(config).context("Failed to serialize config")?;
287
288 fs::write(&self.config_path, contents).with_context(|| {
289 format!(
290 "Failed to write config file: {}",
291 self.config_path.display()
292 )
293 })?;
294
295 Ok(())
296 }
297
298 pub fn load_or_default(&self) -> Result<ConfigFile> {
303 match fs::read_to_string(&self.config_path) {
304 Ok(contents) => {
305 toml::from_str(&contents).with_context(|| "Failed to parse config file")
306 }
307 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(ConfigFile::default()),
308 Err(e) => Err(anyhow::anyhow!(
309 "Failed to read config file: {}: {}",
310 self.config_path.display(),
311 e
312 )),
313 }
314 }
315}
316
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn create_test_manager(temp_dir: &TempDir) -> ConfigManager {
        ConfigManager {
            config_path: temp_dir.path().join("config.toml"),
        }
    }

    /// Returns a variable already present in the process environment whose
    /// name is usable with `std::env::var` and whose value is non-empty,
    /// valid Unicode.
    ///
    /// Tests read existing variables instead of calling `std::env::set_var` /
    /// `remove_var`: the test harness runs tests on parallel threads, and
    /// mutating the environment while other threads read it is unsound (the
    /// reason those functions are `unsafe` in edition 2024).
    fn existing_env_var() -> Option<(String, String)> {
        std::env::vars_os().find_map(|(key, value)| {
            let key = key.into_string().ok()?;
            let value = value.into_string().ok()?;
            // `std::env::var` rejects names that are empty or contain '=' / NUL.
            let usable_key = !key.is_empty() && !key.contains(|c| c == '=' || c == '\0');
            (usable_key && !value.is_empty()).then_some((key, value))
        })
    }

    #[test]
    fn test_save_and_load_config() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        let mut providers = HashMap::new();
        providers.insert(
            "ollama".to_string(),
            ProviderConfig {
                endpoint: "http://localhost:11434".to_string(),
                api_key: None,
                api_key_env: None,
                models: vec!["gemma3:12b".to_string(), "llama3.2".to_string()],
            },
        );

        let config = ConfigFile {
            tl: TlConfig {
                provider: Some("ollama".to_string()),
                model: Some("gemma3:12b".to_string()),
                to: Some("ja".to_string()),
                style: None,
            },
            providers,
            styles: HashMap::new(),
        };

        manager.save(&config).unwrap();
        let loaded = manager.load().unwrap();

        assert_eq!(loaded.tl.provider, Some("ollama".to_string()));
        assert_eq!(loaded.tl.model, Some("gemma3:12b".to_string()));
        assert_eq!(loaded.tl.to, Some("ja".to_string()));
        assert!(loaded.providers.contains_key("ollama"));
    }

    #[test]
    fn test_load_nonexistent_config() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        let result = manager.load();
        assert!(result.is_err());
    }

    #[test]
    fn test_provider_get_api_key_from_env() {
        let Some((name, value)) = existing_env_var() else {
            // No usable variable in this environment; nothing to check.
            return;
        };

        let provider = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: Some("fallback-key".to_string()),
            api_key_env: Some(name),
            models: vec![],
        };

        assert_eq!(provider.get_api_key(), Some(value));
    }

    #[test]
    fn test_provider_get_api_key_fallback() {
        // A name no real environment should define. It is checked rather than
        // removed so the test never mutates process-wide state.
        const UNSET: &str = "TL_TEST_SURELY_UNSET_API_KEY_FALLBACK";
        assert!(
            std::env::var_os(UNSET).is_none(),
            "{UNSET} must not be set for this test"
        );

        let provider = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: Some("fallback-key".to_string()),
            api_key_env: Some(UNSET.to_string()),
            models: vec![],
        };

        assert_eq!(provider.get_api_key(), Some("fallback-key".to_string()));
    }

    #[test]
    fn test_provider_requires_api_key() {
        let provider_with_key = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: Some("key".to_string()),
            api_key_env: None,
            models: vec![],
        };
        assert!(provider_with_key.requires_api_key());

        let provider_with_env = ProviderConfig {
            endpoint: "https://api.example.com".to_string(),
            api_key: None,
            api_key_env: Some("API_KEY".to_string()),
            models: vec![],
        };
        assert!(provider_with_env.requires_api_key());

        let provider_without = ProviderConfig {
            endpoint: "http://localhost:11434".to_string(),
            api_key: None,
            api_key_env: None,
            models: vec![],
        };
        assert!(!provider_without.requires_api_key());
    }

    fn create_test_options() -> ResolveOptions {
        ResolveOptions {
            to: Some("ja".to_string()),
            provider: Some("ollama".to_string()),
            model: Some("gemma3:12b".to_string()),
            style: None,
        }
    }

    fn create_test_config() -> ConfigFile {
        let mut providers = HashMap::new();
        providers.insert(
            "ollama".to_string(),
            ProviderConfig {
                endpoint: "http://localhost:11434".to_string(),
                api_key: None,
                api_key_env: None,
                models: vec!["gemma3:12b".to_string()],
            },
        );
        providers.insert(
            "openrouter".to_string(),
            ProviderConfig {
                endpoint: "https://openrouter.ai/api".to_string(),
                api_key: None,
                api_key_env: Some("TL_TEST_NONEXISTENT_API_KEY".to_string()),
                models: vec!["gpt-4o".to_string()],
            },
        );

        ConfigFile {
            tl: TlConfig {
                provider: Some("ollama".to_string()),
                model: Some("gemma3:12b".to_string()),
                to: Some("ja".to_string()),
                style: None,
            },
            providers,
            styles: HashMap::new(),
        }
    }

    #[test]
    fn test_resolve_config_with_cli_options() {
        let options = create_test_options();
        let config = create_test_config();

        let resolved = resolve_config(&options, &config).unwrap();

        assert_eq!(resolved.provider_name, "ollama");
        assert_eq!(resolved.endpoint, "http://localhost:11434");
        assert_eq!(resolved.model, "gemma3:12b");
        assert_eq!(resolved.target_language, "ja");
        assert!(resolved.api_key.is_none());
    }

    #[test]
    fn test_resolve_config_cli_overrides_file() {
        let mut options = create_test_options();
        options.to = Some("en".to_string());
        options.model = Some("llama3".to_string());

        let config = create_test_config();

        let resolved = resolve_config(&options, &config).unwrap();

        assert_eq!(resolved.target_language, "en");
        assert_eq!(resolved.model, "llama3");
    }

    #[test]
    fn test_resolve_config_falls_back_to_file() {
        let options = ResolveOptions::default();
        let config = create_test_config();

        let resolved = resolve_config(&options, &config).unwrap();

        assert_eq!(resolved.provider_name, "ollama");
        assert_eq!(resolved.model, "gemma3:12b");
        assert_eq!(resolved.target_language, "ja");
    }

    #[test]
    fn test_resolve_config_missing_provider() {
        let options = ResolveOptions {
            to: Some("ja".to_string()),
            provider: None,
            model: Some("model".to_string()),
            style: None,
        };
        let config = ConfigFile::default();

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("provider"));
    }

    #[test]
    fn test_resolve_config_provider_not_found() {
        let mut options = create_test_options();
        options.provider = Some("nonexistent".to_string());

        let config = create_test_config();

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_resolve_config_missing_model() {
        let mut options = create_test_options();
        options.model = None;

        let mut config = create_test_config();
        config.tl.model = None;

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("model"));
    }

    #[test]
    fn test_resolve_config_missing_target_language() {
        let mut options = create_test_options();
        options.to = None;

        let mut config = create_test_config();
        config.tl.to = None;

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        // Match the quoted setting name; a bare "to" would match almost any text.
        assert!(result.unwrap_err().to_string().contains("'to'"));
    }

    #[test]
    fn test_resolve_config_api_key_required_but_missing() {
        let mut options = create_test_options();
        options.provider = Some("openrouter".to_string());

        let config = create_test_config();

        let result = resolve_config(&options, &config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key"));
    }

    #[test]
    fn test_load_or_default_nonexistent_file() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        let result = manager.load_or_default();
        assert!(result.is_ok());
        let config = result.unwrap();
        assert!(config.providers.is_empty());
    }

    #[test]
    fn test_load_or_default_valid_file() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        let config = create_test_config();
        manager.save(&config).unwrap();

        let result = manager.load_or_default();
        assert!(result.is_ok());
        let loaded = result.unwrap();
        assert_eq!(loaded.tl.provider, Some("ollama".to_string()));
    }

    #[test]
    fn test_load_or_default_invalid_file() {
        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        std::fs::write(&manager.config_path, "invalid toml [[[").unwrap();

        let result = manager.load_or_default();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("parse"));
    }

    #[test]
    #[cfg(unix)]
    fn test_load_or_default_unreadable_file() {
        /// Sets the Unix permission bits of `path` to `mode`.
        fn set_mode(path: &std::path::Path, mode: u32) {
            use std::os::unix::fs::PermissionsExt;

            let mut perms = std::fs::metadata(path).unwrap().permissions();
            perms.set_mode(mode);
            std::fs::set_permissions(path, perms).unwrap();
        }

        let temp_dir = TempDir::new().unwrap();
        let manager = create_test_manager(&temp_dir);

        std::fs::write(&manager.config_path, "[tl]\nprovider = \"test\"").unwrap();
        set_mode(&manager.config_path, 0o000);

        // Privileged users (e.g. root in CI containers) ignore mode bits, so
        // the unreadable case cannot be reproduced for them.
        let bypasses_permissions = std::fs::read(&manager.config_path).is_ok();
        let result = manager.load_or_default();

        // Restore before asserting so a failing assertion never leaves the
        // file unreadable.
        set_mode(&manager.config_path, 0o644);

        if !bypasses_permissions {
            assert!(result.is_err());
        }
    }
}