1use std::collections::HashMap;
7use std::fmt;
8use std::path::Path;
9use std::str::FromStr;
10
11use serde::{Deserialize, Serialize};
12
13use crate::error::{Error, Result};
14
/// Audio format a setting can select: `mp3`, `flac` or `wav`.
///
/// serde reads and writes each variant under its lowercase name, and
/// [`AudioFormat::Flac`] is the derived `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
    /// Spelled `mp3` in serialized form.
    Mp3,
    /// Spelled `flac` in serialized form; the `Default` variant.
    #[default]
    Flac,
    /// Spelled `wav` in serialized form.
    Wav,
}
24
25impl FromStr for AudioFormat {
26 type Err = Error;
27
28 fn from_str(s: &str) -> Result<Self> {
29 match s.to_ascii_lowercase().as_str() {
30 "mp3" => Ok(Self::Mp3),
31 "flac" => Ok(Self::Flac),
32 "wav" => Ok(Self::Wav),
33 other => Err(Error::Config(format!("unknown format '{other}'"))),
34 }
35 }
36}
37
38impl fmt::Display for AudioFormat {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 Self::Mp3 => f.write_str("mp3"),
42 Self::Flac => f.write_str("flac"),
43 Self::Wav => f.write_str("wav"),
44 }
45 }
46}
47
/// Optional file-wide settings, stored in [`Config::defaults`].
///
/// Every field may be omitted. `Config::resolve` consults these values only
/// after the flag, environment, per-source and per-account layers, and falls
/// back to a built-in value when this layer is unset too.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Defaults {
    /// File-wide audio format.
    pub format: Option<AudioFormat>,
    /// File-wide concurrency value.
    pub concurrency: Option<u32>,
    /// File-wide retry count.
    pub retries: Option<u32>,
    /// File-wide `min_newest` value.
    // NOTE(review): what `min_newest` controls is not visible in this file.
    pub min_newest: Option<u32>,
    /// File-wide `animated_covers` switch.
    pub animated_covers: Option<bool>,
}
57
/// Optional per-source settings, stored under [`AccountConfig::sources`].
///
/// Among the file-based layers these are checked first by `Config::resolve`:
/// a value set here wins over the account-level and `defaults` values, but
/// not over flags or environment variables.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct SourceConfig {
    /// Audio format for this source.
    pub format: Option<AudioFormat>,
    /// Concurrency value for this source.
    pub concurrency: Option<u32>,
    /// Retry count for this source.
    pub retries: Option<u32>,
    /// `min_newest` value for this source.
    // NOTE(review): what `min_newest` controls is not visible in this file.
    pub min_newest: Option<u32>,
    /// `animated_covers` switch for this source.
    pub animated_covers: Option<bool>,
}
67
/// Settings for one account, keyed by label in [`Config::accounts`].
///
/// All fields are optional; `sources` defaults to an empty map when the key
/// is absent.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct AccountConfig {
    /// Token from the config file. `Config::resolve` uses it only when
    /// neither a flag nor a `SUNO_<LABEL>_TOKEN` / `SUNO_TOKEN` environment
    /// entry supplies one.
    pub token: Option<String>,
    /// Root path for this account. Validation rejects a config in which one
    /// account's root equals, or is a path prefix of, another account's root.
    pub root: Option<String>,
    /// Account identifier, copied unchanged into [`EffectiveSettings`].
    pub account_id: Option<String>,
    /// Account-level audio format.
    pub format: Option<AudioFormat>,
    /// Account-level concurrency value.
    pub concurrency: Option<u32>,
    /// Account-level retry count.
    pub retries: Option<u32>,
    /// Account-level `min_newest` value.
    // NOTE(review): what `min_newest` controls is not visible in this file.
    pub min_newest: Option<u32>,
    /// Account-level `animated_covers` switch.
    pub animated_covers: Option<bool>,
    /// Per-source overrides, keyed by source name.
    #[serde(default)]
    pub sources: HashMap<String, SourceConfig>,
}
85
/// Top-level configuration document.
///
/// Both keys are optional when deserializing, so an empty document yields a
/// `Config` with no accounts and all defaults unset.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Config {
    /// File-wide fallback settings.
    #[serde(default)]
    pub defaults: Defaults,
    /// Accounts keyed by label; the label is what `Config::resolve` takes as
    /// its `account` argument.
    #[serde(default)]
    pub accounts: HashMap<String, AccountConfig>,
}
94
95impl Config {
96 pub fn from_toml(toml_str: &str) -> Result<Self> {
102 let config: Self = toml::from_str(toml_str).map_err(|e| {
103 let raw = e.to_string();
106 let msg = raw
107 .lines()
108 .filter(|l| !l.contains(" | "))
109 .collect::<Vec<_>>()
110 .join("\n")
111 .trim()
112 .to_owned();
113 Error::Config(if msg.is_empty() {
114 "parse error".into()
115 } else {
116 msg
117 })
118 })?;
119 config.validate()?;
120 Ok(config)
121 }
122
123 fn validate(&self) -> Result<()> {
124 let roots: Vec<(&str, &str)> = self
125 .accounts
126 .iter()
127 .filter_map(|(label, acc)| acc.root.as_deref().map(|r| (label.as_str(), r)))
128 .collect();
129
130 for (i, (label_a, root_a)) in roots.iter().enumerate() {
131 for (label_b, root_b) in roots.iter().skip(i + 1) {
132 let a = Path::new(root_a);
133 let b = Path::new(root_b);
134 if a.starts_with(b) || b.starts_with(a) {
135 return Err(Error::Config(format!(
136 "account roots nest: '{label_a}' ({root_a}) and '{label_b}' ({root_b})"
137 )));
138 }
139 }
140 }
141
142 let mut prefix_seen: HashMap<String, &str> = HashMap::new();
143 for label in self.accounts.keys() {
144 let prefix = label_to_env(label);
145 if let Some(other) = prefix_seen.get(&prefix) {
146 return Err(Error::Config(format!(
147 "accounts '{label}' and '{other}' share env prefix '{prefix}'"
148 )));
149 }
150 prefix_seen.insert(prefix, label.as_str());
151 }
152
153 Ok(())
154 }
155
156 pub fn resolve(
162 &self,
163 account: &str,
164 source: Option<&str>,
165 env: &HashMap<String, String>,
166 flags: &FlagOverrides,
167 ) -> Result<EffectiveSettings> {
168 let acc = self
169 .accounts
170 .get(account)
171 .ok_or_else(|| Error::Config(format!("account '{account}' not found")))?;
172
173 let src = source.and_then(|s| acc.sources.get(s));
174 let label_env = label_to_env(account);
175
176 let env_val = |suffix: &str| -> Option<&str> {
178 env.get(&format!("SUNO_{label_env}_{suffix}"))
179 .or_else(|| env.get(&format!("SUNO_{suffix}")))
180 .map(String::as_str)
181 };
182
183 let format_from_env = env_val("FORMAT")
184 .map(str::parse::<AudioFormat>)
185 .transpose()?;
186
187 let format = flags
188 .format
189 .or(format_from_env)
190 .or_else(|| src.and_then(|s| s.format))
191 .or(acc.format)
192 .or(self.defaults.format)
193 .unwrap_or(AudioFormat::Flac);
194
195 let concurrency = resolve_u32(
196 flags.concurrency,
197 env_val("CONCURRENCY"),
198 src.and_then(|s| s.concurrency),
199 acc.concurrency,
200 self.defaults.concurrency,
201 4,
202 "CONCURRENCY",
203 )?;
204
205 let retries = resolve_u32(
206 flags.retries,
207 env_val("RETRIES"),
208 src.and_then(|s| s.retries),
209 acc.retries,
210 self.defaults.retries,
211 3,
212 "RETRIES",
213 )?;
214
215 let min_newest = resolve_u32(
216 flags.min_newest,
217 env_val("MIN_NEWEST"),
218 src.and_then(|s| s.min_newest),
219 acc.min_newest,
220 self.defaults.min_newest,
221 1,
222 "MIN_NEWEST",
223 )?;
224
225 let animated_covers = resolve_bool(
226 flags.animated_covers,
227 env_val("ANIMATED_COVERS"),
228 src.and_then(|s| s.animated_covers),
229 acc.animated_covers,
230 self.defaults.animated_covers,
231 false,
232 "ANIMATED_COVERS",
233 )?;
234
235 let token = flags
236 .token
237 .clone()
238 .or_else(|| env.get(&format!("SUNO_{label_env}_TOKEN")).cloned())
239 .or_else(|| env.get("SUNO_TOKEN").cloned())
240 .or_else(|| acc.token.clone());
241
242 Ok(EffectiveSettings {
243 token,
244 account_id: acc.account_id.clone(),
245 format,
246 concurrency,
247 retries,
248 min_newest,
249 animated_covers,
250 })
251 }
252}
253
254fn resolve_u32(
255 flag: Option<u32>,
256 env_str: Option<&str>,
257 src: Option<u32>,
258 acc: Option<u32>,
259 defaults: Option<u32>,
260 compiled: u32,
261 name: &str,
262) -> Result<u32> {
263 if let Some(v) = flag {
264 return Ok(v);
265 }
266 if let Some(s) = env_str {
267 return s
268 .parse()
269 .map_err(|_| Error::Config(format!("invalid {name}: '{s}'")));
270 }
271 Ok(src.or(acc).or(defaults).unwrap_or(compiled))
272}
273
274fn resolve_bool(
275 flag: Option<bool>,
276 env_str: Option<&str>,
277 src: Option<bool>,
278 acc: Option<bool>,
279 defaults: Option<bool>,
280 compiled: bool,
281 name: &str,
282) -> Result<bool> {
283 if let Some(v) = flag {
284 return Ok(v);
285 }
286 if let Some(s) = env_str {
287 return s
288 .parse()
289 .map_err(|_| Error::Config(format!("invalid {name}: '{s}'")));
290 }
291 Ok(src.or(acc).or(defaults).unwrap_or(compiled))
292}
293
/// Turns an account label into the segment used inside `SUNO_<LABEL>_...`
/// environment variable names: ASCII letters are upper-cased and every `-`
/// becomes `_`. All other characters pass through unchanged.
fn label_to_env(label: &str) -> String {
    label
        .chars()
        .map(|ch| match ch {
            '-' => '_',
            other => other.to_ascii_uppercase(),
        })
        .collect()
}
300
/// Highest-precedence overrides passed to `Config::resolve`.
///
/// Any field that is `Some` wins over the environment and every file layer;
/// `None` lets the lower layers decide.
// NOTE(review): presumably filled from command-line flags - the caller is not
// visible in this file.
#[derive(Debug, Default)]
pub struct FlagOverrides {
    /// Token override.
    pub token: Option<String>,
    /// Audio format override.
    pub format: Option<AudioFormat>,
    /// Concurrency override.
    pub concurrency: Option<u32>,
    /// Retry count override.
    pub retries: Option<u32>,
    /// `min_newest` override.
    pub min_newest: Option<u32>,
    /// `animated_covers` override.
    pub animated_covers: Option<bool>,
}
312
/// Fully resolved settings produced by `Config::resolve`.
///
/// Unlike the layered config structs, every tunable here has a concrete
/// value; only `token` and `account_id` can still be absent.
#[derive(Debug, Clone, PartialEq)]
pub struct EffectiveSettings {
    /// Token from the flag, the environment or the account section, in that
    /// order; `None` when none of them supplies one.
    pub token: Option<String>,
    /// The account section's `account_id`, copied as-is.
    pub account_id: Option<String>,
    /// Resolved audio format (`flac` when no layer sets it).
    pub format: AudioFormat,
    /// Resolved concurrency value (4 when no layer sets it).
    pub concurrency: u32,
    /// Resolved retry count (3 when no layer sets it).
    pub retries: u32,
    /// Resolved `min_newest` value (1 when no layer sets it).
    pub min_newest: u32,
    /// Resolved `animated_covers` switch (`false` when no layer sets it).
    pub animated_covers: bool,
}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty environment map for tests that exercise only file and flag layers.
    fn no_env() -> HashMap<String, String> {
        HashMap::new()
    }

    /// Flag overrides with every field unset.
    fn no_flags() -> FlagOverrides {
        FlagOverrides::default()
    }

    // An empty document is valid and yields no accounts.
    #[test]
    fn parse_empty_toml() {
        let cfg = Config::from_toml("").unwrap();
        assert!(cfg.accounts.is_empty());
    }

    // Account-level string fields are read from the account's table.
    #[test]
    fn parse_basic_account() {
        let toml = r#"
            [accounts.alice]
            token = "tok"
            root = "/music"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let acc = &cfg.accounts["alice"];
        assert_eq!(acc.token.as_deref(), Some("tok"));
        assert_eq!(acc.root.as_deref(), Some("/music"));
    }

    // `account_id` is parsed and passed through `resolve` unchanged.
    #[test]
    fn account_id_parses_and_resolves() {
        let toml = r#"
            [accounts.alice]
            token = "tok"
            root = "/music"
            account_id = "user_abc123"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        assert_eq!(
            cfg.accounts["alice"].account_id.as_deref(),
            Some("user_abc123")
        );
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(eff.account_id.as_deref(), Some("user_abc123"));
    }

    // Every key of the `[defaults]` table is deserialized.
    #[test]
    fn parse_defaults_section() {
        let toml = r#"
            [defaults]
            format = "mp3"
            concurrency = 8
            retries = 5
            min_newest = 2
            animated_covers = true
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        assert_eq!(cfg.defaults.format, Some(AudioFormat::Mp3));
        assert_eq!(cfg.defaults.concurrency, Some(8));
        assert_eq!(cfg.defaults.retries, Some(5));
        assert_eq!(cfg.defaults.min_newest, Some(2));
        assert_eq!(cfg.defaults.animated_covers, Some(true));
    }

    // With no layer setting anything, the built-in values apply.
    #[test]
    fn compiled_defaults_when_nothing_set() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(
            eff,
            EffectiveSettings {
                token: None,
                account_id: None,
                format: AudioFormat::Flac,
                concurrency: 4,
                retries: 3,
                min_newest: 1,
                animated_covers: false,
            }
        );
    }

    // `[defaults]` beats the built-in values, field by field.
    #[test]
    fn file_defaults_override_compiled() {
        let toml = r#"
            [defaults]
            format = "mp3"
            concurrency = 8

            [accounts.alice]
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Mp3);
        assert_eq!(eff.concurrency, 8);
        // `retries` is not set in the file, so the built-in 3 still applies.
        assert_eq!(eff.retries, 3);
    }

    // An account-level value beats `[defaults]`.
    #[test]
    fn account_settings_override_defaults() {
        let toml = r#"
            [defaults]
            format = "mp3"

            [accounts.alice]
            format = "wav"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    // A per-source value beats the account-level value.
    #[test]
    fn per_source_overrides_account() {
        let toml = r#"
            [accounts.alice]
            format = "flac"

            [accounts.alice.sources.liked]
            format = "mp3"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg
            .resolve("alice", Some("liked"), &no_env(), &no_flags())
            .unwrap();
        assert_eq!(eff.format, AudioFormat::Mp3);
    }

    // Asking for a source with no section is not an error; account values apply.
    #[test]
    fn unknown_source_falls_back_to_account() {
        let toml = r#"
            [accounts.alice]
            format = "wav"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg
            .resolve("alice", Some("nonexistent"), &no_env(), &no_flags())
            .unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    // A global `SUNO_<NAME>` variable beats values from the file.
    #[test]
    fn global_env_overrides_file() {
        let toml = r#"
            [accounts.alice]
            format = "flac"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> =
            [("SUNO_FORMAT".into(), "mp3".into())].into_iter().collect();
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Mp3);
    }

    // `SUNO_<LABEL>_<NAME>` beats the global `SUNO_<NAME>`.
    #[test]
    fn per_account_env_overrides_global_env() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> = [
            ("SUNO_FORMAT".into(), "mp3".into()),
            ("SUNO_ALICE_FORMAT".into(), "wav".into()),
        ]
        .into_iter()
        .collect();
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    // The label is upper-cased and `-` becomes `_` when building the env name.
    #[test]
    fn per_account_env_label_uppersnakedcase() {
        let toml = "[accounts.my-lib]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> = [("SUNO_MY_LIB_FORMAT".into(), "wav".into())]
            .into_iter()
            .collect();
        let eff = cfg.resolve("my-lib", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    // A flag beats both the environment and the file.
    #[test]
    fn flag_overrides_env_and_file() {
        let toml = r#"
            [accounts.alice]
            format = "flac"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> =
            [("SUNO_FORMAT".into(), "mp3".into())].into_iter().collect();
        let flags = FlagOverrides {
            format: Some(AudioFormat::Wav),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", None, &env, &flags).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    // Token precedence: flag > environment > file.
    #[test]
    fn token_precedence() {
        let toml = r#"
            [accounts.alice]
            token = "file_tok"
        "#;
        let cfg = Config::from_toml(toml).unwrap();

        // Environment beats the token stored in the file.
        let env: HashMap<String, String> = [("SUNO_TOKEN".into(), "env_tok".into())]
            .into_iter()
            .collect();
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.token.as_deref(), Some("env_tok"));

        // A flag beats the environment.
        let flags = FlagOverrides {
            token: Some("flag_tok".into()),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", None, &env, &flags).unwrap();
        assert_eq!(eff.token.as_deref(), Some("flag_tok"));
    }

    // `SUNO_<LABEL>_TOKEN` beats the global `SUNO_TOKEN`.
    #[test]
    fn per_account_token_env_overrides_global() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> = [
            ("SUNO_TOKEN".into(), "global".into()),
            ("SUNO_ALICE_TOKEN".into(), "per_account".into()),
        ]
        .into_iter()
        .collect();
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.token.as_deref(), Some("per_account"));
    }

    // A non-numeric value in a numeric env variable fails resolution.
    #[test]
    fn invalid_env_u32_errors() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> = [("SUNO_CONCURRENCY".into(), "many".into())]
            .into_iter()
            .collect();
        assert!(cfg.resolve("alice", None, &env, &no_flags()).is_err());
    }

    // Walks `animated_covers` through every layer, lowest to highest.
    #[test]
    fn animated_covers_defaults_off_and_follows_precedence() {
        let cfg = Config::from_toml("[accounts.alice]\n").unwrap();
        // Nothing set anywhere: built-in value is off.
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert!(!eff.animated_covers);

        // `[defaults]` turns it on; the `liked` source turns it back off.
        let toml = r#"
            [defaults]
            animated_covers = true

            [accounts.alice.sources.liked]
            animated_covers = false
        "#;
        let cfg = Config::from_toml(toml).unwrap();

        // Without a source, the `[defaults]` value applies.
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert!(eff.animated_covers);

        // With the source selected, its value beats `[defaults]`.
        let eff = cfg
            .resolve("alice", Some("liked"), &no_env(), &no_flags())
            .unwrap();
        assert!(!eff.animated_covers);

        // The environment beats the per-source value.
        let env: HashMap<String, String> = [("SUNO_ANIMATED_COVERS".into(), "true".into())]
            .into_iter()
            .collect();
        let eff = cfg
            .resolve("alice", Some("liked"), &env, &no_flags())
            .unwrap();
        assert!(eff.animated_covers);

        // A flag beats the environment.
        let flags = FlagOverrides {
            animated_covers: Some(false),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", Some("liked"), &env, &flags).unwrap();
        assert!(!eff.animated_covers);
    }

    // Only `true` / `false` are accepted for boolean env variables.
    #[test]
    fn invalid_env_bool_errors() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> = [("SUNO_ANIMATED_COVERS".into(), "yes".into())]
            .into_iter()
            .collect();
        assert!(cfg.resolve("alice", None, &env, &no_flags()).is_err());
    }

    // Resolving an account that is not configured is an error.
    #[test]
    fn unknown_account_errors() {
        let cfg = Config::from_toml("").unwrap();
        assert!(cfg.resolve("nobody", None, &no_env(), &no_flags()).is_err());
    }

    // One account's root sitting inside another's is rejected at load time.
    #[test]
    fn validation_nested_roots() {
        let toml = r#"
            [accounts.alice]
            root = "/music"

            [accounts.bob]
            root = "/music/bob"
        "#;
        assert!(Config::from_toml(toml).is_err());
    }

    // Sibling directories under a shared parent are fine.
    #[test]
    fn validation_non_nested_roots_ok() {
        let toml = r#"
            [accounts.alice]
            root = "/music/alice"

            [accounts.bob]
            root = "/music/bob"
        "#;
        assert!(Config::from_toml(toml).is_ok());
    }

    // Malformed TOML surfaces as an error rather than a panic.
    #[test]
    fn invalid_toml_errors() {
        assert!(Config::from_toml("not valid toml ][").is_err());
    }

    // Defining the same account table twice is rejected.
    #[test]
    fn duplicate_account_label_errors() {
        let toml = "
            [accounts.alice]
            token = \"tok1\"

            [accounts.alice]
            token = \"tok2\"
        ";
        assert!(Config::from_toml(toml).is_err());
    }

    // The parse error must not quote the source line, which may hold a token.
    #[test]
    fn parse_error_does_not_echo_token() {
        let toml = "[accounts.alice]\ntoken = \"unterminated\n";
        let err = Config::from_toml(toml).unwrap_err().to_string();
        assert!(!err.contains("unterminated"), "error leaked token: {err}");
    }

    // `my-lib` and `my_lib` both map to the `MY_LIB` env prefix, so they collide.
    #[test]
    fn validation_env_prefix_collision_errors() {
        let toml = "
            [accounts.my-lib]
            [accounts.my_lib]
        ";
        assert!(Config::from_toml(toml).is_err());
    }

    // `Display` output parses back to the same variant.
    #[test]
    fn audio_format_display_roundtrip() {
        for fmt in [AudioFormat::Mp3, AudioFormat::Flac, AudioFormat::Wav] {
            let s = fmt.to_string();
            assert_eq!(s.parse::<AudioFormat>().unwrap(), fmt);
        }
    }
}