1use std::path::Path;
18use std::sync::Arc;
19
20use anyhow::Result;
21use chrono::{DateTime, Utc};
22use parking_lot::Mutex;
23use sha2::{Digest, Sha256};
24
25use crate::{
26 config::{self, SharedConfig},
27 engine,
28 http::ApiClient,
29 runtime::build_capabilities,
30 types::{AutoRegisterRequest, RegisterStatus},
31 AGENT_VERSION,
32};
33
34const TRACE_TARGET: &str = "studio_worker::auto_register";
39
40#[derive(Debug, Clone, PartialEq, Eq)]
44pub enum RegistrationState {
45 Pristine,
47 Pending {
50 request_id: String,
51 since: DateTime<Utc>,
53 },
54 Approved,
57 Rejected { reason: String },
60}
61
62pub type SharedRegistration = Arc<Mutex<RegistrationState>>;
65
66pub fn shared_initial() -> SharedRegistration {
67 Arc::new(Mutex::new(RegistrationState::Pristine))
68}
69
70pub async fn tick(
79 cfg: &SharedConfig,
80 config_path: &Path,
81 observers: &SharedRegistration,
82) -> RegistrationState {
83 {
85 let snap = cfg.lock();
86 if snap.worker_id.is_some() && snap.auth_token.is_some() {
87 *observers.lock() = RegistrationState::Approved;
88 return RegistrationState::Approved;
89 }
90 }
91
92 ensure_install_state(cfg, config_path);
94
95 let (api_base_url, request_id, secret, install_id) = {
97 let snap = cfg.lock();
98 (
99 snap.api_base_url.clone(),
100 snap.registration_request_id.clone(),
101 snap.registration_secret.clone(),
102 snap.install_id.clone(),
103 )
104 };
105
106 match (request_id, secret) {
107 (Some(rid), Some(sec)) => {
108 poll_existing(cfg, config_path, observers, api_base_url, rid, sec).await
109 }
110 _ => {
111 create_request(
112 cfg,
113 config_path,
114 observers,
115 api_base_url,
116 install_id.expect("ensure_install_state seeds install_id"),
117 )
118 .await
119 }
120 }
121}
122
123fn ensure_install_state(cfg: &SharedConfig, config_path: &Path) {
124 let mut snap = cfg.lock();
125 let mut dirty = false;
126 if snap.install_id.is_none() {
127 snap.install_id = Some(new_uuid());
128 dirty = true;
129 }
130 if snap.registration_request_id.is_none() && snap.registration_secret.is_none() {
133 snap.registration_secret = Some(new_secret_hex());
134 dirty = true;
135 }
136 if dirty {
137 let snapshot = snap.clone();
138 drop(snap);
139 if let Err(e) = config::save(&snapshot, config_path) {
140 tracing::warn!(
141 target: TRACE_TARGET,
142 op = "ensure-install",
143 config_path = %config_path.display(),
144 error = %e,
145 "failed to persist install state"
146 );
147 }
148 }
149}
150
151async fn create_request(
152 cfg: &SharedConfig,
153 config_path: &Path,
154 observers: &SharedRegistration,
155 api_base_url: String,
156 install_id: String,
157) -> RegistrationState {
158 let existing_secret = cfg.lock().registration_secret.clone();
163 let secret = match existing_secret {
164 Some(s) => s,
165 None => {
166 let s = new_secret_hex();
168 cfg.lock().registration_secret = Some(s.clone());
169 s
170 }
171 };
172 let secret_hash = sha256_hex(&secret);
173
174 let payload = match build_payload(cfg, install_id.clone(), secret_hash) {
176 Ok(p) => p,
177 Err(e) => {
178 tracing::warn!(
179 target: TRACE_TARGET,
180 op = "register-request",
181 error = %e,
182 "engine build failed during register-request"
183 );
184 return RegistrationState::Pristine;
185 }
186 };
187
188 let api_base_url_for_task = api_base_url.clone();
189 let payload_for_task = payload.clone();
190 let result = tokio::task::spawn_blocking(move || -> Result<_> {
191 let api = ApiClient::new(api_base_url_for_task)?;
192 api.register_request(&payload_for_task)
193 })
194 .await;
195
196 let response = match result {
197 Ok(Ok(r)) => r,
198 Ok(Err(e)) => {
199 tracing::warn!(
200 target: TRACE_TARGET,
201 op = "register-request",
202 error = %e,
203 "register-request HTTP failed; will retry next tick"
204 );
205 return RegistrationState::Pristine;
206 }
207 Err(e) => {
208 tracing::warn!(
209 target: TRACE_TARGET,
210 op = "register-request",
211 error = %e,
212 "register-request task panic; will retry next tick"
213 );
214 return RegistrationState::Pristine;
215 }
216 };
217
218 let now = Utc::now();
220 {
221 let mut snap = cfg.lock();
222 snap.registration_request_id = Some(response.request_id.clone());
223 let snapshot = snap.clone();
224 drop(snap);
225 if let Err(e) = config::save(&snapshot, config_path) {
226 tracing::warn!(
227 target: TRACE_TARGET,
228 op = "register-request",
229 config_path = %config_path.display(),
230 error = %e,
231 "failed to persist request_id"
232 );
233 }
234 }
235 let state = RegistrationState::Pending {
236 request_id: response.request_id,
237 since: now,
238 };
239 *observers.lock() = state.clone();
240 state
241}
242
243async fn poll_existing(
244 cfg: &SharedConfig,
245 config_path: &Path,
246 observers: &SharedRegistration,
247 api_base_url: String,
248 request_id: String,
249 secret: String,
250) -> RegistrationState {
251 let api_base_url_for_task = api_base_url.clone();
252 let request_id_for_task = request_id.clone();
253 let secret_for_task = secret.clone();
254 let result = tokio::task::spawn_blocking(move || -> Result<_> {
255 let api = ApiClient::new(api_base_url_for_task)?;
256 api.poll_register_status(&request_id_for_task, &secret_for_task)
257 })
258 .await;
259
260 let outcome = match result {
261 Ok(Ok(o)) => o,
262 Ok(Err(e)) => {
263 tracing::warn!(
264 target: TRACE_TARGET,
265 op = "poll",
266 error = %e,
267 "poll failed; will retry next tick"
268 );
269 let state = RegistrationState::Pending {
270 request_id,
271 since: Utc::now(),
272 };
273 *observers.lock() = state.clone();
274 return state;
275 }
276 Err(e) => {
277 tracing::warn!(
278 target: TRACE_TARGET,
279 op = "poll",
280 error = %e,
281 "poll task panic; will retry next tick"
282 );
283 let state = RegistrationState::Pending {
284 request_id,
285 since: Utc::now(),
286 };
287 *observers.lock() = state.clone();
288 return state;
289 }
290 };
291
292 match outcome {
293 None => {
294 {
297 let mut snap = cfg.lock();
298 snap.registration_request_id = None;
299 snap.registration_secret = None;
300 let snapshot = snap.clone();
301 drop(snap);
302 if let Err(e) = config::save(&snapshot, config_path) {
303 tracing::warn!(
304 target: TRACE_TARGET,
305 op = "poll",
306 config_path = %config_path.display(),
307 error = %e,
308 "failed to persist cleared request state after stale 404; the stale request id stays on disk until the next successful save"
309 );
310 }
311 }
312 *observers.lock() = RegistrationState::Pristine;
313 RegistrationState::Pristine
314 }
315 Some(RegisterStatus::Pending) => {
316 let state = RegistrationState::Pending {
317 request_id,
318 since: Utc::now(),
319 };
320 *observers.lock() = state.clone();
321 state
322 }
323 Some(RegisterStatus::Approved {
324 worker_id,
325 auth_token,
326 }) => {
327 {
328 let mut snap = cfg.lock();
329 snap.worker_id = Some(worker_id);
330 snap.auth_token = Some(auth_token);
331 snap.registration_request_id = None;
332 snap.registration_secret = None;
333 let snapshot = snap.clone();
334 drop(snap);
335 if let Err(e) = config::save(&snapshot, config_path) {
336 tracing::error!(
337 target: TRACE_TARGET,
338 op = "poll",
339 config_path = %config_path.display(),
340 error = %e,
341 "failed to persist approved credentials; this session is registered in memory but the worker will re-register from scratch on the next restart"
342 );
343 }
344 }
345 *observers.lock() = RegistrationState::Approved;
346 RegistrationState::Approved
347 }
348 Some(RegisterStatus::Rejected { reason }) => {
349 {
350 let mut snap = cfg.lock();
351 snap.registration_request_id = None;
352 snap.registration_secret = None;
353 let snapshot = snap.clone();
354 drop(snap);
355 if let Err(e) = config::save(&snapshot, config_path) {
356 tracing::warn!(
357 target: TRACE_TARGET,
358 op = "poll",
359 config_path = %config_path.display(),
360 error = %e,
361 "failed to persist cleared request state after rejection; the stale request id stays on disk until the next successful save"
362 );
363 }
364 }
365 let state = RegistrationState::Rejected { reason };
366 *observers.lock() = state.clone();
367 state
368 }
369 }
370}
371
372fn build_payload(
373 cfg: &SharedConfig,
374 install_id: String,
375 registration_secret_hash: String,
376) -> Result<AutoRegisterRequest> {
377 let snap = cfg.lock().clone();
378 let engine_handle = engine::build(&snap)?;
379 let capabilities = build_capabilities(&snap, &*engine_handle);
380 Ok(AutoRegisterRequest {
381 install_id,
382 registration_secret_hash,
383 capabilities,
384 user_agent: format!("studio-worker/{AGENT_VERSION}"),
385 })
386}
387
388fn new_uuid() -> String {
389 let bytes: [u8; 16] = rand_bytes::<16>();
392 let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
393 format!(
394 "{}-{}-{}-{}-{}",
395 &hex[0..8],
396 &hex[8..12],
397 &hex[12..16],
398 &hex[16..20],
399 &hex[20..32]
400 )
401}
402
403fn new_secret_hex() -> String {
404 let bytes: [u8; 32] = rand_bytes::<32>();
406 bytes.iter().map(|b| format!("{b:02x}")).collect()
407}
408
409fn sha256_hex(input: &str) -> String {
410 let mut hasher = Sha256::new();
411 hasher.update(input.as_bytes());
412 let digest = hasher.finalize();
413 digest.iter().map(|b| format!("{b:02x}")).collect()
414}
415
416fn rand_bytes<const N: usize>() -> [u8; N] {
434 let mut buf = [0u8; N];
435 getrandom::fill(&mut buf).expect("OS entropy source (getrandom) unavailable");
436 buf
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_uuid_has_expected_shape() {
        let id = new_uuid();
        let parts: Vec<&str> = id.split('-').collect();
        assert_eq!(parts.len(), 5);
        let lengths: Vec<usize> = parts.iter().map(|p| p.len()).collect();
        assert_eq!(lengths, [8, 4, 4, 4, 12]);
        assert!(id.chars().all(|c| c.is_ascii_hexdigit() || c == '-'));
    }

    #[test]
    fn new_uuid_is_unique() {
        let a = new_uuid();
        let b = new_uuid();
        assert_ne!(a, b);
    }

    #[test]
    fn new_secret_hex_is_64_chars() {
        let s = new_secret_hex();
        assert_eq!(s.len(), 64);
        assert!(s.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn sha256_hex_is_deterministic() {
        assert_eq!(sha256_hex("abc"), sha256_hex("abc"));
        assert_ne!(sha256_hex("abc"), sha256_hex("abd"));
        assert_eq!(sha256_hex("").len(), 64);
    }

    /// Known-answer vectors (FIPS 180-2): determinism alone would not catch a wrong
    /// algorithm or a broken hex encoding.
    #[test]
    fn sha256_hex_matches_known_vectors() {
        assert_eq!(
            sha256_hex(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            sha256_hex("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    /// A broken RNG (e.g. constant or zero-filled output) would almost surely repeat
    /// within 2,000 draws of 256 bits; a healthy one essentially never does.
    #[test]
    fn rand_bytes_are_distinct_across_many_calls() {
        use std::collections::HashSet;
        let mut seen = HashSet::new();
        for _ in 0..2_000 {
            assert!(
                seen.insert(rand_bytes::<32>()),
                "rand_bytes produced a duplicate 32-byte value"
            );
        }
    }

    /// Every bit of the output should be observed both set and cleared across 256
    /// samples; a stuck bit indicates a broken fill.
    #[test]
    fn rand_bytes_cover_every_bit_position() {
        // OR-accumulator: a bit stays 0 only if it was never set.
        let mut ever_set = [0u8; 32];
        // AND-accumulator: a bit stays 1 only if it was never cleared.
        let mut ever_clear = [0xffu8; 32];
        for _ in 0..256 {
            let sample = rand_bytes::<32>();
            for ((set, clear), byte) in ever_set
                .iter_mut()
                .zip(ever_clear.iter_mut())
                .zip(sample.iter())
            {
                *set |= *byte;
                *clear &= *byte;
            }
        }
        assert_eq!(ever_set, [0xffu8; 32], "a bit position was never set");
        assert_eq!(ever_clear, [0u8; 32], "a bit position was never cleared");
    }
}