Skip to main content

studio_worker/ws/
session.rs

//! Long-running WebSocket session that owns the worker's lifecycle.
//!
//! Replaces the four polling loops (`spawn_heartbeat`, `spawn_claim_loop`,
//! `spawn_log_shipper`, plus the implicit completion path) with a single
//! `spawn_ws_session` coordinator + a small handful of helper tasks that
//! all push frames through a shared `WsSender`.
//!
//! Reconnect policy: on a transport error or non-auth close, back off
//! `BASE_BACKOFF_MS * 2^attempt` and try again, up to
//! `cfg.ws_reconnect_attempts` consecutive failures (`DEFAULT_RECONNECT_ATTEMPTS`
//! when unset; `0` disables the cap).  Out of retries → return `Err` and the
//! systemd / launchd unit restarts the binary.
12use std::sync::{
13    atomic::{AtomicBool, Ordering},
14    Arc,
15};
16use std::time::Duration;
17
18use anyhow::{anyhow, Result};
19use parking_lot::Mutex;
20use tokio::sync::mpsc;
21use tracing::{info, warn};
22
23use crate::config::SharedConfig;
24use crate::engine::Engine;
25use crate::http::ApiClient;
26use crate::runtime::{
27    is_unsupported_kind, prompt_for, push_log_with_observers, record_recent_job, truncate_prompt,
28    wait_with_stop, CurrentJob, JobOutcome, RecentJob, WorkerObservers,
29};
30use crate::types::{LogEntry, TaskResult, WorkerCapabilities};
31use crate::ws::client::{connect, WsClientError, WsResult, WsSender};
32use crate::ws::types::{HelloFrame, JobOfferClaim, WorkerInbound, WorkerOutbound};
33
/// `tracing` target attached to every event this module emits.
const TRACE_TARGET: &str = "studio_worker::ws::session";

