Skip to main content

studio_worker/ws/
session.rs

1//! Long-running WebSocket session that owns the worker's lifecycle.
2//!
3//! Replaces the four polling loops (`spawn_heartbeat`, `spawn_claim_loop`,
4//! `spawn_log_shipper`, plus the implicit completion path) with a single
5//! `spawn_ws_session` coordinator + a small handful of helper tasks that
6//! all push frames through a shared `WsSender`.
7//!
8//! Reconnect policy: on a transport error or non-auth close, back off
9//! `BASE_BACKOFF_MS * 2^attempt` and try again, up to
//! `cfg.ws_reconnect_attempts`.  Out of retries → return `Err` and the
11//! systemd / launchd unit restarts the binary.
12use std::sync::{
13    atomic::{AtomicBool, Ordering},
14    Arc,
15};
16use std::time::Duration;
17
18use anyhow::{anyhow, Result};
19use parking_lot::Mutex;
20use tokio::sync::mpsc;
21use tracing::{info, warn};
22
23use crate::config::SharedConfig;
24use crate::engine::Engine;
25use crate::http::ApiClient;
26use crate::runtime::{
27    build_capabilities, is_unsupported_kind, prompt_for, push_log, WorkerObservers,
28};
29use crate::types::{LogEntry, TaskResult, WorkerCapabilities};
30use crate::ws::client::{connect, WsClientError, WsSender};
31use crate::ws::types::{HelloFrame, JobOfferClaim, WorkerInbound, WorkerOutbound};
32
/// Tracing target used for every event emitted by the session.
const TRACE_TARGET: &str = "studio_worker::ws::session";

