Skip to main content

suno_core/
config.rs

1//! Configuration model and precedence resolution.
2//!
3//! Parses a TOML string and merges in environment variables and CLI flag
4//! overrides supplied by the caller. Performs no disk or environment IO.
5
6use std::collections::HashMap;
7use std::fmt;
8use std::path::Path;
9use std::str::FromStr;
10
11use serde::{Deserialize, Serialize};
12
13use crate::error::{Error, Result};
14
15/// Audio format for downloaded clips.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
17#[serde(rename_all = "lowercase")]
18pub enum AudioFormat {
19    Mp3,
20    #[default]
21    Flac,
22    Wav,
23}
24
25impl FromStr for AudioFormat {
26    type Err = Error;
27
28    fn from_str(s: &str) -> Result<Self> {
29        match s.to_ascii_lowercase().as_str() {
30            "mp3" => Ok(Self::Mp3),
31            "flac" => Ok(Self::Flac),
32            "wav" => Ok(Self::Wav),
33            other => Err(Error::Config(format!("unknown format '{other}'"))),
34        }
35    }
36}
37
38impl fmt::Display for AudioFormat {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            Self::Mp3 => f.write_str("mp3"),
42            Self::Flac => f.write_str("flac"),
43            Self::Wav => f.write_str("wav"),
44        }
45    }
46}
47
48/// Global default settings applied when no account or source override applies.
49#[derive(Debug, Clone, Default, Deserialize)]
50pub struct Defaults {
51    pub format: Option<AudioFormat>,
52    pub concurrency: Option<u32>,
53    pub retries: Option<u32>,
54    pub min_newest: Option<u32>,
55    pub animated_covers: Option<bool>,
56}
57
58/// Per-source overridable settings within an account.
59#[derive(Debug, Clone, Default, Deserialize)]
60pub struct SourceConfig {
61    pub format: Option<AudioFormat>,
62    pub concurrency: Option<u32>,
63    pub retries: Option<u32>,
64    pub min_newest: Option<u32>,
65    pub animated_covers: Option<bool>,
66}
67
68/// Configuration for a single named account.
69#[derive(Debug, Clone, Default, Deserialize)]
70pub struct AccountConfig {
71    pub token: Option<String>,
72    pub root: Option<String>,
73    /// Optional Suno user id to assert this account authenticates as, refusing
74    /// to run on a mismatch (a belt-and-braces check alongside the on-disk
75    /// owner pin in the lineage store).
76    pub account_id: Option<String>,
77    pub format: Option<AudioFormat>,
78    pub concurrency: Option<u32>,
79    pub retries: Option<u32>,
80    pub min_newest: Option<u32>,
81    pub animated_covers: Option<bool>,
82    #[serde(default)]
83    pub sources: HashMap<String, SourceConfig>,
84}
85
86/// Top-level configuration parsed from a TOML file.
87#[derive(Debug, Clone, Default, Deserialize)]
88pub struct Config {
89    #[serde(default)]
90    pub defaults: Defaults,
91    #[serde(default)]
92    pub accounts: HashMap<String, AccountConfig>,
93}
94
95impl Config {
96    /// Parse `toml_str` and validate the result.
97    ///
98    /// Validation rejects any pair of accounts whose root directories nest
99    /// inside one another. Duplicate account labels are rejected by the TOML
100    /// parser itself.
101    pub fn from_toml(toml_str: &str) -> Result<Self> {
102        let config: Self = toml::from_str(toml_str).map_err(|e| {
103            // Strip source-context lines (those containing " | ") to prevent
104            // token values from being echoed in error messages.
105            let raw = e.to_string();
106            let msg = raw
107                .lines()
108                .filter(|l| !l.contains(" | "))
109                .collect::<Vec<_>>()
110                .join("\n")
111                .trim()
112                .to_owned();
113            Error::Config(if msg.is_empty() {
114                "parse error".into()
115            } else {
116                msg
117            })
118        })?;
119        config.validate()?;
120        Ok(config)
121    }
122
123    fn validate(&self) -> Result<()> {
124        let roots: Vec<(&str, &str)> = self
125            .accounts
126            .iter()
127            .filter_map(|(label, acc)| acc.root.as_deref().map(|r| (label.as_str(), r)))
128            .collect();
129
130        for (i, (label_a, root_a)) in roots.iter().enumerate() {
131            for (label_b, root_b) in roots.iter().skip(i + 1) {
132                let a = Path::new(root_a);
133                let b = Path::new(root_b);
134                if a.starts_with(b) || b.starts_with(a) {
135                    return Err(Error::Config(format!(
136                        "account roots nest: '{label_a}' ({root_a}) and '{label_b}' ({root_b})"
137                    )));
138                }
139            }
140        }
141
142        let mut prefix_seen: HashMap<String, &str> = HashMap::new();
143        for label in self.accounts.keys() {
144            let prefix = label_to_env(label);
145            if let Some(other) = prefix_seen.get(&prefix) {
146                return Err(Error::Config(format!(
147                    "accounts '{label}' and '{other}' share env prefix '{prefix}'"
148                )));
149            }
150            prefix_seen.insert(prefix, label.as_str());
151        }
152
153        Ok(())
154    }
155
156    /// Compute effective settings for `account`, optionally scoped to `source`.
157    ///
158    /// The caller supplies the full environment map and any CLI flag overrides.
159    /// Precedence per field: flag > per-account env > global env > per-source
160    /// file > per-account file > global file defaults > compiled default.
161    pub fn resolve(
162        &self,
163        account: &str,
164        source: Option<&str>,
165        env: &HashMap<String, String>,
166        flags: &FlagOverrides,
167    ) -> Result<EffectiveSettings> {
168        let acc = self
169            .accounts
170            .get(account)
171            .ok_or_else(|| Error::Config(format!("account '{account}' not found")))?;
172
173        let src = source.and_then(|s| acc.sources.get(s));
174        let label_env = label_to_env(account);
175
176        // Look up per-account env first, falling back to global.
177        let env_val = |suffix: &str| -> Option<&str> {
178            env.get(&format!("SUNO_{label_env}_{suffix}"))
179                .or_else(|| env.get(&format!("SUNO_{suffix}")))
180                .map(String::as_str)
181        };
182
183        let format_from_env = env_val("FORMAT")
184            .map(str::parse::<AudioFormat>)
185            .transpose()?;
186
187        let format = flags
188            .format
189            .or(format_from_env)
190            .or_else(|| src.and_then(|s| s.format))
191            .or(acc.format)
192            .or(self.defaults.format)
193            .unwrap_or(AudioFormat::Flac);
194
195        let concurrency = resolve_u32(
196            flags.concurrency,
197            env_val("CONCURRENCY"),
198            src.and_then(|s| s.concurrency),
199            acc.concurrency,
200            self.defaults.concurrency,
201            4,
202            "CONCURRENCY",
203        )?;
204
205        let retries = resolve_u32(
206            flags.retries,
207            env_val("RETRIES"),
208            src.and_then(|s| s.retries),
209            acc.retries,
210            self.defaults.retries,
211            3,
212            "RETRIES",
213        )?;
214
215        let min_newest = resolve_u32(
216            flags.min_newest,
217            env_val("MIN_NEWEST"),
218            src.and_then(|s| s.min_newest),
219            acc.min_newest,
220            self.defaults.min_newest,
221            1,
222            "MIN_NEWEST",
223        )?;
224
225        let animated_covers = resolve_bool(
226            flags.animated_covers,
227            env_val("ANIMATED_COVERS"),
228            src.and_then(|s| s.animated_covers),
229            acc.animated_covers,
230            self.defaults.animated_covers,
231            false,
232            "ANIMATED_COVERS",
233        )?;
234
235        let token = flags
236            .token
237            .clone()
238            .or_else(|| env.get(&format!("SUNO_{label_env}_TOKEN")).cloned())
239            .or_else(|| env.get("SUNO_TOKEN").cloned())
240            .or_else(|| acc.token.clone());
241
242        Ok(EffectiveSettings {
243            token,
244            account_id: acc.account_id.clone(),
245            format,
246            concurrency,
247            retries,
248            min_newest,
249            animated_covers,
250        })
251    }
252}
253
254fn resolve_u32(
255    flag: Option<u32>,
256    env_str: Option<&str>,
257    src: Option<u32>,
258    acc: Option<u32>,
259    defaults: Option<u32>,
260    compiled: u32,
261    name: &str,
262) -> Result<u32> {
263    if let Some(v) = flag {
264        return Ok(v);
265    }
266    if let Some(s) = env_str {
267        return s
268            .parse()
269            .map_err(|_| Error::Config(format!("invalid {name}: '{s}'")));
270    }
271    Ok(src.or(acc).or(defaults).unwrap_or(compiled))
272}
273
274fn resolve_bool(
275    flag: Option<bool>,
276    env_str: Option<&str>,
277    src: Option<bool>,
278    acc: Option<bool>,
279    defaults: Option<bool>,
280    compiled: bool,
281    name: &str,
282) -> Result<bool> {
283    if let Some(v) = flag {
284        return Ok(v);
285    }
286    if let Some(s) = env_str {
287        return s
288            .parse()
289            .map_err(|_| Error::Config(format!("invalid {name}: '{s}'")));
290    }
291    Ok(src.or(acc).or(defaults).unwrap_or(compiled))
292}
293
/// Map an account label to the prefix used in its environment variables.
///
/// ASCII letters are upper-cased and each `-` becomes `_`; every other
/// character passes through unchanged. For example `my-lib` maps to `MY_LIB`,
/// giving variables such as `SUNO_MY_LIB_TOKEN`.
fn label_to_env(label: &str) -> String {
    label
        .chars()
        .map(|c| match c {
            '-' => '_',
            other => other.to_ascii_uppercase(),
        })
        .collect()
}
300
301/// CLI flag overrides passed to [`Config::resolve`]. `None` means the flag
302/// was not provided.
303#[derive(Debug, Default)]
304pub struct FlagOverrides {
305    pub token: Option<String>,
306    pub format: Option<AudioFormat>,
307    pub concurrency: Option<u32>,
308    pub retries: Option<u32>,
309    pub min_newest: Option<u32>,
310    pub animated_covers: Option<bool>,
311}
312
313/// Resolved effective settings for one account/source combination.
314#[derive(Debug, Clone, PartialEq)]
315pub struct EffectiveSettings {
316    pub token: Option<String>,
317    /// The optional configured account id assertion (see [`AccountConfig`]).
318    pub account_id: Option<String>,
319    pub format: AudioFormat,
320    pub concurrency: u32,
321    pub retries: u32,
322    pub min_newest: u32,
323    pub animated_covers: bool,
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty environment map.
    fn no_env() -> HashMap<String, String> {
        HashMap::new()
    }

    /// No CLI flag overrides.
    fn no_flags() -> FlagOverrides {
        FlagOverrides::default()
    }

    /// Build an environment map from `(key, value)` pairs.
    fn env_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_owned(), v.to_owned()))
            .collect()
    }

    #[test]
    fn parse_empty_toml() {
        let cfg = Config::from_toml("").unwrap();
        assert!(cfg.accounts.is_empty());
    }

    #[test]
    fn parse_basic_account() {
        let toml = r#"
            [accounts.alice]
            token = "tok"
            root = "/music"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let acc = &cfg.accounts["alice"];
        assert_eq!(acc.token.as_deref(), Some("tok"));
        assert_eq!(acc.root.as_deref(), Some("/music"));
    }

    #[test]
    fn account_id_parses_and_resolves() {
        let toml = r#"
            [accounts.alice]
            token = "tok"
            root = "/music"
            account_id = "user_abc123"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        assert_eq!(
            cfg.accounts["alice"].account_id.as_deref(),
            Some("user_abc123")
        );
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(eff.account_id.as_deref(), Some("user_abc123"));
    }

    #[test]
    fn parse_defaults_section() {
        let toml = r#"
            [defaults]
            format = "mp3"
            concurrency = 8
            retries = 5
            min_newest = 2
            animated_covers = true
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        assert_eq!(cfg.defaults.format, Some(AudioFormat::Mp3));
        assert_eq!(cfg.defaults.concurrency, Some(8));
        assert_eq!(cfg.defaults.retries, Some(5));
        assert_eq!(cfg.defaults.min_newest, Some(2));
        assert_eq!(cfg.defaults.animated_covers, Some(true));
    }

    #[test]
    fn compiled_defaults_when_nothing_set() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(
            eff,
            EffectiveSettings {
                token: None,
                account_id: None,
                format: AudioFormat::Flac,
                concurrency: 4,
                retries: 3,
                min_newest: 1,
                animated_covers: false,
            }
        );
    }

    #[test]
    fn file_defaults_override_compiled() {
        let toml = r#"
            [defaults]
            format = "mp3"
            concurrency = 8

            [accounts.alice]
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Mp3);
        assert_eq!(eff.concurrency, 8);
        assert_eq!(eff.retries, 3); // compiled default
    }

    #[test]
    fn account_settings_override_defaults() {
        let toml = r#"
            [defaults]
            format = "mp3"

            [accounts.alice]
            format = "wav"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    #[test]
    fn per_source_overrides_account() {
        let toml = r#"
            [accounts.alice]
            format = "flac"

            [accounts.alice.sources.liked]
            format = "mp3"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg
            .resolve("alice", Some("liked"), &no_env(), &no_flags())
            .unwrap();
        assert_eq!(eff.format, AudioFormat::Mp3);
    }

    #[test]
    fn unknown_source_falls_back_to_account() {
        let toml = r#"
            [accounts.alice]
            format = "wav"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let eff = cfg
            .resolve("alice", Some("nonexistent"), &no_env(), &no_flags())
            .unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    #[test]
    fn global_env_overrides_file() {
        let toml = r#"
            [accounts.alice]
            format = "flac"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> =
            [("SUNO_FORMAT".into(), "mp3".into())].into_iter().collect();
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Mp3);
    }

    #[test]
    fn per_account_env_overrides_global_env() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> = [
            ("SUNO_FORMAT".into(), "mp3".into()),
            ("SUNO_ALICE_FORMAT".into(), "wav".into()),
        ]
        .into_iter()
        .collect();
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    #[test]
    fn per_account_env_label_upper_snake_case() {
        let toml = "[accounts.my-lib]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> = [("SUNO_MY_LIB_FORMAT".into(), "wav".into())]
            .into_iter()
            .collect();
        let eff = cfg.resolve("my-lib", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    #[test]
    fn env_format_is_case_insensitive() {
        let cfg = Config::from_toml("[accounts.alice]\n").unwrap();
        let env = env_of(&[("SUNO_ALICE_FORMAT", "MP3")]);
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.format, AudioFormat::Mp3);
    }

    #[test]
    fn invalid_env_format_errors() {
        let cfg = Config::from_toml("[accounts.alice]\n").unwrap();
        let env = env_of(&[("SUNO_FORMAT", "ogg")]);
        assert!(cfg.resolve("alice", None, &env, &no_flags()).is_err());
    }

    #[test]
    fn flag_overrides_env_and_file() {
        let toml = r#"
            [accounts.alice]
            format = "flac"
        "#;
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> =
            [("SUNO_FORMAT".into(), "mp3".into())].into_iter().collect();
        let flags = FlagOverrides {
            format: Some(AudioFormat::Wav),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", None, &env, &flags).unwrap();
        assert_eq!(eff.format, AudioFormat::Wav);
    }

    #[test]
    fn u32_precedence_chain() {
        let toml = r#"
            [defaults]
            retries = 5

            [accounts.alice]
            retries = 6

            [accounts.alice.sources.liked]
            retries = 7

            [accounts.bob]
        "#;
        let cfg = Config::from_toml(toml).unwrap();

        // Global file default applies when the account sets nothing.
        let eff = cfg.resolve("bob", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(eff.retries, 5);

        // Per-account file beats the global default; per-source beats both.
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert_eq!(eff.retries, 6);
        let eff = cfg
            .resolve("alice", Some("liked"), &no_env(), &no_flags())
            .unwrap();
        assert_eq!(eff.retries, 7);

        // Global env beats every file layer; another account's env is ignored.
        let env = env_of(&[("SUNO_RETRIES", "8"), ("SUNO_BOB_RETRIES", "11")]);
        let eff = cfg.resolve("alice", Some("liked"), &env, &no_flags()).unwrap();
        assert_eq!(eff.retries, 8);

        // Per-account env beats global env.
        let env = env_of(&[("SUNO_RETRIES", "8"), ("SUNO_ALICE_RETRIES", "9")]);
        let eff = cfg.resolve("alice", Some("liked"), &env, &no_flags()).unwrap();
        assert_eq!(eff.retries, 9);

        // A flag beats everything.
        let flags = FlagOverrides {
            retries: Some(10),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", Some("liked"), &env, &flags).unwrap();
        assert_eq!(eff.retries, 10);
    }

    #[test]
    fn flag_masks_invalid_env_u32() {
        // Once a flag supplies the value, the env string is never parsed.
        let cfg = Config::from_toml("[accounts.alice]\n").unwrap();
        let env = env_of(&[("SUNO_CONCURRENCY", "many")]);
        let flags = FlagOverrides {
            concurrency: Some(2),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", None, &env, &flags).unwrap();
        assert_eq!(eff.concurrency, 2);
    }

    #[test]
    fn token_precedence() {
        let toml = r#"
            [accounts.alice]
            token = "file_tok"
        "#;
        let cfg = Config::from_toml(toml).unwrap();

        // env overrides file
        let env: HashMap<String, String> = [("SUNO_TOKEN".into(), "env_tok".into())]
            .into_iter()
            .collect();
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.token.as_deref(), Some("env_tok"));

        // flag overrides env
        let flags = FlagOverrides {
            token: Some("flag_tok".into()),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", None, &env, &flags).unwrap();
        assert_eq!(eff.token.as_deref(), Some("flag_tok"));
    }

    #[test]
    fn per_account_token_env_overrides_global() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> = [
            ("SUNO_TOKEN".into(), "global".into()),
            ("SUNO_ALICE_TOKEN".into(), "per_account".into()),
        ]
        .into_iter()
        .collect();
        let eff = cfg.resolve("alice", None, &env, &no_flags()).unwrap();
        assert_eq!(eff.token.as_deref(), Some("per_account"));
    }

    #[test]
    fn invalid_env_u32_errors() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> = [("SUNO_CONCURRENCY".into(), "many".into())]
            .into_iter()
            .collect();
        assert!(cfg.resolve("alice", None, &env, &no_flags()).is_err());
    }

    #[test]
    fn animated_covers_defaults_off_and_follows_precedence() {
        // Compiled default is off.
        let cfg = Config::from_toml("[accounts.alice]\n").unwrap();
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert!(!eff.animated_covers);

        // File default on; per-source off; env on; flag off — flag wins.
        let toml = r#"
            [defaults]
            animated_covers = true

            [accounts.alice.sources.liked]
            animated_covers = false
        "#;
        let cfg = Config::from_toml(toml).unwrap();

        // File default (defaults) turns it on for an unscoped resolve.
        let eff = cfg.resolve("alice", None, &no_env(), &no_flags()).unwrap();
        assert!(eff.animated_covers);

        // Per-source file setting overrides the file default.
        let eff = cfg
            .resolve("alice", Some("liked"), &no_env(), &no_flags())
            .unwrap();
        assert!(!eff.animated_covers);

        // Env overrides file (even the per-source off).
        let env: HashMap<String, String> = [("SUNO_ANIMATED_COVERS".into(), "true".into())]
            .into_iter()
            .collect();
        let eff = cfg
            .resolve("alice", Some("liked"), &env, &no_flags())
            .unwrap();
        assert!(eff.animated_covers);

        // Flag overrides env.
        let flags = FlagOverrides {
            animated_covers: Some(false),
            ..Default::default()
        };
        let eff = cfg.resolve("alice", Some("liked"), &env, &flags).unwrap();
        assert!(!eff.animated_covers);
    }

    #[test]
    fn invalid_env_bool_errors() {
        let toml = "[accounts.alice]\n";
        let cfg = Config::from_toml(toml).unwrap();
        let env: HashMap<String, String> = [("SUNO_ANIMATED_COVERS".into(), "yes".into())]
            .into_iter()
            .collect();
        assert!(cfg.resolve("alice", None, &env, &no_flags()).is_err());
    }

    #[test]
    fn unknown_account_errors() {
        let cfg = Config::from_toml("").unwrap();
        assert!(cfg.resolve("nobody", None, &no_env(), &no_flags()).is_err());
    }

    #[test]
    fn validation_nested_roots() {
        let toml = r#"
            [accounts.alice]
            root = "/music"

            [accounts.bob]
            root = "/music/bob"
        "#;
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn validation_identical_roots_errors() {
        // A trailing slash does not make two roots distinct.
        let toml = r#"
            [accounts.alice]
            root = "/music"

            [accounts.bob]
            root = "/music/"
        "#;
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn validation_non_nested_roots_ok() {
        let toml = r#"
            [accounts.alice]
            root = "/music/alice"

            [accounts.bob]
            root = "/music/bob"
        "#;
        assert!(Config::from_toml(toml).is_ok());
    }

    #[test]
    fn validation_sibling_name_prefix_roots_ok() {
        // Nesting is judged per path component, so `/music2` is not inside `/music`.
        let toml = r#"
            [accounts.alice]
            root = "/music"

            [accounts.bob]
            root = "/music2"
        "#;
        assert!(Config::from_toml(toml).is_ok());
    }

    #[test]
    fn invalid_toml_errors() {
        assert!(Config::from_toml("not valid toml ][").is_err());
    }

    #[test]
    fn duplicate_account_label_errors() {
        // The TOML spec forbids defining the same table twice; the parser must
        // reject this.
        let toml = "
            [accounts.alice]
            token = \"tok1\"

            [accounts.alice]
            token = \"tok2\"
        ";
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn parse_error_does_not_echo_token() {
        // A malformed token line must not include the raw value in the error.
        let toml = "[accounts.alice]\ntoken = \"unterminated\n";
        let err = Config::from_toml(toml).unwrap_err().to_string();
        assert!(!err.contains("unterminated"), "error leaked token: {err}");
    }

    #[test]
    fn validation_env_prefix_collision_errors() {
        // 'my-lib' and 'my_lib' both map to SUNO_MY_LIB_* and must be rejected.
        let toml = "
            [accounts.my-lib]
            [accounts.my_lib]
        ";
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn validation_env_prefix_case_collision_errors() {
        // Labels differing only in case both map to SUNO_ALICE_*.
        let toml = "
            [accounts.Alice]
            [accounts.alice]
        ";
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn audio_format_display_roundtrip() {
        for fmt in [AudioFormat::Mp3, AudioFormat::Flac, AudioFormat::Wav] {
            let s = fmt.to_string();
            assert_eq!(s.parse::<AudioFormat>().unwrap(), fmt);
        }
    }
}