1use anyhow::{anyhow, Context, Result};
12use directories::ProjectDirs;
13use parking_lot::Mutex;
14use serde::{Deserialize, Serialize};
15use std::path::{Path, PathBuf};
16
/// `tracing` target attached to every log event emitted by this module (`load` and `save`),
/// so config I/O logging can be filtered separately from the rest of the worker.
const TRACE_TARGET: &str = "studio_worker::config";
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Config {
23 pub api_base_url: String,
25 pub bootstrap_token: String,
27 #[serde(default, skip_serializing_if = "Option::is_none")]
29 pub worker_id: Option<String>,
30 #[serde(default, skip_serializing_if = "Option::is_none")]
32 pub auth_token: Option<String>,
33 pub vram_threshold_gb: f32,
35 pub auto_start: bool,
37 pub auto_enabled: bool,
39 pub engine: String,
43 #[serde(default)]
46 pub engines: Vec<String>,
47 #[serde(default, skip_serializing_if = "Option::is_none")]
49 pub gradio_endpoint_url: Option<String>,
50 #[serde(default)]
53 pub supported_models_override: Vec<String>,
54 #[serde(default = "default_auto_update_enabled")]
57 pub auto_update_enabled: bool,
58 #[serde(default = "default_auto_update_interval")]
60 pub auto_update_interval_secs: u64,
61 #[serde(default = "default_auto_update_feed")]
63 pub auto_update_feed: String,
64 #[serde(default)]
66 pub auto_update_prerelease: bool,
67 #[serde(default, skip_serializing_if = "Option::is_none")]
71 pub models_root: Option<std::path::PathBuf>,
72}
73
/// Serde fallback for `auto_update_enabled`: auto-update is on unless the file says otherwise.
fn default_auto_update_enabled() -> bool {
    true
}

/// Serde fallback for `auto_update_interval_secs`: thirty minutes, expressed in seconds.
fn default_auto_update_interval() -> u64 {
    30 * 60
}

