1use crate::{
9 config::{self, Config, SharedConfig},
10 engine::{self, Engine},
11 http::ApiClient,
12 sys,
13 types::*,
14 update, AGENT_VERSION,
15};
16use anyhow::{anyhow, Result};
17use chrono::{DateTime, SecondsFormat, Utc};
18use parking_lot::Mutex;
19use std::{
20 collections::VecDeque,
21 sync::{
22 atomic::{AtomicBool, Ordering},
23 Arc,
24 },
25 time::Duration,
26};
27use tracing::{info, warn};
28
/// Upper bound on the number of finished jobs retained in
/// `WorkerObservers::recent_jobs`; the oldest entries are evicted first.
pub const RECENT_JOBS_CAP: usize = 50;

/// How many characters of a job prompt are kept for display before
/// `truncate_prompt` cuts the text and appends an ellipsis.
pub const PROMPT_PREVIEW_CHARS: usize = 200;
44
/// Snapshot of the job the claim loop is executing right now.
///
/// `claim_tick` stores this in `WorkerObservers::current_job` right after a
/// successful claim and clears it once the job has finished (or panicked).
#[derive(Debug, Clone)]
pub struct CurrentJob {
    /// Identifier of the claimed job, as returned by the claim request.
    pub job_id: String,
    /// Kind of task the job resolved to.
    pub kind: TaskKind,
    /// Model name requested by the job.
    pub model: String,
    /// Prompt preview, already shortened by `truncate_prompt`.
    pub prompt: String,
    /// Time the claim loop recorded the job as started.
    pub started_at: DateTime<Utc>,
}
55
/// Final result of executing one claimed job.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JobOutcome {
    /// The task was generated and its result uploaded successfully.
    Completed,
    /// Generation failed, the upload failed, or the job task panicked;
    /// `reason` is a human-readable description.
    Failed { reason: String },
}
63
/// A finished job as kept in `WorkerObservers::recent_jobs`.
///
/// Built from the `CurrentJob` snapshot plus the outcome and finish time.
#[derive(Debug, Clone)]
pub struct RecentJob {
    /// Identifier of the job.
    pub job_id: String,
    /// Kind of task the job resolved to.
    pub kind: TaskKind,
    /// Model name requested by the job.
    pub model: String,
    /// Truncated prompt preview.
    pub prompt: String,
    /// Whether the job completed or failed (and why).
    pub outcome: JobOutcome,
    /// When the claim loop recorded the job as started.
    pub started_at: DateTime<Utc>,
    /// When the claim loop observed the job task finishing.
    pub finished_at: DateTime<Utc>,
}
75
/// Result of a single heartbeat attempt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HeartbeatOutcome {
    /// The heartbeat request succeeded.
    Ok,
    /// The engine could not be built, the request failed, or the blocking
    /// task panicked; `reason` describes which.
    Err { reason: String },
}
82
/// Latest heartbeat attempt as exposed through `WorkerObservers::last_heartbeat`.
#[derive(Debug, Clone)]
pub struct HeartbeatStatus {
    /// Time the attempt finished being recorded (taken after the request).
    pub last_attempt_at: DateTime<Utc>,
    /// Whether the attempt succeeded.
    pub outcome: HeartbeatOutcome,
}
88
/// Shared handles exposing what the worker loops are doing.
///
/// Every field is `Arc`-wrapped, so clones share the same state: the loops
/// started by `run_loops` write to it and any holder of a clone can read it
/// (presumably a status UI or tests — confirm against callers).
#[derive(Clone, Default)]
pub struct WorkerObservers {
    /// Job currently being executed by the claim loop, if any.
    pub current_job: Arc<Mutex<Option<CurrentJob>>>,
    /// Recently finished jobs, newest first, capped at `RECENT_JOBS_CAP`.
    pub recent_jobs: Arc<Mutex<VecDeque<RecentJob>>>,
    /// Most recent heartbeat attempt; `None` until the first one completes.
    pub last_heartbeat: Arc<Mutex<Option<HeartbeatStatus>>>,
}
99
100fn truncate_prompt(s: &str) -> String {
101 if s.chars().count() <= PROMPT_PREVIEW_CHARS {
102 return s.to_string();
103 }
104 let mut out: String = s.chars().take(PROMPT_PREVIEW_CHARS).collect();
105 out.push('…');
106 out
107}
108
109fn record_recent_job(observers: &WorkerObservers, entry: RecentJob) {
110 let mut ring = observers.recent_jobs.lock();
111 ring.push_front(entry);
112 while ring.len() > RECENT_JOBS_CAP {
113 ring.pop_back();
114 }
115}
116
117#[doc(hidden)]
121pub fn push_recent_job_for_tests(observers: &WorkerObservers, job_id: &str) {
122 let now = Utc::now();
123 record_recent_job(
124 observers,
125 RecentJob {
126 job_id: job_id.to_string(),
127 kind: TaskKind::Image,
128 model: "synthetic".into(),
129 prompt: String::new(),
130 outcome: JobOutcome::Completed,
131 started_at: now,
132 finished_at: now,
133 },
134 );
135}
136
/// `tracing` target used for this module's structured operational events.
const TRACE_TARGET: &str = "studio_worker::runtime";

