1use std::sync::{
13 atomic::{AtomicBool, Ordering},
14 Arc,
15};
16use std::time::Duration;
17
18use anyhow::{anyhow, Result};
19use parking_lot::Mutex;
20use tokio::sync::mpsc;
21use tracing::{info, warn};
22
23use crate::config::SharedConfig;
24use crate::engine::Engine;
25use crate::http::ApiClient;
26use crate::runtime::{
27 is_unsupported_kind, prompt_for, push_log_with_observers, record_recent_job, truncate_prompt,
28 wait_with_stop, CurrentJob, JobOutcome, RecentJob, WorkerObservers,
29};
30use crate::types::{LogEntry, TaskResult, WorkerCapabilities};
31use crate::ws::client::{connect, WsClientError, WsResult, WsSender};
32use crate::ws::types::{HelloFrame, JobOfferClaim, WorkerInbound, WorkerOutbound};
33
/// `tracing` target used for this module's diagnostic events.
const TRACE_TARGET: &str = "studio_worker::ws::session";

/// Default interval between heartbeat frames.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Default interval between log-batch flushes to the server.
const LOG_FLUSH_INTERVAL: Duration = Duration::from_secs(1);
/// Default polling period for the stop flag.
const SHUTDOWN_TICK: Duration = Duration::from_millis(250);
/// Default backoff before the first reconnect attempt, in milliseconds
/// (doubled for each further consecutive attempt).
const BASE_BACKOFF_MS: u64 = 1_000;
/// Default upper bound on the reconnect backoff, in milliseconds.
const MAX_BACKOFF_MS: u64 = 30_000;
/// Reconnect cap used when the config leaves `ws_reconnect_attempts` unset.
const DEFAULT_RECONNECT_ATTEMPTS: u32 = 5;
/// Default time without any inbound frame after which the connection is
/// treated as dead.
const READ_IDLE_TIMEOUT: Duration = Duration::from_secs(20);
48
/// How a single WebSocket session ended.
#[derive(Debug)]
pub enum SessionOutcome {
    /// The stop flag was observed; the caller should shut down cleanly.
    Stopped,
    /// The connection was lost or could not be established; the caller may
    /// reconnect.
    Disconnected,
    /// The server or transport rejected the worker's credentials (reason
    /// attached).
    AuthFailed(String),
    /// A non-recoverable condition (e.g. a server error frame or missing
    /// credentials); reason attached.
    Fatal(String),
}
62
/// Timing knobs for a WebSocket session.
///
/// `Default` yields the production cadence; `fast_for_tests` shrinks the
/// timers so tests finish quickly.
#[derive(Debug, Clone, Copy)]
pub struct SessionSchedule {
    /// Interval between heartbeat frames.
    pub heartbeat: Duration,
    /// Interval between log-batch flushes.
    pub log_flush: Duration,
    /// Polling period for the stop flag (used while waiting and during a
    /// session).
    pub shutdown_tick: Duration,
    /// Backoff before the first reconnect attempt, in milliseconds; doubled
    /// for each further consecutive attempt.
    pub base_backoff_ms: u64,
    /// Upper bound on the reconnect backoff, in milliseconds.
    pub max_backoff_ms: u64,
    /// How long the reader waits for any inbound frame before declaring the
    /// connection dead.
    pub read_idle_timeout: Duration,
}
74
75impl Default for SessionSchedule {
76 fn default() -> Self {
77 Self {
78 heartbeat: HEARTBEAT_INTERVAL,
79 log_flush: LOG_FLUSH_INTERVAL,
80 shutdown_tick: SHUTDOWN_TICK,
81 base_backoff_ms: BASE_BACKOFF_MS,
82 max_backoff_ms: MAX_BACKOFF_MS,
83 read_idle_timeout: READ_IDLE_TIMEOUT,
84 }
85 }
86}
87
88impl SessionSchedule {
89 pub fn fast_for_tests() -> Self {
90 Self {
91 heartbeat: Duration::from_millis(5),
92 log_flush: Duration::from_millis(5),
93 shutdown_tick: Duration::from_millis(5),
94 base_backoff_ms: 1,
95 max_backoff_ms: 10,
96 read_idle_timeout: Duration::from_secs(5),
99 }
100 }
101}
102
103#[cfg_attr(coverage_nightly, coverage(off))]
111pub async fn spawn_ws_session(
112 cfg: SharedConfig,
113 stop: Arc<AtomicBool>,
114 logs: Arc<Mutex<Vec<LogEntry>>>,
115 busy: Arc<AtomicBool>,
116 paused: Arc<AtomicBool>,
117 observers: WorkerObservers,
118 schedule: SessionSchedule,
119) -> Result<()> {
120 let max_attempts = {
121 let guard = cfg.lock();
122 guard
123 .ws_reconnect_attempts
124 .unwrap_or(DEFAULT_RECONNECT_ATTEMPTS)
125 };
126
127 let mut attempt: u32 = 0;
128 let mut waiting_for_creds_logged = false;
129 loop {
130 if stop.load(Ordering::SeqCst) {
131 return Ok(());
132 }
133 if !has_credentials(&cfg) {
139 if !waiting_for_creds_logged {
140 push_log_with_observers(
141 &logs,
142 Some(&observers),
143 "info",
144 "ws",
145 "waiting for operator approval before opening the session",
146 None,
147 );
148 waiting_for_creds_logged = true;
149 }
150 wait_with_stop(Duration::from_secs(1), &stop, schedule.shutdown_tick).await;
151 continue;
152 }
153 waiting_for_creds_logged = false;
154
155 let welcomed = AtomicBool::new(false);
156 match run_one_session(
157 &cfg, &stop, &logs, &busy, &paused, &observers, schedule, &welcomed,
158 )
159 .await
160 {
161 Ok(SessionOutcome::Stopped) => return Ok(()),
162 Ok(SessionOutcome::AuthFailed(reason)) => {
163 push_log_with_observers(
164 &logs,
165 Some(&observers),
166 "error",
167 "ws",
168 &format!("auth failed: {reason}. Re-register the worker."),
169 None,
170 );
171 return Err(anyhow!("ws auth failed: {reason}"));
172 }
173 Ok(SessionOutcome::Fatal(reason)) => {
174 push_log_with_observers(
175 &logs,
176 Some(&observers),
177 "error",
178 "ws",
179 &format!("fatal: {reason}"),
180 None,
181 );
182 return Err(anyhow!("ws fatal: {reason}"));
183 }
184 Ok(SessionOutcome::Disconnected) | Err(_) => {
185 if welcomed.load(Ordering::SeqCst) {
189 attempt = 0;
190 }
191 attempt += 1;
192 if max_attempts > 0 && attempt > max_attempts {
193 push_log_with_observers(
194 &logs,
195 Some(&observers),
196 "error",
197 "ws",
198 &format!("giving up after {attempt} reconnect attempts"),
199 None,
200 );
201 return Err(anyhow!("ws reconnect cap reached"));
202 }
203 let backoff = backoff_for(attempt, schedule);
204 push_log_with_observers(
205 &logs,
206 Some(&observers),
207 "warn",
208 "ws",
209 &format!(
210 "disconnected; reconnect attempt {attempt} in {}ms",
211 backoff.as_millis()
212 ),
213 None,
214 );
215 wait_with_stop(backoff, &stop, schedule.shutdown_tick).await;
216 }
217 }
218 }
219}
220
/// Result of waiting for the server's `Welcome` frame after sending `Hello`.
enum WelcomeOutcome {
    /// The server sent `Welcome`; the session may proceed.
    Welcomed,
    /// The server or transport reported an authentication failure (reason
    /// attached).
    AuthFailed(String),
    /// The server sent a non-auth error frame before welcoming (message
    /// attached).
    Fatal(String),
    /// The connection ended, a stop event arrived, or the event channel
    /// closed before a welcome was received.
    Disconnected,
}
231
232#[cfg_attr(coverage_nightly, coverage(off))]
238async fn wait_for_welcome(
239 event_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
240 logs: &Arc<Mutex<Vec<LogEntry>>>,
241 observers: &WorkerObservers,
242) -> WelcomeOutcome {
243 while let Some(event) = event_rx.recv().await {
244 match event {
245 SessionEvent::Frame(WorkerOutbound::Welcome { worker_id: wid, .. }) => {
246 push_log_with_observers(
247 logs,
248 Some(observers),
249 "info",
250 "ws",
251 &format!("server welcomed {wid}"),
252 None,
253 );
254 return WelcomeOutcome::Welcomed;
255 }
256 SessionEvent::Frame(WorkerOutbound::Error { code, message }) => {
257 push_log_with_observers(
258 logs,
259 Some(observers),
260 "error",
261 "ws",
262 &format!("server error before welcome {code:?}: {message}"),
263 None,
264 );
265 return match code {
266 crate::ws::types::WorkerErrorCode::AuthFailed => {
267 WelcomeOutcome::AuthFailed(message)
268 }
269 _ => WelcomeOutcome::Fatal(message),
270 };
271 }
272 SessionEvent::Frame(other) => {
273 push_log_with_observers(
274 logs,
275 Some(observers),
276 "warn",
277 "ws",
278 &format!("server sent unexpected frame before welcome: {other:?}"),
279 None,
280 );
281 }
283 SessionEvent::Disconnected(WsClientError::AuthFailed { reason }) => {
284 return WelcomeOutcome::AuthFailed(reason);
285 }
286 SessionEvent::Disconnected(_) => return WelcomeOutcome::Disconnected,
287 SessionEvent::Stopped => return WelcomeOutcome::Disconnected,
288 }
289 }
290 WelcomeOutcome::Disconnected
291}
292
293fn has_credentials(cfg: &SharedConfig) -> bool {
297 let guard = cfg.lock();
298 guard
299 .worker_id
300 .as_deref()
301 .map(|s| !s.is_empty())
302 .unwrap_or(false)
303 && guard
304 .auth_token
305 .as_deref()
306 .map(|s| !s.is_empty())
307 .unwrap_or(false)
308}
309
/// Runs one full WebSocket session: connect, `Hello`, wait for `Welcome`,
/// then dispatch server frames until the session ends.
///
/// `welcomed` is set to `true` once the server's `Welcome` frame has been
/// received.
///
/// Returns `Ok(SessionOutcome::…)` for the expected ways a session ends:
/// - missing credentials → `Fatal`
/// - auth rejection → `AuthFailed`
/// - connect failure or lost connection → `Disconnected`
/// - stop observed → `Stopped`
///
/// # Errors
/// Returns `Err` if building the engine fails, or if sending the `Hello` frame
/// fails.
#[cfg_attr(coverage_nightly, coverage(off))]
#[allow(clippy::too_many_arguments)]
async fn run_one_session(
    cfg: &SharedConfig,
    stop: &Arc<AtomicBool>,
    logs: &Arc<Mutex<Vec<LogEntry>>>,
    busy: &Arc<AtomicBool>,
    paused: &Arc<AtomicBool>,
    observers: &WorkerObservers,
    schedule: SessionSchedule,
    welcomed: &AtomicBool,
) -> Result<SessionOutcome> {
    // Snapshot the connection settings so the config lock is not held across
    // any `.await` below.
    let (api_base_url, worker_id, auth_token) = {
        let guard = cfg.lock();
        (
            guard.api_base_url.clone(),
            guard.worker_id.clone().unwrap_or_default(),
            guard.auth_token.clone().unwrap_or_default(),
        )
    };
    if worker_id.is_empty() || auth_token.is_empty() {
        return Ok(SessionOutcome::Fatal(
            "worker_id or auth_token missing; run register".to_string(),
        ));
    }

    push_log_with_observers(
        logs,
        Some(observers),
        "info",
        "ws",
        &format!("connecting to {api_base_url}"),
        None,
    );
    let client = match connect(&api_base_url, &worker_id, &auth_token).await {
        Ok(c) => c,
        Err(WsClientError::AuthFailed { reason }) => {
            return Ok(SessionOutcome::AuthFailed(reason));
        }
        Err(e) => {
            push_log_with_observers(
                logs,
                Some(observers),
                "warn",
                "ws",
                &format!("connect failed: {e}"),
                None,
            );
            return Ok(SessionOutcome::Disconnected);
        }
    };
    let (sender, receiver) = client.split();

    // NOTE(review): the engine is built after the socket is open. If `build`
    // fails, `?` returns with the connection dropped rather than closed with a
    // close frame. The temporary config guard is held for the whole `build`
    // call.
    let engine = crate::engine::build(&cfg.lock())?;
    let capabilities = crate::runtime::build_capabilities_with(
        &cfg.lock(),
        &*engine,
        !paused.load(Ordering::SeqCst),
    );
    push_log_with_observers(
        logs,
        Some(observers),
        "info",
        "ws",
        &crate::runtime::summarize_capabilities(&capabilities),
        None,
    );
    sender
        .send(&WorkerInbound::Hello(HelloFrame {
            auth_token: auth_token.clone(),
            capabilities: capabilities.clone(),
        }))
        .await
        .map_err(|e| anyhow!("hello send failed: {e}"))?;
    info!(target: TRACE_TARGET, worker_id = %worker_id, "hello sent");

    let (event_tx, event_rx) = mpsc::unbounded_channel::<SessionEvent>();

    let reader = spawn_reader(receiver, event_tx.clone(), schedule.read_idle_timeout);

    // NOTE(review): the shutdown observer is only spawned after the welcome.
    // Until then a stop request is not turned into a `Stopped` event; the wait
    // ends only when the reader reports a disconnect (e.g. its idle timeout).
    let mut event_rx = event_rx;
    match wait_for_welcome(&mut event_rx, logs, observers).await {
        WelcomeOutcome::Welcomed => welcomed.store(true, Ordering::SeqCst),
        WelcomeOutcome::AuthFailed(reason) => {
            let _ = sender.close(1000, "auth failed").await;
            let _ = reader.await;
            return Ok(SessionOutcome::AuthFailed(reason));
        }
        WelcomeOutcome::Fatal(reason) => {
            let _ = sender.close(1000, "protocol violation").await;
            let _ = reader.await;
            return Ok(SessionOutcome::Fatal(reason));
        }
        WelcomeOutcome::Disconnected => {
            let _ = reader.await;
            return Ok(SessionOutcome::Disconnected);
        }
    }

    let capabilities_for_heartbeat = capabilities.clone();
    let heartbeat = spawn_heartbeat_pump(
        capabilities_for_heartbeat,
        sender.clone(),
        stop.clone(),
        paused.clone(),
        observers.clone(),
        schedule,
    );

    let log_shipper = spawn_log_shipper_pump(sender.clone(), logs.clone(), stop.clone(), schedule);

    let shutdown_observer = spawn_shutdown_observer(stop.clone(), event_tx.clone(), schedule);
    // Drop this function's sender handle so the channel closes once the
    // reader and shutdown observer (the remaining producers) have ended.
    drop(event_tx);

    let engine_arc: Arc<dyn Engine> = engine.into();
    let ctx = SessionContext {
        sender: sender.clone(),
        engine: engine_arc,
        logs: logs.clone(),
        busy: busy.clone(),
        paused: paused.clone(),
        observers: observers.clone(),
        api_base_url: api_base_url.clone(),
        worker_id: worker_id.clone(),
        auth_token: auth_token.clone(),
    };
    let outcome = run_dispatch_loop(ctx, event_rx).await;

    // Teardown: stop the background tasks, close the socket, then await every
    // handle so none outlives the session.
    reader.abort();
    heartbeat.abort();
    log_shipper.abort();
    shutdown_observer.abort();
    let _ = sender.close(1000, "session ended").await;
    let _ = reader.await;
    let _ = heartbeat.await;
    let _ = log_shipper.await;
    let _ = shutdown_observer.await;
    Ok(outcome)
}
480
/// Events fed to the session's dispatch logic over an unbounded channel.
#[derive(Debug)]
enum SessionEvent {
    /// A frame received from the server (produced by the reader task).
    Frame(WorkerOutbound),
    /// The stop flag was observed (produced by the shutdown observer).
    Stopped,
    /// The reader ended: the socket closed, errored, or stayed idle past the
    /// read timeout.
    Disconnected(WsClientError),
}
491
/// Per-session state shared with offer handling and spawned job tasks.
struct SessionContext {
    /// Handle for sending frames to the server.
    sender: WsSender,
    /// Engine that executes job tasks.
    engine: Arc<dyn Engine>,
    /// Shared log buffer, drained periodically by the log shipper.
    logs: Arc<Mutex<Vec<LogEntry>>>,
    /// `true` while a job is reserved or running.
    busy: Arc<AtomicBool>,
    /// Operator pause flag; offers are rejected while it is set.
    paused: Arc<AtomicBool>,
    /// Observer hooks (current job, recent jobs, log observers).
    observers: WorkerObservers,
    /// HTTP API base URL (used for binary result uploads).
    api_base_url: String,
    /// This worker's id, as read from the config.
    worker_id: String,
    /// This worker's auth token, as read from the config.
    auth_token: String,
}
505
506#[cfg_attr(coverage_nightly, coverage(off))]
507async fn run_dispatch_loop(
508 ctx: SessionContext,
509 mut event_rx: mpsc::UnboundedReceiver<SessionEvent>,
510) -> SessionOutcome {
511 while let Some(event) = event_rx.recv().await {
512 match event {
513 SessionEvent::Disconnected(WsClientError::AuthFailed { reason }) => {
514 return SessionOutcome::AuthFailed(reason);
515 }
516 SessionEvent::Disconnected(_) => return SessionOutcome::Disconnected,
517 SessionEvent::Stopped => return SessionOutcome::Stopped,
518 SessionEvent::Frame(frame) => match frame {
519 WorkerOutbound::Welcome { worker_id: wid, .. } => {
520 push_log_with_observers(
521 &ctx.logs,
522 Some(&ctx.observers),
523 "info",
524 "ws",
525 &format!("server welcomed {wid}"),
526 None,
527 );
528 }
529 WorkerOutbound::Offer { claim } => {
530 handle_offer(&ctx, *claim);
531 }
532 WorkerOutbound::Error { code, message } => {
533 push_log_with_observers(
534 &ctx.logs,
535 Some(&ctx.observers),
536 "error",
537 "ws",
538 &format!("server error {code:?}: {message}"),
539 None,
540 );
541 return match code {
542 crate::ws::types::WorkerErrorCode::AuthFailed => {
543 SessionOutcome::AuthFailed(message)
544 }
545 _ => SessionOutcome::Fatal(message),
546 };
547 }
548 WorkerOutbound::HeartbeatAck
549 | WorkerOutbound::CompleteAck { .. }
550 | WorkerOutbound::FailAck { .. } => {
551 }
553 },
554 }
555 }
556 SessionOutcome::Disconnected
557}
558
559#[cfg_attr(coverage_nightly, coverage(off))]
560fn handle_offer(ctx: &SessionContext, claim: JobOfferClaim) {
561 let job_id = claim.job_id.clone();
562 push_log_with_observers(
563 &ctx.logs,
564 Some(&ctx.observers),
565 "info",
566 "ws",
567 &format!(
568 "offer received {job_id} model={} vram={}",
569 claim.model, claim.vram_gb_estimate
570 ),
571 Some(job_id.clone()),
572 );
573 if ctx.paused.load(Ordering::SeqCst) {
577 push_log_with_observers(
578 &ctx.logs,
579 Some(&ctx.observers),
580 "info",
581 "ws",
582 &format!("rejecting offer {job_id}: worker is paused"),
583 Some(job_id.clone()),
584 );
585 spawn_reject_offer(
586 ctx.sender.clone(),
587 ctx.logs.clone(),
588 ctx.observers.clone(),
589 job_id,
590 "worker paused by operator",
591 );
592 return;
593 }
594 if !try_reserve_worker(&ctx.busy) {
595 push_log_with_observers(
596 &ctx.logs,
597 Some(&ctx.observers),
598 "info",
599 "ws",
600 &format!("rejecting offer {job_id}: worker is already busy"),
601 Some(job_id.clone()),
602 );
603 spawn_reject_offer(
604 ctx.sender.clone(),
605 ctx.logs.clone(),
606 ctx.observers.clone(),
607 job_id,
608 "worker already has an in-flight job",
609 );
610 return;
611 }
612 let job = claim.into_job_claim();
613 let task_kind = job.task.kind();
614 let full_prompt = prompt_for(&job.task);
622 let prompt_preview = truncate_prompt(&full_prompt);
623 let started_at = chrono::Utc::now();
624
625 let busy_flag = ctx.busy.clone();
626 let logs_for_task = ctx.logs.clone();
627 let observers_for_task = ctx.observers.clone();
628 let sender_for_task = ctx.sender.clone();
629 let engine_for_task = ctx.engine.clone();
630 let api_base_url = ctx.api_base_url.clone();
631 let worker_id = ctx.worker_id.clone();
632 let auth_token = ctx.auth_token.clone();
633 tokio::spawn(async move {
634 let accept_result = sender_for_task
635 .send(&WorkerInbound::Accept {
636 job_id: job_id.clone(),
637 })
638 .await;
639 if let Some((level, message)) = offer_response_breadcrumb("accept", &job_id, &accept_result)
640 {
641 push_log_with_observers(
642 &logs_for_task,
643 Some(&observers_for_task),
644 level,
645 "ws",
646 &message,
647 Some(job_id.clone()),
648 );
649 }
650 if accept_result.is_err() {
651 busy_flag.store(false, Ordering::SeqCst);
652 return;
653 }
654
655 *observers_for_task.current_job.lock() = Some(CurrentJob {
657 job_id: job_id.clone(),
658 kind: task_kind,
659 model: job.model.clone(),
660 prompt: prompt_preview.clone(),
661 started_at,
662 });
663
664 run_offered_job(
665 sender_for_task,
666 engine_for_task,
667 logs_for_task,
668 observers_for_task,
669 api_base_url,
670 worker_id,
671 auth_token,
672 job,
673 started_at,
674 task_kind,
675 full_prompt,
676 prompt_preview,
677 )
678 .await;
679 busy_flag.store(false, Ordering::SeqCst);
680 });
681}
682
/// Atomically claims the worker's single job slot.
///
/// Returns `true` only for the caller that flips `busy` from `false` to
/// `true`. Every concurrent or later caller gets `false` until the flag is
/// cleared again.
fn try_reserve_worker(busy: &AtomicBool) -> bool {
    matches!(
        busy.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst),
        Ok(false)
    )
}
687
688fn spawn_reject_offer(
689 sender: WsSender,
690 logs: Arc<Mutex<Vec<LogEntry>>>,
691 observers: WorkerObservers,
692 job_id: String,
693 reason: &'static str,
694) {
695 tokio::spawn(async move {
696 let result = sender
697 .send(&WorkerInbound::Reject {
698 job_id: job_id.clone(),
699 reason: reason.to_string(),
700 })
701 .await;
702 if let Some((level, message)) = offer_response_breadcrumb("reject", &job_id, &result) {
703 push_log_with_observers(&logs, Some(&observers), level, "ws", &message, Some(job_id));
704 }
705 });
706}
707
708#[allow(clippy::too_many_arguments)]
709#[cfg_attr(coverage_nightly, coverage(off))]
710async fn run_offered_job(
711 sender: WsSender,
712 engine: Arc<dyn Engine>,
713 logs: Arc<Mutex<Vec<LogEntry>>>,
714 observers: WorkerObservers,
715 api_base_url: String,
716 worker_id: String,
717 auth_token: String,
718 job: crate::types::JobClaim,
719 started_at: chrono::DateTime<chrono::Utc>,
720 task_kind: crate::types::TaskKind,
721 full_prompt: String,
722 prompt_preview: String,
723) {
724 let start = std::time::Instant::now();
725 let dispatch = tokio::task::spawn_blocking({
730 let model = job.model.clone();
731 let model_source = job.model_source.clone();
732 let task_for_engine = job.task.clone();
733 let engine = engine.clone();
734 move || -> Result<TaskResult> {
735 engine.dispatch_with_source(&model, task_for_engine, &model_source)
736 }
737 })
738 .await;
739
740 let job_id = job.job_id.clone();
741 #[allow(unused_assignments)]
748 let mut outcome = JobOutcome::Failed {
749 reason: "dispatch did not run to completion".to_string(),
750 };
751 match dispatch {
752 Ok(Ok(result)) => {
753 push_log_with_observers(
754 &logs,
755 Some(&observers),
756 "info",
757 "ws",
758 &format!("{} dispatched in {:?}", task_kind.as_str(), start.elapsed()),
759 Some(job_id.clone()),
760 );
761 match result {
762 TaskResult::Image { bytes, ext }
763 | TaskResult::AudioTts { bytes, ext }
764 | TaskResult::Video { bytes, ext } => {
765 let upload_result = tokio::task::spawn_blocking({
768 let api_base_url = api_base_url.clone();
769 let job_id = job_id.clone();
770 let auth_token = auth_token.clone();
771 let worker_id = worker_id.clone();
772 let prompt = full_prompt.clone();
773 move || -> Result<()> {
774 let api = ApiClient::new(api_base_url)?;
775 api.complete(&worker_id, &auth_token, &job_id, &ext, &prompt, bytes)
776 }
777 })
778 .await;
779 let msg = match upload_result {
780 Ok(Ok(())) => None,
781 Ok(Err(e)) => Some(e.to_string()),
782 Err(e) => Some(format!("upload task panic: {e}")),
783 };
784 if let Some(msg) = msg {
785 push_log_with_observers(
786 &logs,
787 Some(&observers),
788 "error",
789 "ws",
790 &msg,
791 Some(job_id.clone()),
792 );
793 outcome = JobOutcome::Failed {
794 reason: msg.clone(),
795 };
796 let fail_result = sender
797 .send(&WorkerInbound::Fail {
798 job_id: job_id.clone(),
799 error: msg,
800 retryable: true,
801 })
802 .await;
803 record_fail_send(&fail_result, &job_id, &logs, &observers);
804 } else {
805 push_log_with_observers(
806 &logs,
807 Some(&observers),
808 "info",
809 "ws",
810 "binary upload ok",
811 Some(job_id.clone()),
812 );
813 outcome = JobOutcome::Completed;
814 }
829 }
830 TaskResult::Llm { json } | TaskResult::AudioStt { json } => {
831 match sender
838 .send(&WorkerInbound::CompleteJson {
839 job_id: job_id.clone(),
840 result: json,
841 prompt: Some(full_prompt.clone()),
842 })
843 .await
844 {
845 Ok(()) => {
846 push_log_with_observers(
847 &logs,
848 Some(&observers),
849 "info",
850 "ws",
851 "json result sent",
852 Some(job_id.clone()),
853 );
854 outcome = JobOutcome::Completed;
855 }
856 Err(e) => {
857 let msg = format!("failed to send result: {e}");
858 push_log_with_observers(
859 &logs,
860 Some(&observers),
861 "error",
862 "ws",
863 &msg,
864 Some(job_id.clone()),
865 );
866 outcome = JobOutcome::Failed { reason: msg };
867 }
868 }
869 }
870 }
871 }
872 Ok(Err(e)) => {
873 warn!(target: TRACE_TARGET, error = %e, "engine dispatch failed");
874 push_log_with_observers(
875 &logs,
876 Some(&observers),
877 "error",
878 "ws",
879 &format!("dispatch failed: {e}"),
880 Some(job_id.clone()),
881 );
882 outcome = JobOutcome::Failed {
883 reason: e.to_string(),
884 };
885 let fail_result = sender
886 .send(&WorkerInbound::Fail {
887 job_id: job_id.clone(),
888 error: e.to_string(),
889 retryable: !is_unsupported_kind(&e),
890 })
891 .await;
892 record_fail_send(&fail_result, &job_id, &logs, &observers);
893 }
894 Err(e) => {
895 push_log_with_observers(
896 &logs,
897 Some(&observers),
898 "error",
899 "ws",
900 &format!("dispatch task panic: {e}"),
901 Some(job_id.clone()),
902 );
903 outcome = JobOutcome::Failed {
904 reason: e.to_string(),
905 };
906 let fail_result = sender
907 .send(&WorkerInbound::Fail {
908 job_id: job_id.clone(),
909 error: e.to_string(),
910 retryable: true,
911 })
912 .await;
913 record_fail_send(&fail_result, &job_id, &logs, &observers);
914 }
915 }
916
917 *observers.current_job.lock() = None;
920 record_recent_job(
921 &observers,
922 RecentJob {
923 job_id: job_id.clone(),
924 kind: task_kind,
925 model: job.model.clone(),
926 prompt: prompt_preview,
927 outcome,
928 started_at,
929 finished_at: chrono::Utc::now(),
930 },
931 );
932}
933
934#[cfg_attr(coverage_nightly, coverage(off))]
935fn spawn_reader(
936 mut receiver: crate::ws::client::WsReceiver,
937 event_tx: mpsc::UnboundedSender<SessionEvent>,
938 read_idle_timeout: Duration,
939) -> tokio::task::JoinHandle<()> {
940 tokio::spawn(async move {
941 loop {
942 match tokio::time::timeout(read_idle_timeout, receiver.recv()).await {
946 Ok(Ok(Some(frame))) => {
947 if event_tx.send(SessionEvent::Frame(frame)).is_err() {
948 break;
949 }
950 }
951 Ok(Ok(None)) => {
952 let _ =
953 event_tx.send(SessionEvent::Disconnected(WsClientError::ConnectionClosed));
954 break;
955 }
956 Ok(Err(e)) => {
957 let _ = event_tx.send(SessionEvent::Disconnected(e));
958 break;
959 }
960 Err(_elapsed) => {
961 let _ = event_tx.send(SessionEvent::Disconnected(WsClientError::Transport(
962 format!(
963 "no frames from server for {:?}; treating connection as dead",
964 read_idle_timeout
965 ),
966 )));
967 break;
968 }
969 }
970 }
971 })
972}
973
974#[cfg_attr(coverage_nightly, coverage(off))]
975fn spawn_heartbeat_pump(
976 capabilities: WorkerCapabilities,
977 sender: WsSender,
978 stop: Arc<AtomicBool>,
979 paused: Arc<AtomicBool>,
980 observers: WorkerObservers,
981 schedule: SessionSchedule,
982) -> tokio::task::JoinHandle<()> {
983 tokio::spawn(async move {
984 let mut interval = tokio::time::interval(schedule.heartbeat);
985 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
986 loop {
987 interval.tick().await;
988 if stop.load(Ordering::SeqCst) {
989 break;
990 }
991 let mut caps = capabilities.clone();
995 caps.auto_enabled = !paused.load(Ordering::SeqCst);
996 let current_job_id = heartbeat_current_job_id(&observers);
997 if let Err(e) = sender
998 .send(&WorkerInbound::Heartbeat {
999 capabilities: caps,
1000 current_job_id,
1001 })
1002 .await
1003 {
1004 warn!(target: TRACE_TARGET, error = %e, "heartbeat send failed");
1005 break;
1006 }
1007 }
1008 })
1009}
1010
1011fn heartbeat_current_job_id(observers: &WorkerObservers) -> Option<String> {
1012 observers
1013 .current_job
1014 .lock()
1015 .as_ref()
1016 .map(|job| job.job_id.clone())
1017}
1018
1019#[cfg_attr(coverage_nightly, coverage(off))]
1020fn spawn_log_shipper_pump(
1021 sender: WsSender,
1022 logs: Arc<Mutex<Vec<LogEntry>>>,
1023 stop: Arc<AtomicBool>,
1024 schedule: SessionSchedule,
1025) -> tokio::task::JoinHandle<()> {
1026 tokio::spawn(async move {
1027 let mut interval = tokio::time::interval(schedule.log_flush);
1028 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
1029 loop {
1030 interval.tick().await;
1031 if stop.load(Ordering::SeqCst) {
1032 break;
1033 }
1034 let batch = {
1035 let mut guard = logs.lock();
1036 if guard.is_empty() {
1037 continue;
1038 }
1039 std::mem::take(&mut *guard)
1040 };
1041 if let Err(e) = sender
1042 .send(&WorkerInbound::LogBatch { entries: batch })
1043 .await
1044 {
1045 warn!(target: TRACE_TARGET, error = %e, "log batch send failed");
1046 break;
1047 }
1048 }
1049 })
1050}
1051
1052#[cfg_attr(coverage_nightly, coverage(off))]
1053fn spawn_shutdown_observer(
1054 stop: Arc<AtomicBool>,
1055 event_tx: mpsc::UnboundedSender<SessionEvent>,
1056 schedule: SessionSchedule,
1057) -> tokio::task::JoinHandle<()> {
1058 tokio::spawn(async move {
1059 loop {
1060 tokio::time::sleep(schedule.shutdown_tick).await;
1061 if stop.load(Ordering::SeqCst) {
1062 let _ = event_tx.send(SessionEvent::Stopped);
1063 break;
1064 }
1065 if event_tx.is_closed() {
1066 break;
1067 }
1068 }
1069 })
1070}
1071
1072fn backoff_for(attempt: u32, schedule: SessionSchedule) -> Duration {
1073 let factor = 2u64.saturating_pow(attempt.saturating_sub(1));
1074 let raw_ms = schedule.base_backoff_ms.saturating_mul(factor);
1075 Duration::from_millis(raw_ms.min(schedule.max_backoff_ms))
1076}
1077
1078fn offer_response_breadcrumb(
1093 label: &str,
1094 job_id: &str,
1095 result: &WsResult<()>,
1096) -> Option<(&'static str, String)> {
1097 match result {
1098 Ok(()) => None,
1099 Err(e) => Some((
1100 "error",
1101 format!("{label} send failed for offer {job_id}: {e}"),
1102 )),
1103 }
1104}
1105
1106fn fail_send_breadcrumb(job_id: &str, result: &WsResult<()>) -> Option<(&'static str, String)> {
1121 match result {
1122 Ok(()) => None,
1123 Err(e) => Some((
1124 "error",
1125 format!("failed to notify studio of job {job_id} failure: {e}"),
1126 )),
1127 }
1128}
1129
1130fn record_fail_send(
1138 result: &WsResult<()>,
1139 job_id: &str,
1140 logs: &Arc<Mutex<Vec<LogEntry>>>,
1141 observers: &WorkerObservers,
1142) {
1143 if let Some((level, message)) = fail_send_breadcrumb(job_id, result) {
1144 push_log_with_observers(
1145 logs,
1146 Some(observers),
1147 level,
1148 "ws",
1149 &message,
1150 Some(job_id.to_string()),
1151 );
1152 }
1153}
1154
// Unit tests for the pure helpers in this module. The async session and pump
// tasks are not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    // Successful accept/reject sends must not produce a log entry.
    #[test]
    fn offer_response_breadcrumb_is_silent_on_success() {
        assert!(offer_response_breadcrumb("accept", "j-1", &Ok(())).is_none());
        assert!(offer_response_breadcrumb("reject", "j-2", &Ok(())).is_none());
    }

    // A second reservation must fail while the first is still held.
    #[test]
    fn try_reserve_worker_only_allows_one_in_flight_job() {
        let busy = AtomicBool::new(false);
        assert!(try_reserve_worker(&busy));
        assert!(!try_reserve_worker(&busy));
    }

    // The heartbeat must report the published job's own id, and `None` when
    // no job is published.
    #[test]
    fn heartbeat_current_job_id_uses_actual_job_id() {
        let observers = WorkerObservers::default();
        assert_eq!(heartbeat_current_job_id(&observers), None);
        *observers.current_job.lock() = Some(CurrentJob {
            job_id: "job-42".into(),
            kind: crate::types::TaskKind::Image,
            model: "synthetic".into(),
            prompt: "prompt".into(),
            started_at: chrono::Utc::now(),
        });
        assert_eq!(
            heartbeat_current_job_id(&observers).as_deref(),
            Some("job-42")
        );
    }

    #[test]
    fn offer_response_breadcrumb_reports_accept_send_failure() {
        let (level, msg) =
            offer_response_breadcrumb("accept", "j-1", &Err(WsClientError::ConnectionClosed))
                .expect("a failed accept send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("accept send failed"), "got: {msg}");
        assert!(msg.contains("j-1"), "must name the job: {msg}");
        // Assumes `WsClientError::ConnectionClosed` displays as
        // "connection closed".
        assert!(
            msg.contains("connection closed"),
            "must carry the cause: {msg}"
        );
    }

    #[test]
    fn offer_response_breadcrumb_reports_reject_send_failure() {
        let (level, msg) = offer_response_breadcrumb(
            "reject",
            "j-9",
            &Err(WsClientError::Transport("sink gone".into())),
        )
        .expect("a failed reject send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("reject send failed"), "got: {msg}");
        assert!(msg.contains("j-9"), "must name the job: {msg}");
        assert!(msg.contains("sink gone"), "must carry the cause: {msg}");
    }

    #[test]
    fn fail_send_breadcrumb_is_silent_on_success() {
        assert!(fail_send_breadcrumb("j-1", &Ok(())).is_none());
    }

    #[test]
    fn fail_send_breadcrumb_reports_send_failure() {
        let (level, msg) = fail_send_breadcrumb("j-7", &Err(WsClientError::ConnectionClosed))
            .expect("a dropped Fail send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("j-7"), "must name the job: {msg}");
        assert!(
            msg.contains("connection closed"),
            "must carry the cause: {msg}"
        );
    }

    #[test]
    fn fail_send_breadcrumb_carries_transport_cause() {
        let (level, msg) =
            fail_send_breadcrumb("j-3", &Err(WsClientError::Transport("sink gone".into())))
                .expect("a dropped Fail send must surface a breadcrumb");
        assert_eq!(level, "error");
        assert!(msg.contains("j-3"), "must name the job: {msg}");
        assert!(msg.contains("sink gone"), "must carry the cause: {msg}");
    }

    // Backoff doubles from the base per attempt and is clamped at the maximum.
    #[test]
    fn backoff_grows_exponentially_until_cap() {
        let schedule = SessionSchedule {
            base_backoff_ms: 100,
            max_backoff_ms: 1_000,
            heartbeat: Duration::from_secs(1),
            log_flush: Duration::from_secs(1),
            shutdown_tick: Duration::from_secs(1),
            read_idle_timeout: Duration::from_secs(1),
        };
        assert_eq!(backoff_for(1, schedule), Duration::from_millis(100));
        assert_eq!(backoff_for(2, schedule), Duration::from_millis(200));
        assert_eq!(backoff_for(3, schedule), Duration::from_millis(400));
        assert_eq!(backoff_for(4, schedule), Duration::from_millis(800));
        assert_eq!(backoff_for(5, schedule), Duration::from_millis(1_000));
        assert_eq!(backoff_for(10, schedule), Duration::from_millis(1_000));
    }

    // Either credential missing is enough to report "no credentials".
    #[test]
    fn has_credentials_false_when_either_missing() {
        let mut cfg = crate::config::Config::default();
        let shared = crate::config::shared(cfg.clone());
        assert!(!has_credentials(&shared), "both missing");
        cfg.worker_id = Some("w-1".into());
        let shared = crate::config::shared(cfg.clone());
        assert!(!has_credentials(&shared), "only worker_id");
        cfg.worker_id = None;
        cfg.auth_token = Some("tok".into());
        let shared = crate::config::shared(cfg.clone());
        assert!(!has_credentials(&shared), "only auth_token");
    }

    #[test]
    fn has_credentials_true_when_both_present() {
        let cfg = crate::config::Config {
            worker_id: Some("w-1".into()),
            auth_token: Some("tok".into()),
            ..crate::config::Config::default()
        };
        let shared = crate::config::shared(cfg);
        assert!(has_credentials(&shared));
    }

    // Empty strings must be treated the same as absent values.
    #[test]
    fn has_credentials_false_when_empty_strings() {
        let cfg = crate::config::Config {
            worker_id: Some("".into()),
            auth_token: Some("".into()),
            ..crate::config::Config::default()
        };
        let shared = crate::config::shared(cfg);
        assert!(!has_credentials(&shared));
    }
}