1use std::sync::{
13 atomic::{AtomicBool, Ordering},
14 Arc,
15};
16use std::time::Duration;
17
18use anyhow::{anyhow, Result};
19use parking_lot::Mutex;
20use tokio::sync::mpsc;
21use tracing::{info, warn};
22
23use crate::config::SharedConfig;
24use crate::engine::Engine;
25use crate::http::ApiClient;
26use crate::runtime::{
27 build_capabilities, is_unsupported_kind, prompt_for, push_log, WorkerObservers,
28};
29use crate::types::{LogEntry, TaskResult, WorkerCapabilities};
30use crate::ws::client::{connect, WsClientError, WsSender};
31use crate::ws::types::{HelloFrame, JobOfferClaim, WorkerInbound, WorkerOutbound};
32
/// `tracing` target attached to the events emitted from this module.
const TRACE_TARGET: &str = "studio_worker::ws::session";

/// Default period between heartbeat frames.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Default period between attempts to flush the buffered log entries.
const LOG_FLUSH_INTERVAL: Duration = Duration::from_secs(1);
/// Default period at which the stop flag is polled.
const SHUTDOWN_TICK: Duration = Duration::from_millis(250);
/// Default delay before the first reconnect attempt, in milliseconds.
const BASE_BACKOFF_MS: u64 = 1_000;
/// Default upper bound on the reconnect delay, in milliseconds.
const MAX_BACKOFF_MS: u64 = 30_000;
/// Reconnect cap used when the config does not set `ws_reconnect_attempts`.
const DEFAULT_RECONNECT_ATTEMPTS: u32 = 5;
42
/// How a single WebSocket session ended, as reported to the reconnect loop.
#[derive(Debug)]
pub enum SessionOutcome {
    /// The stop flag was observed; the reconnect loop returns `Ok(())`.
    Stopped,
    /// The connection could not be established or was lost; the reconnect
    /// loop backs off and tries again.
    Disconnected,
    /// Authentication was rejected; carries the reason. Not retried.
    AuthFailed(String),
    /// A non-retryable condition (missing credentials, or a server error
    /// other than an auth failure); carries the message.
    Fatal(String),
}
56
/// Timing knobs for a session: pump periods and reconnect backoff bounds.
#[derive(Debug, Clone, Copy)]
pub struct SessionSchedule {
    /// Period between heartbeat frames.
    pub heartbeat: Duration,
    /// Period between attempts to ship the buffered log entries.
    pub log_flush: Duration,
    /// Period at which the stop flag is polled (shutdown observer and
    /// backoff wait).
    pub shutdown_tick: Duration,
    /// Delay before the first reconnect attempt, in milliseconds; doubled on
    /// every further attempt.
    pub base_backoff_ms: u64,
    /// Upper bound on the reconnect delay, in milliseconds.
    pub max_backoff_ms: u64,
}
66
67impl Default for SessionSchedule {
68 fn default() -> Self {
69 Self {
70 heartbeat: HEARTBEAT_INTERVAL,
71 log_flush: LOG_FLUSH_INTERVAL,
72 shutdown_tick: SHUTDOWN_TICK,
73 base_backoff_ms: BASE_BACKOFF_MS,
74 max_backoff_ms: MAX_BACKOFF_MS,
75 }
76 }
77}
78
79impl SessionSchedule {
80 pub fn fast_for_tests() -> Self {
81 Self {
82 heartbeat: Duration::from_millis(5),
83 log_flush: Duration::from_millis(5),
84 shutdown_tick: Duration::from_millis(5),
85 base_backoff_ms: 1,
86 max_backoff_ms: 10,
87 }
88 }
89}
90
91pub async fn spawn_ws_session(
94 cfg: SharedConfig,
95 stop: Arc<AtomicBool>,
96 logs: Arc<Mutex<Vec<LogEntry>>>,
97 busy: Arc<AtomicBool>,
98 _observers: WorkerObservers,
99 schedule: SessionSchedule,
100) -> Result<()> {
101 let _ = &_observers;
106 let max_attempts = {
107 let guard = cfg.lock();
108 guard
109 .ws_reconnect_attempts
110 .unwrap_or(DEFAULT_RECONNECT_ATTEMPTS)
111 };
112
113 let mut attempt: u32 = 0;
114 loop {
115 if stop.load(Ordering::SeqCst) {
116 return Ok(());
117 }
118 match run_one_session(&cfg, &stop, &logs, &busy, schedule).await {
119 Ok(SessionOutcome::Stopped) => return Ok(()),
120 Ok(SessionOutcome::AuthFailed(reason)) => {
121 push_log(
122 &logs,
123 "error",
124 "ws",
125 &format!("auth failed: {reason}. Re-register the worker."),
126 None,
127 );
128 return Err(anyhow!("ws auth failed: {reason}"));
129 }
130 Ok(SessionOutcome::Fatal(reason)) => {
131 push_log(&logs, "error", "ws", &format!("fatal: {reason}"), None);
132 return Err(anyhow!("ws fatal: {reason}"));
133 }
134 Ok(SessionOutcome::Disconnected) | Err(_) => {
135 attempt += 1;
136 if max_attempts > 0 && attempt > max_attempts {
137 push_log(
138 &logs,
139 "error",
140 "ws",
141 &format!("giving up after {attempt} reconnect attempts"),
142 None,
143 );
144 return Err(anyhow!("ws reconnect cap reached"));
145 }
146 let backoff = backoff_for(attempt, schedule);
147 push_log(
148 &logs,
149 "warn",
150 "ws",
151 &format!(
152 "disconnected; reconnect attempt {attempt} in {}ms",
153 backoff.as_millis()
154 ),
155 None,
156 );
157 wait_with_stop(backoff, &stop, schedule.shutdown_tick).await;
158 }
159 }
160 }
161}
162
163async fn run_one_session(
166 cfg: &SharedConfig,
167 stop: &Arc<AtomicBool>,
168 logs: &Arc<Mutex<Vec<LogEntry>>>,
169 busy: &Arc<AtomicBool>,
170 schedule: SessionSchedule,
171) -> Result<SessionOutcome> {
172 let (api_base_url, worker_id, auth_token) = {
173 let guard = cfg.lock();
174 (
175 guard.api_base_url.clone(),
176 guard.worker_id.clone().unwrap_or_default(),
177 guard.auth_token.clone().unwrap_or_default(),
178 )
179 };
180 if worker_id.is_empty() || auth_token.is_empty() {
181 return Ok(SessionOutcome::Fatal(
182 "worker_id or auth_token missing; run register".to_string(),
183 ));
184 }
185
186 push_log(
187 logs,
188 "info",
189 "ws",
190 &format!("connecting to {api_base_url}"),
191 None,
192 );
193 let client = match connect(&api_base_url, &worker_id, &auth_token).await {
194 Ok(c) => c,
195 Err(WsClientError::AuthFailed { reason }) => {
196 return Ok(SessionOutcome::AuthFailed(reason));
197 }
198 Err(e) => {
199 push_log(logs, "warn", "ws", &format!("connect failed: {e}"), None);
200 return Ok(SessionOutcome::Disconnected);
201 }
202 };
203 let (sender, receiver) = client.split();
204
205 let engine = crate::engine::build(&cfg.lock())?;
207 let capabilities = build_capabilities(&cfg.lock(), &*engine);
208 sender
209 .send(&WorkerInbound::Hello(HelloFrame {
210 auth_token: auth_token.clone(),
211 capabilities: capabilities.clone(),
212 }))
213 .await
214 .map_err(|e| anyhow!("hello send failed: {e}"))?;
215 info!(target: TRACE_TARGET, worker_id = %worker_id, "hello sent");
216
217 let (event_tx, event_rx) = mpsc::unbounded_channel::<SessionEvent>();
218
219 let reader = spawn_reader(receiver, event_tx.clone());
221
222 let heartbeat = spawn_heartbeat_pump(
224 cfg.clone(),
225 sender.clone(),
226 stop.clone(),
227 busy.clone(),
228 schedule,
229 );
230
231 let log_shipper = spawn_log_shipper_pump(sender.clone(), logs.clone(), stop.clone(), schedule);
233
234 let shutdown_observer = spawn_shutdown_observer(stop.clone(), event_tx.clone(), schedule);
236 drop(event_tx);
237
238 let engine_arc: Arc<dyn Engine> = engine.into();
239 let ctx = SessionContext {
240 sender: sender.clone(),
241 engine: engine_arc,
242 logs: logs.clone(),
243 busy: busy.clone(),
244 api_base_url: api_base_url.clone(),
245 worker_id: worker_id.clone(),
246 auth_token: auth_token.clone(),
247 };
248 let outcome = run_dispatch_loop(ctx, event_rx).await;
249
250 let _ = sender.close(1000, "session ended").await;
252 let _ = reader.await;
253 let _ = heartbeat.await;
254 let _ = log_shipper.await;
255 let _ = shutdown_observer.await;
256 Ok(outcome)
257}
258
/// Events funnelled into the dispatch loop by the reader and the shutdown
/// observer.
#[derive(Debug)]
enum SessionEvent {
    /// A frame received from the server.
    Frame(WorkerOutbound),
    /// The stop flag was observed.
    Stopped,
    /// The receive side ended, either with an error or because the stream
    /// was closed.
    Disconnected(WsClientError),
}
269
/// Everything the dispatch loop needs to react to server frames and to run
/// offered jobs.
struct SessionContext {
    /// Outbound half of the WebSocket connection.
    sender: WsSender,
    /// Engine that offered jobs are dispatched to.
    engine: Arc<dyn Engine>,
    /// Shared log buffer written to via `push_log`.
    logs: Arc<Mutex<Vec<LogEntry>>>,
    /// Set while an offered job is running.
    busy: Arc<AtomicBool>,
    /// API base URL handed to the HTTP client for binary result uploads.
    api_base_url: String,
    /// Worker id sent along with binary result uploads.
    worker_id: String,
    /// Auth token sent along with binary result uploads.
    auth_token: String,
}
281
282async fn run_dispatch_loop(
283 ctx: SessionContext,
284 mut event_rx: mpsc::UnboundedReceiver<SessionEvent>,
285) -> SessionOutcome {
286 while let Some(event) = event_rx.recv().await {
287 match event {
288 SessionEvent::Disconnected(WsClientError::AuthFailed { reason }) => {
289 return SessionOutcome::AuthFailed(reason);
290 }
291 SessionEvent::Disconnected(_) => return SessionOutcome::Disconnected,
292 SessionEvent::Stopped => return SessionOutcome::Stopped,
293 SessionEvent::Frame(frame) => match frame {
294 WorkerOutbound::Welcome { worker_id: wid, .. } => {
295 push_log(
296 &ctx.logs,
297 "info",
298 "ws",
299 &format!("server welcomed {wid}"),
300 None,
301 );
302 }
303 WorkerOutbound::Offer { claim } => {
304 handle_offer(&ctx, claim);
305 }
306 WorkerOutbound::Error { code, message } => {
307 push_log(
308 &ctx.logs,
309 "error",
310 "ws",
311 &format!("server error {code:?}: {message}"),
312 None,
313 );
314 return match code {
315 crate::ws::types::WorkerErrorCode::AuthFailed => {
316 SessionOutcome::AuthFailed(message)
317 }
318 _ => SessionOutcome::Fatal(message),
319 };
320 }
321 WorkerOutbound::HeartbeatAck
322 | WorkerOutbound::CompleteAck { .. }
323 | WorkerOutbound::FailAck { .. } => {
324 }
326 },
327 }
328 }
329 SessionOutcome::Disconnected
330}
331
332fn handle_offer(ctx: &SessionContext, claim: JobOfferClaim) {
333 let job_id = claim.job_id.clone();
334 push_log(
335 &ctx.logs,
336 "info",
337 "ws",
338 &format!(
339 "offer received {job_id} model={} vram={}",
340 claim.model, claim.vram_gb_estimate
341 ),
342 Some(job_id.clone()),
343 );
344 let sender_for_accept = ctx.sender.clone();
346 let job_id_for_accept = job_id.clone();
347 tokio::spawn(async move {
348 let _ = sender_for_accept
349 .send(&WorkerInbound::Accept {
350 job_id: job_id_for_accept,
351 })
352 .await;
353 });
354
355 let job = claim.into_job_claim();
356 let busy_flag = ctx.busy.clone();
357 busy_flag.store(true, Ordering::SeqCst);
358 let logs_for_task = ctx.logs.clone();
359 let sender_for_task = ctx.sender.clone();
360 let engine_for_task = ctx.engine.clone();
361 let api_base_url = ctx.api_base_url.clone();
362 let worker_id = ctx.worker_id.clone();
363 let auth_token = ctx.auth_token.clone();
364 tokio::spawn(async move {
365 run_offered_job(
366 sender_for_task,
367 engine_for_task,
368 logs_for_task,
369 api_base_url,
370 worker_id,
371 auth_token,
372 job,
373 )
374 .await;
375 busy_flag.store(false, Ordering::SeqCst);
376 });
377}
378
379async fn run_offered_job(
380 sender: WsSender,
381 engine: Arc<dyn Engine>,
382 logs: Arc<Mutex<Vec<LogEntry>>>,
383 api_base_url: String,
384 worker_id: String,
385 auth_token: String,
386 job: crate::types::JobClaim,
387) {
388 let task = job.resolved_task();
389 let task_kind = task.kind();
390 let prompt_for_log = prompt_for(&task);
391 let start = std::time::Instant::now();
392 let dispatch = tokio::task::spawn_blocking({
393 let model = job.model.clone();
394 let task_for_engine = task;
395 let engine = engine.clone();
396 move || -> Result<TaskResult> { engine.dispatch(&model, task_for_engine) }
397 })
398 .await;
399
400 let job_id = job.job_id.clone();
401 match dispatch {
402 Ok(Ok(result)) => {
403 push_log(
404 &logs,
405 "info",
406 "ws",
407 &format!("{} dispatched in {:?}", task_kind.as_str(), start.elapsed()),
408 Some(job_id.clone()),
409 );
410 match result {
411 TaskResult::Image { bytes, ext }
412 | TaskResult::AudioTts { bytes, ext }
413 | TaskResult::Video { bytes, ext } => {
414 let upload_result = tokio::task::spawn_blocking({
417 let api_base_url = api_base_url.clone();
418 let job_id = job_id.clone();
419 let auth_token = auth_token.clone();
420 let worker_id = worker_id.clone();
421 let prompt = prompt_for_log.clone();
422 move || -> Result<()> {
423 let api = ApiClient::new(api_base_url)?;
424 api.complete(&worker_id, &auth_token, &job_id, &ext, &prompt, bytes)
425 }
426 })
427 .await;
428 let msg = match upload_result {
429 Ok(Ok(())) => None,
430 Ok(Err(e)) => Some(e.to_string()),
431 Err(e) => Some(format!("upload task panic: {e}")),
432 };
433 if let Some(msg) = msg {
434 push_log(&logs, "error", "ws", &msg, Some(job_id.clone()));
435 let _ = sender
436 .send(&WorkerInbound::Fail {
437 job_id: job_id.clone(),
438 error: msg,
439 retryable: true,
440 })
441 .await;
442 } else {
443 push_log(
444 &logs,
445 "info",
446 "ws",
447 "binary upload ok",
448 Some(job_id.clone()),
449 );
450 let _ = sender.send(&WorkerInbound::ReadyForMore).await;
452 }
453 }
454 TaskResult::Llm { json } | TaskResult::AudioStt { json } => {
455 let _ = sender
456 .send(&WorkerInbound::CompleteJson {
457 job_id: job_id.clone(),
458 result: json,
459 prompt: Some(prompt_for_log.clone()),
460 })
461 .await;
462 }
463 }
464 }
465 Ok(Err(e)) => {
466 warn!(target: TRACE_TARGET, error = %e, "engine dispatch failed");
467 push_log(
468 &logs,
469 "error",
470 "ws",
471 &format!("dispatch failed: {e}"),
472 Some(job_id.clone()),
473 );
474 let _ = sender
475 .send(&WorkerInbound::Fail {
476 job_id: job_id.clone(),
477 error: e.to_string(),
478 retryable: !is_unsupported_kind(&e),
479 })
480 .await;
481 }
482 Err(e) => {
483 push_log(
484 &logs,
485 "error",
486 "ws",
487 &format!("dispatch task panic: {e}"),
488 Some(job_id.clone()),
489 );
490 let _ = sender
491 .send(&WorkerInbound::Fail {
492 job_id: job_id.clone(),
493 error: e.to_string(),
494 retryable: true,
495 })
496 .await;
497 }
498 }
499}
500
501fn spawn_reader(
502 mut receiver: crate::ws::client::WsReceiver,
503 event_tx: mpsc::UnboundedSender<SessionEvent>,
504) -> tokio::task::JoinHandle<()> {
505 tokio::spawn(async move {
506 loop {
507 match receiver.recv().await {
508 Ok(Some(frame)) => {
509 if event_tx.send(SessionEvent::Frame(frame)).is_err() {
510 break;
511 }
512 }
513 Ok(None) => {
514 let _ =
515 event_tx.send(SessionEvent::Disconnected(WsClientError::ConnectionClosed));
516 break;
517 }
518 Err(e) => {
519 let _ = event_tx.send(SessionEvent::Disconnected(e));
520 break;
521 }
522 }
523 }
524 })
525}
526
527fn spawn_heartbeat_pump(
528 cfg: SharedConfig,
529 sender: WsSender,
530 stop: Arc<AtomicBool>,
531 busy: Arc<AtomicBool>,
532 schedule: SessionSchedule,
533) -> tokio::task::JoinHandle<()> {
534 tokio::spawn(async move {
535 let mut interval = tokio::time::interval(schedule.heartbeat);
536 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
537 loop {
538 interval.tick().await;
539 if stop.load(Ordering::SeqCst) {
540 break;
541 }
542 let snapshot = build_heartbeat_snapshot(&cfg, &busy);
543 if let Err(e) = sender
544 .send(&WorkerInbound::Heartbeat {
545 capabilities: snapshot.capabilities,
546 current_job_id: snapshot.current_job_id,
547 })
548 .await
549 {
550 warn!(target: TRACE_TARGET, error = %e, "heartbeat send failed");
551 break;
552 }
553 }
554 })
555}
556
/// Payload of a single heartbeat frame.
struct HeartbeatSnapshot {
    /// Capabilities reported with the heartbeat.
    capabilities: WorkerCapabilities,
    /// `Some` while the busy flag is set, `None` otherwise.
    current_job_id: Option<String>,
}
561
562fn build_heartbeat_snapshot(cfg: &SharedConfig, busy: &Arc<AtomicBool>) -> HeartbeatSnapshot {
563 let engine = match crate::engine::build(&cfg.lock()) {
564 Ok(e) => e,
565 Err(_) => return placeholder_snapshot(),
566 };
567 let capabilities = build_capabilities(&cfg.lock(), &*engine);
568 let current_job_id = if busy.load(Ordering::SeqCst) {
569 Some("in-flight".to_string())
570 } else {
571 None
572 };
573 HeartbeatSnapshot {
574 capabilities,
575 current_job_id,
576 }
577}
578
579fn placeholder_snapshot() -> HeartbeatSnapshot {
580 HeartbeatSnapshot {
581 capabilities: WorkerCapabilities {
582 machine_name: String::new(),
583 username: String::new(),
584 agent_version: crate::AGENT_VERSION.to_string(),
585 engine: "synthetic".to_string(),
586 vram_total_gb: 0.0,
587 vram_threshold_gb: 0.0,
588 auto_enabled: false,
589 auto_start: false,
590 supported_models: vec![],
591 task_kinds: vec![],
592 supported_models_per_kind: Default::default(),
593 },
594 current_job_id: None,
595 }
596}
597
598fn spawn_log_shipper_pump(
599 sender: WsSender,
600 logs: Arc<Mutex<Vec<LogEntry>>>,
601 stop: Arc<AtomicBool>,
602 schedule: SessionSchedule,
603) -> tokio::task::JoinHandle<()> {
604 tokio::spawn(async move {
605 let mut interval = tokio::time::interval(schedule.log_flush);
606 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
607 loop {
608 interval.tick().await;
609 if stop.load(Ordering::SeqCst) {
610 break;
611 }
612 let batch = {
613 let mut guard = logs.lock();
614 if guard.is_empty() {
615 continue;
616 }
617 std::mem::take(&mut *guard)
618 };
619 if let Err(e) = sender
620 .send(&WorkerInbound::LogBatch { entries: batch })
621 .await
622 {
623 warn!(target: TRACE_TARGET, error = %e, "log batch send failed");
624 break;
625 }
626 }
627 })
628}
629
630fn spawn_shutdown_observer(
631 stop: Arc<AtomicBool>,
632 event_tx: mpsc::UnboundedSender<SessionEvent>,
633 schedule: SessionSchedule,
634) -> tokio::task::JoinHandle<()> {
635 tokio::spawn(async move {
636 loop {
637 tokio::time::sleep(schedule.shutdown_tick).await;
638 if stop.load(Ordering::SeqCst) {
639 let _ = event_tx.send(SessionEvent::Stopped);
640 break;
641 }
642 if event_tx.is_closed() {
643 break;
644 }
645 }
646 })
647}
648
649async fn wait_with_stop(total: Duration, stop: &Arc<AtomicBool>, tick: Duration) {
650 let mut elapsed = Duration::ZERO;
651 while elapsed < total {
652 if stop.load(Ordering::SeqCst) {
653 return;
654 }
655 let next = tick.min(total - elapsed);
656 tokio::time::sleep(next).await;
657 elapsed += next;
658 }
659}
660
661fn backoff_for(attempt: u32, schedule: SessionSchedule) -> Duration {
662 let factor = 2u64.saturating_pow(attempt.saturating_sub(1));
663 let raw_ms = schedule.base_backoff_ms.saturating_mul(factor);
664 Duration::from_millis(raw_ms.min(schedule.max_backoff_ms))
665}
666
#[cfg(test)]
mod tests {
    use super::*;

    /// The delay doubles per attempt starting from the base and stays at the
    /// cap once it is reached.
    #[test]
    fn backoff_grows_exponentially_until_cap() {
        let one_sec = Duration::from_secs(1);
        let schedule = SessionSchedule {
            heartbeat: one_sec,
            log_flush: one_sec,
            shutdown_tick: one_sec,
            base_backoff_ms: 100,
            max_backoff_ms: 1_000,
        };
        // (attempt, expected delay in milliseconds)
        let expectations: [(u32, u64); 6] = [
            (1, 100),
            (2, 200),
            (3, 400),
            (4, 800),
            (5, 1_000),
            (10, 1_000),
        ];
        for (attempt, expected_ms) in expectations {
            assert_eq!(
                backoff_for(attempt, schedule),
                Duration::from_millis(expected_ms)
            );
        }
    }

    /// The placeholder reports no job and the synthetic engine label.
    #[test]
    fn placeholder_snapshot_has_no_current_job() {
        let HeartbeatSnapshot {
            capabilities,
            current_job_id,
        } = placeholder_snapshot();
        assert!(current_job_id.is_none());
        assert_eq!(capabilities.engine, "synthetic");
    }
}