/// Default pause between heartbeat requests.
pub const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Default pause before the next claim attempt after a job was run.
pub const CLAIM_INTERVAL_IDLE: Duration = Duration::from_secs(2);
/// Default pause before the next claim attempt after no job, an error, or a skip.
pub const CLAIM_INTERVAL_AFTER_NULL: Duration = Duration::from_secs(5);
/// Default pause between attempts to ship buffered log entries.
pub const LOG_FLUSH_INTERVAL: Duration = Duration::from_secs(1);
/// Default granularity at which the auto-updater re-checks whether it is due.
pub const AUTO_UPDATE_TICK: Duration = Duration::from_secs(60);
146
/// Sleep durations for the background loops started by `run_loops`.
///
/// `Default` gives the production cadence; `fast_for_tests` shrinks every
/// interval to 1 ms.
#[derive(Debug, Clone, Copy)]
pub struct LoopSchedule {
    /// Sleep before each heartbeat.
    pub heartbeat: Duration,
    /// Sleep before the next claim after a job was run.
    pub claim_idle: Duration,
    /// Sleep before the next claim after no job, an error, or a skip.
    pub claim_after_null: Duration,
    /// Sleep before each log-shipping attempt.
    pub log_flush: Duration,
    /// Step by which the auto-updater advances its elapsed-time counter.
    pub auto_update_tick: Duration,
}
157
158impl Default for LoopSchedule {
159 fn default() -> Self {
160 Self {
161 heartbeat: HEARTBEAT_INTERVAL,
162 claim_idle: CLAIM_INTERVAL_IDLE,
163 claim_after_null: CLAIM_INTERVAL_AFTER_NULL,
164 log_flush: LOG_FLUSH_INTERVAL,
165 auto_update_tick: AUTO_UPDATE_TICK,
166 }
167 }
168}
169
170impl LoopSchedule {
171 pub fn fast_for_tests() -> Self {
174 Self {
175 heartbeat: Duration::from_millis(1),
176 claim_idle: Duration::from_millis(1),
177 claim_after_null: Duration::from_millis(1),
178 log_flush: Duration::from_millis(1),
179 auto_update_tick: Duration::from_millis(1),
180 }
181 }
182}
183
184pub async fn register(
189 config_path: Option<&str>,
190 bootstrap_override: Option<String>,
191 api_base_url: Option<String>,
192) -> Result<()> {
193 let (mut cfg, path) = config::load(config_path)?;
194 if let Some(token) = bootstrap_override {
195 cfg.bootstrap_token = token;
196 }
197 if let Some(url) = api_base_url {
198 cfg.api_base_url = url;
199 }
200 let engine = engine::build(&cfg)?;
201 let cap = build_capabilities(&cfg, &*engine);
202 let api_for_diag = cfg.api_base_url.clone();
203 let response = tokio::task::spawn_blocking({
204 let api_base_url = cfg.api_base_url.clone();
205 let bootstrap = cfg.bootstrap_token.clone();
206 let worker_id = cfg.worker_id.clone();
207 let cap = cap.clone();
208 move || -> Result<RegisterResponse> {
209 let api = ApiClient::new(api_base_url)?;
210 api.register(&bootstrap, cap, worker_id)
211 }
212 })
213 .await?
214 .map_err(|e| friendly_register_error(e, &api_for_diag))?;
215 cfg.worker_id = Some(response.worker_id.clone());
216 cfg.auth_token = Some(response.auth_token);
217 config::save(&cfg, &path)?;
218 info!(
219 worker_id = %response.worker_id,
220 api = %cfg.api_base_url,
221 "registered with studio API"
222 );
223 Ok(())
224}
225
226fn friendly_register_error(err: anyhow::Error, api_base_url: &str) -> anyhow::Error {
230 let message = format!("{:#}", err);
233 let is_connection_refused =
234 message.contains("Connection refused") || message.contains("ConnectionRefused");
235 if is_connection_refused {
236 anyhow!(
237 "could not reach the studio API at {api_base_url}: {message}\n\
238 \n\
239 Hint: pass --api-base-url <URL> on the register command, e.g.\n\
240 studio-worker register \\\n\
241 --bootstrap-token <TOKEN> \\\n\
242 --api-base-url https://studio.example.com\n\
243 \n\
244 The bootstrap token is the WORKER_BOOTSTRAP_TOKEN wrangler secret\n\
245 on the studio side (for local dev the default is `dev-bootstrap-token`)."
246 )
247 } else if message.contains("401") || message.contains("403") {
248 anyhow!(
249 "the studio API rejected our bootstrap token: {message}\n\
250 \n\
251 Check that --bootstrap-token matches the WORKER_BOOTSTRAP_TOKEN\n\
252 secret on the studio side."
253 )
254 } else {
255 err
256 }
257}
258
259pub async fn status(config_path: Option<&str>) -> Result<()> {
260 let (cfg, path) = config::load(config_path)?;
261 println!("{}", format_status(&cfg, &path));
262 Ok(())
263}
264
265pub fn format_status(cfg: &Config, path: &std::path::Path) -> String {
266 let mut out = String::new();
267 use std::fmt::Write as _;
268 let _ = writeln!(out, "config path: {}", path.display());
269 let _ = writeln!(out, "api_base_url: {}", cfg.api_base_url);
270 let _ = writeln!(
271 out,
272 "worker_id: {}",
273 cfg.worker_id.as_deref().unwrap_or("(not registered)")
274 );
275 let _ = writeln!(out, "engine: {}", cfg.engine);
276 let _ = writeln!(out, "vram_threshold_gb: {}", cfg.vram_threshold_gb);
277 let _ = writeln!(out, "auto_enabled: {}", cfg.auto_enabled);
278 let _ = writeln!(out, "auto_start: {}", cfg.auto_start);
279 let _ = writeln!(out, "auto_update: {}", cfg.auto_update_enabled);
280 let _ = writeln!(
281 out,
282 "update_interval: {}s",
283 cfg.auto_update_interval_secs
284 );
285 out
286}
287
288pub fn set_enabled(config_path: Option<&str>, enabled: bool) -> Result<()> {
289 let (mut cfg, path) = config::load(config_path)?;
290 cfg.auto_enabled = enabled;
291 config::save(&cfg, &path)?;
292 info!(
293 target: TRACE_TARGET,
294 op = "set_enabled",
295 auto_enabled = enabled,
296 config_path = path.display().to_string(),
297 "auto-claim flag persisted"
298 );
299 println!("auto_enabled = {enabled}");
300 Ok(())
301}
302
303pub fn set_threshold(config_path: Option<&str>, gb: f32) -> Result<()> {
304 if gb < 0.0 {
305 return Err(anyhow!("threshold must be >= 0"));
306 }
307 let (mut cfg, path) = config::load(config_path)?;
308 cfg.vram_threshold_gb = gb;
309 config::save(&cfg, &path)?;
310 info!(
311 target: TRACE_TARGET,
312 op = "set_threshold",
313 vram_threshold_gb = gb,
314 config_path = path.display().to_string(),
315 "VRAM threshold persisted"
316 );
317 println!("vram_threshold_gb = {gb}");
318 Ok(())
319}
320
321pub fn log_startup_banner(cfg: &Config, path: &std::path::Path) {
326 info!(
327 target: TRACE_TARGET,
328 op = "startup",
329 version = AGENT_VERSION,
330 config_path = path.display().to_string(),
331 api_base_url = cfg.api_base_url.as_str(),
332 engine = cfg.engine.as_str(),
333 vram_threshold_gb = cfg.vram_threshold_gb,
334 auto_enabled = cfg.auto_enabled,
335 auto_update_enabled = cfg.auto_update_enabled,
336 auto_update_interval_secs = cfg.auto_update_interval_secs,
337 worker_id = cfg.worker_id.as_deref().unwrap_or("(unregistered)"),
338 "studio-worker booting"
339 );
340}
341
342pub fn show_config(config_path: Option<&str>) -> Result<()> {
343 let (cfg, path) = config::load(config_path)?;
344 println!("# {}", path.display());
345 print!("{}", toml::to_string_pretty(&cfg)?);
346 Ok(())
347}
348
349pub async fn check_update(config_path: Option<&str>) -> Result<()> {
350 let (cfg, _) = config::load(config_path)?;
351 let current = semver::Version::parse(AGENT_VERSION)
352 .map_err(|e| anyhow!("invalid current version {AGENT_VERSION}: {e}"))?;
353 let outcome = tokio::task::spawn_blocking(move || {
354 update::check(&cfg.auto_update_feed, ¤t, cfg.auto_update_prerelease)
355 })
356 .await??;
357 println!("{}", format_check_outcome(&outcome));
358 Ok(())
359}
360
361pub fn format_check_outcome(outcome: &update::CheckOutcome) -> String {
362 match outcome {
363 update::CheckOutcome::UpToDate { current } => format!("up to date: {current}"),
364 update::CheckOutcome::NewerAvailable { current, latest } => {
365 format!("update available: {current} -> {latest}")
366 }
367 }
368}
369
370pub async fn run(config_path: Option<&str>) -> Result<()> {
375 let (mut cfg, path) = config::load(config_path)?;
376 log_startup_banner(&cfg, &path);
377 if cfg.worker_id.is_none() || cfg.auth_token.is_none() {
378 let engine = engine::build(&cfg)?;
379 let cap = build_capabilities(&cfg, &*engine);
380 let response = tokio::task::spawn_blocking({
381 let api_base_url = cfg.api_base_url.clone();
382 let bootstrap = cfg.bootstrap_token.clone();
383 move || -> Result<RegisterResponse> {
384 let api = ApiClient::new(api_base_url)?;
385 api.register(&bootstrap, cap, None)
386 }
387 })
388 .await??;
389 cfg.worker_id = Some(response.worker_id);
390 cfg.auth_token = Some(response.auth_token);
391 config::save(&cfg, &path)?;
392 info!(
393 worker_id = %cfg.worker_id.as_deref().unwrap_or(""),
394 "auto-registered on first run"
395 );
396 }
397
398 let cfg = config::shared(cfg);
399 let stop = Arc::new(AtomicBool::new(false));
400 let busy = Arc::new(AtomicBool::new(false));
401 let logs: Arc<Mutex<Vec<LogEntry>>> = Arc::new(Mutex::new(Vec::new()));
402 let observers = WorkerObservers::default();
403
404 let stop_clone = stop.clone();
405 tokio::spawn(async move {
406 let _ = tokio::signal::ctrl_c().await;
407 stop_clone.store(true, Ordering::SeqCst);
408 });
409
410 run_loops(cfg, stop, logs, busy, observers, LoopSchedule::default()).await;
411 Ok(())
412}
413
414pub async fn run_loops(
418 cfg: SharedConfig,
419 stop: Arc<AtomicBool>,
420 logs: Arc<Mutex<Vec<LogEntry>>>,
421 busy: Arc<AtomicBool>,
422 observers: WorkerObservers,
423 schedule: LoopSchedule,
424) {
425 let heartbeat = spawn_heartbeat(
426 cfg.clone(),
427 stop.clone(),
428 logs.clone(),
429 busy.clone(),
430 observers.clone(),
431 schedule,
432 );
433 let claim = spawn_claim_loop(
434 cfg.clone(),
435 stop.clone(),
436 logs.clone(),
437 busy.clone(),
438 observers.clone(),
439 schedule,
440 );
441 let log_shipper = spawn_log_shipper(cfg.clone(), stop.clone(), logs.clone(), schedule);
442 let auto_updater = spawn_auto_updater(
443 cfg.clone(),
444 stop.clone(),
445 logs.clone(),
446 busy.clone(),
447 schedule,
448 );
449 let _ = tokio::join!(heartbeat, claim, log_shipper, auto_updater);
450}
451
/// Result of one `claim_tick`, used by the claim loop to choose its next delay.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClaimOutcome {
    /// A job was claimed and executed (whether the job itself succeeded or not).
    RanJob,
    /// The claim request succeeded but returned no job.
    NoJobs,
    /// The engine could not be built, the claim request failed, or the
    /// claim task panicked; carries the error text.
    Error(String),
    /// Claiming is disabled (`auto_enabled` is false); no request was made.
    Skipped,
}
468
469pub async fn heartbeat_tick(
473 cfg: &Config,
474 busy_now: bool,
475 logs: &Arc<Mutex<Vec<LogEntry>>>,
476 observers: &WorkerObservers,
477) -> Result<()> {
478 let engine = match engine::build(cfg) {
479 Ok(e) => e,
480 Err(e) => {
481 push_log(
482 logs,
483 "warn",
484 "heartbeat",
485 &format!("engine error: {e}"),
486 None,
487 );
488 *observers.last_heartbeat.lock() = Some(HeartbeatStatus {
489 last_attempt_at: Utc::now(),
490 outcome: HeartbeatOutcome::Err {
491 reason: format!("engine error: {e}"),
492 },
493 });
494 return Ok(());
495 }
496 };
497 let cap = build_capabilities(cfg, &*engine);
498 let token = cfg.auth_token.clone().unwrap_or_default();
499 let worker_id = cfg.worker_id.clone().unwrap_or_default();
500 let api_base_url = cfg.api_base_url.clone();
501 let logs_for_task = logs.clone();
502 let result = tokio::task::spawn_blocking(move || -> Result<()> {
503 let api = ApiClient::new(api_base_url)?;
504 api.heartbeat(&worker_id, &token, cap, None)
505 })
506 .await;
507 let outcome = match result {
508 Ok(Ok(())) => HeartbeatOutcome::Ok,
509 Ok(Err(e)) => {
510 push_log(
511 &logs_for_task,
512 "warn",
513 "heartbeat",
514 &format!("heartbeat failed (busy={busy_now}): {e}"),
515 None,
516 );
517 HeartbeatOutcome::Err {
518 reason: e.to_string(),
519 }
520 }
521 Err(e) => {
522 push_log(
523 &logs_for_task,
524 "warn",
525 "heartbeat",
526 &format!("heartbeat task panic: {e}"),
527 None,
528 );
529 HeartbeatOutcome::Err {
530 reason: format!("task panic: {e}"),
531 }
532 }
533 };
534 *observers.last_heartbeat.lock() = Some(HeartbeatStatus {
535 last_attempt_at: Utc::now(),
536 outcome,
537 });
538 Ok(())
539}
540
/// Runs one iteration of the claim loop.
///
/// Returns `ClaimOutcome::Skipped` without contacting the API when
/// `cfg.auto_enabled` is false. Otherwise it builds the engine and asks the
/// API for a job. If a job is returned, it runs the job to completion on a
/// blocking thread and then returns `ClaimOutcome::RanJob`, whether or not the
/// job succeeded.
///
/// While the job runs, `busy` is `true` and `observers.current_job` holds a
/// snapshot of it. Afterwards the snapshot is cleared, the finished job is
/// pushed onto `observers.recent_jobs`, and `busy` is reset. A panic inside
/// the job task is recorded as a failed job instead of being propagated.
pub async fn claim_tick(
    cfg: &Config,
    logs: &Arc<Mutex<Vec<LogEntry>>>,
    busy: &Arc<AtomicBool>,
    observers: &WorkerObservers,
) -> ClaimOutcome {
    if !cfg.auto_enabled {
        return ClaimOutcome::Skipped;
    }
    let engine = match engine::build(cfg) {
        Ok(e) => e,
        Err(e) => {
            push_log(logs, "warn", "claim", &format!("engine error: {e}"), None);
            return ClaimOutcome::Error(e.to_string());
        }
    };
    let token = cfg.auth_token.clone().unwrap_or_default();
    let worker_id = cfg.worker_id.clone().unwrap_or_default();
    let api_base_url = cfg.api_base_url.clone();

    // ApiClient calls run on the blocking pool (as everywhere in this module);
    // the client is handed back so the job can reuse it for complete/fail.
    let claim_result = tokio::task::spawn_blocking({
        let token = token.clone();
        let worker_id = worker_id.clone();
        let api_base_url = api_base_url.clone();
        move || -> Result<(ApiClient, Option<JobClaim>)> {
            let api = ApiClient::new(api_base_url)?;
            let claim = api.claim(&worker_id, &token)?;
            Ok((api, claim))
        }
    })
    .await;

    match claim_result {
        Ok(Ok((api, Some(job)))) => {
            // Mark busy before doing any work so the auto-updater and the
            // heartbeat observe the job immediately.
            busy.store(true, Ordering::SeqCst);
            push_log(
                logs,
                "info",
                "claim",
                &format!(
                    "claimed job {} (model={}, vram={}GB)",
                    job.job_id, job.model, job.vram_gb_estimate
                ),
                Some(job.job_id.clone()),
            );

            let resolved = job.resolved_task();
            let snapshot = CurrentJob {
                job_id: job.job_id.clone(),
                kind: resolved.kind(),
                model: job.model.clone(),
                prompt: truncate_prompt(&prompt_for(&resolved)),
                started_at: Utc::now(),
            };
            *observers.current_job.lock() = Some(snapshot.clone());

            let logs_clone = logs.clone();
            let token_clone = token.clone();
            let worker_id_clone = worker_id.clone();
            let engine_handle = engine;
            // The engine, client and job are moved onto the blocking thread.
            let join = tokio::task::spawn_blocking(move || {
                run_job(
                    &api,
                    &token_clone,
                    &worker_id_clone,
                    &*engine_handle,
                    &logs_clone,
                    job,
                )
            })
            .await;
            // A panic in run_job surfaces as a JoinError; record it as a failure.
            let outcome = match join {
                Ok(o) => o,
                Err(e) => JobOutcome::Failed {
                    reason: format!("job task panic: {e}"),
                },
            };
            *observers.current_job.lock() = None;
            record_recent_job(
                observers,
                RecentJob {
                    job_id: snapshot.job_id,
                    kind: snapshot.kind,
                    model: snapshot.model,
                    prompt: snapshot.prompt,
                    outcome,
                    started_at: snapshot.started_at,
                    finished_at: Utc::now(),
                },
            );
            // Cleared last, after observers reflect the finished job.
            busy.store(false, Ordering::SeqCst);
            ClaimOutcome::RanJob
        }
        Ok(Ok((_api, None))) => ClaimOutcome::NoJobs,
        Ok(Err(e)) => {
            push_log(
                logs,
                "warn",
                "claim",
                &format!("claim request errored: {e}"),
                None,
            );
            ClaimOutcome::Error(e.to_string())
        }
        Err(e) => {
            push_log(
                logs,
                "warn",
                "claim",
                &format!("claim task panic: {e}"),
                None,
            );
            ClaimOutcome::Error(e.to_string())
        }
    }
}
664
665pub async fn log_shipper_tick(cfg: &Config, logs: &Arc<Mutex<Vec<LogEntry>>>) -> usize {
668 let token = cfg.auth_token.clone().unwrap_or_default();
669 let worker_id = cfg.worker_id.clone().unwrap_or_default();
670 if worker_id.is_empty() || token.is_empty() {
671 logs.lock().clear();
673 return 0;
674 }
675 let batch = {
676 let mut guard = logs.lock();
677 if guard.is_empty() {
678 return 0;
679 }
680 LogBatch {
681 entries: std::mem::take(&mut *guard),
682 }
683 };
684 let count = batch.entries.len();
685 let api_base_url = cfg.api_base_url.clone();
686 let _ = tokio::task::spawn_blocking(move || -> Result<()> {
687 let api = ApiClient::new(api_base_url)?;
688 api.ship_logs(&worker_id, &token, batch)
689 })
690 .await;
691 count
692}
693
/// What one `auto_update_tick` decided or did.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AutoUpdateDecision {
    /// `auto_update_enabled` is false; nothing was checked.
    Disabled,
    /// A job was running, so the check was skipped.
    SkippedBusy,
    /// The feed reported no newer version.
    UpToDate,
    /// The feed check failed, `AGENT_VERSION` did not parse, or the blocking
    /// task panicked; carries the error text.
    CheckError(String),
    /// A newer version was applied; a restart is pending.
    Updated,
    /// A newer version was found but applying it failed; carries the error text.
    UpdateError(String),
}
710
711pub async fn auto_update_tick(
712 cfg: &Config,
713 busy: bool,
714 logs: &Arc<Mutex<Vec<LogEntry>>>,
715) -> AutoUpdateDecision {
716 if !cfg.auto_update_enabled {
717 return AutoUpdateDecision::Disabled;
718 }
719 if busy {
720 push_log(
721 logs,
722 "info",
723 "auto-update",
724 "skipping check: worker is busy on a job",
725 None,
726 );
727 return AutoUpdateDecision::SkippedBusy;
728 }
729 let feed = cfg.auto_update_feed.clone();
730 let prerelease = cfg.auto_update_prerelease;
731 let logs_for_task = logs.clone();
732 let outcome = tokio::task::spawn_blocking(move || -> Result<AutoUpdateDecision> {
733 let current = semver::Version::parse(AGENT_VERSION)
734 .map_err(|e| anyhow!("invalid AGENT_VERSION {AGENT_VERSION}: {e}"))?;
735 match update::check(&feed, ¤t, prerelease) {
736 Ok(update::CheckOutcome::UpToDate { current }) => {
737 push_log(
738 &logs_for_task,
739 "info",
740 "auto-update",
741 &format!("up to date at {current}"),
742 None,
743 );
744 Ok(AutoUpdateDecision::UpToDate)
745 }
746 Ok(update::CheckOutcome::NewerAvailable { current, latest }) => {
747 push_log(
748 &logs_for_task,
749 "info",
750 "auto-update",
751 &format!("update available {current} -> {latest}; applying"),
752 None,
753 );
754 match update::apply(&feed, &latest) {
755 Ok(()) => {
756 push_log(
757 &logs_for_task,
758 "info",
759 "auto-update",
760 "binary replaced; restart pending",
761 None,
762 );
763 Ok(AutoUpdateDecision::Updated)
764 }
765 Err(e) => {
766 push_log(
767 &logs_for_task,
768 "error",
769 "auto-update",
770 &format!("update failed: {e}"),
771 None,
772 );
773 Ok(AutoUpdateDecision::UpdateError(e.to_string()))
774 }
775 }
776 }
777 Err(e) => {
778 push_log(
779 &logs_for_task,
780 "warn",
781 "auto-update",
782 &format!("check failed: {e}"),
783 None,
784 );
785 Ok(AutoUpdateDecision::CheckError(e.to_string()))
786 }
787 }
788 })
789 .await;
790 match outcome {
791 Ok(Ok(decision)) => decision,
792 Ok(Err(e)) => AutoUpdateDecision::CheckError(e.to_string()),
793 Err(e) => AutoUpdateDecision::CheckError(e.to_string()),
794 }
795}
796
797pub fn spawn_heartbeat(
803 cfg: SharedConfig,
804 stop: Arc<AtomicBool>,
805 logs: Arc<Mutex<Vec<LogEntry>>>,
806 busy: Arc<AtomicBool>,
807 observers: WorkerObservers,
808 schedule: LoopSchedule,
809) -> tokio::task::JoinHandle<()> {
810 tokio::spawn(async move {
811 while !stop.load(Ordering::SeqCst) {
812 tokio::time::sleep(schedule.heartbeat).await;
813 let snapshot = cfg.lock().clone();
814 let busy_now = busy.load(Ordering::SeqCst);
815 let _ = heartbeat_tick(&snapshot, busy_now, &logs, &observers).await;
816 }
817 })
818}
819
820pub fn spawn_claim_loop(
821 cfg: SharedConfig,
822 stop: Arc<AtomicBool>,
823 logs: Arc<Mutex<Vec<LogEntry>>>,
824 busy: Arc<AtomicBool>,
825 observers: WorkerObservers,
826 schedule: LoopSchedule,
827) -> tokio::task::JoinHandle<()> {
828 tokio::spawn(async move {
829 let mut next_delay = schedule.claim_idle;
830 while !stop.load(Ordering::SeqCst) {
831 tokio::time::sleep(next_delay).await;
832 let snapshot = cfg.lock().clone();
833 let outcome = claim_tick(&snapshot, &logs, &busy, &observers).await;
834 next_delay = match outcome {
835 ClaimOutcome::RanJob => schedule.claim_idle,
836 _ => schedule.claim_after_null,
837 };
838 }
839 })
840}
841
842pub fn next_delay_for(outcome: &ClaimOutcome) -> Duration {
844 match outcome {
845 ClaimOutcome::RanJob => CLAIM_INTERVAL_IDLE,
846 ClaimOutcome::NoJobs | ClaimOutcome::Error(_) | ClaimOutcome::Skipped => {
847 CLAIM_INTERVAL_AFTER_NULL
848 }
849 }
850}
851
852pub fn spawn_log_shipper(
853 cfg: SharedConfig,
854 stop: Arc<AtomicBool>,
855 logs: Arc<Mutex<Vec<LogEntry>>>,
856 schedule: LoopSchedule,
857) -> tokio::task::JoinHandle<()> {
858 tokio::spawn(async move {
859 while !stop.load(Ordering::SeqCst) {
860 tokio::time::sleep(schedule.log_flush).await;
861 let snapshot = cfg.lock().clone();
862 let _ = log_shipper_tick(&snapshot, &logs).await;
863 }
864 })
865}
866
867pub fn spawn_auto_updater(
868 cfg: SharedConfig,
869 stop: Arc<AtomicBool>,
870 logs: Arc<Mutex<Vec<LogEntry>>>,
871 busy: Arc<AtomicBool>,
872 schedule: LoopSchedule,
873) -> tokio::task::JoinHandle<()> {
874 tokio::spawn(async move {
875 let mut elapsed = Duration::from_secs(0);
876 while !stop.load(Ordering::SeqCst) {
877 tokio::time::sleep(schedule.auto_update_tick).await;
878 elapsed += schedule.auto_update_tick;
879 let snapshot = cfg.lock().clone();
880 if elapsed < Duration::from_secs(snapshot.auto_update_interval_secs) {
881 continue;
882 }
883 elapsed = Duration::from_secs(0);
884 let busy_now = busy.load(Ordering::SeqCst);
885 let decision = auto_update_tick(&snapshot, busy_now, &logs).await;
886 if matches!(decision, AutoUpdateDecision::Updated) {
887 stop.store(true, Ordering::SeqCst);
888 update::restart_self();
889 }
890 }
891 })
892}
893
894fn run_job(
899 api: &ApiClient,
900 token: &str,
901 worker_id: &str,
902 engine: &dyn Engine,
903 logs: &Arc<Mutex<Vec<LogEntry>>>,
904 job: JobClaim,
905) -> JobOutcome {
906 let start = std::time::Instant::now();
907 let task = job.resolved_task();
908 let task_kind = task.kind();
909 let prompt_for_log = prompt_for(&task);
910 let result = engine.dispatch(&job.model, task);
911 match result {
912 Ok(task_result) => {
913 push_log(
914 logs,
915 "info",
916 "generate",
917 &format!(
918 "{} task generated in {:?}",
919 task_kind.as_str(),
920 start.elapsed()
921 ),
922 Some(job.job_id.clone()),
923 );
924 let outcome = match task_result {
925 TaskResult::Image { bytes, ext } => {
926 api.complete(worker_id, token, &job.job_id, &ext, &prompt_for_log, bytes)
927 }
928 TaskResult::AudioTts { bytes, ext } => {
929 api.complete(worker_id, token, &job.job_id, &ext, &prompt_for_log, bytes)
930 }
931 TaskResult::Video { bytes, ext } => {
932 api.complete(worker_id, token, &job.job_id, &ext, &prompt_for_log, bytes)
933 }
934 TaskResult::Llm { json } => {
935 api.complete_json(worker_id, token, &job.job_id, &prompt_for_log, &json)
936 }
937 TaskResult::AudioStt { json } => {
938 api.complete_json(worker_id, token, &job.job_id, &prompt_for_log, &json)
939 }
940 };
941 match outcome {
942 Err(e) => {
943 let reason = format!("complete failed: {e}");
944 push_log(logs, "error", "complete", &reason, Some(job.job_id.clone()));
945 JobOutcome::Failed { reason }
946 }
947 Ok(()) => {
948 push_log(
949 logs,
950 "info",
951 "complete",
952 "job uploaded",
953 Some(job.job_id.clone()),
954 );
955 JobOutcome::Completed
956 }
957 }
958 }
959 Err(e) => {
960 warn!("generate failed: {e:#}");
961 let reason = format!("generate failed: {e}");
962 push_log(logs, "error", "generate", &reason, Some(job.job_id.clone()));
963 let retryable = !is_unsupported_kind(&e);
964 let _ = api.fail(worker_id, token, &job.job_id, &e.to_string(), retryable);
965 JobOutcome::Failed { reason }
966 }
967 }
968}
969
970pub fn prompt_for(task: &Task) -> String {
971 match task {
972 Task::Image(p) => p.prompt.clone(),
973 Task::Llm(p) => p
974 .messages
975 .last()
976 .map(|m| m.content.clone())
977 .unwrap_or_default(),
978 Task::AudioStt(p) => p.input_url.clone(),
979 Task::AudioTts(p) => p.text.clone(),
980 Task::Video(p) => p.prompt.clone(),
981 }
982}
983
984pub fn is_unsupported_kind(e: &anyhow::Error) -> bool {
985 e.to_string().contains("cannot serve")
986}
987
988pub fn build_capabilities(cfg: &Config, engine: &dyn Engine) -> WorkerCapabilities {
993 let vram = sys::detect_vram_gb().unwrap_or(0.0);
994 let caps = engine.capabilities();
995 let supported_models_per_kind = caps.supported_models_per_kind.clone();
996 let task_kinds = caps.kinds();
997 let supported_models = {
1001 let mut all = caps.flat_models();
1002 all.sort();
1003 all.dedup();
1004 all
1005 };
1006 let supported_models = if cfg.supported_models_override.is_empty() {
1007 supported_models
1008 } else {
1009 cfg.supported_models_override.clone()
1010 };
1011
1012 WorkerCapabilities {
1013 machine_name: sys::machine_name(),
1014 username: sys::username(),
1015 agent_version: AGENT_VERSION.to_string(),
1016 engine: cfg.engine.clone(),
1017 vram_total_gb: vram,
1018 vram_threshold_gb: cfg.vram_threshold_gb,
1019 auto_enabled: cfg.auto_enabled,
1020 auto_start: cfg.auto_start,
1021 supported_models,
1022 task_kinds,
1023 supported_models_per_kind,
1024 }
1025}
1026
1027pub fn push_log(
1028 logs: &Arc<Mutex<Vec<LogEntry>>>,
1029 level: &str,
1030 category: &str,
1031 message: &str,
1032 job_id: Option<String>,
1033) {
1034 let entry = LogEntry {
1035 ts: Utc::now().to_rfc3339_opts(SecondsFormat::Millis, true),
1036 level: level.to_string(),
1037 category: category.to_string(),
1038 message: message.to_string(),
1039 job_id,
1040 };
1041 if level == "error" {
1042 tracing::error!(target: "studio_worker", "[{category}] {message}");
1043 } else if level == "warn" {
1044 tracing::warn!(target: "studio_worker", "[{category}] {message}");
1045 } else {
1046 info!(target: "studio_worker", "[{category}] {message}");
1047 }
1048 logs.lock().push(entry);
1049}
1050
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::engine::SyntheticEngine;

    #[test]
    fn capabilities_advertises_all_synthetic_kinds() {
        let cfg = Config::default();
        let engine = SyntheticEngine::new(vec![]);
        let cap = build_capabilities(&cfg, &engine);
        assert_eq!(cap.engine, "synthetic");
        assert_eq!(cap.task_kinds.len(), TaskKind::ALL.len());
        for kind in TaskKind::ALL {
            assert!(cap.supported_models_per_kind.contains_key(&kind));
        }
    }

    #[test]
    fn capabilities_uses_override_for_legacy_flat_list() {
        let cfg = Config {
            supported_models_override: vec!["only-this".into()],
            ..Config::default()
        };
        let engine = SyntheticEngine::new(vec![]);
        let cap = build_capabilities(&cfg, &engine);
        assert_eq!(cap.supported_models, vec!["only-this".to_string()]);
    }

    #[test]
    fn prompt_for_extracts_per_kind() {
        let image = Task::Image(ImageParams {
            prompt: "a stone golem".into(),
            width: 512,
            height: 512,
            steps: 20,
            seed: None,
            ext: "webp".into(),
        });
        assert_eq!(prompt_for(&image), "a stone golem");

        let llm = Task::Llm(LlmParams {
            messages: vec![
                ChatMessage {
                    role: "system".into(),
                    content: "be helpful".into(),
                },
                ChatMessage {
                    role: "user".into(),
                    content: "hi".into(),
                },
            ],
            max_tokens: 32,
            temperature: 0.5,
        });
        assert_eq!(prompt_for(&llm), "hi");

        let llm_empty = Task::Llm(LlmParams {
            messages: vec![],
            max_tokens: 1,
            temperature: 0.0,
        });
        assert_eq!(prompt_for(&llm_empty), "");

        let stt = Task::AudioStt(AudioSttParams {
            input_url: "https://example.com/clip.wav".into(),
            language: None,
        });
        assert_eq!(prompt_for(&stt), "https://example.com/clip.wav");

        let tts = Task::AudioTts(AudioTtsParams {
            text: "hi there".into(),
            voice: "v".into(),
            ext: "wav".into(),
        });
        assert_eq!(prompt_for(&tts), "hi there");

        let video = Task::Video(VideoParams {
            prompt: "a tiny dragon".into(),
            seconds: 1.0,
            width: 256,
            height: 256,
            ext: "mp4".into(),
        });
        assert_eq!(prompt_for(&video), "a tiny dragon");
    }

    #[test]
    fn is_unsupported_kind_matches_engine_message() {
        let err = anyhow!("gradio engine cannot serve llm tasks");
        assert!(is_unsupported_kind(&err));
        let other = anyhow!("network timeout");
        assert!(!is_unsupported_kind(&other));
    }

    #[test]
    fn next_delay_for_picks_idle_after_a_job() {
        assert_eq!(next_delay_for(&ClaimOutcome::RanJob), CLAIM_INTERVAL_IDLE);
    }

    #[test]
    fn next_delay_for_backs_off_when_no_jobs_or_errors() {
        assert_eq!(
            next_delay_for(&ClaimOutcome::NoJobs),
            CLAIM_INTERVAL_AFTER_NULL
        );
        assert_eq!(
            next_delay_for(&ClaimOutcome::Error("boom".into())),
            CLAIM_INTERVAL_AFTER_NULL
        );
        assert_eq!(
            next_delay_for(&ClaimOutcome::Skipped),
            CLAIM_INTERVAL_AFTER_NULL
        );
    }

    #[test]
    fn format_status_includes_every_field() {
        let cfg = Config::default();
        let out = format_status(&cfg, std::path::Path::new("/tmp/x.toml"));
        assert!(out.contains("config path:"));
        assert!(out.contains("api_base_url:"));
        assert!(out.contains("worker_id:"));
        assert!(out.contains("(not registered)"));
        assert!(out.contains("auto_update:"));
        assert!(out.contains("update_interval:"));
    }

    #[test]
    fn format_status_shows_worker_id_when_registered() {
        let cfg = Config {
            worker_id: Some("w-abc".into()),
            ..Config::default()
        };
        let out = format_status(&cfg, std::path::Path::new("/tmp/x.toml"));
        assert!(out.contains("w-abc"));
    }

    #[test]
    fn format_check_outcome_handles_both_branches() {
        let up = update::CheckOutcome::UpToDate {
            current: semver::Version::new(1, 2, 3),
        };
        assert!(format_check_outcome(&up).contains("up to date"));
        let newer = update::CheckOutcome::NewerAvailable {
            current: semver::Version::new(1, 2, 3),
            latest: semver::Version::new(1, 3, 0),
        };
        let s = format_check_outcome(&newer);
        assert!(s.contains("1.2.3 -> 1.3.0"));
    }

    #[test]
    fn push_log_appends_an_entry() {
        let logs: Arc<Mutex<Vec<LogEntry>>> = Arc::new(Mutex::new(Vec::new()));
        push_log(&logs, "info", "test", "hi", None);
        push_log(&logs, "warn", "test", "wat", Some("j-1".into()));
        push_log(&logs, "error", "test", "boom", None);
        let v = logs.lock();
        assert_eq!(v.len(), 3);
        assert_eq!(v[0].level, "info");
        assert_eq!(v[1].level, "warn");
        assert_eq!(v[1].job_id.as_deref(), Some("j-1"));
        assert_eq!(v[2].level, "error");
    }

    /// Short prompts pass through untouched; long ones are cut by character
    /// count (never mid-UTF-8 sequence) and get a trailing ellipsis.
    #[test]
    fn truncate_prompt_cuts_long_prompts_on_char_boundaries() {
        assert_eq!(truncate_prompt(""), "");
        assert_eq!(truncate_prompt("hi"), "hi");
        let exact = "a".repeat(PROMPT_PREVIEW_CHARS);
        assert_eq!(truncate_prompt(&exact), exact);
        let long = "é".repeat(PROMPT_PREVIEW_CHARS + 5);
        let cut = truncate_prompt(&long);
        assert_eq!(cut.chars().count(), PROMPT_PREVIEW_CHARS + 1);
        assert!(cut.ends_with('…'));
        assert!(cut.starts_with(&"é".repeat(PROMPT_PREVIEW_CHARS)));
    }

    /// The recent-jobs ring is newest-first and never exceeds its cap.
    #[test]
    fn recent_jobs_ring_is_newest_first_and_capped() {
        let observers = WorkerObservers::default();
        let pushed = RECENT_JOBS_CAP + 3;
        for i in 0..pushed {
            push_recent_job_for_tests(&observers, &format!("j-{i}"));
        }
        let newest = format!("j-{}", pushed - 1);
        let oldest_kept = format!("j-{}", pushed - RECENT_JOBS_CAP);
        let ring = observers.recent_jobs.lock();
        assert_eq!(ring.len(), RECENT_JOBS_CAP);
        assert_eq!(
            ring.front().map(|j| j.job_id.as_str()),
            Some(newest.as_str())
        );
        assert_eq!(
            ring.back().map(|j| j.job_id.as_str()),
            Some(oldest_kept.as_str())
        );
    }

    /// Registered config pointing the API client at `api_base_url`, with
    /// claiming on and auto-update off.
    fn cfg_pointing_at(api_base_url: String) -> Config {
        Config {
            api_base_url,
            worker_id: Some("w-test".into()),
            auth_token: Some("tok-test".into()),
            engine: "synthetic".into(),
            auto_enabled: true,
            auto_update_enabled: false,
            ..Config::default()
        }
    }

    #[tokio::test]
    async fn claim_tick_returns_skipped_when_auto_enabled_is_false() {
        let cfg = Config {
            auto_enabled: false,
            ..Config::default()
        };
        let logs = Arc::new(Mutex::new(Vec::new()));
        let busy = Arc::new(AtomicBool::new(false));
        let observers = WorkerObservers::default();
        let outcome = claim_tick(&cfg, &logs, &busy, &observers).await;
        assert_eq!(outcome, ClaimOutcome::Skipped);
    }

    #[tokio::test]
    async fn auto_update_tick_disabled_when_flag_off() {
        let cfg = Config {
            auto_update_enabled: false,
            ..Config::default()
        };
        let logs = Arc::new(Mutex::new(Vec::new()));
        let decision = auto_update_tick(&cfg, false, &logs).await;
        assert_eq!(decision, AutoUpdateDecision::Disabled);
    }

    #[tokio::test]
    async fn auto_update_tick_skipped_when_busy() {
        let cfg = Config {
            auto_update_enabled: true,
            ..Config::default()
        };
        let logs = Arc::new(Mutex::new(Vec::new()));
        let decision = auto_update_tick(&cfg, true, &logs).await;
        assert_eq!(decision, AutoUpdateDecision::SkippedBusy);
        let entries = logs.lock();
        assert!(entries.iter().any(|e| e.message.contains("busy on a job")));
    }

    #[tokio::test]
    async fn log_shipper_tick_returns_zero_when_buffer_empty() {
        let cfg = cfg_pointing_at("http://unused.invalid".into());
        let logs = Arc::new(Mutex::new(Vec::new()));
        let n = log_shipper_tick(&cfg, &logs).await;
        assert_eq!(n, 0);
    }

    #[tokio::test]
    async fn log_shipper_tick_returns_zero_when_unregistered() {
        let cfg = Config {
            worker_id: None,
            auth_token: None,
            ..cfg_pointing_at("http://unused.invalid".into())
        };
        let logs = Arc::new(Mutex::new(vec![LogEntry {
            ts: "ts".into(),
            level: "info".into(),
            category: "x".into(),
            message: "m".into(),
            job_id: None,
        }]));
        let n = log_shipper_tick(&cfg, &logs).await;
        assert_eq!(n, 0);
    }
}