1use std::sync::{
13 atomic::{AtomicBool, Ordering},
14 Arc,
15};
16use std::time::Duration;
17
18use anyhow::{anyhow, Result};
19use parking_lot::Mutex;
20use tokio::sync::mpsc;
21use tracing::{info, warn};
22
23use crate::config::SharedConfig;
24use crate::engine::Engine;
25use crate::http::ApiClient;
26use crate::runtime::{
27 is_unsupported_kind, prompt_for, push_log_with_observers, record_recent_job, truncate_prompt,
28 wait_with_stop, CurrentJob, JobOutcome, RecentJob, WorkerObservers,
29};
30use crate::types::{LogEntry, TaskResult, WorkerCapabilities};
31use crate::ws::client::{connect, WsClientError, WsResult, WsSender};
32use crate::ws::types::{HelloFrame, JobOfferClaim, WorkerInbound, WorkerOutbound};
33
/// `tracing` target attached to the `info!`/`warn!` events emitted in this file.
const TRACE_TARGET: &str = "studio_worker::ws::session";

/// Default period of the heartbeat pump (`SessionSchedule::heartbeat`).
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Default period at which buffered log entries are shipped (`SessionSchedule::log_flush`).
const LOG_FLUSH_INTERVAL: Duration = Duration::from_secs(1);
/// Default granularity at which the stop flag is polled (`SessionSchedule::shutdown_tick`).
const SHUTDOWN_TICK: Duration = Duration::from_millis(250);
/// Default reconnect delay for the first attempt, in milliseconds; doubled per attempt.
const BASE_BACKOFF_MS: u64 = 1_000;
/// Default upper bound on the reconnect delay, in milliseconds.
const MAX_BACKOFF_MS: u64 = 30_000;
/// Reconnect cap used when the config's `ws_reconnect_attempts` is unset.
const DEFAULT_RECONNECT_ATTEMPTS: u32 = 5;
/// Default time the reader waits for a frame before treating the connection as dead.
const READ_IDLE_TIMEOUT: Duration = Duration::from_secs(20);
48
/// How a single WebSocket session ended, as reported to the reconnect loop in
/// `spawn_ws_session`.
#[derive(Debug)]
pub enum SessionOutcome {
    /// The stop flag was observed; the reconnect loop returns `Ok(())`.
    Stopped,
    /// The connection could not be opened or was lost; the reconnect loop
    /// backs off and tries again.
    Disconnected,
    /// The server rejected the worker's credentials; carries the reason.
    /// The reconnect loop returns an error instead of retrying.
    AuthFailed(String),
    /// A non-retryable failure; carries the reason. The reconnect loop
    /// returns an error instead of retrying.
    Fatal(String),
}
62
/// Timing knobs for a session, bundled so tests can substitute much shorter
/// values (see `SessionSchedule::fast_for_tests`).
#[derive(Debug, Clone, Copy)]
pub struct SessionSchedule {
    /// Period between heartbeat frames.
    pub heartbeat: Duration,
    /// Period between attempts to ship buffered log entries.
    pub log_flush: Duration,
    /// Granularity at which the stop flag is polled.
    pub shutdown_tick: Duration,
    /// Reconnect delay for the first attempt, in milliseconds; doubled per attempt.
    pub base_backoff_ms: u64,
    /// Upper bound on the reconnect delay, in milliseconds.
    pub max_backoff_ms: u64,
    /// How long the reader waits for a frame before treating the connection as dead.
    pub read_idle_timeout: Duration,
}
74
75impl Default for SessionSchedule {
76 fn default() -> Self {
77 Self {
78 heartbeat: HEARTBEAT_INTERVAL,
79 log_flush: LOG_FLUSH_INTERVAL,
80 shutdown_tick: SHUTDOWN_TICK,
81 base_backoff_ms: BASE_BACKOFF_MS,
82 max_backoff_ms: MAX_BACKOFF_MS,
83 read_idle_timeout: READ_IDLE_TIMEOUT,
84 }
85 }
86}
87
88impl SessionSchedule {
89 pub fn fast_for_tests() -> Self {
90 Self {
91 heartbeat: Duration::from_millis(5),
92 log_flush: Duration::from_millis(5),
93 shutdown_tick: Duration::from_millis(5),
94 base_backoff_ms: 1,
95 max_backoff_ms: 10,
96 read_idle_timeout: Duration::from_secs(5),
99 }
100 }
101}
102
103#[cfg_attr(coverage_nightly, coverage(off))]
111pub async fn spawn_ws_session(
112 cfg: SharedConfig,
113 stop: Arc<AtomicBool>,
114 logs: Arc<Mutex<Vec<LogEntry>>>,
115 busy: Arc<AtomicBool>,
116 paused: Arc<AtomicBool>,
117 observers: WorkerObservers,
118 schedule: SessionSchedule,
119) -> Result<()> {
120 let max_attempts = {
121 let guard = cfg.lock();
122 guard
123 .ws_reconnect_attempts
124 .unwrap_or(DEFAULT_RECONNECT_ATTEMPTS)
125 };
126
127 let mut attempt: u32 = 0;
128 let mut waiting_for_creds_logged = false;
129 loop {
130 if stop.load(Ordering::SeqCst) {
131 return Ok(());
132 }
133 if !has_credentials(&cfg) {
139 if !waiting_for_creds_logged {
140 push_log_with_observers(
141 &logs,
142 Some(&observers),
143 "info",
144 "ws",
145 "waiting for operator approval before opening the session",
146 None,
147 );
148 waiting_for_creds_logged = true;
149 }
150 wait_with_stop(Duration::from_secs(1), &stop, schedule.shutdown_tick).await;
151 continue;
152 }
153 waiting_for_creds_logged = false;
154
155 let welcomed = AtomicBool::new(false);
156 match run_one_session(
157 &cfg, &stop, &logs, &busy, &paused, &observers, schedule, &welcomed,
158 )
159 .await
160 {
161 Ok(SessionOutcome::Stopped) => return Ok(()),
162 Ok(SessionOutcome::AuthFailed(reason)) => {
163 push_log_with_observers(
164 &logs,
165 Some(&observers),
166 "error",
167 "ws",
168 &format!("auth failed: {reason}. Re-register the worker."),
169 None,
170 );
171 return Err(anyhow!("ws auth failed: {reason}"));
172 }
173 Ok(SessionOutcome::Fatal(reason)) => {
174 push_log_with_observers(
175 &logs,
176 Some(&observers),
177 "error",
178 "ws",
179 &format!("fatal: {reason}"),
180 None,
181 );
182 return Err(anyhow!("ws fatal: {reason}"));
183 }
184 Ok(SessionOutcome::Disconnected) | Err(_) => {
185 if welcomed.load(Ordering::SeqCst) {
189 attempt = 0;
190 }
191 attempt += 1;
192 if max_attempts > 0 && attempt > max_attempts {
193 push_log_with_observers(
194 &logs,
195 Some(&observers),
196 "error",
197 "ws",
198 &format!("giving up after {attempt} reconnect attempts"),
199 None,
200 );
201 return Err(anyhow!("ws reconnect cap reached"));
202 }
203 let backoff = backoff_for(attempt, schedule);
204 push_log_with_observers(
205 &logs,
206 Some(&observers),
207 "warn",
208 "ws",
209 &format!(
210 "disconnected; reconnect attempt {attempt} in {}ms",
211 backoff.as_millis()
212 ),
213 None,
214 );
215 wait_with_stop(backoff, &stop, schedule.shutdown_tick).await;
216 }
217 }
218 }
219}
220
/// Result of waiting for the server's first decisive frame after the hello
/// (see `wait_for_welcome`).
enum WelcomeOutcome {
    /// The server sent a welcome frame.
    Welcomed,
    /// The server (or the transport) reported an authentication failure; carries the reason.
    AuthFailed(String),
    /// The server sent a non-auth error frame before welcoming; carries its message.
    Fatal(String),
    /// The connection ended, or the event channel closed, before any welcome arrived.
    Disconnected,
}
231
232#[cfg_attr(coverage_nightly, coverage(off))]
238async fn wait_for_welcome(
239 event_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
240 logs: &Arc<Mutex<Vec<LogEntry>>>,
241 observers: &WorkerObservers,
242) -> WelcomeOutcome {
243 while let Some(event) = event_rx.recv().await {
244 match event {
245 SessionEvent::Frame(WorkerOutbound::Welcome { worker_id: wid, .. }) => {
246 push_log_with_observers(
247 logs,
248 Some(observers),
249 "info",
250 "ws",
251 &format!("server welcomed {wid}"),
252 None,
253 );
254 return WelcomeOutcome::Welcomed;
255 }
256 SessionEvent::Frame(WorkerOutbound::Error { code, message }) => {
257 push_log_with_observers(
258 logs,
259 Some(observers),
260 "error",
261 "ws",
262 &format!("server error before welcome {code:?}: {message}"),
263 None,
264 );
265 return match code {
266 crate::ws::types::WorkerErrorCode::AuthFailed => {
267 WelcomeOutcome::AuthFailed(message)
268 }
269 _ => WelcomeOutcome::Fatal(message),
270 };
271 }
272 SessionEvent::Frame(other) => {
273 push_log_with_observers(
274 logs,
275 Some(observers),
276 "warn",
277 "ws",
278 &format!("server sent unexpected frame before welcome: {other:?}"),
279 None,
280 );
281 }
283 SessionEvent::Disconnected(WsClientError::AuthFailed { reason }) => {
284 return WelcomeOutcome::AuthFailed(reason);
285 }
286 SessionEvent::Disconnected(_) => return WelcomeOutcome::Disconnected,
287 SessionEvent::Stopped => return WelcomeOutcome::Disconnected,
288 }
289 }
290 WelcomeOutcome::Disconnected
291}
292
293fn has_credentials(cfg: &SharedConfig) -> bool {
297 let guard = cfg.lock();
298 guard
299 .worker_id
300 .as_deref()
301 .map(|s| !s.is_empty())
302 .unwrap_or(false)
303 && guard
304 .auth_token
305 .as_deref()
306 .map(|s| !s.is_empty())
307 .unwrap_or(false)
308}
309
/// Runs exactly one WebSocket session: connect, hello/welcome handshake,
/// background pumps, dispatch loop, teardown.
///
/// `welcomed` is set to `true` once the server's welcome has been received,
/// so the caller can tell a session that was established from one that never
/// got past the handshake.
///
/// Returns the session's `SessionOutcome`. Returns `Err` when building the
/// engine fails or when the hello frame cannot be sent.
#[cfg_attr(coverage_nightly, coverage(off))]
#[allow(clippy::too_many_arguments)]
async fn run_one_session(
    cfg: &SharedConfig,
    stop: &Arc<AtomicBool>,
    logs: &Arc<Mutex<Vec<LogEntry>>>,
    busy: &Arc<AtomicBool>,
    paused: &Arc<AtomicBool>,
    observers: &WorkerObservers,
    schedule: SessionSchedule,
    welcomed: &AtomicBool,
) -> Result<SessionOutcome> {
    // Snapshot the connection parameters so the config lock is released
    // before any await point.
    let (api_base_url, worker_id, auth_token) = {
        let guard = cfg.lock();
        (
            guard.api_base_url.clone(),
            guard.worker_id.clone().unwrap_or_default(),
            guard.auth_token.clone().unwrap_or_default(),
        )
    };
    if worker_id.is_empty() || auth_token.is_empty() {
        return Ok(SessionOutcome::Fatal(
            "worker_id or auth_token missing; run register".to_string(),
        ));
    }

    push_log_with_observers(
        logs,
        Some(observers),
        "info",
        "ws",
        &format!("connecting to {api_base_url}"),
        None,
    );
    // An auth failure at connect time is terminal; any other connect error is
    // reported as a disconnect so the caller can back off and retry.
    let client = match connect(&api_base_url, &worker_id, &auth_token).await {
        Ok(c) => c,
        Err(WsClientError::AuthFailed { reason }) => {
            return Ok(SessionOutcome::AuthFailed(reason));
        }
        Err(e) => {
            push_log_with_observers(
                logs,
                Some(observers),
                "warn",
                "ws",
                &format!("connect failed: {e}"),
                None,
            );
            return Ok(SessionOutcome::Disconnected);
        }
    };
    let (sender, receiver) = client.split();

    // The engine is built from the current config for each session; the
    // capabilities advertised in the hello are derived from it and from the
    // current paused state.
    let engine = crate::engine::build(&cfg.lock())?;
    let capabilities = crate::runtime::build_capabilities_with(
        &cfg.lock(),
        &*engine,
        !paused.load(Ordering::SeqCst),
    );
    push_log_with_observers(
        logs,
        Some(observers),
        "info",
        "ws",
        &crate::runtime::summarize_capabilities(&capabilities),
        None,
    );
    sender
        .send(&WorkerInbound::Hello(HelloFrame {
            auth_token: auth_token.clone(),
            capabilities: capabilities.clone(),
        }))
        .await
        .map_err(|e| anyhow!("hello send failed: {e}"))?;
    info!(target: TRACE_TARGET, worker_id = %worker_id, "hello sent");

    // Single channel that funnels reader frames and the stop notification
    // into the welcome wait and the dispatch loop.
    let (event_tx, event_rx) = mpsc::unbounded_channel::<SessionEvent>();

    let reader = spawn_reader(receiver, event_tx.clone(), schedule.read_idle_timeout);

    let mut event_rx = event_rx;
    // Heartbeat, log shipping and job dispatch start only after the server
    // has welcomed the worker.
    match wait_for_welcome(&mut event_rx, logs, observers).await {
        WelcomeOutcome::Welcomed => welcomed.store(true, Ordering::SeqCst),
        WelcomeOutcome::AuthFailed(reason) => {
            let _ = sender.close(1000, "auth failed").await;
            let _ = reader.await;
            return Ok(SessionOutcome::AuthFailed(reason));
        }
        WelcomeOutcome::Fatal(reason) => {
            let _ = sender.close(1000, "protocol violation").await;
            let _ = reader.await;
            return Ok(SessionOutcome::Fatal(reason));
        }
        WelcomeOutcome::Disconnected => {
            let _ = reader.await;
            return Ok(SessionOutcome::Disconnected);
        }
    }

    let capabilities_for_heartbeat = capabilities.clone();
    let heartbeat = spawn_heartbeat_pump(
        capabilities_for_heartbeat,
        sender.clone(),
        stop.clone(),
        paused.clone(),
        observers.clone(),
        schedule,
    );

    let log_shipper = spawn_log_shipper_pump(sender.clone(), logs.clone(), stop.clone(), schedule);

    let shutdown_observer = spawn_shutdown_observer(stop.clone(), event_tx.clone(), schedule);
    // Drop this function's own sender so the channel closes once the reader
    // and the shutdown observer (the remaining holders) have both finished.
    drop(event_tx);

    let engine_arc: Arc<dyn Engine> = engine.into();
    let ctx = SessionContext {
        sender: sender.clone(),
        engine: engine_arc,
        logs: logs.clone(),
        busy: busy.clone(),
        paused: paused.clone(),
        observers: observers.clone(),
        api_base_url: api_base_url.clone(),
        worker_id: worker_id.clone(),
        auth_token: auth_token.clone(),
    };
    let outcome = run_dispatch_loop(ctx, event_rx).await;

    // Teardown: cancel every background task, close the socket, then await
    // the handles so no task from this session is still running on return.
    reader.abort();
    heartbeat.abort();
    log_shipper.abort();
    shutdown_observer.abort();
    let _ = sender.close(1000, "session ended").await;
    let _ = reader.await;
    let _ = heartbeat.await;
    let _ = log_shipper.await;
    let _ = shutdown_observer.await;
    Ok(outcome)
}
479
/// Events delivered over the session's internal channel to the welcome wait
/// and the dispatch loop.
#[derive(Debug)]
enum SessionEvent {
    /// A frame received from the server (sent by the reader task).
    Frame(WorkerOutbound),
    /// The stop flag was observed (sent by the shutdown observer task).
    Stopped,
    /// The reader task is ending; carries the cause (closed connection,
    /// receive error, or read idle timeout).
    Disconnected(WsClientError),
}
490
/// Everything the dispatch loop and the per-job tasks need from the session,
/// bundled so it can be passed around as one value.
struct SessionContext {
    /// Outgoing half of the WebSocket connection.
    sender: WsSender,
    /// Engine that executes offered jobs.
    engine: Arc<dyn Engine>,
    /// Shared log buffer.
    logs: Arc<Mutex<Vec<LogEntry>>>,
    /// Set while a job is in flight; used to reject concurrent offers.
    busy: Arc<AtomicBool>,
    /// Set while the worker is paused; offers are rejected while it is set.
    paused: Arc<AtomicBool>,
    /// Observer state (current job, recent jobs) updated as jobs run.
    observers: WorkerObservers,
    /// Base URL handed to `ApiClient` when uploading binary results.
    api_base_url: String,
    /// Worker id passed along with binary result uploads.
    worker_id: String,
    /// Auth token passed along with binary result uploads.
    auth_token: String,
}
504
505#[cfg_attr(coverage_nightly, coverage(off))]
506async fn run_dispatch_loop(
507 ctx: SessionContext,
508 mut event_rx: mpsc::UnboundedReceiver<SessionEvent>,
509) -> SessionOutcome {
510 while let Some(event) = event_rx.recv().await {
511 match event {
512 SessionEvent::Disconnected(WsClientError::AuthFailed { reason }) => {
513 return SessionOutcome::AuthFailed(reason);
514 }
515 SessionEvent::Disconnected(_) => return SessionOutcome::Disconnected,
516 SessionEvent::Stopped => return SessionOutcome::Stopped,
517 SessionEvent::Frame(frame) => match frame {
518 WorkerOutbound::Welcome { worker_id: wid, .. } => {
519 push_log_with_observers(
520 &ctx.logs,
521 Some(&ctx.observers),
522 "info",
523 "ws",
524 &format!("server welcomed {wid}"),
525 None,
526 );
527 }
528 WorkerOutbound::Offer { claim } => {
529 handle_offer(&ctx, *claim);
530 }
531 WorkerOutbound::Error { code, message } => {
532 push_log_with_observers(
533 &ctx.logs,
534 Some(&ctx.observers),
535 "error",
536 "ws",
537 &format!("server error {code:?}: {message}"),
538 None,
539 );
540 return match code {
541 crate::ws::types::WorkerErrorCode::AuthFailed => {
542 SessionOutcome::AuthFailed(message)
543 }
544 _ => SessionOutcome::Fatal(message),
545 };
546 }
547 WorkerOutbound::HeartbeatAck
548 | WorkerOutbound::CompleteAck { .. }
549 | WorkerOutbound::FailAck { .. } => {
550 }
552 },
553 }
554 }
555 SessionOutcome::Disconnected
556}
557
558#[cfg_attr(coverage_nightly, coverage(off))]
559fn handle_offer(ctx: &SessionContext, claim: JobOfferClaim) {
560 let job_id = claim.job_id.clone();
561 push_log_with_observers(
562 &ctx.logs,
563 Some(&ctx.observers),
564 "info",
565 "ws",
566 &format!(
567 "offer received {job_id} model={} vram={}",
568 claim.model, claim.vram_gb_estimate
569 ),
570 Some(job_id.clone()),
571 );
572 if ctx.paused.load(Ordering::SeqCst) {
576 push_log_with_observers(
577 &ctx.logs,
578 Some(&ctx.observers),
579 "info",
580 "ws",
581 &format!("rejecting offer {job_id}: worker is paused"),
582 Some(job_id.clone()),
583 );
584 spawn_reject_offer(
585 ctx.sender.clone(),
586 ctx.logs.clone(),
587 ctx.observers.clone(),
588 job_id,
589 "worker paused by operator",
590 );
591 return;
592 }
593 if !try_reserve_worker(&ctx.busy) {
594 push_log_with_observers(
595 &ctx.logs,
596 Some(&ctx.observers),
597 "info",
598 "ws",
599 &format!("rejecting offer {job_id}: worker is already busy"),
600 Some(job_id.clone()),
601 );
602 spawn_reject_offer(
603 ctx.sender.clone(),
604 ctx.logs.clone(),
605 ctx.observers.clone(),
606 job_id,
607 "worker already has an in-flight job",
608 );
609 return;
610 }
611 let job = claim.into_job_claim();
612 let task_kind = job.task.kind();
613 let full_prompt = prompt_for(&job.task);
621 let prompt_preview = truncate_prompt(&full_prompt);
622 let started_at = chrono::Utc::now();
623
624 let busy_flag = ctx.busy.clone();
625 let logs_for_task = ctx.logs.clone();
626 let observers_for_task = ctx.observers.clone();
627 let sender_for_task = ctx.sender.clone();
628 let engine_for_task = ctx.engine.clone();
629 let api_base_url = ctx.api_base_url.clone();
630 let worker_id = ctx.worker_id.clone();
631 let auth_token = ctx.auth_token.clone();
632 tokio::spawn(async move {
633 let accept_result = sender_for_task
634 .send(&WorkerInbound::Accept {
635 job_id: job_id.clone(),
636 })
637 .await;
638 if let Some((level, message)) = offer_response_breadcrumb("accept", &job_id, &accept_result)
639 {
640 push_log_with_observers(
641 &logs_for_task,
642 Some(&observers_for_task),
643 level,
644 "ws",
645 &message,
646 Some(job_id.clone()),
647 );
648 }
649 if accept_result.is_err() {
650 busy_flag.store(false, Ordering::SeqCst);
651 return;
652 }
653
654 *observers_for_task.current_job.lock() = Some(CurrentJob {
656 job_id: job_id.clone(),
657 kind: task_kind,
658 model: job.model.clone(),
659 prompt: prompt_preview.clone(),
660 started_at,
661 });
662
663 run_offered_job(
664 sender_for_task,
665 engine_for_task,
666 logs_for_task,
667 observers_for_task,
668 api_base_url,
669 worker_id,
670 auth_token,
671 job,
672 started_at,
673 task_kind,
674 full_prompt,
675 prompt_preview,
676 )
677 .await;
678 busy_flag.store(false, Ordering::SeqCst);
679 });
680}
681
/// Atomically claims the worker for a job.
///
/// Returns `true` when the flag was clear and is now set by this call, and
/// `false` when it was already set (the flag then stays set).
fn try_reserve_worker(busy: &AtomicBool) -> bool {
    // `swap` returns the previous value: the reservation succeeded exactly
    // when the flag was not set before.
    let was_busy = busy.swap(true, Ordering::SeqCst);
    !was_busy
}
686
687fn spawn_reject_offer(
688 sender: WsSender,
689 logs: Arc<Mutex<Vec<LogEntry>>>,
690 observers: WorkerObservers,
691 job_id: String,
692 reason: &'static str,
693) {
694 tokio::spawn(async move {
695 let result = sender
696 .send(&WorkerInbound::Reject {
697 job_id: job_id.clone(),
698 reason: reason.to_string(),
699 })
700 .await;
701 if let Some((level, message)) = offer_response_breadcrumb("reject", &job_id, &result) {
702 push_log_with_observers(&logs, Some(&observers), level, "ws", &message, Some(job_id));
703 }
704 });
705}
706
707#[allow(clippy::too_many_arguments)]
708#[cfg_attr(coverage_nightly, coverage(off))]
709async fn run_offered_job(
710 sender: WsSender,
711 engine: Arc<dyn Engine>,
712 logs: Arc<Mutex<Vec<LogEntry>>>,
713 observers: WorkerObservers,
714 api_base_url: String,
715 worker_id: String,
716 auth_token: String,
717 job: crate::types::JobClaim,
718 started_at: chrono::DateTime<chrono::Utc>,
719 task_kind: crate::types::TaskKind,
720 full_prompt: String,
721 prompt_preview: String,
722) {
723 let start = std::time::Instant::now();
724 let dispatch = tokio::task::spawn_blocking({
729 let model = job.model.clone();
730 let model_source = job.model_source.clone();
731 let task_for_engine = job.task.clone();
732 let engine = engine.clone();
733 move || -> Result<TaskResult> {
734 engine.dispatch_with_source(&model, task_for_engine, &model_source)
735 }
736 })
737 .await;
738
739 let job_id = job.job_id.clone();
740 #[allow(unused_assignments)]
747 let mut outcome = JobOutcome::Failed {
748 reason: "dispatch did not run to completion".to_string(),
749 };
750 match dispatch {
751 Ok(Ok(result)) => {
752 push_log_with_observers(
753 &logs,
754 Some(&observers),
755 "info",
756 "ws",
757 &format!("{} dispatched in {:?}", task_kind.as_str(), start.elapsed()),
758 Some(job_id.clone()),
759 );
760 match result {
761 TaskResult::Image { bytes, ext }
762 | TaskResult::AudioTts { bytes, ext }
763 | TaskResult::Video { bytes, ext } => {
764 let upload_result = tokio::task::spawn_blocking({
767 let api_base_url = api_base_url.clone();
768 let job_id = job_id.clone();
769 let auth_token = auth_token.clone();
770 let worker_id = worker_id.clone();
771 let prompt = full_prompt.clone();
772 move || -> Result<()> {
773 let api = ApiClient::new(api_base_url)?;
774 api.complete(&worker_id, &auth_token, &job_id, &ext, &prompt, bytes)
775 }
776 })
777 .await;
778 let msg = match upload_result {
779 Ok(Ok(())) => None,
780 Ok(Err(e)) => Some(e.to_string()),
781 Err(e) => Some(format!("upload task panic: {e}")),
782 };
783 if let Some(msg) = msg {
784 push_log_with_observers(
785 &logs,
786 Some(&observers),
787 "error",
788 "ws",
789 &msg,
790 Some(job_id.clone()),
791 );
792 outcome = JobOutcome::Failed {
793 reason: msg.clone(),
794 };
795 let fail_result = sender
796 .send(&WorkerInbound::Fail {
797 job_id: job_id.clone(),
798 error: msg,
799 retryable: true,
800 })
801 .await;
802 record_fail_send(&fail_result, &job_id, &logs, &observers);
803 } else {
804 push_log_with_observers(
805 &logs,
806 Some(&observers),
807 "info",
808 "ws",
809 "binary upload ok",
810 Some(job_id.clone()),
811 );
812 outcome = JobOutcome::Completed;
813 }
828 }
829 TaskResult::Llm { json } | TaskResult::AudioStt { json } => {
830 match sender
837 .send(&WorkerInbound::CompleteJson {
838 job_id: job_id.clone(),
839 result: json,
840 prompt: Some(full_prompt.clone()),
841 })
842 .await
843 {
844 Ok(()) => {
845 push_log_with_observers(
846 &logs,
847 Some(&observers),
848 "info",
849 "ws",
850 "json result sent",
851 Some(job_id.clone()),
852 );
853 outcome = JobOutcome::Completed;
854 }
855 Err(e) => {
856 let msg = format!("failed to send result: {e}");
857 push_log_with_observers(
858 &logs,
859 Some(&observers),
860 "error",
861 "ws",
862 &msg,
863 Some(job_id.clone()),
864 );
865 outcome = JobOutcome::Failed { reason: msg };
866 }
867 }
868 }
869 }
870 }
871 Ok(Err(e)) => {
872 warn!(target: TRACE_TARGET, error = %e, "engine dispatch failed");
873 push_log_with_observers(
874 &logs,
875 Some(&observers),
876 "error",
877 "ws",
878 &format!("dispatch failed: {e}"),
879 Some(job_id.clone()),
880 );
881 outcome = JobOutcome::Failed {
882 reason: e.to_string(),
883 };
884 let fail_result = sender
885 .send(&WorkerInbound::Fail {
886 job_id: job_id.clone(),
887 error: e.to_string(),
888 retryable: !is_unsupported_kind(&e),
889 })
890 .await;
891 record_fail_send(&fail_result, &job_id, &logs, &observers);
892 }
893 Err(e) => {
894 push_log_with_observers(
895 &logs,
896 Some(&observers),
897 "error",
898 "ws",
899 &format!("dispatch task panic: {e}"),
900 Some(job_id.clone()),
901 );
902 outcome = JobOutcome::Failed {
903 reason: e.to_string(),
904 };
905 let fail_result = sender
906 .send(&WorkerInbound::Fail {
907 job_id: job_id.clone(),
908 error: e.to_string(),
909 retryable: true,
910 })
911 .await;
912 record_fail_send(&fail_result, &job_id, &logs, &observers);
913 }
914 }
915
916 *observers.current_job.lock() = None;
919 record_recent_job(
920 &observers,
921 RecentJob {
922 job_id: job_id.clone(),
923 kind: task_kind,
924 model: job.model.clone(),
925 prompt: prompt_preview,
926 outcome,
927 started_at,
928 finished_at: chrono::Utc::now(),
929 },
930 );
931}
932
933#[cfg_attr(coverage_nightly, coverage(off))]
934fn spawn_reader(
935 mut receiver: crate::ws::client::WsReceiver,
936 event_tx: mpsc::UnboundedSender<SessionEvent>,
937 read_idle_timeout: Duration,
938) -> tokio::task::JoinHandle<()> {
939 tokio::spawn(async move {
940 loop {
941 match tokio::time::timeout(read_idle_timeout, receiver.recv()).await {
945 Ok(Ok(Some(frame))) => {
946 if event_tx.send(SessionEvent::Frame(frame)).is_err() {
947 break;
948 }
949 }
950 Ok(Ok(None)) => {
951 let _ =
952 event_tx.send(SessionEvent::Disconnected(WsClientError::ConnectionClosed));
953 break;
954 }
955 Ok(Err(e)) => {
956 let _ = event_tx.send(SessionEvent::Disconnected(e));
957 break;
958 }
959 Err(_elapsed) => {
960 let _ = event_tx.send(SessionEvent::Disconnected(WsClientError::Transport(
961 format!(
962 "no frames from server for {:?}; treating connection as dead",
963 read_idle_timeout
964 ),
965 )));
966 break;
967 }
968 }
969 }
970 })
971}
972
973#[cfg_attr(coverage_nightly, coverage(off))]
974fn spawn_heartbeat_pump(
975 capabilities: WorkerCapabilities,
976 sender: WsSender,
977 stop: Arc<AtomicBool>,
978 paused: Arc<AtomicBool>,
979 observers: WorkerObservers,
980 schedule: SessionSchedule,
981) -> tokio::task::JoinHandle<()> {
982 tokio::spawn(async move {
983 let mut interval = tokio::time::interval(schedule.heartbeat);
984 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
985 loop {
986 interval.tick().await;
987 if stop.load(Ordering::SeqCst) {
988 break;
989 }
990 let mut caps = capabilities.clone();
994 caps.auto_enabled = !paused.load(Ordering::SeqCst);
995 let current_job_id = heartbeat_current_job_id(&observers);
996 if let Err(e) = sender
997 .send(&WorkerInbound::Heartbeat {
998 capabilities: caps,
999 current_job_id,
1000 })
1001 .await
1002 {
1003 warn!(target: TRACE_TARGET, error = %e, "heartbeat send failed");
1004 break;
1005 }
1006 }
1007 })
1008}
1009
1010fn heartbeat_current_job_id(observers: &WorkerObservers) -> Option<String> {
1011 observers
1012 .current_job
1013 .lock()
1014 .as_ref()
1015 .map(|job| job.job_id.clone())
1016}
1017
1018#[cfg_attr(coverage_nightly, coverage(off))]
1019fn spawn_log_shipper_pump(
1020 sender: WsSender,
1021 logs: Arc<Mutex<Vec<LogEntry>>>,
1022 stop: Arc<AtomicBool>,
1023 schedule: SessionSchedule,
1024) -> tokio::task::JoinHandle<()> {
1025 tokio::spawn(async move {
1026 let mut interval = tokio::time::interval(schedule.log_flush);
1027 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
1028 loop {
1029 interval.tick().await;
1030 if stop.load(Ordering::SeqCst) {
1031 break;
1032 }
1033 let batch = {
1034 let mut guard = logs.lock();
1035 if guard.is_empty() {
1036 continue;
1037 }
1038 std::mem::take(&mut *guard)
1039 };
1040 if let Err(e) = sender
1041 .send(&WorkerInbound::LogBatch { entries: batch })
1042 .await
1043 {
1044 warn!(target: TRACE_TARGET, error = %e, "log batch send failed");
1045 break;
1046 }
1047 }
1048 })
1049}
1050
1051#[cfg_attr(coverage_nightly, coverage(off))]
1052fn spawn_shutdown_observer(
1053 stop: Arc<AtomicBool>,
1054 event_tx: mpsc::UnboundedSender<SessionEvent>,
1055 schedule: SessionSchedule,
1056) -> tokio::task::JoinHandle<()> {
1057 tokio::spawn(async move {
1058 loop {
1059 tokio::time::sleep(schedule.shutdown_tick).await;
1060 if stop.load(Ordering::SeqCst) {
1061 let _ = event_tx.send(SessionEvent::Stopped);
1062 break;
1063 }
1064 if event_tx.is_closed() {
1065 break;
1066 }
1067 }
1068 })
1069}
1070
1071fn backoff_for(attempt: u32, schedule: SessionSchedule) -> Duration {
1072 let factor = 2u64.saturating_pow(attempt.saturating_sub(1));
1073 let raw_ms = schedule.base_backoff_ms.saturating_mul(factor);
1074 Duration::from_millis(raw_ms.min(schedule.max_backoff_ms))
1075}
1076
1077fn offer_response_breadcrumb(
1092 label: &str,
1093 job_id: &str,
1094 result: &WsResult<()>,
1095) -> Option<(&'static str, String)> {
1096 match result {
1097 Ok(()) => None,
1098 Err(e) => Some((
1099 "error",
1100 format!("{label} send failed for offer {job_id}: {e}"),
1101 )),
1102 }
1103}
1104
1105fn fail_send_breadcrumb(job_id: &str, result: &WsResult<()>) -> Option<(&'static str, String)> {
1120 match result {
1121 Ok(()) => None,
1122 Err(e) => Some((
1123 "error",
1124 format!("failed to notify studio of job {job_id} failure: {e}"),
1125 )),
1126 }
1127}
1128
1129fn record_fail_send(
1137 result: &WsResult<()>,
1138 job_id: &str,
1139 logs: &Arc<Mutex<Vec<LogEntry>>>,
1140 observers: &WorkerObservers,
1141) {
1142 if let Some((level, message)) = fail_send_breadcrumb(job_id, result) {
1143 push_log_with_observers(
1144 logs,
1145 Some(observers),
1146 level,
1147 "ws",
1148 &message,
1149 Some(job_id.to_string()),
1150 );
1151 }
1152}
1153
// Unit tests for the pure helpers in this file: breadcrumb builders, the
// busy reservation, heartbeat job-id lookup, backoff, and credential checks.
#[cfg(test)]
mod tests {
    use super::*;

