Skip to main content

suno_core/
config.rs

1//! Configuration model and precedence resolution.
2//!
3//! Parses a TOML string and merges in environment variables and CLI flag
4//! overrides supplied by the caller. Performs no disk or environment IO.
5
6use std::collections::HashMap;
7use std::fmt;
8use std::path::Path;
9use std::str::FromStr;
10
11use serde::{Deserialize, Serialize};
12
13use crate::error::{Error, Result};
14
/// Audio format for downloaded clips.
///
/// (De)serialized with lowercase names (`"mp3"`, `"flac"`, `"wav"`) via
/// `rename_all`. The `Default` variant is `Flac`, the same value
/// `Config::resolve` falls back to when no layer sets a format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
    /// MP3 (`"mp3"`).
    Mp3,
    /// FLAC (`"flac"`); the default.
    #[default]
    Flac,
    /// WAV (`"wav"`).
    Wav,
}
24
25impl FromStr for AudioFormat {
26    type Err = Error;
27
28    fn from_str(s: &str) -> Result<Self> {
29        match s.to_ascii_lowercase().as_str() {
30            "mp3" => Ok(Self::Mp3),
31            "flac" => Ok(Self::Flac),
32            "wav" => Ok(Self::Wav),
33            other => Err(Error::Config(format!("unknown format '{other}'"))),
34        }
35    }
36}
37
38impl fmt::Display for AudioFormat {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            Self::Mp3 => f.write_str("mp3"),
42            Self::Flac => f.write_str("flac"),
43            Self::Wav => f.write_str("wav"),
44        }
45    }
46}
47
/// Global default settings applied when no account or source override applies.
///
/// Read from the optional `[defaults]` table; every field may be omitted.
/// These are the lowest-priority file layer in [`Config::resolve`]. A field
/// left `None` here falls through to the compiled default used there:
/// format `Flac`, concurrency 4, retries 3, min_newest 1, animated_covers `false`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Defaults {
    pub format: Option<AudioFormat>,
    pub concurrency: Option<u32>,
    pub retries: Option<u32>,
    pub min_newest: Option<u32>,
    pub animated_covers: Option<bool>,
}
57
/// Per-source overridable settings within an account.
///
/// Parsed from `[accounts.<label>.sources.<name>]` tables. When
/// [`Config::resolve`] is called with a matching source name, these values
/// take precedence over the account's own settings and `[defaults]`, but are
/// still overridden by environment variables and CLI flags. A source name
/// with no matching table contributes nothing.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct SourceConfig {
    pub format: Option<AudioFormat>,
    pub concurrency: Option<u32>,
    pub retries: Option<u32>,
    pub min_newest: Option<u32>,
    pub animated_covers: Option<bool>,
}
67
68/// Configuration for a single named account.
69#[derive(Debug, Clone, Default, Deserialize)]
70pub struct AccountConfig {
71    pub token: Option<String>,
72    pub root: Option<String>,
73    pub format: Option<AudioFormat>,
74    pub concurrency: Option<u32>,
75    pub retries: Option<u32>,
76    pub min_newest: Option<u32>,
77    pub animated_covers: Option<bool>,
78    #[serde(default)]
79    pub sources: HashMap<String, SourceConfig>,
80}
81
/// Top-level configuration parsed from a TOML file.
///
/// Both sections are optional. A missing `[defaults]` table yields
/// `Defaults::default()` (every field `None`), and a missing `accounts` table
/// yields an empty map. Prefer constructing it with [`Config::from_toml`],
/// which also runs cross-account validation. The fields are public, so a
/// hand-built value bypasses that check.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Config {
    /// Global `[defaults]` table: the lowest-priority file layer.
    #[serde(default)]
    pub defaults: Defaults,
    /// `[accounts.<label>]` tables keyed by account label.
    #[serde(default)]
    pub accounts: HashMap<String, AccountConfig>,
}
90
91impl Config {
92    /// Parse `toml_str` and validate the result.
93    ///
94    /// Validation rejects any pair of accounts whose root directories nest
95    /// inside one another. Duplicate account labels are rejected by the TOML
96    /// parser itself.
97    pub fn from_toml(toml_str: &str) -> Result<Self> {
98        let config: Self = toml::from_str(toml_str).map_err(|e| {
99            // Strip source-context lines (those containing " | ") to prevent
100            // token values from being echoed in error messages.
101            let raw = e.to_string();
102            let msg = raw
103                .lines()
104                .filter(|l| !l.contains(" | "))
105                .collect::<Vec<_>>()
106                .join("\n")
107                .trim()
108                .to_owned();
109            Error::Config(if msg.is_empty() {
110                "parse error".into()
111            } else {
112                msg
113            })
114        })?;
115        config.validate()?;
116        Ok(config)
117    }
118
119    fn validate(&self) -> Result<()> {
120        let roots: Vec<(&str, &str)> = self
121            .accounts
122            .iter()
123            .filter_map(|(label, acc)| acc.root.as_deref().map(|r| (label.as_str(), r)))
124            .collect();
125
126        for (i, (label_a, root_a)) in roots.iter().enumerate() {
127            for (label_b, root_b) in roots.iter().skip(i + 1) {
128                let a = Path::new(root_a);
129                let b = Path::new(root_b);
130                if a.starts_with(b) || b.starts_with(a) {
131                    return Err(Error::Config(format!(
132                        "account roots nest: '{label_a}' ({root_a}) and '{label_b}' ({root_b})"
133                    )));
134                }
135            }
136        }
137
138        let mut prefix_seen: HashMap<String, &str> = HashMap::new();
139        for label in self.accounts.keys() {
140            let prefix = label_to_env(label);
141            if let Some(other) = prefix_seen.get(&prefix) {
142                return Err(Error::Config(format!(
143                    "accounts '{label}' and '{other}' share env prefix '{prefix}'"
144                )));
145            }
146            prefix_seen.insert(prefix, label.as_str());
147        }
148
149        Ok(())
150    }
151
152    /// Compute effective settings for `account`, optionally scoped to `source`.
153    ///
154    /// The caller supplies the full environment map and any CLI flag overrides.
155    /// Precedence per field: flag > per-account env > global env > per-source
156    /// file > per-account file > global file defaults > compiled default.
157    pub fn resolve(
158        &self,
159        account: &str,
160        source: Option<&str>,
161        env: &HashMap<String, String>,
162        flags: &FlagOverrides,
163    ) -> Result<EffectiveSettings> {
164        let acc = self
165            .accounts
166            .get(account)
167            .ok_or_else(|| Error::Config(format!("account '{account}' not found")))?;
168
169        let src = source.and_then(|s| acc.sources.get(s));
170        let label_env = label_to_env(account);
171
172        // Look up per-account env first, falling back to global.
173        let env_val = |suffix: &str| -> Option<&str> {
174            env.get(&format!("SUNO_{label_env}_{suffix}"))
175                .or_else(|| env.get(&format!("SUNO_{suffix}")))
176                .map(String::as_str)
177        };
178
179        let format_from_env = env_val("FORMAT")
180            .map(str::parse::<AudioFormat>)
181            .transpose()?;
182
183        let format = flags
184            .format
185            .or(format_from_env)
186            .or_else(|| src.and_then(|s| s.format))
187            .or(acc.format)
188            .or(self.defaults.format)
189            .unwrap_or(AudioFormat::Flac);
190
191        let concurrency = resolve_u32(
192            flags.concurrency,
193            env_val("CONCURRENCY"),
194            src.and_then(|s| s.concurrency),
195            acc.concurrency,
196            self.defaults.concurrency,
197            4,
198            "CONCURRENCY",
199        )?;
200
201        let retries = resolve_u32(
202            flags.retries,
203            env_val("RETRIES"),
204            src.and_then(|s| s.retries),
205            acc.retries,
206            self.defaults.retries,
207            3,
208            "RETRIES",
209        )?;
210
211        let min_newest = resolve_u32(
212            flags.min_newest,
213            env_val("MIN_NEWEST"),
214            src.and_then(|s| s.min_newest),
215            acc.min_newest,
216            self.defaults.min_newest,
217            1,
218            "MIN_NEWEST",
219        )?;
220
221        let animated_covers = resolve_bool(
222            flags.animated_covers,
223            env_val("ANIMATED_COVERS"),
224            src.and_then(|s| s.animated_covers),
225            acc.animated_covers,
226            self.defaults.animated_covers,
227            false,
228            "ANIMATED_COVERS",
229        )?;
230
231        let token = flags
232            .token
233            .clone()
234            .or_else(|| env.get(&format!("SUNO_{label_env}_TOKEN")).cloned())
235            .or_else(|| env.get("SUNO_TOKEN").cloned())
236            .or_else(|| acc.token.clone());
237
238        Ok(EffectiveSettings {
239            token,
240            format,
241            concurrency,
242            retries,
243            min_newest,
244            animated_covers,
245        })
246    }
247}
248
249fn resolve_u32(
250    flag: Option<u32>,
251    env_str: Option<&str>,
252    src: Option<u32>,
253    acc: Option<u32>,
254    defaults: Option<u32>,
255    compiled: u32,
256    name: &str,
257) -> Result<u32> {
258    if let Some(v) = flag {
259        return Ok(v);
260    }
261    if let Some(s) = env_str {
262        return s
263            .parse()
264            .map_err(|_| Error::Config(format!("invalid {name}: '{s}'")));
265    }
266    Ok(src.or(acc).or(defaults).unwrap_or(compiled))
267}
268
269fn resolve_bool(
270    flag: Option<bool>,
271    env_str: Option<&str>,
272    src: Option<bool>,
273    acc: Option<bool>,
274    defaults: Option<bool>,
275    compiled: bool,
276    name: &str,
277) -> Result<bool> {
278    if let Some(v) = flag {
279        return Ok(v);
280    }
281    if let Some(s) = env_str {
282        return s
283            .parse()
284            .map_err(|_| Error::Config(format!("invalid {name}: '{s}'")));
285    }
286    Ok(src.or(acc).or(defaults).unwrap_or(compiled))
287}
288
/// Map an account label to the prefix used in its `SUNO_<PREFIX>_*`
/// environment variables.
///
/// ASCII letters are upper-cased and every `-` becomes `_`, so `my-lib`
/// becomes `MY_LIB`. All other characters (digits, `_`, non-ASCII) are
/// copied unchanged.
fn label_to_env(label: &str) -> String {
    label
        .chars()
        .map(|ch| match ch {
            '-' => '_',
            other => other.to_ascii_uppercase(),
        })
        .collect()
}
295
296/// CLI flag overrides passed to [`Config::resolve`]. `None` means the flag
297/// was not provided.
298#[derive(Debug, Default)]
299pub struct FlagOverrides {
300    pub token: Option<String>,
301    pub format: Option<AudioFormat>,
302    pub concurrency: Option<u32>,
303    pub retries: Option<u32>,
304    pub min_newest: Option<u32>,
305    pub animated_covers: Option<bool>,
306}
307
308/// Resolved effective settings for one account/source combination.
309#[derive(Debug, Clone, PartialEq)]
310pub struct EffectiveSettings {
311    pub token: Option<String>,
312    pub format: AudioFormat,
313    pub concurrency: u32,
314    pub retries: u32,
315    pub min_newest: u32,
316    pub animated_covers: bool,
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    fn no_env() -> HashMap<String, String> {
        HashMap::new()
    }

    /// Build an environment map from `(name, value)` pairs.
    fn env_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_owned(), v.to_owned()))
            .collect()
    }

    fn no_flags() -> FlagOverrides {
        FlagOverrides::default()
    }

    #[test]
    fn parse_empty_toml() {
        let cfg = Config::from_toml("").unwrap();
        assert!(cfg.accounts.is_empty());
    }

    #[test]
    fn parse_basic_account() {
        let toml = r#"
            [accounts.alice]
            token = "tok"
            root = "/music"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let acc = &cfg.accounts["alice"];
        assert_eq!(acc.token.as_deref(), Some("tok"));
        assert_eq!(acc.root.as_deref(), Some("/music"));
    }

    #[test]
    fn parse_defaults_section() {
        let toml = r#"
            [defaults]
            format = "mp3"
            concurrency = 8
            retries = 5
            min_newest = 2
            animated_covers = true
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        assert_eq!(cfg.defaults.format, Some(AudioFormat::Mp3));
        assert_eq!(cfg.defaults.concurrency, Some(8));
        assert_eq!(cfg.defaults.retries, Some(5));
        assert_eq!(cfg.defaults.min_newest, Some(2));
        assert_eq!(cfg.defaults.animated_covers, Some(true));
    }

    #[test]
    fn compiled_defaults_when_nothing_set() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(
            eff,
            EffectiveSettings {
                token: None,
                format: AudioFormat::Flac,
                concurrency: 4,
                retries: 3,
                min_newest: 1,
                animated_covers: false,
            }
        );
    }

    #[test]
    fn file_defaults_override_compiled() {
        let toml = r#"
            [defaults]
            format = "mp3"
            concurrency = 8

            [accounts.alice]
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Mp3);
        assert_eq!(eff.concurrency, 8);
        assert_eq!(eff.retries, 3); // compiled default
    }

    #[test]
    fn account_settings_override_defaults() {
        let toml = r#"
            [defaults]
            format = "mp3"

            [accounts.alice]
            format = "wav"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    #[test]
    fn per_source_overrides_account() {
        let toml = r#"
            [accounts.alice]
            format = "flac"

            [accounts.alice.sources.liked]
            format = "mp3"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg
            .resolve("alice", Some("liked"), &no_env(), &no_flags())
            .unwrap();
        assert_eq!(eff.format, AudioFormat::Mp3);
    }

    #[test]
    fn unknown_source_falls_back_to_account() {
        let toml = r#"
            [accounts.alice]
            format = "wav"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg
            .resolve("alice", Some("nonexistent"), &no_env(), &no_flags())
            .unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    #[test]
    fn global_env_overrides_file() {
        let toml = r#"
            [accounts.alice]
            format = "flac"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let env = env_from(&[("SUNO_FORMAT", "mp3")]);
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Mp3);
    }

    #[test]
    fn per_account_env_overrides_global_env() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env = env_from(&[("SUNO_FORMAT", "mp3"), ("SUNO_ALICE_FORMAT", "wav")]);
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    #[test]
    fn per_account_env_label_is_upper_snake_case() {
        let toml = "[accounts.my-lib]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env = env_from(&[("SUNO_MY_LIB_FORMAT", "wav")]);
        let eff = cfg.resolve("my-lib", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    #[test]
    fn per_account_invalid_env_does_not_fall_back_to_global() {
        // The per-account variable wins the lookup even when it is malformed.
        let cfg = Config::from_toml("[accounts.alice]\n").unwrap();
        let env = env_from(&[("SUNO_RETRIES", "5"), ("SUNO_ALICE_RETRIES", "lots")]);
        assert!(cfg.resolve("alice", None, &env, &no_flags()).is_err());
    }

    #[test]
    fn flag_overrides_env_and_file() {
        let toml = r#"
            [accounts.alice]
            format = "flac"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let env = env_from(&[("SUNO_FORMAT", "mp3")]);
        let flags = FlagOverrides {
            format: Some(AudioFormat::Wav),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", None, &env, &flags).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    #[test]
    fn u32_flag_shadows_invalid_env() {
        let cfg = Config::from_toml("[accounts.alice]\n").unwrap();
        let env = env_from(&[("SUNO_CONCURRENCY", "many")]);
        let flags = FlagOverrides {
            concurrency: Some(2),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", None, &env, &flags).unwrap();
        assert_eq!(eff.concurrency, 2);
    }

    #[test]
    fn token_precedence() {
        let toml = r#"
            [accounts.alice]
            token = "file_tok"
        "#;
        let cfg = Config::from_toml(toml).unwrap();

        // env overrides file
        let env = env_from(&[("SUNO_TOKEN", "env_tok")]);
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.token.as_deref(), Some("env_tok"));

        // flag overrides env
        let flags = FlagOverrides {
            token: Some("flag_tok".into()),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", None, &env, &flags).unwrap();
        assert_eq!(eff.token.as_deref(), Some("flag_tok"));
    }

    #[test]
    fn per_account_token_env_overrides_global() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env = env_from(&[("SUNO_TOKEN", "global"), ("SUNO_ALICE_TOKEN", "per_account")]);
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.token.as_deref(), Some("per_account"));
    }

    #[test]
    fn invalid_env_u32_errors() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env = env_from(&[("SUNO_CONCURRENCY", "many")]);
        assert!(cfg.resolve("alice", None, &env, &no_flags()).is_err());
    }

    #[test]
    fn animated_covers_defaults_off_and_follows_precedence() {
        // Compiled default is off.
        let cfg = Config::from_toml("[accounts.alice]\n").unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert!(!eff.animated_covers);

        // File default on; per-source off; env on; flag off — flag wins.
        let toml = r#"
            [defaults]
            animated_covers = true

            [accounts.alice.sources.liked]
            animated_covers = false
        "#;
        let cfg = Config::from_toml(toml).unwrap();

        // File default (defaults) turns it on for an unscoped resolve.
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert!(eff.animated_covers);

        // Per-source file setting overrides the file default.
        let eff = cfg
            .resolve("alice", Some("liked"), &no_env(), &no_flags())
            .unwrap();
        assert!(!eff.animated_covers);

        // Env overrides file (even the per-source off).
        let env = env_from(&[("SUNO_ANIMATED_COVERS", "true")]);
        let eff = cfg
            .resolve("alice", Some("liked"), &env, &no_flags())
            .unwrap();
        assert!(eff.animated_covers);

        // Flag overrides env.
        let flags = FlagOverrides {
            animated_covers: Some(false),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", Some("liked"), &env, &flags).unwrap();
        assert!(!eff.animated_covers);
    }

    #[test]
    fn invalid_env_bool_errors() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env = env_from(&[("SUNO_ANIMATED_COVERS", "yes")]);
        assert!(cfg.resolve("alice", None, &env, &no_flags()).is_err());
    }

    #[test]
    fn unknown_account_errors() {
        let cfg = Config::from_toml("").unwrap();
        assert!(cfg.resolve("nobody", None, &no_env(), &no_flags()).is_err());
    }

    #[test]
    fn validation_nested_roots() {
        let toml = r#"
            [accounts.alice]
            root = "/music"

            [accounts.bob]
            root = "/music/bob"
        "#;
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn validation_identical_roots_errors() {
        let toml = r#"
            [accounts.alice]
            root = "/music"

            [accounts.bob]
            root = "/music"
        "#;
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn validation_trailing_slash_still_nests() {
        let toml = r#"
            [accounts.alice]
            root = "/music/"

            [accounts.bob]
            root = "/music/bob"
        "#;
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn validation_non_nested_roots_ok() {
        let toml = r#"
            [accounts.alice]
            root = "/music/alice"

            [accounts.bob]
            root = "/music/bob"
        "#;
        assert!(Config::from_toml(toml).is_ok());
    }

    #[test]
    fn validation_shared_string_prefix_is_not_nesting() {
        // Roots are compared by path component, not by raw string prefix.
        let toml = r#"
            [accounts.alice]
            root = "/music"

            [accounts.bob]
            root = "/musicbox"
        "#;
        assert!(Config::from_toml(toml).is_ok());
    }

    #[test]
    fn invalid_toml_errors() {
        assert!(Config::from_toml("not valid toml ][").is_err());
    }

    #[test]
    fn duplicate_account_label_errors() {
        // The TOML spec prohibits duplicate keys; the parser must reject this.
        let toml = "
            [accounts.alice]
            token = \"tok1\"

            [accounts.alice]
            token = \"tok2\"
        ";
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn parse_error_does_not_echo_token() {
        // A malformed token line must not include the raw value in the error.
        // The sentinel is chosen so it cannot plausibly occur in a parser
        // diagnostic; a word like "unterminated" could, and would then fail
        // this test for the wrong reason.
        let secret = "s3cr3t-T0KEN-4f9a";
        let toml = format!("[accounts.alice]\ntoken = \"{secret}\n");
        let err = Config::from_toml(&toml).unwrap_err().to_string();
        assert!(!err.contains(secret), "error leaked token: {err}");
    }

    #[test]
    fn validation_env_prefix_collision_errors() {
        // 'my-lib' and 'my_lib' both map to SUNO_MY_LIB_* and must be rejected.
        let toml = "
            [accounts.my-lib]
            [accounts.my_lib]
        ";
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn validation_env_prefix_case_collision_errors() {
        // 'Alice' and 'alice' both map to SUNO_ALICE_* and must be rejected.
        let toml = "
            [accounts.Alice]
            [accounts.alice]
        ";
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn audio_format_parse_is_case_insensitive() {
        assert_eq!("MP3".parse::<AudioFormat>().unwrap(), AudioFormat::Mp3);
        assert_eq!("Flac".parse::<AudioFormat>().unwrap(), AudioFormat::Flac);
        assert!("ogg".parse::<AudioFormat>().is_err());
    }

    #[test]
    fn audio_format_display_roundtrip() {
        for fmt in [AudioFormat::Mp3, AudioFormat::Flac, AudioFormat::Wav] {
            let s = fmt.to_string();
            assert_eq!(s.parse::<AudioFormat>().unwrap(), fmt);
        }
    }
}