/// Interval between heartbeat frames (production value).
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// Interval between log-shipper flushes (production value).
const LOG_FLUSH_INTERVAL: Duration = Duration::from_millis(1_000);
/// Tick used while polling the stop flag.
const SHUTDOWN_TICK: Duration = Duration::from_millis(250);
/// Base of the exponential reconnect backoff, in milliseconds.
const BASE_BACKOFF_MS: u64 = 1_000;
/// Ceiling for the reconnect backoff, in milliseconds (30 s).
const MAX_BACKOFF_MS: u64 = 30 * 1_000;
/// Reconnect cap used when `ws_reconnect_attempts` is not set in the config.
const DEFAULT_RECONNECT_ATTEMPTS: u32 = 5;
/// Silence budget for the reader: if not a single frame (not even a
/// `heartbeatAck`) arrives within this window the connection is declared dead
/// and the session is torn down.  The studio acks every ~5 s heartbeat, so a
/// healthy socket never gets close; this only fires on a half-open / dead-peer
/// socket where the reader would otherwise park on `source.next()` forever.
const READ_IDLE_TIMEOUT: Duration = Duration::from_millis(20_000);
48
/// How a single connect → hello → dispatch cycle ended.
///
/// `spawn_ws_session` backs off and reconnects only on `Disconnected`;
/// every other variant ends the driver loop.
#[derive(Debug)]
pub enum SessionOutcome {
    /// The connection was lost unexpectedly; retry after a backoff.
    Disconnected,
    /// The caller asked for shutdown; do not reconnect.
    Stopped,
    /// The server rejected the worker's credentials (carries the reason);
    /// do not reconnect.
    AuthFailed(String),
    /// An unrecoverable error, e.g. a non-auth server error frame
    /// (carries the message); do not reconnect.
    Fatal(String),
}
62
63/// Tunables for the session loop \u2014 dialed down in tests.
64#[derive(Debug, Clone, Copy)]
65pub struct SessionSchedule {
66    pub heartbeat: Duration,
67    pub log_flush: Duration,
68    pub shutdown_tick: Duration,
69    pub base_backoff_ms: u64,
70    pub max_backoff_ms: u64,
71    /// Reader gives up + reports a disconnect if no server frame arrives within this window.
72    pub read_idle_timeout: Duration,
73}
74
75impl Default for SessionSchedule {
76    fn default() -> Self {
77        Self {
78            heartbeat: HEARTBEAT_INTERVAL,
79            log_flush: LOG_FLUSH_INTERVAL,
80            shutdown_tick: SHUTDOWN_TICK,
81            base_backoff_ms: BASE_BACKOFF_MS,
82            max_backoff_ms: MAX_BACKOFF_MS,
83            read_idle_timeout: READ_IDLE_TIMEOUT,
84        }
85    }
86}
87
88impl SessionSchedule {
89    pub fn fast_for_tests() -> Self {
90        Self {
91            heartbeat: Duration::from_millis(5),
92            log_flush: Duration::from_millis(5),
93            shutdown_tick: Duration::from_millis(5),
94            base_backoff_ms: 1,
95            max_backoff_ms: 10,
96            // Generous vs the 5ms heartbeat so the existing fast tests never trip it; the
97            // silent-connection test overrides this with a tiny value to exercise the timeout.
98            read_idle_timeout: Duration::from_secs(5),
99        }
100    }
101}
102
103/// Top-level driver: connect, run a session, reconnect on disconnect,
104/// give up after `cfg.ws_reconnect_attempts` failures.
105///
106/// `paused` is a runtime-only flag (not persisted to `Config`).  When
107/// true, the heartbeat reports `autoEnabled = false` and incoming
108/// offers are rejected, so the studio stops sending new jobs.  In-
109/// flight work is allowed to finish.
110#[cfg_attr(coverage_nightly, coverage(off))]
111pub async fn spawn_ws_session(
112    cfg: SharedConfig,
113    stop: Arc<AtomicBool>,
114    logs: Arc<Mutex<Vec<LogEntry>>>,
115    busy: Arc<AtomicBool>,
116    paused: Arc<AtomicBool>,
117    observers: WorkerObservers,
118    schedule: SessionSchedule,
119) -> Result<()> {
120    let max_attempts = {
121        let guard = cfg.lock();
122        guard
123            .ws_reconnect_attempts
124            .unwrap_or(DEFAULT_RECONNECT_ATTEMPTS)
125    };
126
127    let mut attempt: u32 = 0;
128    let mut waiting_for_creds_logged = false;
129    loop {
130        if stop.load(Ordering::SeqCst) {
131            return Ok(());
132        }
133        // Credentials may not exist yet (first launch — the
134        // auto-register loop is racing to populate them).  Poll the
135        // shared config until both `worker_id` and `auth_token` show
136        // up, instead of failing the whole session loop.  This is
137        // what lets the UI's parallel auto-register + WS flow work.
138        if !has_credentials(&cfg) {
139            if !waiting_for_creds_logged {
140                push_log_with_observers(
141                    &logs,
142                    Some(&observers),
143                    "info",
144                    "ws",
145                    "waiting for operator approval before opening the session",
146                    None,
147                );
148                waiting_for_creds_logged = true;
149            }
150            wait_with_stop(Duration::from_secs(1), &stop, schedule.shutdown_tick).await;
151            continue;
152        }
153        waiting_for_creds_logged = false;
154
155        let welcomed = AtomicBool::new(false);
156        match run_one_session(
157            &cfg, &stop, &logs, &busy, &paused, &observers, schedule, &welcomed,
158        )
159        .await
160        {
161            Ok(SessionOutcome::Stopped) => return Ok(()),
162            Ok(SessionOutcome::AuthFailed(reason)) => {
163                push_log_with_observers(
164                    &logs,
165                    Some(&observers),
166                    "error",
167                    "ws",
168                    &format!("auth failed: {reason}. Re-register the worker."),
169                    None,
170                );
171                return Err(anyhow!("ws auth failed: {reason}"));
172            }
173            Ok(SessionOutcome::Fatal(reason)) => {
174                push_log_with_observers(
175                    &logs,
176                    Some(&observers),
177                    "error",
178                    "ws",
179                    &format!("fatal: {reason}"),
180                    None,
181                );
182                return Err(anyhow!("ws fatal: {reason}"));
183            }
184            Ok(SessionOutcome::Disconnected) | Err(_) => {
185                // A session that successfully connected shouldn't count its later drop toward the
186                // connect-failure cap — only consecutive failures to connect should accumulate, so
187                // a long-lived worker isn't killed by transient mid-session disconnects.
188                if welcomed.load(Ordering::SeqCst) {
189                    attempt = 0;
190                }
191                attempt += 1;
192                if max_attempts > 0 && attempt > max_attempts {
193                    push_log_with_observers(
194                        &logs,
195                        Some(&observers),
196                        "error",
197                        "ws",
198                        &format!("giving up after {attempt} reconnect attempts"),
199                        None,
200                    );
201                    return Err(anyhow!("ws reconnect cap reached"));
202                }
203                let backoff = backoff_for(attempt, schedule);
204                push_log_with_observers(
205                    &logs,
206                    Some(&observers),
207                    "warn",
208                    "ws",
209                    &format!(
210                        "disconnected; reconnect attempt {attempt} in {}ms",
211                        backoff.as_millis()
212                    ),
213                    None,
214                );
215                wait_with_stop(backoff, &stop, schedule.shutdown_tick).await;
216            }
217        }
218    }
219}
220
/// Verdict of the post-Hello handshake.  The heartbeat / log-shipper pumps
/// are only started on `Welcomed`, so they can never race the studio's
/// asynchronous Hello-auth flow.
enum WelcomeOutcome {
    /// The server sent `Welcome`; the full session may start.
    Welcomed,
    /// The connection ended (or shutdown was signalled) before any verdict.
    Disconnected,
    /// Authentication was rejected; carries the reason.
    AuthFailed(String),
    /// The server sent a non-auth error frame; carries its message.
    Fatal(String),
}
231
232/// Pull events from the reader until we see a Welcome (success) or an
233/// Error / Disconnect (failure).  Any acks / offers that arrive
234/// before the Welcome are pushed into the logs and discarded — the
235/// studio shouldn't be sending them at this stage, but if it does,
236/// the dispatch loop will pick the next ones up.
237#[cfg_attr(coverage_nightly, coverage(off))]
238async fn wait_for_welcome(
239    event_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
240    logs: &Arc<Mutex<Vec<LogEntry>>>,
241    observers: &WorkerObservers,
242) -> WelcomeOutcome {
243    while let Some(event) = event_rx.recv().await {
244        match event {
245            SessionEvent::Frame(WorkerOutbound::Welcome { worker_id: wid, .. }) => {
246                push_log_with_observers(
247                    logs,
248                    Some(observers),
249                    "info",
250                    "ws",
251                    &format!("server welcomed {wid}"),
252                    None,
253                );
254                return WelcomeOutcome::Welcomed;
255            }
256            SessionEvent::Frame(WorkerOutbound::Error { code, message }) => {
257                push_log_with_observers(
258                    logs,
259                    Some(observers),
260                    "error",
261                    "ws",
262                    &format!("server error before welcome {code:?}: {message}"),
263                    None,
264                );
265                return match code {
266                    crate::ws::types::WorkerErrorCode::AuthFailed => {
267                        WelcomeOutcome::AuthFailed(message)
268                    }
269                    _ => WelcomeOutcome::Fatal(message),
270                };
271            }
272            SessionEvent::Frame(other) => {
273                push_log_with_observers(
274                    logs,
275                    Some(observers),
276                    "warn",
277                    "ws",
278                    &format!("server sent unexpected frame before welcome: {other:?}"),
279                    None,
280                );
281                // Keep waiting — maybe the next frame is Welcome.
282            }
283            SessionEvent::Disconnected(WsClientError::AuthFailed { reason }) => {
284                return WelcomeOutcome::AuthFailed(reason);
285            }
286            SessionEvent::Disconnected(_) => return WelcomeOutcome::Disconnected,
287            SessionEvent::Stopped => return WelcomeOutcome::Disconnected,
288        }
289    }
290    WelcomeOutcome::Disconnected
291}
292
293/// True iff the shared config has both `worker_id` and `auth_token`
294/// populated.  The auto-register flow writes them through on
295/// approval.
296fn has_credentials(cfg: &SharedConfig) -> bool {
297    let guard = cfg.lock();
298    guard
299        .worker_id
300        .as_deref()
301        .map(|s| !s.is_empty())
302        .unwrap_or(false)
303        && guard
304            .auth_token
305            .as_deref()
306            .map(|s| !s.is_empty())
307            .unwrap_or(false)
308}
309
/// One end-to-end session attempt: connect, send Hello, wait for the
/// server's `Welcome`, then run the dispatch loop until shutdown or
/// disconnect.
///
/// `welcomed` is set to `true` as soon as the server's `Welcome` arrives.
/// This lets the caller tell "connected, then dropped" apart from "never
/// got in" when counting reconnect attempts.
///
/// Protocol-level results are returned as `Ok(outcome)`:
/// * a failed connect is `Disconnected`;
/// * a rejected auth is `AuthFailed`;
/// * missing credentials are `Fatal`.
///
/// `Err` is returned only when building the engine fails or the Hello
/// frame cannot be sent.
#[cfg_attr(coverage_nightly, coverage(off))]
// Eight collaborators (config + shared flags + observers + schedule + welcomed signal);
// grouping them adds indirection without improving readability.
#[allow(clippy::too_many_arguments)]
async fn run_one_session(
    cfg: &SharedConfig,
    stop: &Arc<AtomicBool>,
    logs: &Arc<Mutex<Vec<LogEntry>>>,
    busy: &Arc<AtomicBool>,
    paused: &Arc<AtomicBool>,
    observers: &WorkerObservers,
    schedule: SessionSchedule,
    welcomed: &AtomicBool,
) -> Result<SessionOutcome> {
    // Snapshot the connection settings under one lock.  The guard is
    // dropped at the end of this block, before any `.await`.
    let (api_base_url, worker_id, auth_token) = {
        let guard = cfg.lock();
        (
            guard.api_base_url.clone(),
            guard.worker_id.clone().unwrap_or_default(),
            guard.auth_token.clone().unwrap_or_default(),
        )
    };
    // Defensive re-check.  The caller already gates on `has_credentials`;
    // this presumably guards against the config changing in between.
    if worker_id.is_empty() || auth_token.is_empty() {
        return Ok(SessionOutcome::Fatal(
            "worker_id or auth_token missing; run register".to_string(),
        ));
    }

    push_log_with_observers(
        logs,
        Some(observers),
        "info",
        "ws",
        &format!("connecting to {api_base_url}"),
        None,
    );
    // An auth rejection at connect time is terminal.  Any other connect
    // error is reported as a disconnect so the driver backs off and retries.
    let client = match connect(&api_base_url, &worker_id, &auth_token).await {
        Ok(c) => c,
        Err(WsClientError::AuthFailed { reason }) => {
            return Ok(SessionOutcome::AuthFailed(reason));
        }
        Err(e) => {
            push_log_with_observers(
                logs,
                Some(observers),
                "warn",
                "ws",
                &format!("connect failed: {e}"),
                None,
            );
            return Ok(SessionOutcome::Disconnected);
        }
    };
    let (sender, receiver) = client.split();

    // Send hello with the current capabilities.
    // Each `cfg.lock()` below is a statement-scoped temporary, so the two
    // locks are never held at the same time.
    // NOTE(review): if `engine::build` fails, the `?` returns with the socket
    // just dropped (no close frame is sent) — confirm this is acceptable.
    let engine = crate::engine::build(&cfg.lock())?;
    // The third argument is `!paused`.  Presumably it is the "accepting
    // work" / autoEnabled bit — confirm against `build_capabilities_with`.
    let capabilities = crate::runtime::build_capabilities_with(
        &cfg.lock(),
        &*engine,
        !paused.load(Ordering::SeqCst),
    );
    // Record exactly what we're about to advertise so the worker's logs
    // (and the studio's shipped-log view) show the offered kinds /
    // models / VRAM budget — otherwise the handshake is opaque and
    // "why won't it claim X jobs" can't be answered from the logs.
    push_log_with_observers(
        logs,
        Some(observers),
        "info",
        "ws",
        &crate::runtime::summarize_capabilities(&capabilities),
        None,
    );
    sender
        .send(&WorkerInbound::Hello(HelloFrame {
            auth_token: auth_token.clone(),
            capabilities: capabilities.clone(),
        }))
        .await
        .map_err(|e| anyhow!("hello send failed: {e}"))?;
    info!(target: TRACE_TARGET, worker_id = %worker_id, "hello sent");

    let (event_tx, event_rx) = mpsc::unbounded_channel::<SessionEvent>();

    // Reader task: pump frames into the event channel.
    let reader = spawn_reader(receiver, event_tx.clone(), schedule.read_idle_timeout);

    // Wait for the server's `Welcome` (or an error) before starting
    // the heartbeat / log-shipper pumps.  Without this gate, the
    // first heartbeat fires immediately (tokio `interval()` returns
    // at t=0) and races the studio's async Hello-auth flow: a
    // heartbeat arriving while the session is still marked
    // `authenticated: false` server-side gets rejected with
    // `protocol_violation: session not authenticated`, killing the
    // session.
    // NOTE(review): the shutdown observer is only spawned *after* this
    // wait.  A stop request made during the handshake is therefore not
    // seen here.  The wait ends only when the reader delivers a verdict
    // (or a disconnect, e.g. after `read_idle_timeout` of silence).
    let mut event_rx = event_rx;
    match wait_for_welcome(&mut event_rx, logs, observers).await {
        WelcomeOutcome::Welcomed => welcomed.store(true, Ordering::SeqCst),
        // On a rejected handshake, close politely, then wait for the reader
        // task to finish before reporting.
        WelcomeOutcome::AuthFailed(reason) => {
            let _ = sender.close(1000, "auth failed").await;
            let _ = reader.await;
            return Ok(SessionOutcome::AuthFailed(reason));
        }
        WelcomeOutcome::Fatal(reason) => {
            let _ = sender.close(1000, "protocol violation").await;
            let _ = reader.await;
            return Ok(SessionOutcome::Fatal(reason));
        }
        WelcomeOutcome::Disconnected => {
            let _ = reader.await;
            return Ok(SessionOutcome::Disconnected);
        }
    }

    // Heartbeat task.  Reuse the engine we already built for the
    // Hello frame instead of rebuilding it on every heartbeat —
    // rebuilding fires every engine's registration log every 5s and
    // floods the logs.
    let capabilities_for_heartbeat = capabilities.clone();
    let heartbeat = spawn_heartbeat_pump(
        capabilities_for_heartbeat,
        sender.clone(),
        stop.clone(),
        paused.clone(),
        observers.clone(),
        schedule,
    );

    // Log shipper task.
    let log_shipper = spawn_log_shipper_pump(sender.clone(), logs.clone(), stop.clone(), schedule);

    // Shutdown observer: ticks until stop flag is set, then drops the channel.
    let shutdown_observer = spawn_shutdown_observer(stop.clone(), event_tx.clone(), schedule);
    // Drop our own sender handle so the channel closes once the reader and
    // the observer are gone.  `run_dispatch_loop` then returns `Disconnected`.
    drop(event_tx);

    let engine_arc: Arc<dyn Engine> = engine.into();
    let ctx = SessionContext {
        sender: sender.clone(),
        engine: engine_arc,
        logs: logs.clone(),
        busy: busy.clone(),
        paused: paused.clone(),
        observers: observers.clone(),
        api_base_url: api_base_url.clone(),
        worker_id: worker_id.clone(),
        auth_token: auth_token.clone(),
    };
    let outcome = run_dispatch_loop(ctx, event_rx).await;

    // The session is ending (disconnect or shutdown). The heartbeat / log-shipper /
    // shutdown-observer pumps only break on the *global* stop flag or a send failure, so on a
    // silent-but-open socket — where heartbeat sends still succeed into the TCP buffer — they would
    // loop forever and block this function from returning, which is exactly the post-job reconnect
    // hang. Abort them so teardown is bounded regardless of socket state, then best-effort close +
    // drain the aborted handles (await returns promptly with Cancelled).
    reader.abort();
    heartbeat.abort();
    log_shipper.abort();
    shutdown_observer.abort();
    let _ = sender.close(1000, "session ended").await;
    let _ = reader.await;
    let _ = heartbeat.await;
    let _ = log_shipper.await;
    let _ = shutdown_observer.await;
    Ok(outcome)
}
479
480/// All the events the dispatch loop reacts to.
481#[derive(Debug)]
482enum SessionEvent {
483    /// Frame arrived from the server.
484    Frame(WorkerOutbound),
485    /// Engine task finished (success or fail already reported).
486    Stopped,
487    /// Reader hit EOF / error.
488    Disconnected(WsClientError),
489}
490
491/// Bundle of immutable per-session settings the dispatcher passes
492/// around — keeps clippy's `too_many_arguments` lint happy.
493struct SessionContext {
494    sender: WsSender,
495    engine: Arc<dyn Engine>,
496    logs: Arc<Mutex<Vec<LogEntry>>>,
497    busy: Arc<AtomicBool>,
498    paused: Arc<AtomicBool>,
499    observers: WorkerObservers,
500    api_base_url: String,
501    worker_id: String,
502    auth_token: String,
503}
504
505#[cfg_attr(coverage_nightly, coverage(off))]
506async fn run_dispatch_loop(
507    ctx: SessionContext,
508    mut event_rx: mpsc::UnboundedReceiver<SessionEvent>,
509) -> SessionOutcome {
510    while let Some(event) = event_rx.recv().await {
511        match event {
512            SessionEvent::Disconnected(WsClientError::AuthFailed { reason }) => {
513                return SessionOutcome::AuthFailed(reason);
514            }
515            SessionEvent::Disconnected(_) => return SessionOutcome::Disconnected,
516            SessionEvent::Stopped => return SessionOutcome::Stopped,
517            SessionEvent::Frame(frame) => match frame {
518                WorkerOutbound::Welcome { worker_id: wid, .. } => {
519                    push_log_with_observers(
520                        &ctx.logs,
521                        Some(&ctx.observers),
522                        "info",
523                        "ws",
524                        &format!("server welcomed {wid}"),
525                        None,
526                    );
527                }
528                WorkerOutbound::Offer { claim } => {
529                    handle_offer(&ctx, *claim);
530                }
531                WorkerOutbound::Error { code, message } => {
532                    push_log_with_observers(
533                        &ctx.logs,
534                        Some(&ctx.observers),
535                        "error",
536                        "ws",
537                        &format!("server error {code:?}: {message}"),
538                        None,
539                    );
540                    return match code {
541                        crate::ws::types::WorkerErrorCode::AuthFailed => {
542                            SessionOutcome::AuthFailed(message)
543                        }
544                        _ => SessionOutcome::Fatal(message),
545                    };
546                }
547                WorkerOutbound::HeartbeatAck
548                | WorkerOutbound::CompleteAck { .. }
549                | WorkerOutbound::FailAck { .. } => {
550                    // Acks are best-effort; ignore.
551                }
552            },
553        }
554    }
555    SessionOutcome::Disconnected
556}
557
558#[cfg_attr(coverage_nightly, coverage(off))]
559fn handle_offer(ctx: &SessionContext, claim: JobOfferClaim) {
560    let job_id = claim.job_id.clone();
561    push_log_with_observers(
562        &ctx.logs,
563        Some(&ctx.observers),
564        "info",
565        "ws",
566        &format!(
567            "offer received {job_id} model={} vram={}",
568            claim.model, claim.vram_gb_estimate
569        ),
570        Some(job_id.clone()),
571    );
572    // Operator pressed Pause: reject the offer so the studio retries
573    // on a different worker (or requeues until we resume).  No engine
574    // dispatch, no busy flag flip.
575    if ctx.paused.load(Ordering::SeqCst) {
576        push_log_with_observers(
577            &ctx.logs,
578            Some(&ctx.observers),
579            "info",
580            "ws",
581            &format!("rejecting offer {job_id}: worker is paused"),
582            Some(job_id.clone()),
583        );
584        spawn_reject_offer(
585            ctx.sender.clone(),
586            ctx.logs.clone(),
587            ctx.observers.clone(),
588            job_id,
589            "worker paused by operator",
590        );
591        return;
592    }
593    if !try_reserve_worker(&ctx.busy) {
594        push_log_with_observers(
595            &ctx.logs,
596            Some(&ctx.observers),
597            "info",
598            "ws",
599            &format!("rejecting offer {job_id}: worker is already busy"),
600            Some(job_id.clone()),
601        );
602        spawn_reject_offer(
603            ctx.sender.clone(),
604            ctx.logs.clone(),
605            ctx.observers.clone(),
606            job_id,
607            "worker already has an in-flight job",
608        );
609        return;
610    }
611    let job = claim.into_job_claim();
612    let task_kind = job.task.kind();
613    // The FULL prompt goes back to the studio (and to the engine).
614    // The bounded preview (`truncate_prompt`) is only for the UI's
615    // Jobs tab so the in-memory observer ring stays small even when
616    // LLM prompts are huge.  Mixing the two used to send the
617    // truncated 200-char preview as the `prompt` form field on the
618    // multipart `/complete`, which the studio then persisted onto the
619    // row — mangling every operator-facing prompt in the DB.
620    let full_prompt = prompt_for(&job.task);
621    let prompt_preview = truncate_prompt(&full_prompt);
622    let started_at = chrono::Utc::now();
623
624    let busy_flag = ctx.busy.clone();
625    let logs_for_task = ctx.logs.clone();
626    let observers_for_task = ctx.observers.clone();
627    let sender_for_task = ctx.sender.clone();
628    let engine_for_task = ctx.engine.clone();
629    let api_base_url = ctx.api_base_url.clone();
630    let worker_id = ctx.worker_id.clone();
631    let auth_token = ctx.auth_token.clone();
632    tokio::spawn(async move {
633        let accept_result = sender_for_task
634            .send(&WorkerInbound::Accept {
635                job_id: job_id.clone(),
636            })
637            .await;
638        if let Some((level, message)) = offer_response_breadcrumb("accept", &job_id, &accept_result)
639        {
640            push_log_with_observers(
641                &logs_for_task,
642                Some(&observers_for_task),
643                level,
644                "ws",
645                &message,
646                Some(job_id.clone()),
647            );
648        }
649        if accept_result.is_err() {
650            busy_flag.store(false, Ordering::SeqCst);
651            return;
652        }
653
654        // Surface the job to the UI's Jobs tab — bounded preview only.
655        *observers_for_task.current_job.lock() = Some(CurrentJob {
656            job_id: job_id.clone(),
657            kind: task_kind,
658            model: job.model.clone(),
659            prompt: prompt_preview.clone(),
660            started_at,
661        });
662
663        run_offered_job(
664            sender_for_task,
665            engine_for_task,
666            logs_for_task,
667            observers_for_task,
668            api_base_url,
669            worker_id,
670            auth_token,
671            job,
672            started_at,
673            task_kind,
674            full_prompt,
675            prompt_preview,
676        )
677        .await;
678        busy_flag.store(false, Ordering::SeqCst);
679    });
680}
681
/// Atomically claim the worker for a new job.
///
/// Flips `busy` from `false` to `true` and returns `true` on success.  If the
/// flag was already set, another job holds the worker: the flag is left
/// untouched and `false` is returned.
fn try_reserve_worker(busy: &AtomicBool) -> bool {
    let claimed = busy.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
    matches!(claimed, Ok(_))
}
686
687fn spawn_reject_offer(
688    sender: WsSender,
689    logs: Arc<Mutex<Vec<LogEntry>>>,
690    observers: WorkerObservers,
691    job_id: String,
692    reason: &'static str,
693) {
694    tokio::spawn(async move {
695        let result = sender
696            .send(&WorkerInbound::Reject {
697                job_id: job_id.clone(),
698                reason: reason.to_string(),
699            })
700            .await;
701        if let Some((level, message)) = offer_response_breadcrumb("reject", &job_id, &result) {
702            push_log_with_observers(&logs, Some(&observers), level, "ws", &message, Some(job_id));
703        }
704    });
705}
706
707#[allow(clippy::too_many_arguments)]
708#[cfg_attr(coverage_nightly, coverage(off))]
709async fn run_offered_job(
710    sender: WsSender,
711    engine: Arc<dyn Engine>,
712    logs: Arc<Mutex<Vec<LogEntry>>>,
713    observers: WorkerObservers,
714    api_base_url: String,
715    worker_id: String,
716    auth_token: String,
717    job: crate::types::JobClaim,
718    started_at: chrono::DateTime<chrono::Utc>,
719    task_kind: crate::types::TaskKind,
720    full_prompt: String,
721    prompt_preview: String,
722) {
723    let start = std::time::Instant::now();
724    // Pass the studio's `ModelSource` to the engine so sd-cpp /
725    // llama-cpp know which files to load.  Required on every offer
726    // — the studio refuses to promote a job without a model source
727    // and the worker refuses any claim that lacks one.
728    let dispatch = tokio::task::spawn_blocking({
729        let model = job.model.clone();
730        let model_source = job.model_source.clone();
731        let task_for_engine = job.task.clone();
732        let engine = engine.clone();
733        move || -> Result<TaskResult> {
734            engine.dispatch_with_source(&model, task_for_engine, &model_source)
735        }
736    })
737    .await;
738
739    let job_id = job.job_id.clone();
740    // Tracks the outcome we record into the RecentJob ring once every
741    // dispatch arm below has either succeeded or surfaced an error.
742    // The default value here only survives if the match falls through
743    // without assigning, which is unreachable; we keep it as a
744    // belt-and-braces default so the recent-jobs ring is never left
745    // half-populated by a future code-path that forgets to assign.
746    #[allow(unused_assignments)]
747    let mut outcome = JobOutcome::Failed {
748        reason: "dispatch did not run to completion".to_string(),
749    };
750    match dispatch {
751        Ok(Ok(result)) => {
752            push_log_with_observers(
753                &logs,
754                Some(&observers),
755                "info",
756                "ws",
757                &format!("{} dispatched in {:?}", task_kind.as_str(), start.elapsed()),
758                Some(job_id.clone()),
759            );
760            match result {
761                TaskResult::Image { bytes, ext }
762                | TaskResult::AudioTts { bytes, ext }
763                | TaskResult::Video { bytes, ext } => {
764                    // Binary outputs go via HTTP multipart \u2014 the only
765                    // worker-side HTTP route that survives the migration.
766                    let upload_result = tokio::task::spawn_blocking({
767                        let api_base_url = api_base_url.clone();
768                        let job_id = job_id.clone();
769                        let auth_token = auth_token.clone();
770                        let worker_id = worker_id.clone();
771                        let prompt = full_prompt.clone();
772                        move || -> Result<()> {
773                            let api = ApiClient::new(api_base_url)?;
774                            api.complete(&worker_id, &auth_token, &job_id, &ext, &prompt, bytes)
775                        }
776                    })
777                    .await;
778                    let msg = match upload_result {
779                        Ok(Ok(())) => None,
780                        Ok(Err(e)) => Some(e.to_string()),
781                        Err(e) => Some(format!("upload task panic: {e}")),
782                    };
783                    if let Some(msg) = msg {
784                        push_log_with_observers(
785                            &logs,
786                            Some(&observers),
787                            "error",
788                            "ws",
789                            &msg,
790                            Some(job_id.clone()),
791                        );
792                        outcome = JobOutcome::Failed {
793                            reason: msg.clone(),
794                        };
795                        let fail_result = sender
796                            .send(&WorkerInbound::Fail {
797                                job_id: job_id.clone(),
798                                error: msg,
799                                retryable: true,
800                            })
801                            .await;
802                        record_fail_send(&fail_result, &job_id, &logs, &observers);
803                    } else {
804                        push_log_with_observers(
805                            &logs,
806                            Some(&observers),
807                            "info",
808                            "ws",
809                            "binary upload ok",
810                            Some(job_id.clone()),
811                        );
812                        outcome = JobOutcome::Completed;
813                        // The studio's HTTP `/complete` handler defers a
814                        // `notifyJobCompleted` RPC to the
815                        // WorkerConnections DO; that's the canonical
816                        // "offer next job" nudge.  Sending an extra
817                        // `ReadyForMore` here races that flow: both can
818                        // call `offerNextFor` concurrently, double-
819                        // reserve the session's `currentJob` slot, and
820                        // ship two `Offer` frames — the second `Accept`
821                        // then trips the studio's `session not
822                        // authenticated`-shaped `accept for unknown
823                        // jobId` invariant and the DO kills the
824                        // session.  See:
825                        //   apps/studio/src/worker/modules/graphics/
826                        //     WorkerConnections/orchestrator.ts (commitOffer)
827                    }
828                }
829                TaskResult::Llm { json } | TaskResult::AudioStt { json } => {
830                    // Mirror the binary path: branch on the send result
831                    // so a dropped `completeJson` frame is recorded as a
832                    // failure (never a false-positive `Completed`) and a
833                    // successful send leaves an explicit completion
834                    // breadcrumb in the logs + shipped studio logs,
835                    // symmetric with the binary path's "binary upload ok".
836                    match sender
837                        .send(&WorkerInbound::CompleteJson {
838                            job_id: job_id.clone(),
839                            result: json,
840                            prompt: Some(full_prompt.clone()),
841                        })
842                        .await
843                    {
844                        Ok(()) => {
845                            push_log_with_observers(
846                                &logs,
847                                Some(&observers),
848                                "info",
849                                "ws",
850                                "json result sent",
851                                Some(job_id.clone()),
852                            );
853                            outcome = JobOutcome::Completed;
854                        }
855                        Err(e) => {
856                            let msg = format!("failed to send result: {e}");
857                            push_log_with_observers(
858                                &logs,
859                                Some(&observers),
860                                "error",
861                                "ws",
862                                &msg,
863                                Some(job_id.clone()),
864                            );
865                            outcome = JobOutcome::Failed { reason: msg };
866                        }
867                    }
868                }
869            }
870        }
871        Ok(Err(e)) => {
872            warn!(target: TRACE_TARGET, error = %e, "engine dispatch failed");
873            push_log_with_observers(
874                &logs,
875                Some(&observers),
876                "error",
877                "ws",
878                &format!("dispatch failed: {e}"),
879                Some(job_id.clone()),
880            );
881            outcome = JobOutcome::Failed {
882                reason: e.to_string(),
883            };
884            let fail_result = sender
885                .send(&WorkerInbound::Fail {
886                    job_id: job_id.clone(),
887                    error: e.to_string(),
888                    retryable: !is_unsupported_kind(&e),
889                })
890                .await;
891            record_fail_send(&fail_result, &job_id, &logs, &observers);
892        }
893        Err(e) => {
894            push_log_with_observers(
895                &logs,
896                Some(&observers),
897                "error",
898                "ws",
899                &format!("dispatch task panic: {e}"),
900                Some(job_id.clone()),
901            );
902            outcome = JobOutcome::Failed {
903                reason: e.to_string(),
904            };
905            let fail_result = sender
906                .send(&WorkerInbound::Fail {
907                    job_id: job_id.clone(),
908                    error: e.to_string(),
909                    retryable: true,
910                })
911                .await;
912            record_fail_send(&fail_result, &job_id, &logs, &observers);
913        }
914    }
915
916    // Surface the finished job to the UI: clear the current-job slot
917    // and push a RecentJob entry into the ring.
918    *observers.current_job.lock() = None;
919    record_recent_job(
920        &observers,
921        RecentJob {
922            job_id: job_id.clone(),
923            kind: task_kind,
924            model: job.model.clone(),
925            prompt: prompt_preview,
926            outcome,
927            started_at,
928            finished_at: chrono::Utc::now(),
929        },
930    );
931}
932
933#[cfg_attr(coverage_nightly, coverage(off))]
934fn spawn_reader(
935    mut receiver: crate::ws::client::WsReceiver,
936    event_tx: mpsc::UnboundedSender<SessionEvent>,
937    read_idle_timeout: Duration,
938) -> tokio::task::JoinHandle<()> {
939    tokio::spawn(async move {
940        loop {
941            // Bound the wait so a half-open / dead-peer socket can't block the reader forever.
942            // A live studio acks every heartbeat (~5s), so a frame always lands well inside the
943            // window; elapsing it means the connection is gone and the session must reconnect.
944            match tokio::time::timeout(read_idle_timeout, receiver.recv()).await {
945                Ok(Ok(Some(frame))) => {
946                    if event_tx.send(SessionEvent::Frame(frame)).is_err() {
947                        break;
948                    }
949                }
950                Ok(Ok(None)) => {
951                    let _ =
952                        event_tx.send(SessionEvent::Disconnected(WsClientError::ConnectionClosed));
953                    break;
954                }
955                Ok(Err(e)) => {
956                    let _ = event_tx.send(SessionEvent::Disconnected(e));
957                    break;
958                }
959                Err(_elapsed) => {
960                    let _ = event_tx.send(SessionEvent::Disconnected(WsClientError::Transport(
961                        format!(
962                            "no frames from server for {:?}; treating connection as dead",
963                            read_idle_timeout
964                        ),
965                    )));
966                    break;
967                }
968            }
969        }
970    })
971}
972
973#[cfg_attr(coverage_nightly, coverage(off))]
974fn spawn_heartbeat_pump(
975    capabilities: WorkerCapabilities,
976    sender: WsSender,
977    stop: Arc<AtomicBool>,
978    paused: Arc<AtomicBool>,
979    observers: WorkerObservers,
980    schedule: SessionSchedule,
981) -> tokio::task::JoinHandle<()> {
982    tokio::spawn(async move {
983        let mut interval = tokio::time::interval(schedule.heartbeat);
984        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
985        loop {
986            interval.tick().await;
987            if stop.load(Ordering::SeqCst) {
988                break;
989            }
990            // The capability snapshot is captured once at session
991            // start.  Only `auto_enabled` (the pause flag) and the
992            // current job id can change between heartbeats.
993            let mut caps = capabilities.clone();
994            caps.auto_enabled = !paused.load(Ordering::SeqCst);
995            let current_job_id = heartbeat_current_job_id(&observers);
996            if let Err(e) = sender
997                .send(&WorkerInbound::Heartbeat {
998                    capabilities: caps,
999                    current_job_id,
1000                })
1001                .await
1002            {
1003                warn!(target: TRACE_TARGET, error = %e, "heartbeat send failed");
1004                break;
1005            }
1006        }
1007    })
1008}
1009
1010fn heartbeat_current_job_id(observers: &WorkerObservers) -> Option<String> {
1011    observers
1012        .current_job
1013        .lock()
1014        .as_ref()
1015        .map(|job| job.job_id.clone())
1016}
1017
1018#[cfg_attr(coverage_nightly, coverage(off))]
1019fn spawn_log_shipper_pump(
1020    sender: WsSender,
1021    logs: Arc<Mutex<Vec<LogEntry>>>,
1022    stop: Arc<AtomicBool>,
1023    schedule: SessionSchedule,
1024) -> tokio::task::JoinHandle<()> {
1025    tokio::spawn(async move {
1026        let mut interval = tokio::time::interval(schedule.log_flush);
1027        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
1028        loop {
1029            interval.tick().await;
1030            if stop.load(Ordering::SeqCst) {
1031                break;
1032            }
1033            let batch = {
1034                let mut guard = logs.lock();
1035                if guard.is_empty() {
1036                    continue;
1037                }
1038                std::mem::take(&mut *guard)
1039            };
1040            if let Err(e) = sender
1041                .send(&WorkerInbound::LogBatch { entries: batch })
1042                .await
1043            {
1044                warn!(target: TRACE_TARGET, error = %e, "log batch send failed");
1045                break;
1046            }
1047        }
1048    })
1049}
1050
1051#[cfg_attr(coverage_nightly, coverage(off))]
1052fn spawn_shutdown_observer(
1053    stop: Arc<AtomicBool>,
1054    event_tx: mpsc::UnboundedSender<SessionEvent>,
1055    schedule: SessionSchedule,
1056) -> tokio::task::JoinHandle<()> {
1057    tokio::spawn(async move {
1058        loop {
1059            tokio::time::sleep(schedule.shutdown_tick).await;
1060            if stop.load(Ordering::SeqCst) {
1061                let _ = event_tx.send(SessionEvent::Stopped);
1062                break;
1063            }
1064            if event_tx.is_closed() {
1065                break;
1066            }
1067        }
1068    })
1069}
1070
1071fn backoff_for(attempt: u32, schedule: SessionSchedule) -> Duration {
1072    let factor = 2u64.saturating_pow(attempt.saturating_sub(1));
1073    let raw_ms = schedule.base_backoff_ms.saturating_mul(factor);
1074    Duration::from_millis(raw_ms.min(schedule.max_backoff_ms))
1075}
1076
1077/// Decide whether a just-attempted offer-response send (accept /
1078/// reject) warrants a session-level breadcrumb.
1079///
1080/// Returns `None` on success: the happy path is already implied by the
1081/// surrounding "dispatched" / "rejecting offer: paused" breadcrumbs, so
1082/// re-logging it would only add per-job noise.  Returns
1083/// `Some(("error", …))` when the send failed — a dropped accept leaves
1084/// the worker running a job the studio never marked accepted, and a
1085/// dropped reject leaves the offer reserved on a paused worker until it
1086/// times out.  The transport layer already logs the failure locally on
1087/// `studio_worker::ws::client`, but only a session-level breadcrumb
1088/// reaches the UI's Logs tab and the studio-shipped log view with the
1089/// offending `job_id` attached.  Pure so the wording + level are
1090/// unit-tested without a live WS sink.
1091fn offer_response_breadcrumb(
1092    label: &str,
1093    job_id: &str,
1094    result: &WsResult<()>,
1095) -> Option<(&'static str, String)> {
1096    match result {
1097        Ok(()) => None,
1098        Err(e) => Some((
1099            "error",
1100            format!("{label} send failed for offer {job_id}: {e}"),
1101        )),
1102    }
1103}
1104
1105/// Decide whether a just-attempted `Fail`-frame send warrants a
1106/// session-level breadcrumb.
1107///
1108/// Returns `None` on success: the caller already logged the underlying
1109/// job failure (the upload error, dispatch error, or panic), so a `Fail`
1110/// frame that lands needs no second per-job line.  Returns
1111/// `Some(("error", …))` when the send itself failed — a dropped `Fail`
1112/// leaves the studio believing the job is still in flight (reserved on
1113/// the session's `currentJob` slot) until it times out, with no local
1114/// record that the notification never landed.  The transport layer logs
1115/// the drop locally on `studio_worker::ws::client`, but only a
1116/// session-level breadcrumb reaches the UI's Logs tab and the
1117/// studio-shipped log view with the offending `job_id` attached.  Pure
1118/// so the wording + level are unit-tested without a live WS sink.
1119fn fail_send_breadcrumb(job_id: &str, result: &WsResult<()>) -> Option<(&'static str, String)> {
1120    match result {
1121        Ok(()) => None,
1122        Err(e) => Some((
1123            "error",
1124            format!("failed to notify studio of job {job_id} failure: {e}"),
1125        )),
1126    }
1127}
1128
1129/// Push a session-level breadcrumb when a `Fail`-frame send dropped.
1130///
1131/// Trivial glue over [`fail_send_breadcrumb`]: the three job-failure
1132/// arms (upload error, dispatch error, dispatch panic) all notify the
1133/// studio with a `Fail` frame and then call this, so a dropped
1134/// notification is recorded with the `job_id` attached instead of being
1135/// swallowed by `let _ = sender.send(...)`.
1136fn record_fail_send(
1137    result: &WsResult<()>,
1138    job_id: &str,
1139    logs: &Arc<Mutex<Vec<LogEntry>>>,
1140    observers: &WorkerObservers,
1141) {
1142    if let Some((level, message)) = fail_send_breadcrumb(job_id, result) {
1143        push_log_with_observers(
1144            logs,
1145            Some(observers),
1146            level,
1147            "ws",
1148            &message,
1149            Some(job_id.to_string()),
1150        );
1151    }
1152}
1153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn offer_response_breadcrumb_is_silent_on_success() {
        // The happy path is already implied by the surrounding
        // "dispatched" / "rejecting offer: paused" breadcrumbs, so a
        // successful accept / reject send must not add per-job noise.
        assert!(offer_response_breadcrumb("accept", "j-1", &Ok(())).is_none());
        assert!(offer_response_breadcrumb("reject", "j-2", &Ok(())).is_none());
    }

    #[test]
    fn try_reserve_worker_only_allows_one_in_flight_job() {
        let busy = AtomicBool::new(false);
        assert!(try_reserve_worker(&busy));
        assert!(!try_reserve_worker(&busy));
    }

    #[test]
    fn heartbeat_current_job_id_uses_actual_job_id() {
        let observers = WorkerObservers::default();
        assert_eq!(heartbeat_current_job_id(&observers), None);
        *observers.current_job.lock() = Some(CurrentJob {
            job_id: "job-42".into(),
            kind: crate::types::TaskKind::Image,
            model: "synthetic".into(),
            prompt: "prompt".into(),
            started_at: chrono::Utc::now(),
        });
        assert_eq!(
            heartbeat_current_job_id(&observers).as_deref(),
            Some("job-42")
        );
    }

    #[test]
    fn offer_response_breadcrumb_reports_accept_send_failure() {
        let (level, msg) =
            offer_response_breadcrumb("accept", "j-1", &Err(WsClientError::ConnectionClosed))
                .expect("a failed accept send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("accept send failed"), "got: {msg}");
        assert!(msg.contains("j-1"), "must name the job: {msg}");
        assert!(
            msg.contains("connection closed"),
            "must carry the cause: {msg}"
        );
    }

    #[test]
    fn offer_response_breadcrumb_reports_reject_send_failure() {
        let (level, msg) = offer_response_breadcrumb(
            "reject",
            "j-9",
            &Err(WsClientError::Transport("sink gone".into())),
        )
        .expect("a failed reject send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("reject send failed"), "got: {msg}");
        assert!(msg.contains("j-9"), "must name the job: {msg}");
        assert!(msg.contains("sink gone"), "must carry the cause: {msg}");
    }

    #[test]
    fn fail_send_breadcrumb_is_silent_on_success() {
        // The underlying job failure (upload / dispatch / panic) is
        // already logged by the caller, so a Fail-frame that lands must
        // not add a second per-job line.
        assert!(fail_send_breadcrumb("j-1", &Ok(())).is_none());
    }

    #[test]
    fn fail_send_breadcrumb_reports_send_failure() {
        let (level, msg) = fail_send_breadcrumb("j-7", &Err(WsClientError::ConnectionClosed))
            .expect("a dropped Fail send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("j-7"), "must name the job: {msg}");
        assert!(
            msg.contains("connection closed"),
            "must carry the cause: {msg}"
        );
    }

    #[test]
    fn fail_send_breadcrumb_carries_transport_cause() {
        let (level, msg) =
            fail_send_breadcrumb("j-3", &Err(WsClientError::Transport("sink gone".into())))
                .expect("a dropped Fail send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("j-3"), "must name the job: {msg}");
        assert!(msg.contains("sink gone"), "must carry the cause: {msg}");
    }

    #[test]
    fn backoff_grows_exponentially_until_cap() {
        let schedule = SessionSchedule {
            base_backoff_ms: 100,
            max_backoff_ms: 1_000,
            heartbeat: Duration::from_secs(1),
            log_flush: Duration::from_secs(1),
            shutdown_tick: Duration::from_secs(1),
            read_idle_timeout: Duration::from_secs(1),
        };
        assert_eq!(backoff_for(1, schedule), Duration::from_millis(100));
        assert_eq!(backoff_for(2, schedule), Duration::from_millis(200));
        assert_eq!(backoff_for(3, schedule), Duration::from_millis(400));
        assert_eq!(backoff_for(4, schedule), Duration::from_millis(800));
        // Capped.
        assert_eq!(backoff_for(5, schedule), Duration::from_millis(1_000));
        assert_eq!(backoff_for(10, schedule), Duration::from_millis(1_000));
    }

    #[test]
    fn backoff_handles_attempt_zero_and_saturates_on_huge_attempts() {
        let schedule = SessionSchedule {
            base_backoff_ms: 100,
            max_backoff_ms: 1_000,
            heartbeat: Duration::from_secs(1),
            log_flush: Duration::from_secs(1),
            shutdown_tick: Duration::from_secs(1),
            read_idle_timeout: Duration::from_secs(1),
        };
        // Attempt 0 must not underflow the exponent: it waits the base delay,
        // exactly like the first attempt.
        assert_eq!(backoff_for(0, schedule), Duration::from_millis(100));
        // Absurd attempt counts saturate to the cap instead of overflowing
        // (which would panic in debug builds).
        assert_eq!(
            backoff_for(u32::MAX, schedule),
            Duration::from_millis(1_000)
        );
    }

    #[test]
    fn has_credentials_false_when_either_missing() {
        let mut cfg = crate::config::Config::default();
        let shared = crate::config::shared(cfg.clone());
        assert!(!has_credentials(&shared), "both missing");
        cfg.worker_id = Some("w-1".into());
        let shared = crate::config::shared(cfg.clone());
        assert!(!has_credentials(&shared), "only worker_id");
        cfg.worker_id = None;
        cfg.auth_token = Some("tok".into());
        let shared = crate::config::shared(cfg.clone());
        assert!(!has_credentials(&shared), "only auth_token");
    }

    #[test]
    fn has_credentials_true_when_both_present() {
        let cfg = crate::config::Config {
            worker_id: Some("w-1".into()),
            auth_token: Some("tok".into()),
            ..crate::config::Config::default()
        };
        let shared = crate::config::shared(cfg);
        assert!(has_credentials(&shared));
    }

    #[test]
    fn has_credentials_false_when_empty_strings() {
        let cfg = crate::config::Config {
            worker_id: Some("".into()),
            auth_token: Some("".into()),
            ..crate::config::Config::default()
        };
        let shared = crate::config::shared(cfg);
        assert!(!has_credentials(&shared));
    }
}