    // A successful accept or reject send produces no log line.
    #[test]
    fn offer_response_breadcrumb_is_silent_on_success() {
        assert!(offer_response_breadcrumb("accept", "j-1", &Ok(())).is_none());
        assert!(offer_response_breadcrumb("reject", "j-2", &Ok(())).is_none());
    }

    // The first reservation succeeds; a second one fails while the flag is set.
    #[test]
    fn try_reserve_worker_only_allows_one_in_flight_job() {
        let busy = AtomicBool::new(false);
        assert!(try_reserve_worker(&busy));
        assert!(!try_reserve_worker(&busy));
    }

    // The heartbeat reports `None` with no job, then the recorded job's id.
    #[test]
    fn heartbeat_current_job_id_uses_actual_job_id() {
        let observers = WorkerObservers::default();
        assert_eq!(heartbeat_current_job_id(&observers), None);
        *observers.current_job.lock() = Some(CurrentJob {
            job_id: "job-42".into(),
            kind: crate::types::TaskKind::Image,
            model: "synthetic".into(),
            prompt: "prompt".into(),
            started_at: chrono::Utc::now(),
        });
        assert_eq!(
            heartbeat_current_job_id(&observers).as_deref(),
            Some("job-42")
        );
    }

    // A failed accept send yields an error line naming the job and the cause.
    #[test]
    fn offer_response_breadcrumb_reports_accept_send_failure() {
        let (level, msg) =
            offer_response_breadcrumb("accept", "j-1", &Err(WsClientError::ConnectionClosed))
                .expect("a failed accept send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("accept send failed"), "got: {msg}");
        assert!(msg.contains("j-1"), "must name the job: {msg}");
        assert!(
            msg.contains("connection closed"),
            "must carry the cause: {msg}"
        );
    }

    // A failed reject send yields an error line naming the job and the cause.
    #[test]
    fn offer_response_breadcrumb_reports_reject_send_failure() {
        let (level, msg) = offer_response_breadcrumb(
            "reject",
            "j-9",
            &Err(WsClientError::Transport("sink gone".into())),
        )
        .expect("a failed reject send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("reject send failed"), "got: {msg}");
        assert!(msg.contains("j-9"), "must name the job: {msg}");
        assert!(msg.contains("sink gone"), "must carry the cause: {msg}");
    }

    // A fail frame that was sent successfully produces no log line.
    #[test]
    fn fail_send_breadcrumb_is_silent_on_success() {
        assert!(fail_send_breadcrumb("j-1", &Ok(())).is_none());
    }

    // A fail frame that could not be sent yields an error line with job and cause.
    #[test]
    fn fail_send_breadcrumb_reports_send_failure() {
        let (level, msg) = fail_send_breadcrumb("j-7", &Err(WsClientError::ConnectionClosed))
            .expect("a dropped Fail send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("j-7"), "must name the job: {msg}");
        assert!(
            msg.contains("connection closed"),
            "must carry the cause: {msg}"
        );
    }

    // The transport error's own text is carried into the breadcrumb.
    #[test]
    fn fail_send_breadcrumb_carries_transport_cause() {
        let (level, msg) =
            fail_send_breadcrumb("j-3", &Err(WsClientError::Transport("sink gone".into())))
                .expect("a dropped Fail send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("j-3"), "must name the job: {msg}");
        assert!(msg.contains("sink gone"), "must carry the cause: {msg}");
    }

    // The delay doubles per attempt from the base and never exceeds the cap.
    #[test]
    fn backoff_grows_exponentially_until_cap() {
        let schedule = SessionSchedule {
            base_backoff_ms: 100,
            max_backoff_ms: 1_000,
            heartbeat: Duration::from_secs(1),
            log_flush: Duration::from_secs(1),
            shutdown_tick: Duration::from_secs(1),
            read_idle_timeout: Duration::from_secs(1),
        };
        assert_eq!(backoff_for(1, schedule), Duration::from_millis(100));
        assert_eq!(backoff_for(2, schedule), Duration::from_millis(200));
        assert_eq!(backoff_for(3, schedule), Duration::from_millis(400));
        assert_eq!(backoff_for(4, schedule), Duration::from_millis(800));
        // 1_600 ms uncapped, so the cap applies from the fifth attempt on.
        assert_eq!(backoff_for(5, schedule), Duration::from_millis(1_000));
        assert_eq!(backoff_for(10, schedule), Duration::from_millis(1_000));
    }

    // Credentials are incomplete when either the id or the token is unset.
    #[test]
    fn has_credentials_false_when_either_missing() {
        let mut cfg = crate::config::Config::default();
        let shared = crate::config::shared(cfg.clone());
        assert!(!has_credentials(&shared), "both missing");
        cfg.worker_id = Some("w-1".into());
        let shared = crate::config::shared(cfg.clone());
        assert!(!has_credentials(&shared), "only worker_id");
        cfg.worker_id = None;
        cfg.auth_token = Some("tok".into());
        let shared = crate::config::shared(cfg.clone());
        assert!(!has_credentials(&shared), "only auth_token");
    }

    // Both values set and non-empty counts as having credentials.
    #[test]
    fn has_credentials_true_when_both_present() {
        let cfg = crate::config::Config {
            worker_id: Some("w-1".into()),
            auth_token: Some("tok".into()),
            ..crate::config::Config::default()
        };
        let shared = crate::config::shared(cfg);
        assert!(has_credentials(&shared));
    }

    // Empty strings are treated the same as unset values.
    #[test]
    fn has_credentials_false_when_empty_strings() {
        let cfg = crate::config::Config {
            worker_id: Some("".into()),
            auth_token: Some("".into()),
            ..crate::config::Config::default()
        };
        let shared = crate::config::shared(cfg);
        assert!(!has_credentials(&shared));
    }
}