/// Serde fallback for `auto_update_feed`: the project's GitHub releases API endpoint.
fn default_auto_update_feed() -> String {
    String::from("https://api.github.com/repos/webbertakken/studio-worker/releases")
}
83
84impl Default for Config {
85 fn default() -> Self {
86 Self {
87 api_base_url: "http://localhost:9790".into(),
88 bootstrap_token: "dev-bootstrap-token".into(),
89 worker_id: None,
90 auth_token: None,
91 vram_threshold_gb: 12.0,
92 auto_start: true,
93 auto_enabled: true,
94 engine: "synthetic".into(),
95 engines: Vec::new(),
96 gradio_endpoint_url: None,
97 supported_models_override: Vec::new(),
98 auto_update_enabled: default_auto_update_enabled(),
99 auto_update_interval_secs: default_auto_update_interval(),
100 auto_update_feed: default_auto_update_feed(),
101 auto_update_prerelease: false,
102 models_root: None,
103 }
104 }
105}
106
107fn default_config_path() -> Result<PathBuf> {
108 let dirs = ProjectDirs::from("gg", "minis", "minis-studio-worker")
109 .ok_or_else(|| anyhow!("cannot resolve config directory"))?;
110 Ok(dirs.config_dir().join("config.toml"))
111}
112
113pub fn resolve_path(override_path: Option<&str>) -> Result<PathBuf> {
114 if let Some(p) = override_path {
115 Ok(PathBuf::from(p))
116 } else {
117 default_config_path()
118 }
119}
120
121pub fn load(override_path: Option<&str>) -> Result<(Config, PathBuf)> {
122 let path = resolve_path(override_path)?;
123 if !path.exists() {
124 let cfg = Config::default();
125 save(&cfg, &path)?;
126 tracing::info!(
127 target: TRACE_TARGET,
128 op = "load",
129 source = "default_created",
130 config_path = %path.display(),
131 engine = %cfg.engine,
132 api_base_url = %cfg.api_base_url,
133 vram_threshold_gb = cfg.vram_threshold_gb,
134 auto_enabled = cfg.auto_enabled,
135 "config file missing — bootstrapped defaults"
136 );
137 return Ok((cfg, path));
138 }
139 let text =
140 std::fs::read_to_string(&path).with_context(|| format!("reading {}", path.display()))?;
141 let cfg: Config = toml::from_str(&text).with_context(|| "parsing config.toml")?;
142 tracing::debug!(
143 target: TRACE_TARGET,
144 op = "load",
145 source = "existing_file",
146 config_path = %path.display(),
147 engine = %cfg.engine,
148 api_base_url = %cfg.api_base_url,
149 vram_threshold_gb = cfg.vram_threshold_gb,
150 auto_enabled = cfg.auto_enabled,
151 worker_id = cfg.worker_id.as_deref().unwrap_or("(unregistered)"),
152 has_auth_token = cfg.auth_token.is_some(),
153 "loaded config from disk"
154 );
155 Ok((cfg, path))
156}
157
158pub fn save(cfg: &Config, path: &Path) -> Result<()> {
159 if let Some(parent) = path.parent() {
160 std::fs::create_dir_all(parent)
161 .with_context(|| format!("creating {}", parent.display()))?;
162 }
163 let text = toml::to_string_pretty(cfg).with_context(|| "serialising config")?;
164 let bytes = text.len();
165 std::fs::write(path, text).with_context(|| format!("writing {}", path.display()))?;
166 tracing::debug!(
167 target: TRACE_TARGET,
168 op = "save",
169 config_path = %path.display(),
170 engine = %cfg.engine,
171 vram_threshold_gb = cfg.vram_threshold_gb,
172 auto_enabled = cfg.auto_enabled,
173 bytes = bytes,
174 "persisted config to disk"
175 );
176 Ok(())
177}
178
179pub type SharedConfig = std::sync::Arc<Mutex<Config>>;
181
182pub fn shared(cfg: Config) -> SharedConfig {
183 std::sync::Arc::new(Mutex::new(cfg))
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Loads through the public string-based entry point, as callers do.
    fn load_from(path: &Path) -> Result<(Config, PathBuf)> {
        load(Some(&*path.to_string_lossy()))
    }

    #[test]
    fn default_values_are_sensible() {
        let cfg = Config::default();
        assert_eq!(cfg.engine, "synthetic");
        assert!(cfg.auto_enabled);
        assert!(cfg.auto_start);
        assert!(cfg.auto_update_enabled);
        assert_eq!(cfg.auto_update_interval_secs, 1800);
        assert!(!cfg.auto_update_prerelease);
        assert!(cfg.auto_update_feed.contains("webbertakken/studio-worker"));
        assert_eq!(cfg.vram_threshold_gb, 12.0);
        assert!(cfg.worker_id.is_none());
        assert!(cfg.auth_token.is_none());
    }

    #[test]
    fn resolve_path_uses_override_when_provided() {
        let path = resolve_path(Some("/tmp/test-config.toml")).unwrap();
        assert_eq!(path, PathBuf::from("/tmp/test-config.toml"));
    }

    #[test]
    fn resolve_path_defaults_when_no_override() {
        let path = resolve_path(None).unwrap();
        let s = path.to_string_lossy();
        assert!(
            s.contains("minis-studio-worker") || s.contains("minis.gg.minis-studio-worker"),
            "unexpected default path: {s}"
        );
        assert!(s.ends_with("config.toml"));
    }

    #[test]
    fn load_creates_default_when_file_missing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("sub").join("config.toml");
        let path_str = path.to_string_lossy().to_string();
        let (cfg, returned_path) = load(Some(&path_str)).unwrap();
        assert_eq!(returned_path, path);
        assert_eq!(cfg.engine, "synthetic");
        assert!(path.exists());
    }

    #[test]
    fn bootstrapped_default_file_loads_back() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        let (created, _) = load_from(&path).unwrap();
        let (reloaded, _) = load_from(&path).unwrap();
        assert_eq!(reloaded.api_base_url, created.api_base_url);
        assert_eq!(reloaded.bootstrap_token, created.bootstrap_token);
        assert_eq!(reloaded.engine, created.engine);
        assert_eq!(reloaded.vram_threshold_gb, created.vram_threshold_gb);
        assert_eq!(reloaded.auto_update_feed, created.auto_update_feed);
        assert!(reloaded.worker_id.is_none());
        assert!(reloaded.auth_token.is_none());
    }

    #[test]
    fn round_trip_via_save_and_load_preserves_fields() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        let cfg = Config {
            engine: "gradio".into(),
            engines: vec!["gradio".into(), "synthetic".into()],
            gradio_endpoint_url: Some("http://example.invalid".into()),
            worker_id: Some("w-123".into()),
            auth_token: Some("tok-xyz".into()),
            vram_threshold_gb: 24.0,
            auto_update_enabled: false,
            auto_update_interval_secs: 60,
            auto_update_feed: "https://example.invalid/releases".into(),
            auto_update_prerelease: true,
            supported_models_override: vec!["foo".into(), "bar".into()],
            models_root: Some(PathBuf::from("/srv/models")),
            ..Config::default()
        };
        save(&cfg, &path).unwrap();

        let path_str = path.to_string_lossy().to_string();
        let (loaded, _) = load(Some(&path_str)).unwrap();
        assert_eq!(loaded.engine, cfg.engine);
        assert_eq!(loaded.engines, cfg.engines);
        assert_eq!(loaded.gradio_endpoint_url, cfg.gradio_endpoint_url);
        assert_eq!(loaded.worker_id, cfg.worker_id);
        assert_eq!(loaded.auth_token, cfg.auth_token);
        assert_eq!(loaded.vram_threshold_gb, cfg.vram_threshold_gb);
        assert_eq!(loaded.auto_update_enabled, cfg.auto_update_enabled);
        assert_eq!(loaded.auto_update_interval_secs, cfg.auto_update_interval_secs);
        assert_eq!(loaded.auto_update_feed, cfg.auto_update_feed);
        assert_eq!(loaded.auto_update_prerelease, cfg.auto_update_prerelease);
        assert_eq!(
            loaded.supported_models_override,
            cfg.supported_models_override
        );
        assert_eq!(loaded.models_root, cfg.models_root);
    }

    #[test]
    fn load_applies_serde_defaults_to_minimal_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        std::fs::write(
            &path,
            r#"
api_base_url = "http://example.invalid"
bootstrap_token = "boot"
vram_threshold_gb = 8.0
auto_start = false
auto_enabled = false
engine = "gradio"
"#,
        )
        .unwrap();
        let (cfg, _) = load_from(&path).unwrap();
        assert_eq!(cfg.api_base_url, "http://example.invalid");
        assert_eq!(cfg.engine, "gradio");
        assert!(!cfg.auto_start);
        assert!(!cfg.auto_enabled);
        assert!(cfg.worker_id.is_none());
        assert!(cfg.auth_token.is_none());
        assert!(cfg.engines.is_empty());
        assert!(cfg.gradio_endpoint_url.is_none());
        assert!(cfg.supported_models_override.is_empty());
        assert!(cfg.models_root.is_none());
        assert!(cfg.auto_update_enabled);
        assert_eq!(cfg.auto_update_interval_secs, 1800);
        assert_eq!(cfg.auto_update_feed, default_auto_update_feed());
        assert!(!cfg.auto_update_prerelease);
    }

    #[test]
    fn load_rejects_file_missing_required_field() {
        // `engine` has no serde default, so leaving it out is a parse error.
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        std::fs::write(
            &path,
            r#"
api_base_url = "http://example.invalid"
bootstrap_token = "boot"
vram_threshold_gb = 8.0
auto_start = true
auto_enabled = true
"#,
        )
        .unwrap();
        let err = load_from(&path).unwrap_err();
        assert!(err.to_string().contains("parsing config.toml"));
    }

    #[test]
    fn save_overwrites_existing_file_and_leaves_no_extra_files() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        save(&Config::default(), &path).unwrap();
        let updated = Config {
            engine: "gradio".into(),
            worker_id: Some("w-9".into()),
            ..Config::default()
        };
        save(&updated, &path).unwrap();

        let (loaded, _) = load_from(&path).unwrap();
        assert_eq!(loaded.engine, "gradio");
        assert_eq!(loaded.worker_id.as_deref(), Some("w-9"));

        let entries: Vec<std::ffi::OsString> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(entries, vec![std::ffi::OsString::from("config.toml")]);
    }

    #[test]
    fn shared_wraps_in_arc_mutex() {
        let cfg = Config::default();
        let shared = shared(cfg.clone());
        let guard = shared.lock();
        assert_eq!(guard.engine, cfg.engine);
    }

    #[test]
    fn load_returns_error_on_malformed_toml() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        std::fs::write(&path, "this :: is = not = toml = :").unwrap();
        let path_str = path.to_string_lossy().to_string();
        let err = load(Some(&path_str)).unwrap_err();
        assert!(err.to_string().contains("parsing config.toml"));
    }
}