1use std::path::Path;
18use std::sync::Arc;
19
20use anyhow::Result;
21use chrono::{DateTime, Utc};
22use parking_lot::Mutex;
23use sha2::{Digest, Sha256};
24
25use crate::{
26 config::{self, SharedConfig},
27 engine,
28 http::ApiClient,
29 runtime::build_capabilities,
30 types::{AutoRegisterRequest, RegisterStatus},
31 AGENT_VERSION,
32};
33
/// Progress of this worker's auto-registration handshake.
///
/// Returned by [`tick`] and also written into the `observers` handle
/// ([`SharedRegistration`]) that is passed to it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistrationState {
    /// No registration request is currently tracked.
    ///
    /// This is the value produced by [`shared_initial`]. It is also what
    /// `tick` reports when creating a request fails, or when a poll returns
    /// no status (`None`) for the stored request.
    Pristine,
    /// A registration request exists and is waiting for a decision.
    Pending {
        /// Request id returned by the register-request call.
        request_id: String,
        /// Local UTC timestamp (`Utc::now()`) recorded by this worker; it is
        /// not supplied by the server.
        since: DateTime<Utc>,
    },
    /// The config holds both `worker_id` and `auth_token`. They were either
    /// already present or just received in an approval.
    Approved,
    /// The server rejected the request; `reason` is the explanation it returned.
    Rejected { reason: String },
}
55
56pub type SharedRegistration = Arc<Mutex<RegistrationState>>;
59
60pub fn shared_initial() -> SharedRegistration {
61 Arc::new(Mutex::new(RegistrationState::Pristine))
62}
63
64pub async fn tick(
73 cfg: &SharedConfig,
74 config_path: &Path,
75 observers: &SharedRegistration,
76) -> RegistrationState {
77 {
79 let snap = cfg.lock();
80 if snap.worker_id.is_some() && snap.auth_token.is_some() {
81 *observers.lock() = RegistrationState::Approved;
82 return RegistrationState::Approved;
83 }
84 }
85
86 ensure_install_state(cfg, config_path);
88
89 let (api_base_url, request_id, secret, install_id, label) = {
91 let snap = cfg.lock();
92 (
93 snap.api_base_url.clone(),
94 snap.registration_request_id.clone(),
95 snap.registration_secret.clone(),
96 snap.install_id.clone(),
97 snap.label.clone(),
98 )
99 };
100
101 match (request_id, secret) {
102 (Some(rid), Some(sec)) => {
103 poll_existing(cfg, config_path, observers, api_base_url, rid, sec).await
104 }
105 _ => {
106 create_request(
107 cfg,
108 config_path,
109 observers,
110 api_base_url,
111 install_id.expect("ensure_install_state seeds install_id"),
112 label,
113 )
114 .await
115 }
116 }
117}
118
119fn ensure_install_state(cfg: &SharedConfig, config_path: &Path) {
120 let mut snap = cfg.lock();
121 let mut dirty = false;
122 if snap.install_id.is_none() {
123 snap.install_id = Some(new_uuid());
124 dirty = true;
125 }
126 if snap.registration_request_id.is_none() && snap.registration_secret.is_none() {
129 snap.registration_secret = Some(new_secret_hex());
130 dirty = true;
131 }
132 if dirty {
133 let snapshot = snap.clone();
134 drop(snap);
135 if let Err(e) = config::save(&snapshot, config_path) {
136 tracing::warn!(
137 target: "studio_worker::auto_register",
138 "failed to persist install state: {e}"
139 );
140 }
141 }
142}
143
144async fn create_request(
145 cfg: &SharedConfig,
146 config_path: &Path,
147 observers: &SharedRegistration,
148 api_base_url: String,
149 install_id: String,
150 label: Option<String>,
151) -> RegistrationState {
152 let secret = match cfg.lock().registration_secret.clone() {
153 Some(s) => s,
154 None => {
155 let s = new_secret_hex();
157 cfg.lock().registration_secret = Some(s.clone());
158 s
159 }
160 };
161 let secret_hash = sha256_hex(&secret);
162
163 let payload = match build_payload(cfg, install_id.clone(), secret_hash, label) {
165 Ok(p) => p,
166 Err(e) => {
167 tracing::warn!(
168 target: "studio_worker::auto_register",
169 "engine build failed during register-request: {e}"
170 );
171 return RegistrationState::Pristine;
172 }
173 };
174
175 let api_base_url_for_task = api_base_url.clone();
176 let payload_for_task = payload.clone();
177 let result = tokio::task::spawn_blocking(move || -> Result<_> {
178 let api = ApiClient::new(api_base_url_for_task)?;
179 api.register_request(&payload_for_task)
180 })
181 .await;
182
183 let response = match result {
184 Ok(Ok(r)) => r,
185 Ok(Err(e)) => {
186 tracing::warn!(
187 target: "studio_worker::auto_register",
188 "register-request HTTP failed; will retry next tick: {e}"
189 );
190 return RegistrationState::Pristine;
191 }
192 Err(e) => {
193 tracing::warn!(
194 target: "studio_worker::auto_register",
195 "register-request task panic; will retry next tick: {e}"
196 );
197 return RegistrationState::Pristine;
198 }
199 };
200
201 let now = Utc::now();
203 {
204 let mut snap = cfg.lock();
205 snap.registration_request_id = Some(response.request_id.clone());
206 let snapshot = snap.clone();
207 drop(snap);
208 if let Err(e) = config::save(&snapshot, config_path) {
209 tracing::warn!(
210 target: "studio_worker::auto_register",
211 "failed to persist request_id: {e}"
212 );
213 }
214 }
215 let state = RegistrationState::Pending {
216 request_id: response.request_id,
217 since: now,
218 };
219 *observers.lock() = state.clone();
220 state
221}
222
223async fn poll_existing(
224 cfg: &SharedConfig,
225 config_path: &Path,
226 observers: &SharedRegistration,
227 api_base_url: String,
228 request_id: String,
229 secret: String,
230) -> RegistrationState {
231 let api_base_url_for_task = api_base_url.clone();
232 let request_id_for_task = request_id.clone();
233 let secret_for_task = secret.clone();
234 let result = tokio::task::spawn_blocking(move || -> Result<_> {
235 let api = ApiClient::new(api_base_url_for_task)?;
236 api.poll_register_status(&request_id_for_task, &secret_for_task)
237 })
238 .await;
239
240 let outcome = match result {
241 Ok(Ok(o)) => o,
242 Ok(Err(e)) => {
243 tracing::warn!(
244 target: "studio_worker::auto_register",
245 "poll failed; will retry next tick: {e}"
246 );
247 let state = RegistrationState::Pending {
248 request_id,
249 since: Utc::now(),
250 };
251 *observers.lock() = state.clone();
252 return state;
253 }
254 Err(e) => {
255 tracing::warn!(
256 target: "studio_worker::auto_register",
257 "poll task panic; will retry next tick: {e}"
258 );
259 let state = RegistrationState::Pending {
260 request_id,
261 since: Utc::now(),
262 };
263 *observers.lock() = state.clone();
264 return state;
265 }
266 };
267
268 match outcome {
269 None => {
270 {
273 let mut snap = cfg.lock();
274 snap.registration_request_id = None;
275 snap.registration_secret = None;
276 let snapshot = snap.clone();
277 drop(snap);
278 let _ = config::save(&snapshot, config_path);
279 }
280 *observers.lock() = RegistrationState::Pristine;
281 RegistrationState::Pristine
282 }
283 Some(RegisterStatus::Pending) => {
284 let state = RegistrationState::Pending {
285 request_id,
286 since: Utc::now(),
287 };
288 *observers.lock() = state.clone();
289 state
290 }
291 Some(RegisterStatus::Approved {
292 worker_id,
293 auth_token,
294 }) => {
295 {
296 let mut snap = cfg.lock();
297 snap.worker_id = Some(worker_id);
298 snap.auth_token = Some(auth_token);
299 snap.registration_request_id = None;
300 snap.registration_secret = None;
301 let snapshot = snap.clone();
302 drop(snap);
303 let _ = config::save(&snapshot, config_path);
304 }
305 *observers.lock() = RegistrationState::Approved;
306 RegistrationState::Approved
307 }
308 Some(RegisterStatus::Rejected { reason }) => {
309 {
310 let mut snap = cfg.lock();
311 snap.registration_request_id = None;
312 snap.registration_secret = None;
313 let snapshot = snap.clone();
314 drop(snap);
315 let _ = config::save(&snapshot, config_path);
316 }
317 let state = RegistrationState::Rejected { reason };
318 *observers.lock() = state.clone();
319 state
320 }
321 }
322}
323
324fn build_payload(
325 cfg: &SharedConfig,
326 install_id: String,
327 registration_secret_hash: String,
328 label: Option<String>,
329) -> Result<AutoRegisterRequest> {
330 let snap = cfg.lock().clone();
331 let engine_handle = engine::build(&snap)?;
332 let capabilities = build_capabilities(&snap, &*engine_handle);
333 Ok(AutoRegisterRequest {
334 install_id,
335 registration_secret_hash,
336 label,
337 capabilities,
338 user_agent: format!("studio-worker/{AGENT_VERSION}"),
339 })
340}
341
342fn new_uuid() -> String {
343 let bytes: [u8; 16] = rand_bytes_16();
346 let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
347 format!(
348 "{}-{}-{}-{}-{}",
349 &hex[0..8],
350 &hex[8..12],
351 &hex[12..16],
352 &hex[16..20],
353 &hex[20..32]
354 )
355}
356
357fn new_secret_hex() -> String {
358 let bytes: [u8; 32] = rand_bytes_32();
360 bytes.iter().map(|b| format!("{b:02x}")).collect()
361}
362
363fn sha256_hex(input: &str) -> String {
364 let mut hasher = Sha256::new();
365 hasher.update(input.as_bytes());
366 let digest = hasher.finalize();
367 digest.iter().map(|b| format!("{b:02x}")).collect()
368}
369
370#[cfg(unix)]
376fn rand_bytes_16() -> [u8; 16] {
377 rand_bytes::<16>()
378}
379
380#[cfg(unix)]
381fn rand_bytes_32() -> [u8; 32] {
382 rand_bytes::<32>()
383}
384
385#[cfg(unix)]
386fn rand_bytes<const N: usize>() -> [u8; N] {
387 use std::io::Read;
388 let mut buf = [0u8; N];
389 if let Ok(mut f) = std::fs::File::open("/dev/urandom") {
390 if f.read_exact(&mut buf).is_ok() {
391 return buf;
392 }
393 }
394 fallback_bytes(&mut buf);
395 buf
396}
397
/// Returns 16 bytes produced by `fallback_bytes`, the only source used on Windows.
#[cfg(windows)]
fn rand_bytes_16() -> [u8; 16] {
    let mut out = [0; 16];
    fallback_bytes(&mut out[..]);
    out
}
404
/// Returns 32 bytes produced by `fallback_bytes`, the only source used on Windows.
#[cfg(windows)]
fn rand_bytes_32() -> [u8; 32] {
    let mut out = [0; 32];
    fallback_bytes(&mut out[..]);
    out
}
411
412fn fallback_bytes(buf: &mut [u8]) {
413 use std::time::{SystemTime, UNIX_EPOCH};
418 let nanos = SystemTime::now()
419 .duration_since(UNIX_EPOCH)
420 .map(|d| d.as_nanos())
421 .unwrap_or(0);
422 let mut counter: u64 = 0;
423 let mut offset = 0;
424 while offset < buf.len() {
425 let mut hasher = Sha256::new();
426 hasher.update(nanos.to_le_bytes());
427 hasher.update(counter.to_le_bytes());
428 let digest = hasher.finalize();
429 let take = (buf.len() - offset).min(digest.len());
430 buf[offset..offset + take].copy_from_slice(&digest[..take]);
431 offset += take;
432 counter += 1;
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Five dash-separated groups of 8-4-4-4-12 hex digits.
    #[test]
    fn new_uuid_has_expected_shape() {
        let id = new_uuid();
        let group_lengths: Vec<usize> = id.split('-').map(str::len).collect();
        assert_eq!(group_lengths, [8, 4, 4, 4, 12]);
        assert!(id.chars().all(|c| c == '-' || c.is_ascii_hexdigit()));
    }

    /// Two consecutive ids must differ.
    #[test]
    fn new_uuid_is_unique() {
        assert_ne!(new_uuid(), new_uuid());
    }

    /// 32 random bytes encode to 64 hex characters.
    #[test]
    fn new_secret_hex_is_64_chars() {
        let secret = new_secret_hex();
        assert_eq!(secret.len(), 64);
        assert!(secret.bytes().all(|b| b.is_ascii_hexdigit()));
    }

    /// Same input gives the same digest, different input a different digest,
    /// and the output is always 64 hex characters.
    #[test]
    fn sha256_hex_is_deterministic() {
        let abc = sha256_hex("abc");
        assert_eq!(abc, sha256_hex("abc"));
        assert_ne!(abc, sha256_hex("abd"));
        assert_eq!(sha256_hex("").len(), 64);
    }
}