1use anyhow::{anyhow, Context, Result};
27use directories::{ProjectDirs, UserDirs};
28use parking_lot::Mutex;
29use serde::{Deserialize, Serialize};
30use std::path::{Path, PathBuf};
31
/// `tracing` target attached to every event this module emits, so config
/// load/save events can be filtered by target independently of other logs.
const TRACE_TARGET: &str = "studio_worker::config";
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct Config {
38 pub api_base_url: String,
40 #[serde(default, skip_serializing_if = "Option::is_none")]
44 pub worker_id: Option<String>,
45 #[serde(default, skip_serializing_if = "Option::is_none")]
48 pub auth_token: Option<String>,
49 pub vram_threshold_gb: f32,
51 pub auto_start: bool,
53 #[serde(default = "default_auto_update_enabled")]
56 pub auto_update_enabled: bool,
57 #[serde(default = "default_auto_update_interval")]
59 pub auto_update_interval_secs: u64,
60 #[serde(default = "default_auto_update_feed")]
62 pub auto_update_feed: String,
63 #[serde(default)]
65 pub auto_update_prerelease: bool,
66 #[serde(default = "default_models_root_persisted")]
70 pub models_root: PathBuf,
71 #[serde(default, skip_serializing_if = "Option::is_none")]
75 pub ws_reconnect_attempts: Option<u32>,
76 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub install_id: Option<String>,
81 #[serde(default, skip_serializing_if = "Option::is_none")]
84 pub registration_request_id: Option<String>,
85 #[serde(default, skip_serializing_if = "Option::is_none")]
88 pub registration_secret: Option<String>,
89}
90
/// Serde fallback for `Config::auto_update_enabled`: updates are on by default.
const fn default_auto_update_enabled() -> bool {
    true
}

/// Serde fallback for `Config::auto_update_interval_secs`: thirty minutes.
const fn default_auto_update_interval() -> u64 {
    30 * 60
}

