Skip to main content

studio_worker/
config.rs

//! Persistent worker configuration, stored as `config.toml` in the
//! per-user config directory resolved by `directories::ProjectDirs`
//! (for example `~/.config/minis-studio-worker/config.toml` on Linux;
//! macOS and Windows use their platform-specific application config
//! directory).  An explicit path override takes precedence.
//!
//! Every load/save emits a structured tracing breadcrumb so operators
//! can tell from `journalctl` which file the worker actually consulted
//! (and whether the file existed or was freshly bootstrapped with
//! defaults).  The events deliberately omit the two secret fields
//! — `bootstrap_token` and `auth_token` — so logs can be shipped
//! off-box without leaking credentials.  See `tests/config_tracing.rs`
//! for the regression contract.
11use anyhow::{anyhow, Context, Result};
12use directories::ProjectDirs;
13use parking_lot::Mutex;
14use serde::{Deserialize, Serialize};
15use std::path::{Path, PathBuf};
16
/// `target` attached to every tracing event emitted by this module.
/// Kept stable so log filters such as
/// `RUST_LOG=studio_worker::config=debug` keep matching.
const TRACE_TARGET: &str = "studio_worker::config";
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Config {
23    /// Base URL of the studio API (e.g. `https://studio.minis.gg`).
24    pub api_base_url: String,
25    /// Shared secret used only for the first registration.
26    pub bootstrap_token: String,
27    /// Worker id, filled in by `register`.
28    #[serde(default, skip_serializing_if = "Option::is_none")]
29    pub worker_id: Option<String>,
30    /// Per-worker token issued at registration.
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub auth_token: Option<String>,
33    /// VRAM threshold the worker reports as its max claim size, in GB.
34    pub vram_threshold_gb: f32,
35    /// Whether to auto-launch the run loop at boot via the OS service.
36    pub auto_start: bool,
37    /// Whether the worker should claim new jobs.
38    pub auto_enabled: bool,
39    /// Engine identifier: `synthetic`, `gradio`, `multi`, or — when
40    /// built with the matching cargo feature — `llama`, `whisper`,
41    /// `image-candle`, `video`, `tts`.
42    pub engine: String,
43    /// When `engine = "multi"`, the per-modality engines to combine.
44    /// First engine that claims support for a job's kind+model wins.
45    #[serde(default)]
46    pub engines: Vec<String>,
47    /// Local Gradio endpoint URL when `engine = "gradio"`.
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub gradio_endpoint_url: Option<String>,
50    /// Explicit override of supported models.  When empty, the engine
51    /// reports its native list.
52    #[serde(default)]
53    pub supported_models_override: Vec<String>,
54    /// Periodically check the release feed and auto-install newer
55    /// versions when no job is running.
56    #[serde(default = "default_auto_update_enabled")]
57    pub auto_update_enabled: bool,
58    /// How often (seconds) to check the release feed.
59    #[serde(default = "default_auto_update_interval")]
60    pub auto_update_interval_secs: u64,
61    /// GitHub Releases feed for this binary.
62    #[serde(default = "default_auto_update_feed")]
63    pub auto_update_feed: String,
64    /// Whether to upgrade to pre-release versions.
65    #[serde(default)]
66    pub auto_update_prerelease: bool,
67    /// Root directory for downloaded model files (per-engine
68    /// subdirectories: `llm/`, `stt/`, `tts/`, `image/`, `video/`).
69    /// Defaults to the OS cache dir.
70    #[serde(default, skip_serializing_if = "Option::is_none")]
71    pub models_root: Option<std::path::PathBuf>,
72}
73
/// Serde fallback for `Config::auto_update_enabled`: auto-update is on
/// unless the config file says otherwise.
fn default_auto_update_enabled() -> bool {
    true
}
/// Serde fallback for `Config::auto_update_interval_secs`: thirty
/// minutes, expressed in seconds.
fn default_auto_update_interval() -> u64 {
    const THIRTY_MINUTES_SECS: u64 = 30 * 60;
    THIRTY_MINUTES_SECS
}
/// Serde fallback for `Config::auto_update_feed`: the release feed URL
/// used when the config file does not name one.
fn default_auto_update_feed() -> String {
    String::from("https://api.github.com/repos/webbertakken/studio-worker/releases")
}
83
84impl Default for Config {
85    fn default() -> Self {
86        Self {
87            api_base_url: "http://localhost:9790".into(),
88            bootstrap_token: "dev-bootstrap-token".into(),
89            worker_id: None,
90            auth_token: None,
91            vram_threshold_gb: 12.0,
92            auto_start: true,
93            auto_enabled: true,
94            engine: "synthetic".into(),
95            engines: Vec::new(),
96            gradio_endpoint_url: None,
97            supported_models_override: Vec::new(),
98            auto_update_enabled: default_auto_update_enabled(),
99            auto_update_interval_secs: default_auto_update_interval(),
100            auto_update_feed: default_auto_update_feed(),
101            auto_update_prerelease: false,
102            models_root: None,
103        }
104    }
105}
106
107fn default_config_path() -> Result<PathBuf> {
108    let dirs = ProjectDirs::from("gg", "minis", "minis-studio-worker")
109        .ok_or_else(|| anyhow!("cannot resolve config directory"))?;
110    Ok(dirs.config_dir().join("config.toml"))
111}
112
113pub fn resolve_path(override_path: Option<&str>) -> Result<PathBuf> {
114    if let Some(p) = override_path {
115        Ok(PathBuf::from(p))
116    } else {
117        default_config_path()
118    }
119}
120
121pub fn load(override_path: Option<&str>) -> Result<(Config, PathBuf)> {
122    let path = resolve_path(override_path)?;
123    if !path.exists() {
124        let cfg = Config::default();
125        save(&cfg, &path)?;
126        tracing::info!(
127            target: TRACE_TARGET,
128            op = "load",
129            source = "default_created",
130            config_path = %path.display(),
131            engine = %cfg.engine,
132            api_base_url = %cfg.api_base_url,
133            vram_threshold_gb = cfg.vram_threshold_gb,
134            auto_enabled = cfg.auto_enabled,
135            "config file missing — bootstrapped defaults"
136        );
137        return Ok((cfg, path));
138    }
139    let text =
140        std::fs::read_to_string(&path).with_context(|| format!("reading {}", path.display()))?;
141    let cfg: Config = toml::from_str(&text).with_context(|| "parsing config.toml")?;
142    tracing::debug!(
143        target: TRACE_TARGET,
144        op = "load",
145        source = "existing_file",
146        config_path = %path.display(),
147        engine = %cfg.engine,
148        api_base_url = %cfg.api_base_url,
149        vram_threshold_gb = cfg.vram_threshold_gb,
150        auto_enabled = cfg.auto_enabled,
151        worker_id = cfg.worker_id.as_deref().unwrap_or("(unregistered)"),
152        has_auth_token = cfg.auth_token.is_some(),
153        "loaded config from disk"
154    );
155    Ok((cfg, path))
156}
157
158pub fn save(cfg: &Config, path: &Path) -> Result<()> {
159    if let Some(parent) = path.parent() {
160        std::fs::create_dir_all(parent)
161            .with_context(|| format!("creating {}", parent.display()))?;
162    }
163    let text = toml::to_string_pretty(cfg).with_context(|| "serialising config")?;
164    let bytes = text.len();
165    std::fs::write(path, text).with_context(|| format!("writing {}", path.display()))?;
166    tracing::debug!(
167        target: TRACE_TARGET,
168        op = "save",
169        config_path = %path.display(),
170        engine = %cfg.engine,
171        vram_threshold_gb = cfg.vram_threshold_gb,
172        auto_enabled = cfg.auto_enabled,
173        bytes = bytes,
174        "persisted config to disk"
175    );
176    Ok(())
177}
178
179/// Wrap a Config in a mutex for use across the runtime.
180pub type SharedConfig = std::sync::Arc<Mutex<Config>>;
181
182pub fn shared(cfg: Config) -> SharedConfig {
183    std::sync::Arc::new(Mutex::new(cfg))
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Renders a path as the owned string form used for `load`'s override.
    fn as_override(path: &Path) -> String {
        path.to_string_lossy().to_string()
    }

    /// The built-in defaults carry the expected out-of-the-box values.
    #[test]
    fn default_values_are_sensible() {
        let defaults = Config::default();

        // Engine and scheduling switches.
        assert_eq!(defaults.engine, "synthetic");
        assert!(defaults.auto_enabled);
        assert!(defaults.auto_start);

        // Auto-update settings.
        assert!(defaults.auto_update_enabled);
        assert_eq!(defaults.auto_update_interval_secs, 1800);
        assert!(!defaults.auto_update_prerelease);
        assert!(defaults
            .auto_update_feed
            .contains("webbertakken/studio-worker"));

        // Capacity and registration state.
        assert_eq!(defaults.vram_threshold_gb, 12.0);
        assert!(defaults.worker_id.is_none());
        assert!(defaults.auth_token.is_none());
    }

    /// An explicit override is returned verbatim.
    #[test]
    fn resolve_path_uses_override_when_provided() {
        let resolved = resolve_path(Some("/tmp/test-config.toml")).unwrap();
        assert_eq!(resolved, PathBuf::from("/tmp/test-config.toml"));
    }

    /// Without an override the path points at the application's
    /// `config.toml`.
    #[test]
    fn resolve_path_defaults_when_no_override() {
        let resolved = resolve_path(None).unwrap();
        let rendered = resolved.to_string_lossy();
        let mentions_app = rendered.contains("minis-studio-worker")
            || rendered.contains("minis.gg.minis-studio-worker");
        assert!(mentions_app, "unexpected default path: {rendered}");
        assert!(rendered.ends_with("config.toml"));
    }

    /// Loading from a non-existent path writes the defaults there.
    #[test]
    fn load_creates_default_when_file_missing() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("sub").join("config.toml");
        let target_arg = as_override(&target);

        let (cfg, returned_path) = load(Some(target_arg.as_str())).unwrap();

        assert_eq!(returned_path, target);
        assert_eq!(cfg.engine, "synthetic");
        // File should have been written.
        assert!(target.exists());
    }

    /// Values written by `save` come back unchanged from `load`.
    #[test]
    fn round_trip_via_save_and_load_preserves_fields() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("config.toml");
        let written = Config {
            engine: "gradio".into(),
            gradio_endpoint_url: Some("http://example.invalid".into()),
            worker_id: Some("w-123".into()),
            auth_token: Some("tok-xyz".into()),
            vram_threshold_gb: 24.0,
            auto_update_prerelease: true,
            supported_models_override: vec!["foo".into(), "bar".into()],
            ..Config::default()
        };
        save(&written, &target).unwrap();

        let target_arg = as_override(&target);
        let (read_back, _) = load(Some(target_arg.as_str())).unwrap();

        assert_eq!(read_back.engine, written.engine);
        assert_eq!(read_back.gradio_endpoint_url, written.gradio_endpoint_url);
        assert_eq!(read_back.worker_id, written.worker_id);
        assert_eq!(read_back.auth_token, written.auth_token);
        assert_eq!(read_back.vram_threshold_gb, written.vram_threshold_gb);
        assert_eq!(
            read_back.auto_update_prerelease,
            written.auto_update_prerelease
        );
        assert_eq!(
            read_back.supported_models_override,
            written.supported_models_override
        );
    }

    /// The shared handle exposes the wrapped config through its lock.
    #[test]
    fn shared_wraps_in_arc_mutex() {
        let original = Config::default();
        let handle = shared(original.clone());
        let locked = handle.lock();
        assert_eq!(locked.engine, original.engine);
    }

    /// Unparseable file contents surface as a parse error.
    #[test]
    fn load_returns_error_on_malformed_toml() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("config.toml");
        std::fs::write(&target, "this :: is = not = toml = :").unwrap();

        let target_arg = as_override(&target);
        let failure = load(Some(target_arg.as_str())).unwrap_err();

        assert!(failure.to_string().contains("parsing config.toml"));
    }
}