/// Production default for `SessionSchedule::heartbeat`: period between
/// heartbeat frames.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Production default for `SessionSchedule::log_flush`: period between
/// attempts to ship the buffered log entries.
const LOG_FLUSH_INTERVAL: Duration = Duration::from_secs(1);
/// Production default for `SessionSchedule::shutdown_tick`: how often the
/// stop flag is polled.
const SHUTDOWN_TICK: Duration = Duration::from_millis(250);
/// Production default for `SessionSchedule::base_backoff_ms`: delay (in
/// milliseconds) before the first reconnect attempt.
const BASE_BACKOFF_MS: u64 = 1_000;
/// Production default for `SessionSchedule::max_backoff_ms`: ceiling (in
/// milliseconds) applied to the exponential reconnect delay.
const MAX_BACKOFF_MS: u64 = 30_000;
/// Reconnect cap used when `cfg.ws_reconnect_attempts` is not set.
const DEFAULT_RECONNECT_ATTEMPTS: u32 = 5;
42
/// Outcome of a single session attempt.  The reconnect loop decides
/// whether to back off + retry based on the variant: only
/// `Disconnected` leads to another attempt.
#[derive(Debug)]
pub enum SessionOutcome {
    /// Caller requested shutdown; do not reconnect.
    Stopped,
    /// Lost the connection unexpectedly (or never established it);
    /// reconnect after backoff.
    Disconnected,
    /// Server rejected auth; do not reconnect.  Carries the reason
    /// reported by the connect step or by the server's error frame.
    AuthFailed(String),
    /// Server sent a fatal error frame, or the local config has no
    /// `worker_id` / `auth_token`; do not reconnect.  Carries the message.
    Fatal(String),
}
56
57/// Tunables for the session loop \u2014 dialed down in tests.
58#[derive(Debug, Clone, Copy)]
59pub struct SessionSchedule {
60    pub heartbeat: Duration,
61    pub log_flush: Duration,
62    pub shutdown_tick: Duration,
63    pub base_backoff_ms: u64,
64    pub max_backoff_ms: u64,
65}
66
67impl Default for SessionSchedule {
68    fn default() -> Self {
69        Self {
70            heartbeat: HEARTBEAT_INTERVAL,
71            log_flush: LOG_FLUSH_INTERVAL,
72            shutdown_tick: SHUTDOWN_TICK,
73            base_backoff_ms: BASE_BACKOFF_MS,
74            max_backoff_ms: MAX_BACKOFF_MS,
75        }
76    }
77}
78
79impl SessionSchedule {
80    pub fn fast_for_tests() -> Self {
81        Self {
82            heartbeat: Duration::from_millis(5),
83            log_flush: Duration::from_millis(5),
84            shutdown_tick: Duration::from_millis(5),
85            base_backoff_ms: 1,
86            max_backoff_ms: 10,
87        }
88    }
89}
90
91/// Top-level driver: connect, run a session, reconnect on disconnect,
92/// give up after `cfg.ws_reconnect_attempts` failures.
93pub async fn spawn_ws_session(
94    cfg: SharedConfig,
95    stop: Arc<AtomicBool>,
96    logs: Arc<Mutex<Vec<LogEntry>>>,
97    busy: Arc<AtomicBool>,
98    _observers: WorkerObservers,
99    schedule: SessionSchedule,
100) -> Result<()> {
101    // `observers` is threaded through so the optional native UI gets a
102    // live view of the worker.  Wired into the dispatch path in a
103    // follow-up commit; for now the slots stay empty (Default::default
104    // on the WorkerObservers).  The headless build doesn't care.
105    let _ = &_observers;
106    let max_attempts = {
107        let guard = cfg.lock();
108        guard
109            .ws_reconnect_attempts
110            .unwrap_or(DEFAULT_RECONNECT_ATTEMPTS)
111    };
112
113    let mut attempt: u32 = 0;
114    loop {
115        if stop.load(Ordering::SeqCst) {
116            return Ok(());
117        }
118        match run_one_session(&cfg, &stop, &logs, &busy, schedule).await {
119            Ok(SessionOutcome::Stopped) => return Ok(()),
120            Ok(SessionOutcome::AuthFailed(reason)) => {
121                push_log(
122                    &logs,
123                    "error",
124                    "ws",
125                    &format!("auth failed: {reason}. Re-register the worker."),
126                    None,
127                );
128                return Err(anyhow!("ws auth failed: {reason}"));
129            }
130            Ok(SessionOutcome::Fatal(reason)) => {
131                push_log(&logs, "error", "ws", &format!("fatal: {reason}"), None);
132                return Err(anyhow!("ws fatal: {reason}"));
133            }
134            Ok(SessionOutcome::Disconnected) | Err(_) => {
135                attempt += 1;
136                if max_attempts > 0 && attempt > max_attempts {
137                    push_log(
138                        &logs,
139                        "error",
140                        "ws",
141                        &format!("giving up after {attempt} reconnect attempts"),
142                        None,
143                    );
144                    return Err(anyhow!("ws reconnect cap reached"));
145                }
146                let backoff = backoff_for(attempt, schedule);
147                push_log(
148                    &logs,
149                    "warn",
150                    "ws",
151                    &format!(
152                        "disconnected; reconnect attempt {attempt} in {}ms",
153                        backoff.as_millis()
154                    ),
155                    None,
156                );
157                wait_with_stop(backoff, &stop, schedule.shutdown_tick).await;
158            }
159        }
160    }
161}
162
163/// One end-to-end session attempt: connect, hello, run until shutdown
164/// or disconnect.
165async fn run_one_session(
166    cfg: &SharedConfig,
167    stop: &Arc<AtomicBool>,
168    logs: &Arc<Mutex<Vec<LogEntry>>>,
169    busy: &Arc<AtomicBool>,
170    schedule: SessionSchedule,
171) -> Result<SessionOutcome> {
172    let (api_base_url, worker_id, auth_token) = {
173        let guard = cfg.lock();
174        (
175            guard.api_base_url.clone(),
176            guard.worker_id.clone().unwrap_or_default(),
177            guard.auth_token.clone().unwrap_or_default(),
178        )
179    };
180    if worker_id.is_empty() || auth_token.is_empty() {
181        return Ok(SessionOutcome::Fatal(
182            "worker_id or auth_token missing; run register".to_string(),
183        ));
184    }
185
186    push_log(
187        logs,
188        "info",
189        "ws",
190        &format!("connecting to {api_base_url}"),
191        None,
192    );
193    let client = match connect(&api_base_url, &worker_id, &auth_token).await {
194        Ok(c) => c,
195        Err(WsClientError::AuthFailed { reason }) => {
196            return Ok(SessionOutcome::AuthFailed(reason));
197        }
198        Err(e) => {
199            push_log(logs, "warn", "ws", &format!("connect failed: {e}"), None);
200            return Ok(SessionOutcome::Disconnected);
201        }
202    };
203    let (sender, receiver) = client.split();
204
205    // Send hello with the current capabilities.
206    let engine = crate::engine::build(&cfg.lock())?;
207    let capabilities = build_capabilities(&cfg.lock(), &*engine);
208    sender
209        .send(&WorkerInbound::Hello(HelloFrame {
210            auth_token: auth_token.clone(),
211            capabilities: capabilities.clone(),
212        }))
213        .await
214        .map_err(|e| anyhow!("hello send failed: {e}"))?;
215    info!(target: TRACE_TARGET, worker_id = %worker_id, "hello sent");
216
217    let (event_tx, event_rx) = mpsc::unbounded_channel::<SessionEvent>();
218
219    // Reader task: pump frames into the event channel.
220    let reader = spawn_reader(receiver, event_tx.clone());
221
222    // Heartbeat task.
223    let heartbeat = spawn_heartbeat_pump(
224        cfg.clone(),
225        sender.clone(),
226        stop.clone(),
227        busy.clone(),
228        schedule,
229    );
230
231    // Log shipper task.
232    let log_shipper = spawn_log_shipper_pump(sender.clone(), logs.clone(), stop.clone(), schedule);
233
234    // Shutdown observer: ticks until stop flag is set, then drops the channel.
235    let shutdown_observer = spawn_shutdown_observer(stop.clone(), event_tx.clone(), schedule);
236    drop(event_tx);
237
238    let engine_arc: Arc<dyn Engine> = engine.into();
239    let ctx = SessionContext {
240        sender: sender.clone(),
241        engine: engine_arc,
242        logs: logs.clone(),
243        busy: busy.clone(),
244        api_base_url: api_base_url.clone(),
245        worker_id: worker_id.clone(),
246        auth_token: auth_token.clone(),
247    };
248    let outcome = run_dispatch_loop(ctx, event_rx).await;
249
250    // Best-effort graceful close + task drain.
251    let _ = sender.close(1000, "session ended").await;
252    let _ = reader.await;
253    let _ = heartbeat.await;
254    let _ = log_shipper.await;
255    let _ = shutdown_observer.await;
256    Ok(outcome)
257}
258
/// All the events the dispatch loop reacts to.
#[derive(Debug)]
enum SessionEvent {
    /// Frame arrived from the server (forwarded by the reader task).
    Frame(WorkerOutbound),
    /// The shared stop flag was set; emitted by the shutdown observer so
    /// the dispatch loop ends with `SessionOutcome::Stopped`.
    Stopped,
    /// Reader hit EOF / error.
    Disconnected(WsClientError),
}
269
/// Bundle of immutable per-session settings the dispatcher passes
/// around — keeps clippy's `too_many_arguments` lint happy.
struct SessionContext {
    /// Outbound half of the socket; cloned into every job task.
    sender: WsSender,
    /// Engine that executes offered jobs.
    engine: Arc<dyn Engine>,
    /// Shared log buffer that session and job events are pushed into.
    logs: Arc<Mutex<Vec<LogEntry>>>,
    /// Set while an offered job is running; read by the heartbeat.
    busy: Arc<AtomicBool>,
    /// Base URL for the HTTP upload of binary job results.
    api_base_url: String,
    /// Worker identity passed to the HTTP upload.
    worker_id: String,
    /// Credential passed to the HTTP upload.
    auth_token: String,
}
281
282async fn run_dispatch_loop(
283    ctx: SessionContext,
284    mut event_rx: mpsc::UnboundedReceiver<SessionEvent>,
285) -> SessionOutcome {
286    while let Some(event) = event_rx.recv().await {
287        match event {
288            SessionEvent::Disconnected(WsClientError::AuthFailed { reason }) => {
289                return SessionOutcome::AuthFailed(reason);
290            }
291            SessionEvent::Disconnected(_) => return SessionOutcome::Disconnected,
292            SessionEvent::Stopped => return SessionOutcome::Stopped,
293            SessionEvent::Frame(frame) => match frame {
294                WorkerOutbound::Welcome { worker_id: wid, .. } => {
295                    push_log(
296                        &ctx.logs,
297                        "info",
298                        "ws",
299                        &format!("server welcomed {wid}"),
300                        None,
301                    );
302                }
303                WorkerOutbound::Offer { claim } => {
304                    handle_offer(&ctx, claim);
305                }
306                WorkerOutbound::Error { code, message } => {
307                    push_log(
308                        &ctx.logs,
309                        "error",
310                        "ws",
311                        &format!("server error {code:?}: {message}"),
312                        None,
313                    );
314                    return match code {
315                        crate::ws::types::WorkerErrorCode::AuthFailed => {
316                            SessionOutcome::AuthFailed(message)
317                        }
318                        _ => SessionOutcome::Fatal(message),
319                    };
320                }
321                WorkerOutbound::HeartbeatAck
322                | WorkerOutbound::CompleteAck { .. }
323                | WorkerOutbound::FailAck { .. } => {
324                    // Acks are best-effort; ignore.
325                }
326            },
327        }
328    }
329    SessionOutcome::Disconnected
330}
331
332fn handle_offer(ctx: &SessionContext, claim: JobOfferClaim) {
333    let job_id = claim.job_id.clone();
334    push_log(
335        &ctx.logs,
336        "info",
337        "ws",
338        &format!(
339            "offer received {job_id} model={} vram={}",
340            claim.model, claim.vram_gb_estimate
341        ),
342        Some(job_id.clone()),
343    );
344    // Send accept first so the server marks the offer consumed.
345    let sender_for_accept = ctx.sender.clone();
346    let job_id_for_accept = job_id.clone();
347    tokio::spawn(async move {
348        let _ = sender_for_accept
349            .send(&WorkerInbound::Accept {
350                job_id: job_id_for_accept,
351            })
352            .await;
353    });
354
355    let job = claim.into_job_claim();
356    let busy_flag = ctx.busy.clone();
357    busy_flag.store(true, Ordering::SeqCst);
358    let logs_for_task = ctx.logs.clone();
359    let sender_for_task = ctx.sender.clone();
360    let engine_for_task = ctx.engine.clone();
361    let api_base_url = ctx.api_base_url.clone();
362    let worker_id = ctx.worker_id.clone();
363    let auth_token = ctx.auth_token.clone();
364    tokio::spawn(async move {
365        run_offered_job(
366            sender_for_task,
367            engine_for_task,
368            logs_for_task,
369            api_base_url,
370            worker_id,
371            auth_token,
372            job,
373        )
374        .await;
375        busy_flag.store(false, Ordering::SeqCst);
376    });
377}
378
379async fn run_offered_job(
380    sender: WsSender,
381    engine: Arc<dyn Engine>,
382    logs: Arc<Mutex<Vec<LogEntry>>>,
383    api_base_url: String,
384    worker_id: String,
385    auth_token: String,
386    job: crate::types::JobClaim,
387) {
388    let task = job.resolved_task();
389    let task_kind = task.kind();
390    let prompt_for_log = prompt_for(&task);
391    let start = std::time::Instant::now();
392    let dispatch = tokio::task::spawn_blocking({
393        let model = job.model.clone();
394        let task_for_engine = task;
395        let engine = engine.clone();
396        move || -> Result<TaskResult> { engine.dispatch(&model, task_for_engine) }
397    })
398    .await;
399
400    let job_id = job.job_id.clone();
401    match dispatch {
402        Ok(Ok(result)) => {
403            push_log(
404                &logs,
405                "info",
406                "ws",
407                &format!("{} dispatched in {:?}", task_kind.as_str(), start.elapsed()),
408                Some(job_id.clone()),
409            );
410            match result {
411                TaskResult::Image { bytes, ext }
412                | TaskResult::AudioTts { bytes, ext }
413                | TaskResult::Video { bytes, ext } => {
414                    // Binary outputs go via HTTP multipart \u2014 the only
415                    // worker-side HTTP route that survives the migration.
416                    let upload_result = tokio::task::spawn_blocking({
417                        let api_base_url = api_base_url.clone();
418                        let job_id = job_id.clone();
419                        let auth_token = auth_token.clone();
420                        let worker_id = worker_id.clone();
421                        let prompt = prompt_for_log.clone();
422                        move || -> Result<()> {
423                            let api = ApiClient::new(api_base_url)?;
424                            api.complete(&worker_id, &auth_token, &job_id, &ext, &prompt, bytes)
425                        }
426                    })
427                    .await;
428                    let msg = match upload_result {
429                        Ok(Ok(())) => None,
430                        Ok(Err(e)) => Some(e.to_string()),
431                        Err(e) => Some(format!("upload task panic: {e}")),
432                    };
433                    if let Some(msg) = msg {
434                        push_log(&logs, "error", "ws", &msg, Some(job_id.clone()));
435                        let _ = sender
436                            .send(&WorkerInbound::Fail {
437                                job_id: job_id.clone(),
438                                error: msg,
439                                retryable: true,
440                            })
441                            .await;
442                    } else {
443                        push_log(
444                            &logs,
445                            "info",
446                            "ws",
447                            "binary upload ok",
448                            Some(job_id.clone()),
449                        );
450                        // Nudge the server so it offers the next job.
451                        let _ = sender.send(&WorkerInbound::ReadyForMore).await;
452                    }
453                }
454                TaskResult::Llm { json } | TaskResult::AudioStt { json } => {
455                    let _ = sender
456                        .send(&WorkerInbound::CompleteJson {
457                            job_id: job_id.clone(),
458                            result: json,
459                            prompt: Some(prompt_for_log.clone()),
460                        })
461                        .await;
462                }
463            }
464        }
465        Ok(Err(e)) => {
466            warn!(target: TRACE_TARGET, error = %e, "engine dispatch failed");
467            push_log(
468                &logs,
469                "error",
470                "ws",
471                &format!("dispatch failed: {e}"),
472                Some(job_id.clone()),
473            );
474            let _ = sender
475                .send(&WorkerInbound::Fail {
476                    job_id: job_id.clone(),
477                    error: e.to_string(),
478                    retryable: !is_unsupported_kind(&e),
479                })
480                .await;
481        }
482        Err(e) => {
483            push_log(
484                &logs,
485                "error",
486                "ws",
487                &format!("dispatch task panic: {e}"),
488                Some(job_id.clone()),
489            );
490            let _ = sender
491                .send(&WorkerInbound::Fail {
492                    job_id: job_id.clone(),
493                    error: e.to_string(),
494                    retryable: true,
495                })
496                .await;
497        }
498    }
499}
500
501fn spawn_reader(
502    mut receiver: crate::ws::client::WsReceiver,
503    event_tx: mpsc::UnboundedSender<SessionEvent>,
504) -> tokio::task::JoinHandle<()> {
505    tokio::spawn(async move {
506        loop {
507            match receiver.recv().await {
508                Ok(Some(frame)) => {
509                    if event_tx.send(SessionEvent::Frame(frame)).is_err() {
510                        break;
511                    }
512                }
513                Ok(None) => {
514                    let _ =
515                        event_tx.send(SessionEvent::Disconnected(WsClientError::ConnectionClosed));
516                    break;
517                }
518                Err(e) => {
519                    let _ = event_tx.send(SessionEvent::Disconnected(e));
520                    break;
521                }
522            }
523        }
524    })
525}
526
527fn spawn_heartbeat_pump(
528    cfg: SharedConfig,
529    sender: WsSender,
530    stop: Arc<AtomicBool>,
531    busy: Arc<AtomicBool>,
532    schedule: SessionSchedule,
533) -> tokio::task::JoinHandle<()> {
534    tokio::spawn(async move {
535        let mut interval = tokio::time::interval(schedule.heartbeat);
536        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
537        loop {
538            interval.tick().await;
539            if stop.load(Ordering::SeqCst) {
540                break;
541            }
542            let snapshot = build_heartbeat_snapshot(&cfg, &busy);
543            if let Err(e) = sender
544                .send(&WorkerInbound::Heartbeat {
545                    capabilities: snapshot.capabilities,
546                    current_job_id: snapshot.current_job_id,
547                })
548                .await
549            {
550                warn!(target: TRACE_TARGET, error = %e, "heartbeat send failed");
551                break;
552            }
553        }
554    })
555}
556
/// Payload of one heartbeat frame, assembled by `build_heartbeat_snapshot`.
struct HeartbeatSnapshot {
    /// Capabilities advertised to the server with this heartbeat.
    capabilities: WorkerCapabilities,
    /// `Some` while a job is running.  The builder only knows the busy
    /// flag, so it reports the fixed marker `"in-flight"` rather than a
    /// real job id.
    current_job_id: Option<String>,
}
561
562fn build_heartbeat_snapshot(cfg: &SharedConfig, busy: &Arc<AtomicBool>) -> HeartbeatSnapshot {
563    let engine = match crate::engine::build(&cfg.lock()) {
564        Ok(e) => e,
565        Err(_) => return placeholder_snapshot(),
566    };
567    let capabilities = build_capabilities(&cfg.lock(), &*engine);
568    let current_job_id = if busy.load(Ordering::SeqCst) {
569        Some("in-flight".to_string())
570    } else {
571        None
572    };
573    HeartbeatSnapshot {
574        capabilities,
575        current_job_id,
576    }
577}
578
579fn placeholder_snapshot() -> HeartbeatSnapshot {
580    HeartbeatSnapshot {
581        capabilities: WorkerCapabilities {
582            machine_name: String::new(),
583            username: String::new(),
584            agent_version: crate::AGENT_VERSION.to_string(),
585            engine: "synthetic".to_string(),
586            vram_total_gb: 0.0,
587            vram_threshold_gb: 0.0,
588            auto_enabled: false,
589            auto_start: false,
590            supported_models: vec![],
591            task_kinds: vec![],
592            supported_models_per_kind: Default::default(),
593        },
594        current_job_id: None,
595    }
596}
597
598fn spawn_log_shipper_pump(
599    sender: WsSender,
600    logs: Arc<Mutex<Vec<LogEntry>>>,
601    stop: Arc<AtomicBool>,
602    schedule: SessionSchedule,
603) -> tokio::task::JoinHandle<()> {
604    tokio::spawn(async move {
605        let mut interval = tokio::time::interval(schedule.log_flush);
606        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
607        loop {
608            interval.tick().await;
609            if stop.load(Ordering::SeqCst) {
610                break;
611            }
612            let batch = {
613                let mut guard = logs.lock();
614                if guard.is_empty() {
615                    continue;
616                }
617                std::mem::take(&mut *guard)
618            };
619            if let Err(e) = sender
620                .send(&WorkerInbound::LogBatch { entries: batch })
621                .await
622            {
623                warn!(target: TRACE_TARGET, error = %e, "log batch send failed");
624                break;
625            }
626        }
627    })
628}
629
630fn spawn_shutdown_observer(
631    stop: Arc<AtomicBool>,
632    event_tx: mpsc::UnboundedSender<SessionEvent>,
633    schedule: SessionSchedule,
634) -> tokio::task::JoinHandle<()> {
635    tokio::spawn(async move {
636        loop {
637            tokio::time::sleep(schedule.shutdown_tick).await;
638            if stop.load(Ordering::SeqCst) {
639                let _ = event_tx.send(SessionEvent::Stopped);
640                break;
641            }
642            if event_tx.is_closed() {
643                break;
644            }
645        }
646    })
647}
648
649async fn wait_with_stop(total: Duration, stop: &Arc<AtomicBool>, tick: Duration) {
650    let mut elapsed = Duration::ZERO;
651    while elapsed < total {
652        if stop.load(Ordering::SeqCst) {
653            return;
654        }
655        let next = tick.min(total - elapsed);
656        tokio::time::sleep(next).await;
657        elapsed += next;
658    }
659}
660
661fn backoff_for(attempt: u32, schedule: SessionSchedule) -> Duration {
662    let factor = 2u64.saturating_pow(attempt.saturating_sub(1));
663    let raw_ms = schedule.base_backoff_ms.saturating_mul(factor);
664    Duration::from_millis(raw_ms.min(schedule.max_backoff_ms))
665}
666
#[cfg(test)]
mod tests {
    use super::*;

    /// The delay doubles per attempt from the base and is clamped to the
    /// configured maximum.
    #[test]
    fn backoff_grows_exponentially_until_cap() {
        let one_second = Duration::from_secs(1);
        let schedule = SessionSchedule {
            heartbeat: one_second,
            log_flush: one_second,
            shutdown_tick: one_second,
            base_backoff_ms: 100,
            max_backoff_ms: 1_000,
        };
        // (attempt, expected delay in ms); the last two rows hit the cap.
        let expectations: [(u32, u64); 6] = [
            (1, 100),
            (2, 200),
            (3, 400),
            (4, 800),
            (5, 1_000),
            (10, 1_000),
        ];
        for &(attempt, expected_ms) in expectations.iter() {
            assert_eq!(
                backoff_for(attempt, schedule),
                Duration::from_millis(expected_ms)
            );
        }
    }

    /// The fallback snapshot reports no job and the synthetic engine.
    #[test]
    fn placeholder_snapshot_has_no_current_job() {
        let HeartbeatSnapshot {
            capabilities,
            current_job_id,
        } = placeholder_snapshot();
        assert!(current_job_id.is_none());
        assert_eq!(capabilities.engine, "synthetic");
    }
}