/// Serde fallback for `Config::auto_update_feed`: the GitHub releases API
/// endpoint of the `webbertakken/studio-worker` repository.
fn default_auto_update_feed() -> String {
    String::from("https://api.github.com/repos/webbertakken/studio-worker/releases")
}
100
101pub fn default_models_root() -> PathBuf {
105 if let Some(user) = UserDirs::new() {
106 return user.home_dir().join("models");
107 }
108 std::env::temp_dir().join("studio-worker-models")
109}
110
111fn default_models_root_persisted() -> PathBuf {
112 default_models_root()
113}
114
115fn expand_home(path: PathBuf) -> PathBuf {
120 let s = path.to_string_lossy();
121 if s == "~" {
122 return UserDirs::new()
123 .map(|d| d.home_dir().to_path_buf())
124 .unwrap_or(path);
125 }
126 if let Some(rest) = s.strip_prefix("~/") {
127 if let Some(d) = UserDirs::new() {
128 return d.home_dir().join(rest);
129 }
130 }
131 path
132}
133
134impl Default for Config {
135 fn default() -> Self {
136 Self {
137 api_base_url: "https://studio.minis.gg/".into(),
138 worker_id: None,
139 auth_token: None,
140 vram_threshold_gb: 12.0,
141 auto_start: true,
142 auto_update_enabled: default_auto_update_enabled(),
143 auto_update_interval_secs: default_auto_update_interval(),
144 auto_update_feed: default_auto_update_feed(),
145 auto_update_prerelease: false,
146 models_root: default_models_root(),
147 ws_reconnect_attempts: None,
148 install_id: None,
149 registration_request_id: None,
150 registration_secret: None,
151 }
152 }
153}
154
155fn default_config_path() -> Result<PathBuf> {
156 let dirs = ProjectDirs::from("gg", "minis", "minis-studio-worker")
157 .ok_or_else(|| anyhow!("cannot resolve config directory"))?;
158 Ok(dirs.config_dir().join("config.toml"))
159}
160
161pub fn resolve_path(override_path: Option<&str>) -> Result<PathBuf> {
162 if let Some(p) = override_path {
163 Ok(PathBuf::from(p))
164 } else {
165 default_config_path()
166 }
167}
168
169pub fn load(override_path: Option<&str>) -> Result<(Config, PathBuf)> {
170 let path = resolve_path(override_path)?;
171 if !path.exists() {
172 let cfg = Config::default();
173 save(&cfg, &path)?;
174 tracing::info!(
175 target: TRACE_TARGET,
176 op = "load",
177 source = "default_created",
178 config_path = %path.display(),
179 api_base_url = %cfg.api_base_url,
180 vram_threshold_gb = cfg.vram_threshold_gb,
181 auto_start = cfg.auto_start,
182 models_root = %cfg.models_root.display(),
183 "config file missing — bootstrapped defaults"
184 );
185 return Ok((cfg, path));
186 }
187 let text = match std::fs::read_to_string(&path) {
188 Ok(text) => text,
189 Err(e) => {
190 tracing::warn!(
194 target: TRACE_TARGET,
195 op = "load",
196 config_path = %path.display(),
197 error = %e,
198 "failed to read config file"
199 );
200 return Err(e).with_context(|| format!("reading {}", path.display()));
201 }
202 };
203 let mut cfg: Config = match toml::from_str(&text) {
204 Ok(cfg) => cfg,
205 Err(e) => {
206 tracing::warn!(
212 target: TRACE_TARGET,
213 op = "load",
214 config_path = %path.display(),
215 "config file is not valid TOML"
216 );
217 return Err(e).context("parsing config.toml");
218 }
219 };
220 cfg.models_root = expand_home(std::mem::take(&mut cfg.models_root));
221 tracing::debug!(
222 target: TRACE_TARGET,
223 op = "load",
224 source = "existing_file",
225 config_path = %path.display(),
226 api_base_url = %cfg.api_base_url,
227 vram_threshold_gb = cfg.vram_threshold_gb,
228 auto_start = cfg.auto_start,
229 models_root = %cfg.models_root.display(),
230 worker_id = cfg.worker_id.as_deref().unwrap_or("(unregistered)"),
231 has_auth_token = cfg.auth_token.is_some(),
232 "loaded config from disk"
233 );
234 Ok((cfg, path))
235}
236
237pub fn save(cfg: &Config, path: &Path) -> Result<()> {
238 match write_config(cfg, path) {
239 Ok(bytes) => {
240 tracing::debug!(
241 target: TRACE_TARGET,
242 op = "save",
243 config_path = %path.display(),
244 vram_threshold_gb = cfg.vram_threshold_gb,
245 auto_start = cfg.auto_start,
246 models_root = %cfg.models_root.display(),
247 bytes = bytes,
248 "persisted config to disk"
249 );
250 Ok(())
251 }
252 Err(e) => {
253 tracing::warn!(
260 target: TRACE_TARGET,
261 op = "save",
262 config_path = %path.display(),
263 error = %e,
264 "failed to persist config to disk"
265 );
266 Err(e)
267 }
268 }
269}
270
271fn write_config(cfg: &Config, path: &Path) -> Result<usize> {
276 if let Some(parent) = path.parent() {
277 std::fs::create_dir_all(parent)
278 .with_context(|| format!("creating {}", parent.display()))?;
279 }
280 let text = toml::to_string_pretty(cfg).with_context(|| "serialising config")?;
281 let bytes = text.len();
282 write_atomic(path, text.as_bytes())?;
283 Ok(bytes)
284}
285
286fn write_atomic(path: &Path, bytes: &[u8]) -> Result<()> {
303 use std::io::Write as _;
304 let dir = match path.parent() {
305 Some(p) if !p.as_os_str().is_empty() => p,
306 _ => Path::new("."),
307 };
308 let mut tmp = tempfile::NamedTempFile::new_in(dir)
309 .with_context(|| format!("creating temp file in {}", dir.display()))?;
310 tmp.write_all(bytes)
311 .with_context(|| "writing temp config")?;
312 tmp.as_file()
313 .sync_all()
314 .with_context(|| "flushing temp config to disk")?;
315 tmp.persist(path)
316 .map_err(|e| anyhow!("atomically replacing {}: {}", path.display(), e.error))?;
317 Ok(())
318}
319
320pub type SharedConfig = std::sync::Arc<Mutex<Config>>;
322
323pub fn shared(cfg: Config) -> SharedConfig {
324 std::sync::Arc::new(Mutex::new(cfg))
325}
326
// Unit tests for config defaults, path resolution, load/save round-trips,
// tilde expansion and the on-disk write guarantees (permissions, atomicity).
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Pins the values `Config::default()` starts a fresh install with.
    #[test]
    fn default_values_are_sensible() {
        let cfg = Config::default();
        assert_eq!(cfg.api_base_url, "https://studio.minis.gg/");
        assert!(cfg.auto_start);
        assert!(cfg.auto_update_enabled);
        assert_eq!(cfg.auto_update_interval_secs, 1800);
        assert!(!cfg.auto_update_prerelease);
        assert!(cfg.auto_update_feed.contains("webbertakken/studio-worker"));
        assert_eq!(cfg.vram_threshold_gb, 12.0);
        assert!(cfg.worker_id.is_none());
        assert!(cfg.auth_token.is_none());
        let m = cfg.models_root.to_string_lossy().to_string();
        // Accept either branch of `default_models_root`: `<home>/models`, or
        // the temp-dir fallback used when no home directory can be resolved.
        assert!(m.ends_with("models") || m.contains("studio-worker-models"));
    }

    // An explicit override is used verbatim.
    #[test]
    fn resolve_path_uses_override_when_provided() {
        let path = resolve_path(Some("/tmp/test-config.toml")).unwrap();
        assert_eq!(path, PathBuf::from("/tmp/test-config.toml"));
    }

    // The default location is platform-specific (ProjectDirs); only the app
    // name and file name are stable across platforms.
    #[test]
    fn resolve_path_defaults_when_no_override() {
        let path = resolve_path(None).unwrap();
        let s = path.to_string_lossy();
        // NOTE(review): the second disjunct is already implied by the first.
        assert!(
            s.contains("minis-studio-worker") || s.contains("minis.gg.minis-studio-worker"),
            "unexpected default path: {s}"
        );
        assert!(s.ends_with("config.toml"));
    }

    // First run: a missing file (even in a missing directory) is bootstrapped
    // with defaults and written to disk.
    #[test]
    fn load_creates_default_when_file_missing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("sub").join("config.toml");
        let path_str = path.to_string_lossy().to_string();
        let (cfg, returned_path) = load(Some(&path_str)).unwrap();
        assert_eq!(returned_path, path);
        assert_eq!(cfg.api_base_url, "https://studio.minis.gg/");
        assert!(path.exists());
    }

    // Non-default values, including optional secrets, survive save -> load.
    #[test]
    fn round_trip_via_save_and_load_preserves_fields() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        let cfg = Config {
            worker_id: Some("w-123".into()),
            auth_token: Some("tok-xyz".into()),
            vram_threshold_gb: 24.0,
            auto_update_prerelease: true,
            models_root: PathBuf::from("/tmp/test-models"),
            ..Config::default()
        };
        save(&cfg, &path).unwrap();

        let path_str = path.to_string_lossy().to_string();
        let (loaded, _) = load(Some(&path_str)).unwrap();
        assert_eq!(loaded.api_base_url, cfg.api_base_url);
        assert_eq!(loaded.worker_id, cfg.worker_id);
        assert_eq!(loaded.auth_token, cfg.auth_token);
        assert_eq!(loaded.vram_threshold_gb, cfg.vram_threshold_gb);
        assert_eq!(loaded.auto_update_prerelease, cfg.auto_update_prerelease);
        assert_eq!(loaded.models_root, cfg.models_root);
    }

    #[test]
    fn shared_wraps_in_arc_mutex() {
        let cfg = Config::default();
        let shared = shared(cfg.clone());
        let guard = shared.lock();
        assert_eq!(guard.api_base_url, cfg.api_base_url);
    }

    // Malformed TOML is an error carrying the "parsing config.toml" context,
    // not a silent reset to defaults.
    #[test]
    fn load_returns_error_on_malformed_toml() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        std::fs::write(&path, "this :: is = not = toml = :").unwrap();
        let path_str = path.to_string_lossy().to_string();
        let err = load(Some(&path_str)).unwrap_err();
        assert!(err.to_string().contains("parsing config.toml"));
    }

    // Keys that are no longer part of `Config` are ignored on load, and
    // missing defaulted fields are filled in.
    #[test]
    fn load_strips_legacy_engine_fields_silently() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        let legacy = r#"
        api_base_url = "https://example.invalid"
        vram_threshold_gb = 8.0
        auto_start = true
        engine = "multi"
        engines = ["llama", "synthetic"]
        auto_enabled = false
        label = "alice's rig"
        "#;
        std::fs::write(&path, legacy).unwrap();
        let (cfg, _) = load(Some(&path.to_string_lossy())).unwrap();
        assert_eq!(cfg.api_base_url, "https://example.invalid");
        assert_eq!(cfg.vram_threshold_gb, 8.0);
    }

    // `load` expands `~/` in `models_root` into an absolute path.
    #[test]
    fn load_expands_leading_tilde_in_models_root() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        let raw = r#"
        api_base_url = "https://x.invalid"
        vram_threshold_gb = 4.0
        auto_start = true
        auto_update_enabled = false
        auto_update_interval_secs = 1
        auto_update_feed = "https://x.invalid"
        auto_update_prerelease = false
        models_root = "~/models-test"
        "#;
        std::fs::write(&path, raw).unwrap();
        let (cfg, _) = load(Some(&path.to_string_lossy())).unwrap();
        assert!(
            cfg.models_root.is_absolute(),
            "~/ should expand to an absolute path, got {}",
            cfg.models_root.display()
        );
        assert!(cfg.models_root.ends_with("models-test"));
    }

    #[test]
    fn expand_home_leaves_absolute_paths_alone() {
        let p = PathBuf::from("/tmp/anywhere");
        assert_eq!(expand_home(p.clone()), p);
    }

    // Bare `~` is accepted unchanged when no home directory can be resolved.
    #[test]
    fn expand_home_handles_bare_tilde() {
        let expanded = expand_home(PathBuf::from("~"));
        assert!(
            expanded.is_absolute() || expanded == Path::new("~"),
            "bare ~ expands to home (or stays put on weird boxes), got {}",
            expanded.display()
        );
    }

    // The saved file holds credentials, so it must have no group/other
    // permission bits (Unix only; checks the low 0o077 mask).
    #[cfg(unix)]
    #[test]
    fn save_writes_config_owner_only_because_it_holds_secrets() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        let cfg = Config {
            auth_token: Some("super-secret-token".into()),
            registration_secret: Some("reg-secret".into()),
            ..Config::default()
        };
        save(&cfg, &path).unwrap();
        let mode = std::fs::metadata(&path).unwrap().permissions().mode();
        assert_eq!(
            mode & 0o077,
            0,
            "secrets-bearing config must not be group/world-accessible; got mode {mode:o}"
        );
    }

    // Overwriting a larger file with a smaller one must leave no stale tail
    // or stale keys, and the temp file must not be left behind in the directory.
    #[test]
    fn save_atomically_replaces_existing_config_without_temp_litter() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");

        let big = Config {
            api_base_url: "https://a-very-long-host-name.example.invalid/studio/".into(),
            worker_id: Some("worker-with-a-longish-id-000000".into()),
            ..Config::default()
        };
        save(&big, &path).unwrap();

        let small = Config {
            api_base_url: "https://x/".into(),
            ..Config::default()
        };
        save(&small, &path).unwrap();

        let (loaded, _) = load(Some(&path.to_string_lossy())).unwrap();
        assert_eq!(loaded.api_base_url, "https://x/");
        assert!(
            loaded.worker_id.is_none(),
            "a replacing save must not leave the previous worker_id behind"
        );

        let names: Vec<String> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|e| e.unwrap().file_name().to_string_lossy().to_string())
            .collect();
        assert_eq!(
            names,
            vec!["config.toml".to_string()],
            "atomic save must leave only the target file, found: {names:?}"
        );
    }
}