1use anyhow::{anyhow, Context, Result};
12use directories::ProjectDirs;
13use parking_lot::Mutex;
14use serde::{Deserialize, Serialize};
15use std::path::{Path, PathBuf};
16
/// `tracing` target attached to every load/save event emitted by this module,
/// so config-related logs can be filtered by target independently of the rest
/// of the worker.
const TRACE_TARGET: &str = "studio_worker::config";
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Config {
23 pub api_base_url: String,
25 #[serde(default, skip_serializing_if = "Option::is_none")]
28 pub worker_id: Option<String>,
29 #[serde(default, skip_serializing_if = "Option::is_none")]
31 pub auth_token: Option<String>,
32 pub vram_threshold_gb: f32,
34 pub auto_start: bool,
36 pub auto_enabled: bool,
38 pub engine: String,
42 #[serde(default)]
45 pub engines: Vec<String>,
46 #[serde(default, skip_serializing_if = "Option::is_none")]
48 pub gradio_endpoint_url: Option<String>,
49 #[serde(default)]
52 pub supported_models_override: Vec<String>,
53 #[serde(default = "default_auto_update_enabled")]
56 pub auto_update_enabled: bool,
57 #[serde(default = "default_auto_update_interval")]
59 pub auto_update_interval_secs: u64,
60 #[serde(default = "default_auto_update_feed")]
62 pub auto_update_feed: String,
63 #[serde(default)]
65 pub auto_update_prerelease: bool,
66 #[serde(default, skip_serializing_if = "Option::is_none")]
70 pub models_root: Option<std::path::PathBuf>,
71 #[serde(default, skip_serializing_if = "Option::is_none")]
75 pub ws_reconnect_attempts: Option<u32>,
76 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub install_id: Option<String>,
81 #[serde(default, skip_serializing_if = "Option::is_none")]
85 pub label: Option<String>,
86 #[serde(default, skip_serializing_if = "Option::is_none")]
89 pub registration_request_id: Option<String>,
90 #[serde(default, skip_serializing_if = "Option::is_none")]
93 pub registration_secret: Option<String>,
94}
95
/// Fallback for `Config::auto_update_enabled` (serde and `Default`): enabled.
fn default_auto_update_enabled() -> bool {
    true
}

/// Fallback for `Config::auto_update_interval_secs`: thirty minutes, in seconds.
fn default_auto_update_interval() -> u64 {
    const MINUTES: u64 = 30;
    const SECS_PER_MINUTE: u64 = 60;
    MINUTES * SECS_PER_MINUTE
}

