Skip to main content

studio_worker/
config.rs

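//! Persistent config in the per-user config directory resolved by
//! `directories::ProjectDirs` (qualifier `gg`, org `minis`, app
//! `minis-studio-worker`): `~/.config/minis-studio-worker/config.toml` on
//! Linux, `~/Library/Application Support/gg.minis.minis-studio-worker/config.toml`
//! on macOS, and `%APPDATA%\minis\minis-studio-worker\config\config.toml` on
//! Windows.  An explicit override path (see `resolve_path`) bypasses this
//! lookup entirely.
//!
//! Every load/save emits a structured tracing breadcrumb so operators
//! can tell from `journalctl` which file the worker actually consulted
//! (and whether the file existed or was freshly bootstrapped with
//! defaults).  The events deliberately omit the secret fields
//! (`auth_token`, `registration_secret`) so logs can be shipped
//! off-box without leaking credentials.  See `tests/config_tracing.rs`
//! for the regression contract.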
11use anyhow::{anyhow, Context, Result};
12use directories::ProjectDirs;
13use parking_lot::Mutex;
14use serde::{Deserialize, Serialize};
15use std::path::{Path, PathBuf};
16
/// `tracing` target attached to every config load/save event.
///
/// Spelled out as a literal (rather than derived from the module path) so
/// the filter `RUST_LOG=studio_worker::config=debug` keeps working even if
/// this module is moved.
const TRACE_TARGET: &str = "studio_worker::config";
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Config {
23    /// Base URL of the studio API (e.g. `https://studio.minis.gg`).
24    pub api_base_url: String,
25    /// Worker id, written on operator approval.  Cleared by
26    /// `studio-worker register --reset`.
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub worker_id: Option<String>,
29    /// Per-worker token issued at registration.
30    #[serde(default, skip_serializing_if = "Option::is_none")]
31    pub auth_token: Option<String>,
32    /// VRAM threshold the worker reports as its max claim size, in GB.
33    pub vram_threshold_gb: f32,
34    /// Whether to auto-launch the run loop at boot via the OS service.
35    pub auto_start: bool,
36    /// Whether the worker should claim new jobs.
37    pub auto_enabled: bool,
38    /// Engine identifier: `synthetic`, `gradio`, `multi`, or — when
39    /// built with the matching cargo feature — `llama`, `whisper`,
40    /// `image-candle`, `video`, `tts`.
41    pub engine: String,
42    /// When `engine = "multi"`, the per-modality engines to combine.
43    /// First engine that claims support for a job's kind+model wins.
44    #[serde(default)]
45    pub engines: Vec<String>,
46    /// Local Gradio endpoint URL when `engine = "gradio"`.
47    #[serde(default, skip_serializing_if = "Option::is_none")]
48    pub gradio_endpoint_url: Option<String>,
49    /// Explicit override of supported models.  When empty, the engine
50    /// reports its native list.
51    #[serde(default)]
52    pub supported_models_override: Vec<String>,
53    /// Periodically check the release feed and auto-install newer
54    /// versions when no job is running.
55    #[serde(default = "default_auto_update_enabled")]
56    pub auto_update_enabled: bool,
57    /// How often (seconds) to check the release feed.
58    #[serde(default = "default_auto_update_interval")]
59    pub auto_update_interval_secs: u64,
60    /// GitHub Releases feed for this binary.
61    #[serde(default = "default_auto_update_feed")]
62    pub auto_update_feed: String,
63    /// Whether to upgrade to pre-release versions.
64    #[serde(default)]
65    pub auto_update_prerelease: bool,
66    /// Root directory for downloaded model files (per-engine
67    /// subdirectories: `llm/`, `stt/`, `tts/`, `image/`, `video/`).
68    /// Defaults to the OS cache dir.
69    #[serde(default, skip_serializing_if = "Option::is_none")]
70    pub models_root: Option<std::path::PathBuf>,
71    /// Maximum number of WebSocket reconnect attempts before the
72    /// worker gives up and exits non-zero (relying on the service
73    /// manager to restart it).  `0` = infinite.  Defaults to `5`.
74    #[serde(default, skip_serializing_if = "Option::is_none")]
75    pub ws_reconnect_attempts: Option<u32>,
76    /// Per-install UUID written once on first launch.  Stable across
77    /// worker restarts so the studio can dedup pending requests.
78    /// Internal state, populated by the auto-register flow.
79    #[serde(default, skip_serializing_if = "Option::is_none")]
80    pub install_id: Option<String>,
81    /// Optional human label shown in the studio's Pending Workers
82    /// panel.  Defaults to None; user-settable via
83    /// `studio-worker register --label "..."`.
84    #[serde(default, skip_serializing_if = "Option::is_none")]
85    pub label: Option<String>,
86    /// `requestId` returned by `POST /workers/register-request`.
87    /// Cleared on approval / rejection.  Internal.
88    #[serde(default, skip_serializing_if = "Option::is_none")]
89    pub registration_request_id: Option<String>,
90    /// Bearer secret presented when polling the request status.
91    /// Cleared on approval / rejection.  Internal.
92    #[serde(default, skip_serializing_if = "Option::is_none")]
93    pub registration_secret: Option<String>,
94}
95
/// Serde default for `auto_update_enabled`: auto-update is on unless the
/// config file turns it off.
fn default_auto_update_enabled() -> bool {
    true
}

/// Serde default for `auto_update_interval_secs`: poll every 30 minutes.
fn default_auto_update_interval() -> u64 {
    30 * 60
}

/// Serde default for `auto_update_feed`: this binary's GitHub Releases API
/// endpoint.
fn default_auto_update_feed() -> String {
    String::from("https://api.github.com/repos/webbertakken/studio-worker/releases")
}
105
106impl Default for Config {
107    fn default() -> Self {
108        Self {
109            api_base_url: "https://studio.minis.gg".into(),
110            worker_id: None,
111            auth_token: None,
112            vram_threshold_gb: 12.0,
113            auto_start: true,
114            auto_enabled: true,
115            engine: "synthetic".into(),
116            engines: Vec::new(),
117            gradio_endpoint_url: None,
118            supported_models_override: Vec::new(),
119            auto_update_enabled: default_auto_update_enabled(),
120            auto_update_interval_secs: default_auto_update_interval(),
121            auto_update_feed: default_auto_update_feed(),
122            auto_update_prerelease: false,
123            models_root: None,
124            ws_reconnect_attempts: None,
125            install_id: None,
126            label: None,
127            registration_request_id: None,
128            registration_secret: None,
129        }
130    }
131}
132
133fn default_config_path() -> Result<PathBuf> {
134    let dirs = ProjectDirs::from("gg", "minis", "minis-studio-worker")
135        .ok_or_else(|| anyhow!("cannot resolve config directory"))?;
136    Ok(dirs.config_dir().join("config.toml"))
137}
138
139pub fn resolve_path(override_path: Option<&str>) -> Result<PathBuf> {
140    if let Some(p) = override_path {
141        Ok(PathBuf::from(p))
142    } else {
143        default_config_path()
144    }
145}
146
147pub fn load(override_path: Option<&str>) -> Result<(Config, PathBuf)> {
148    let path = resolve_path(override_path)?;
149    if !path.exists() {
150        let cfg = Config::default();
151        save(&cfg, &path)?;
152        tracing::info!(
153            target: TRACE_TARGET,
154            op = "load",
155            source = "default_created",
156            config_path = %path.display(),
157            engine = %cfg.engine,
158            api_base_url = %cfg.api_base_url,
159            vram_threshold_gb = cfg.vram_threshold_gb,
160            auto_enabled = cfg.auto_enabled,
161            "config file missing — bootstrapped defaults"
162        );
163        return Ok((cfg, path));
164    }
165    let text =
166        std::fs::read_to_string(&path).with_context(|| format!("reading {}", path.display()))?;
167    let cfg: Config = toml::from_str(&text).with_context(|| "parsing config.toml")?;
168    tracing::debug!(
169        target: TRACE_TARGET,
170        op = "load",
171        source = "existing_file",
172        config_path = %path.display(),
173        engine = %cfg.engine,
174        api_base_url = %cfg.api_base_url,
175        vram_threshold_gb = cfg.vram_threshold_gb,
176        auto_enabled = cfg.auto_enabled,
177        worker_id = cfg.worker_id.as_deref().unwrap_or("(unregistered)"),
178        has_auth_token = cfg.auth_token.is_some(),
179        "loaded config from disk"
180    );
181    Ok((cfg, path))
182}
183
184pub fn save(cfg: &Config, path: &Path) -> Result<()> {
185    if let Some(parent) = path.parent() {
186        std::fs::create_dir_all(parent)
187            .with_context(|| format!("creating {}", parent.display()))?;
188    }
189    let text = toml::to_string_pretty(cfg).with_context(|| "serialising config")?;
190    let bytes = text.len();
191    std::fs::write(path, text).with_context(|| format!("writing {}", path.display()))?;
192    tracing::debug!(
193        target: TRACE_TARGET,
194        op = "save",
195        config_path = %path.display(),
196        engine = %cfg.engine,
197        vram_threshold_gb = cfg.vram_threshold_gb,
198        auto_enabled = cfg.auto_enabled,
199        bytes = bytes,
200        "persisted config to disk"
201    );
202    Ok(())
203}
204
205/// Wrap a Config in a mutex for use across the runtime.
206pub type SharedConfig = std::sync::Arc<Mutex<Config>>;
207
208pub fn shared(cfg: Config) -> SharedConfig {
209    std::sync::Arc::new(Mutex::new(cfg))
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Render a path in the `&str`-compatible form `load` accepts as an
    /// override.
    fn as_override(path: &Path) -> String {
        path.to_string_lossy().into_owned()
    }

    #[test]
    fn default_values_are_sensible() {
        let Config {
            engine,
            auto_enabled,
            auto_start,
            auto_update_enabled,
            auto_update_interval_secs,
            auto_update_prerelease,
            auto_update_feed,
            vram_threshold_gb,
            worker_id,
            auth_token,
            ..
        } = Config::default();
        assert_eq!(engine, "synthetic");
        assert!(auto_enabled);
        assert!(auto_start);
        assert!(auto_update_enabled);
        assert_eq!(auto_update_interval_secs, 1800);
        assert!(!auto_update_prerelease);
        assert!(auto_update_feed.contains("webbertakken/studio-worker"));
        assert_eq!(vram_threshold_gb, 12.0);
        assert_eq!(worker_id, None);
        assert_eq!(auth_token, None);
    }

    #[test]
    fn resolve_path_uses_override_when_provided() {
        let resolved = resolve_path(Some("/tmp/test-config.toml")).expect("override path");
        assert_eq!(resolved, PathBuf::from("/tmp/test-config.toml"));
    }

    #[test]
    fn resolve_path_defaults_when_no_override() {
        let resolved = resolve_path(None).expect("default config path");
        let shown = resolved.to_string_lossy();
        // "minis-studio-worker" is a substring of every platform's layout,
        // including the macOS reverse-DNS directory name.
        assert!(
            shown.contains("minis-studio-worker"),
            "unexpected default path: {shown}"
        );
        assert!(shown.ends_with("config.toml"));
    }

    #[test]
    fn load_creates_default_when_file_missing() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("sub").join("config.toml");
        let override_arg = as_override(&target);
        let (cfg, returned_path) = load(Some(override_arg.as_str())).unwrap();
        assert_eq!(returned_path, target);
        assert_eq!(cfg.engine, "synthetic");
        // Bootstrapping must persist the defaults (parent dir included).
        assert!(target.exists());
    }

    #[test]
    fn round_trip_via_save_and_load_preserves_fields() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("config.toml");
        let original = Config {
            engine: "gradio".into(),
            gradio_endpoint_url: Some("http://example.invalid".into()),
            worker_id: Some("w-123".into()),
            auth_token: Some("tok-xyz".into()),
            vram_threshold_gb: 24.0,
            auto_update_prerelease: true,
            supported_models_override: vec!["foo".into(), "bar".into()],
            ..Config::default()
        };
        save(&original, &target).unwrap();

        let override_arg = as_override(&target);
        let (reloaded, _) = load(Some(override_arg.as_str())).unwrap();
        assert_eq!(reloaded.engine, original.engine);
        assert_eq!(reloaded.gradio_endpoint_url, original.gradio_endpoint_url);
        assert_eq!(reloaded.worker_id, original.worker_id);
        assert_eq!(reloaded.auth_token, original.auth_token);
        assert_eq!(reloaded.vram_threshold_gb, original.vram_threshold_gb);
        assert_eq!(
            reloaded.auto_update_prerelease,
            original.auto_update_prerelease
        );
        assert_eq!(
            reloaded.supported_models_override,
            original.supported_models_override
        );
    }

    #[test]
    fn shared_wraps_in_arc_mutex() {
        let baseline = Config::default();
        let handle = shared(baseline.clone());
        assert_eq!(handle.lock().engine, baseline.engine);
    }

    #[test]
    fn load_returns_error_on_malformed_toml() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("config.toml");
        std::fs::write(&target, "this :: is = not = toml = :").unwrap();
        let override_arg = as_override(&target);
        let err = load(Some(override_arg.as_str())).expect_err("malformed TOML must not load");
        assert!(err.to_string().contains("parsing config.toml"));
    }
}