/// Fallback for `Config::auto_update_feed`: the project's GitHub releases API.
fn default_auto_update_feed() -> String {
    String::from("https://api.github.com/repos/webbertakken/studio-worker/releases")
}
105
106impl Default for Config {
107 fn default() -> Self {
108 Self {
109 api_base_url: "https://studio.minis.gg".into(),
110 worker_id: None,
111 auth_token: None,
112 vram_threshold_gb: 12.0,
113 auto_start: true,
114 auto_enabled: true,
115 engine: "synthetic".into(),
116 engines: Vec::new(),
117 gradio_endpoint_url: None,
118 supported_models_override: Vec::new(),
119 auto_update_enabled: default_auto_update_enabled(),
120 auto_update_interval_secs: default_auto_update_interval(),
121 auto_update_feed: default_auto_update_feed(),
122 auto_update_prerelease: false,
123 models_root: None,
124 ws_reconnect_attempts: None,
125 install_id: None,
126 label: None,
127 registration_request_id: None,
128 registration_secret: None,
129 }
130 }
131}
132
133fn default_config_path() -> Result<PathBuf> {
134 let dirs = ProjectDirs::from("gg", "minis", "minis-studio-worker")
135 .ok_or_else(|| anyhow!("cannot resolve config directory"))?;
136 Ok(dirs.config_dir().join("config.toml"))
137}
138
139pub fn resolve_path(override_path: Option<&str>) -> Result<PathBuf> {
140 if let Some(p) = override_path {
141 Ok(PathBuf::from(p))
142 } else {
143 default_config_path()
144 }
145}
146
147pub fn load(override_path: Option<&str>) -> Result<(Config, PathBuf)> {
148 let path = resolve_path(override_path)?;
149 if !path.exists() {
150 let cfg = Config::default();
151 save(&cfg, &path)?;
152 tracing::info!(
153 target: TRACE_TARGET,
154 op = "load",
155 source = "default_created",
156 config_path = %path.display(),
157 engine = %cfg.engine,
158 api_base_url = %cfg.api_base_url,
159 vram_threshold_gb = cfg.vram_threshold_gb,
160 auto_enabled = cfg.auto_enabled,
161 "config file missing — bootstrapped defaults"
162 );
163 return Ok((cfg, path));
164 }
165 let text =
166 std::fs::read_to_string(&path).with_context(|| format!("reading {}", path.display()))?;
167 let cfg: Config = toml::from_str(&text).with_context(|| "parsing config.toml")?;
168 tracing::debug!(
169 target: TRACE_TARGET,
170 op = "load",
171 source = "existing_file",
172 config_path = %path.display(),
173 engine = %cfg.engine,
174 api_base_url = %cfg.api_base_url,
175 vram_threshold_gb = cfg.vram_threshold_gb,
176 auto_enabled = cfg.auto_enabled,
177 worker_id = cfg.worker_id.as_deref().unwrap_or("(unregistered)"),
178 has_auth_token = cfg.auth_token.is_some(),
179 "loaded config from disk"
180 );
181 Ok((cfg, path))
182}
183
184pub fn save(cfg: &Config, path: &Path) -> Result<()> {
185 if let Some(parent) = path.parent() {
186 std::fs::create_dir_all(parent)
187 .with_context(|| format!("creating {}", parent.display()))?;
188 }
189 let text = toml::to_string_pretty(cfg).with_context(|| "serialising config")?;
190 let bytes = text.len();
191 std::fs::write(path, text).with_context(|| format!("writing {}", path.display()))?;
192 tracing::debug!(
193 target: TRACE_TARGET,
194 op = "save",
195 config_path = %path.display(),
196 engine = %cfg.engine,
197 vram_threshold_gb = cfg.vram_threshold_gb,
198 auto_enabled = cfg.auto_enabled,
199 bytes = bytes,
200 "persisted config to disk"
201 );
202 Ok(())
203}
204
205pub type SharedConfig = std::sync::Arc<Mutex<Config>>;
207
208pub fn shared(cfg: Config) -> SharedConfig {
209 std::sync::Arc::new(Mutex::new(cfg))
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn default_values_are_sensible() {
        let cfg = Config::default();
        assert_eq!(cfg.engine, "synthetic");
        assert!(cfg.auto_enabled);
        assert!(cfg.auto_start);
        assert!(cfg.auto_update_enabled);
        assert_eq!(cfg.auto_update_interval_secs, 1800);
        assert!(!cfg.auto_update_prerelease);
        assert!(cfg.auto_update_feed.contains("webbertakken/studio-worker"));
        assert_eq!(cfg.vram_threshold_gb, 12.0);
        assert!(cfg.worker_id.is_none());
        assert!(cfg.auth_token.is_none());
    }

    #[test]
    fn resolve_path_uses_override_when_provided() {
        let path = resolve_path(Some("/tmp/test-config.toml")).unwrap();
        assert_eq!(path, PathBuf::from("/tmp/test-config.toml"));
    }

    #[test]
    fn resolve_path_defaults_when_no_override() {
        let path = resolve_path(None).unwrap();
        let s = path.to_string_lossy();
        // A single substring check suffices: the former alternative
        // `minis.gg.minis-studio-worker` already contains this substring, so
        // it could never change the outcome.
        assert!(
            s.contains("minis-studio-worker"),
            "unexpected default path: {s}"
        );
        assert_eq!(
            path.file_name().and_then(|name| name.to_str()),
            Some("config.toml")
        );
    }

    #[test]
    fn load_creates_default_when_file_missing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("sub").join("config.toml");
        let path_str = path.to_string_lossy().to_string();
        let (cfg, returned_path) = load(Some(&path_str)).unwrap();
        assert_eq!(returned_path, path);
        assert_eq!(cfg.engine, "synthetic");
        assert!(path.exists());

        // A second load must read the bootstrapped file back, not recreate it.
        let (reloaded, _) = load(Some(&path_str)).unwrap();
        assert_eq!(reloaded.engine, cfg.engine);
        assert_eq!(reloaded.auto_update_feed, cfg.auto_update_feed);
    }

    #[test]
    fn round_trip_via_save_and_load_preserves_fields() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        let cfg = Config {
            engine: "gradio".into(),
            gradio_endpoint_url: Some("http://example.invalid".into()),
            worker_id: Some("w-123".into()),
            auth_token: Some("tok-xyz".into()),
            vram_threshold_gb: 24.0,
            auto_update_prerelease: true,
            supported_models_override: vec!["foo".into(), "bar".into()],
            ..Config::default()
        };
        save(&cfg, &path).unwrap();

        let path_str = path.to_string_lossy().to_string();
        let (loaded, _) = load(Some(&path_str)).unwrap();
        assert_eq!(loaded.engine, cfg.engine);
        assert_eq!(loaded.gradio_endpoint_url, cfg.gradio_endpoint_url);
        assert_eq!(loaded.worker_id, cfg.worker_id);
        assert_eq!(loaded.auth_token, cfg.auth_token);
        assert_eq!(loaded.vram_threshold_gb, cfg.vram_threshold_gb);
        assert_eq!(loaded.auto_update_prerelease, cfg.auto_update_prerelease);
        assert_eq!(
            loaded.supported_models_override,
            cfg.supported_models_override
        );
    }

    #[test]
    fn save_overwrites_existing_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        save(&Config::default(), &path).unwrap();

        let updated = Config {
            engine: "gradio".into(),
            ..Config::default()
        };
        save(&updated, &path).unwrap();

        let path_str = path.to_string_lossy().to_string();
        let (loaded, _) = load(Some(&path_str)).unwrap();
        assert_eq!(loaded.engine, "gradio");
    }

    #[test]
    fn shared_wraps_in_arc_mutex() {
        let cfg = Config::default();
        let shared = shared(cfg.clone());
        let guard = shared.lock();
        assert_eq!(guard.engine, cfg.engine);
    }

    #[test]
    fn load_returns_error_on_malformed_toml() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        std::fs::write(&path, "this :: is = not = toml = :").unwrap();
        let path_str = path.to_string_lossy().to_string();
        let err = load(Some(&path_str)).unwrap_err();
        assert!(err.to_string().contains("parsing config.toml"));
    }
}