Skip to main content

studio_worker/
runtime.rs

1//! Long-running heartbeat + claim loop, log shipper, auto-update task,
2//! and the one-shot CLI helpers.
3//!
4//! The four loops (`spawn_heartbeat`, `spawn_claim_loop`,
5//! `spawn_log_shipper`, `spawn_auto_updater`) are thin `while !stop`
6//! wrappers around the testable `*_tick` helpers below — those are the
7//! places real per-iteration logic lives, and they're 100% unit-tested.
8use crate::{
9    config::{self, Config, SharedConfig},
10    engine::{self, Engine},
11    http::ApiClient,
12    sys,
13    types::*,
14    update, AGENT_VERSION,
15};
16use anyhow::{anyhow, Result};
17use chrono::{DateTime, SecondsFormat, Utc};
18use parking_lot::Mutex;
19use std::{
20    collections::VecDeque,
21    sync::{
22        atomic::{AtomicBool, Ordering},
23        Arc,
24    },
25    time::Duration,
26};
27use tracing::{info, warn};
28
29// ---------------------------------------------------------------------------
30// Live observation slots consumed by the upcoming egui UI (see
31// `plans/native-ui.md`).  None of this is wire-format; it's purely
32// in-memory state that the existing tick functions populate so an
33// in-process subscriber can read what the loops already know.
34// ---------------------------------------------------------------------------
35
/// Maximum number of finished jobs kept in `WorkerObservers::recent_jobs`.
/// New entries are pushed at the front of the ring; once the cap is
/// exceeded the oldest entries fall off the back.
pub const RECENT_JOBS_CAP: usize = 50;

/// Prompt previews stored in `CurrentJob` / `RecentJob` are clipped to
/// this many chars so the in-memory state stays bounded even when LLM
/// prompts are huge.  A clipped preview additionally carries a trailing
/// `…` marker, so its length is this value plus one.
pub const PROMPT_PREVIEW_CHARS: usize = 200;
44
/// Job in flight right now.  Populated by `claim_tick` before
/// dispatch, cleared once the job finishes (success or failure).
#[derive(Debug, Clone)]
pub struct CurrentJob {
    /// Identifier of the claimed job, copied from the claim payload.
    pub job_id: String,
    /// Kind of the task resolved from the claim.
    pub kind: TaskKind,
    /// Model name, copied from the claim payload.
    pub model: String,
    /// Prompt preview, already clipped to `PROMPT_PREVIEW_CHARS`.
    pub prompt: String,
    /// Wall-clock time (UTC) at which the snapshot was taken, i.e. just
    /// before the job was handed to the blocking worker thread.
    pub started_at: DateTime<Utc>,
}
55
/// Outcome a finished job ended with.  Failures carry the human
/// reason (already surfaced to logs + Sentry).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JobOutcome {
    /// The job ran to completion.
    Completed,
    /// The job did not complete.  `reason` is free-form text meant for
    /// display; `claim_tick` also uses this variant when the blocking job
    /// task itself fails to join.
    Failed { reason: String },
}
63
/// One finished job, retained in the recent-jobs ring for the UI.
///
/// `claim_tick` builds this from the `CurrentJob` snapshot of the job
/// plus its outcome.
#[derive(Debug, Clone)]
pub struct RecentJob {
    /// Identifier of the job.
    pub job_id: String,
    /// Kind of the task that was run.
    pub kind: TaskKind,
    /// Model name the job asked for.
    pub model: String,
    /// Prompt preview (clipped to `PROMPT_PREVIEW_CHARS`).
    pub prompt: String,
    /// How the job ended.
    pub outcome: JobOutcome,
    /// When the job was snapshotted as "current" (UTC).
    pub started_at: DateTime<Utc>,
    /// When the job was recorded as finished (UTC).
    pub finished_at: DateTime<Utc>,
}
75
/// Result of the most recent `heartbeat_tick`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HeartbeatOutcome {
    /// The heartbeat request returned successfully.
    Ok,
    /// The heartbeat could not be delivered.  `reason` is display text;
    /// `heartbeat_tick` sets it for an engine build error, a failed
    /// heartbeat request, or a failure to join the blocking task.
    Err { reason: String },
}
82
/// Timestamped record of the latest heartbeat attempt, stored in
/// `WorkerObservers::last_heartbeat`.
#[derive(Debug, Clone)]
pub struct HeartbeatStatus {
    /// UTC time at which the attempt's result was recorded (stamped by
    /// `heartbeat_tick` when it writes the status, i.e. after the attempt
    /// has finished).
    pub last_attempt_at: DateTime<Utc>,
    /// Whether that attempt succeeded.
    pub outcome: HeartbeatOutcome,
}
88
/// Bundle of in-process observation slots written by the runtime
/// loops and read by the UI.  `Default` gives empty slots so existing
/// (headless) call sites stay one-liners.  Cheap to clone — every
/// field is an `Arc`.
#[derive(Clone, Default)]
pub struct WorkerObservers {
    /// The job being executed right now, or `None` when no job is running.
    pub current_job: Arc<Mutex<Option<CurrentJob>>>,
    /// Finished jobs, newest first, capped at `RECENT_JOBS_CAP` entries.
    pub recent_jobs: Arc<Mutex<VecDeque<RecentJob>>>,
    /// Result of the latest heartbeat attempt; `None` until the first
    /// `heartbeat_tick` has run.
    pub last_heartbeat: Arc<Mutex<Option<HeartbeatStatus>>>,
}
99
100fn truncate_prompt(s: &str) -> String {
101    if s.chars().count() <= PROMPT_PREVIEW_CHARS {
102        return s.to_string();
103    }
104    let mut out: String = s.chars().take(PROMPT_PREVIEW_CHARS).collect();
105    out.push('…');
106    out
107}
108
109fn record_recent_job(observers: &WorkerObservers, entry: RecentJob) {
110    let mut ring = observers.recent_jobs.lock();
111    ring.push_front(entry);
112    while ring.len() > RECENT_JOBS_CAP {
113        ring.pop_back();
114    }
115}
116
117/// Test-only helper to populate the recent-jobs ring without driving a
118/// full claim cycle.  Lives in the library surface so integration
119/// tests can pin the ring-capacity contract cheaply.
120#[doc(hidden)]
121pub fn push_recent_job_for_tests(observers: &WorkerObservers, job_id: &str) {
122    let now = Utc::now();
123    record_recent_job(
124        observers,
125        RecentJob {
126            job_id: job_id.to_string(),
127            kind: TaskKind::Image,
128            model: "synthetic".into(),
129            prompt: String::new(),
130            outcome: JobOutcome::Completed,
131            started_at: now,
132            finished_at: now,
133        },
134    );
135}
136
/// Tracing target for runtime-level events (startup, state mutations).
/// Stable so operators can filter with `RUST_LOG=studio_worker::runtime=debug`.
const TRACE_TARGET: &str = "studio_worker::runtime";

/// Default pause between heartbeats (`LoopSchedule::heartbeat`).
pub const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// Default pause before the first claim attempt and after a tick that
/// actually ran a job (`LoopSchedule::claim_idle`).
pub const CLAIM_INTERVAL_IDLE: Duration = Duration::from_secs(2);
/// Default pause after a claim tick that did not run a job — no jobs,
/// an error, or claiming disabled (`LoopSchedule::claim_after_null`).
pub const CLAIM_INTERVAL_AFTER_NULL: Duration = Duration::from_secs(5);
/// Default pause between log-buffer flushes (`LoopSchedule::log_flush`).
pub const LOG_FLUSH_INTERVAL: Duration = Duration::from_secs(1);
/// Default wake-up granularity of the auto-updater loop
/// (`LoopSchedule::auto_update_tick`).  How often a check actually runs
/// is governed by the config's `auto_update_interval_secs`.
pub const AUTO_UPDATE_TICK: Duration = Duration::from_secs(60);
146
/// Schedule for the four background loops.  Tests dial these intervals
/// down to milliseconds so we can exercise the loop bodies quickly.
#[derive(Debug, Clone, Copy)]
pub struct LoopSchedule {
    /// Sleep before each heartbeat tick.
    pub heartbeat: Duration,
    /// Sleep before the first claim tick and after a tick that ran a job.
    pub claim_idle: Duration,
    /// Sleep after a claim tick that did not run a job.
    pub claim_after_null: Duration,
    /// Sleep before each log-shipper tick.
    pub log_flush: Duration,
    /// Sleep between auto-updater wake-ups; each wake-up adds this much to
    /// the elapsed time compared against `auto_update_interval_secs`.
    pub auto_update_tick: Duration,
}
157
impl Default for LoopSchedule {
    /// Production schedule: every field takes the value of its matching
    /// module-level interval constant.
    fn default() -> Self {
        Self {
            heartbeat: HEARTBEAT_INTERVAL,
            claim_idle: CLAIM_INTERVAL_IDLE,
            claim_after_null: CLAIM_INTERVAL_AFTER_NULL,
            log_flush: LOG_FLUSH_INTERVAL,
            auto_update_tick: AUTO_UPDATE_TICK,
        }
    }
}
169
170impl LoopSchedule {
171    /// Schedule with 1 ms intervals — used by tests to exercise the
172    /// loop wrappers without blocking.
173    pub fn fast_for_tests() -> Self {
174        Self {
175            heartbeat: Duration::from_millis(1),
176            claim_idle: Duration::from_millis(1),
177            claim_after_null: Duration::from_millis(1),
178            log_flush: Duration::from_millis(1),
179            auto_update_tick: Duration::from_millis(1),
180        }
181    }
182}
183
184// ---------------------------------------------------------------------------
185// One-shot helpers used by the CLI subcommands
186// ---------------------------------------------------------------------------
187
188pub async fn register(
189    config_path: Option<&str>,
190    bootstrap_override: Option<String>,
191    api_base_url: Option<String>,
192) -> Result<()> {
193    let (mut cfg, path) = config::load(config_path)?;
194    if let Some(token) = bootstrap_override {
195        cfg.bootstrap_token = token;
196    }
197    if let Some(url) = api_base_url {
198        cfg.api_base_url = url;
199    }
200    let engine = engine::build(&cfg)?;
201    let cap = build_capabilities(&cfg, &*engine);
202    let api_for_diag = cfg.api_base_url.clone();
203    let response = tokio::task::spawn_blocking({
204        let api_base_url = cfg.api_base_url.clone();
205        let bootstrap = cfg.bootstrap_token.clone();
206        let worker_id = cfg.worker_id.clone();
207        let cap = cap.clone();
208        move || -> Result<RegisterResponse> {
209            let api = ApiClient::new(api_base_url)?;
210            api.register(&bootstrap, cap, worker_id)
211        }
212    })
213    .await?
214    .map_err(|e| friendly_register_error(e, &api_for_diag))?;
215    cfg.worker_id = Some(response.worker_id.clone());
216    cfg.auth_token = Some(response.auth_token);
217    config::save(&cfg, &path)?;
218    info!(
219        worker_id = %response.worker_id,
220        api = %cfg.api_base_url,
221        "registered with studio API"
222    );
223    Ok(())
224}
225
226/// Wrap network/HTTP errors from register() with a hint that points the
227/// operator at `--api-base-url` and the right secret.  Saves people from
228/// hitting the default `http://localhost:9790` and wondering what happened.
229fn friendly_register_error(err: anyhow::Error, api_base_url: &str) -> anyhow::Error {
230    // Walk the full error chain so we catch the cause inside
231    // reqwest/hyper, not just the top-level wrap.
232    let message = format!("{:#}", err);
233    let is_connection_refused =
234        message.contains("Connection refused") || message.contains("ConnectionRefused");
235    if is_connection_refused {
236        anyhow!(
237            "could not reach the studio API at {api_base_url}: {message}\n\
238             \n\
239             Hint: pass --api-base-url <URL> on the register command, e.g.\n\
240               studio-worker register \\\n\
241                 --bootstrap-token <TOKEN> \\\n\
242                 --api-base-url https://studio.example.com\n\
243             \n\
244             The bootstrap token is the WORKER_BOOTSTRAP_TOKEN wrangler secret\n\
245             on the studio side (for local dev the default is `dev-bootstrap-token`)."
246        )
247    } else if message.contains("401") || message.contains("403") {
248        anyhow!(
249            "the studio API rejected our bootstrap token: {message}\n\
250             \n\
251             Check that --bootstrap-token matches the WORKER_BOOTSTRAP_TOKEN\n\
252             secret on the studio side."
253        )
254    } else {
255        err
256    }
257}
258
259pub async fn status(config_path: Option<&str>) -> Result<()> {
260    let (cfg, path) = config::load(config_path)?;
261    println!("{}", format_status(&cfg, &path));
262    Ok(())
263}
264
265pub fn format_status(cfg: &Config, path: &std::path::Path) -> String {
266    let mut out = String::new();
267    use std::fmt::Write as _;
268    let _ = writeln!(out, "config path:        {}", path.display());
269    let _ = writeln!(out, "api_base_url:       {}", cfg.api_base_url);
270    let _ = writeln!(
271        out,
272        "worker_id:          {}",
273        cfg.worker_id.as_deref().unwrap_or("(not registered)")
274    );
275    let _ = writeln!(out, "engine:             {}", cfg.engine);
276    let _ = writeln!(out, "vram_threshold_gb:  {}", cfg.vram_threshold_gb);
277    let _ = writeln!(out, "auto_enabled:       {}", cfg.auto_enabled);
278    let _ = writeln!(out, "auto_start:         {}", cfg.auto_start);
279    let _ = writeln!(out, "auto_update:        {}", cfg.auto_update_enabled);
280    let _ = writeln!(
281        out,
282        "update_interval:    {}s",
283        cfg.auto_update_interval_secs
284    );
285    out
286}
287
288pub fn set_enabled(config_path: Option<&str>, enabled: bool) -> Result<()> {
289    let (mut cfg, path) = config::load(config_path)?;
290    cfg.auto_enabled = enabled;
291    config::save(&cfg, &path)?;
292    info!(
293        target: TRACE_TARGET,
294        op = "set_enabled",
295        auto_enabled = enabled,
296        config_path = path.display().to_string(),
297        "auto-claim flag persisted"
298    );
299    println!("auto_enabled = {enabled}");
300    Ok(())
301}
302
303pub fn set_threshold(config_path: Option<&str>, gb: f32) -> Result<()> {
304    if gb < 0.0 {
305        return Err(anyhow!("threshold must be >= 0"));
306    }
307    let (mut cfg, path) = config::load(config_path)?;
308    cfg.vram_threshold_gb = gb;
309    config::save(&cfg, &path)?;
310    info!(
311        target: TRACE_TARGET,
312        op = "set_threshold",
313        vram_threshold_gb = gb,
314        config_path = path.display().to_string(),
315        "VRAM threshold persisted"
316    );
317    println!("vram_threshold_gb = {gb}");
318    Ok(())
319}
320
321/// Emit a one-shot startup banner so operators can confirm which
322/// config the worker actually loaded.  Without this the only thing in
323/// `journalctl -u studio-worker` on a healthy boot is whatever the
324/// loops happen to log on their first tick.
325pub fn log_startup_banner(cfg: &Config, path: &std::path::Path) {
326    info!(
327        target: TRACE_TARGET,
328        op = "startup",
329        version = AGENT_VERSION,
330        config_path = path.display().to_string(),
331        api_base_url = cfg.api_base_url.as_str(),
332        engine = cfg.engine.as_str(),
333        vram_threshold_gb = cfg.vram_threshold_gb,
334        auto_enabled = cfg.auto_enabled,
335        auto_update_enabled = cfg.auto_update_enabled,
336        auto_update_interval_secs = cfg.auto_update_interval_secs,
337        worker_id = cfg.worker_id.as_deref().unwrap_or("(unregistered)"),
338        "studio-worker booting"
339    );
340}
341
342pub fn show_config(config_path: Option<&str>) -> Result<()> {
343    let (cfg, path) = config::load(config_path)?;
344    println!("# {}", path.display());
345    print!("{}", toml::to_string_pretty(&cfg)?);
346    Ok(())
347}
348
349pub async fn check_update(config_path: Option<&str>) -> Result<()> {
350    let (cfg, _) = config::load(config_path)?;
351    let current = semver::Version::parse(AGENT_VERSION)
352        .map_err(|e| anyhow!("invalid current version {AGENT_VERSION}: {e}"))?;
353    let outcome = tokio::task::spawn_blocking(move || {
354        update::check(&cfg.auto_update_feed, &current, cfg.auto_update_prerelease)
355    })
356    .await??;
357    println!("{}", format_check_outcome(&outcome));
358    Ok(())
359}
360
361pub fn format_check_outcome(outcome: &update::CheckOutcome) -> String {
362    match outcome {
363        update::CheckOutcome::UpToDate { current } => format!("up to date: {current}"),
364        update::CheckOutcome::NewerAvailable { current, latest } => {
365            format!("update available: {current} -> {latest}")
366        }
367    }
368}
369
370// ---------------------------------------------------------------------------
371// Long-running run loop
372// ---------------------------------------------------------------------------
373
374pub async fn run(config_path: Option<&str>) -> Result<()> {
375    let (mut cfg, path) = config::load(config_path)?;
376    log_startup_banner(&cfg, &path);
377    if cfg.worker_id.is_none() || cfg.auth_token.is_none() {
378        let engine = engine::build(&cfg)?;
379        let cap = build_capabilities(&cfg, &*engine);
380        let response = tokio::task::spawn_blocking({
381            let api_base_url = cfg.api_base_url.clone();
382            let bootstrap = cfg.bootstrap_token.clone();
383            move || -> Result<RegisterResponse> {
384                let api = ApiClient::new(api_base_url)?;
385                api.register(&bootstrap, cap, None)
386            }
387        })
388        .await??;
389        cfg.worker_id = Some(response.worker_id);
390        cfg.auth_token = Some(response.auth_token);
391        config::save(&cfg, &path)?;
392        info!(
393            worker_id = %cfg.worker_id.as_deref().unwrap_or(""),
394            "auto-registered on first run"
395        );
396    }
397
398    let cfg = config::shared(cfg);
399    let stop = Arc::new(AtomicBool::new(false));
400    let busy = Arc::new(AtomicBool::new(false));
401    let logs: Arc<Mutex<Vec<LogEntry>>> = Arc::new(Mutex::new(Vec::new()));
402    let observers = WorkerObservers::default();
403
404    let stop_clone = stop.clone();
405    tokio::spawn(async move {
406        let _ = tokio::signal::ctrl_c().await;
407        stop_clone.store(true, Ordering::SeqCst);
408    });
409
410    run_loops(cfg, stop, logs, busy, observers, LoopSchedule::default()).await;
411    Ok(())
412}
413
414/// Spawn the four loops with the supplied schedule and wait for them to
415/// exit.  Pulled out of `run` so tests can drive the loop wrappers with
416/// short intervals.
417pub async fn run_loops(
418    cfg: SharedConfig,
419    stop: Arc<AtomicBool>,
420    logs: Arc<Mutex<Vec<LogEntry>>>,
421    busy: Arc<AtomicBool>,
422    observers: WorkerObservers,
423    schedule: LoopSchedule,
424) {
425    let heartbeat = spawn_heartbeat(
426        cfg.clone(),
427        stop.clone(),
428        logs.clone(),
429        busy.clone(),
430        observers.clone(),
431        schedule,
432    );
433    let claim = spawn_claim_loop(
434        cfg.clone(),
435        stop.clone(),
436        logs.clone(),
437        busy.clone(),
438        observers.clone(),
439        schedule,
440    );
441    let log_shipper = spawn_log_shipper(cfg.clone(), stop.clone(), logs.clone(), schedule);
442    let auto_updater = spawn_auto_updater(
443        cfg.clone(),
444        stop.clone(),
445        logs.clone(),
446        busy.clone(),
447        schedule,
448    );
449    let _ = tokio::join!(heartbeat, claim, log_shipper, auto_updater);
450}
451
452// ---------------------------------------------------------------------------
453// Per-tick helpers — pure async fns, easy to drive from unit tests.
454// ---------------------------------------------------------------------------
455
/// Outcome of a single claim attempt, as returned by `claim_tick`.  The
/// claim loop uses it to pick the delay before the next attempt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClaimOutcome {
    /// We claimed and ran a job — back to the fast poll interval.  Returned
    /// whether the job itself completed or failed.
    RanJob,
    /// 204 from the API — nothing to do, back off briefly.
    NoJobs,
    /// HTTP / wiring error — back off briefly and try again.  Carries the
    /// error rendered as text.
    Error(String),
    /// Worker is currently `auto_enabled = false`; skip the claim entirely.
    Skipped,
}
468
469/// Send one heartbeat.  Returns Ok regardless of whether the HTTP call
470/// succeeded — failures are captured in the log buffer and in
471/// `observers.last_heartbeat` for the UI.
472pub async fn heartbeat_tick(
473    cfg: &Config,
474    busy_now: bool,
475    logs: &Arc<Mutex<Vec<LogEntry>>>,
476    observers: &WorkerObservers,
477) -> Result<()> {
478    let engine = match engine::build(cfg) {
479        Ok(e) => e,
480        Err(e) => {
481            push_log(
482                logs,
483                "warn",
484                "heartbeat",
485                &format!("engine error: {e}"),
486                None,
487            );
488            *observers.last_heartbeat.lock() = Some(HeartbeatStatus {
489                last_attempt_at: Utc::now(),
490                outcome: HeartbeatOutcome::Err {
491                    reason: format!("engine error: {e}"),
492                },
493            });
494            return Ok(());
495        }
496    };
497    let cap = build_capabilities(cfg, &*engine);
498    let token = cfg.auth_token.clone().unwrap_or_default();
499    let worker_id = cfg.worker_id.clone().unwrap_or_default();
500    let api_base_url = cfg.api_base_url.clone();
501    let logs_for_task = logs.clone();
502    let result = tokio::task::spawn_blocking(move || -> Result<()> {
503        let api = ApiClient::new(api_base_url)?;
504        api.heartbeat(&worker_id, &token, cap, None)
505    })
506    .await;
507    let outcome = match result {
508        Ok(Ok(())) => HeartbeatOutcome::Ok,
509        Ok(Err(e)) => {
510            push_log(
511                &logs_for_task,
512                "warn",
513                "heartbeat",
514                &format!("heartbeat failed (busy={busy_now}): {e}"),
515                None,
516            );
517            HeartbeatOutcome::Err {
518                reason: e.to_string(),
519            }
520        }
521        Err(e) => {
522            push_log(
523                &logs_for_task,
524                "warn",
525                "heartbeat",
526                &format!("heartbeat task panic: {e}"),
527                None,
528            );
529            HeartbeatOutcome::Err {
530                reason: format!("task panic: {e}"),
531            }
532        }
533    };
534    *observers.last_heartbeat.lock() = Some(HeartbeatStatus {
535        last_attempt_at: Utc::now(),
536        outcome,
537    });
538    Ok(())
539}
540
/// One claim attempt + (when a job is claimed) run-to-completion.
///
/// Returns `Skipped` without touching the network when `auto_enabled` is
/// off, `NoJobs` when the API had nothing to hand out, `Error` when the
/// engine could not be built or the claim request failed, and `RanJob`
/// once a claimed job has finished — regardless of whether the job itself
/// succeeded.
///
/// While a job runs, `busy` is set and `observers.current_job` holds its
/// snapshot; afterwards the job is appended to `observers.recent_jobs`
/// and both are reset.
pub async fn claim_tick(
    cfg: &Config,
    logs: &Arc<Mutex<Vec<LogEntry>>>,
    busy: &Arc<AtomicBool>,
    observers: &WorkerObservers,
) -> ClaimOutcome {
    if !cfg.auto_enabled {
        return ClaimOutcome::Skipped;
    }
    // Build the engine up front so a broken engine config is reported
    // before any job is claimed.
    let engine = match engine::build(cfg) {
        Ok(e) => e,
        Err(e) => {
            push_log(logs, "warn", "claim", &format!("engine error: {e}"), None);
            return ClaimOutcome::Error(e.to_string());
        }
    };
    let token = cfg.auth_token.clone().unwrap_or_default();
    let worker_id = cfg.worker_id.clone().unwrap_or_default();
    let api_base_url = cfg.api_base_url.clone();

    // The claim request is blocking; the client is returned alongside the
    // claim so the same client can be reused by `run_job`.
    let claim_result = tokio::task::spawn_blocking({
        let token = token.clone();
        let worker_id = worker_id.clone();
        let api_base_url = api_base_url.clone();
        move || -> Result<(ApiClient, Option<JobClaim>)> {
            let api = ApiClient::new(api_base_url)?;
            let claim = api.claim(&worker_id, &token)?;
            Ok((api, claim))
        }
    })
    .await;

    match claim_result {
        Ok(Ok((api, Some(job)))) => {
            // Mark the worker busy for the whole duration of the job.
            busy.store(true, Ordering::SeqCst);
            push_log(
                logs,
                "info",
                "claim",
                &format!(
                    "claimed job {} (model={}, vram={}GB)",
                    job.job_id, job.model, job.vram_gb_estimate
                ),
                Some(job.job_id.clone()),
            );

            // Snapshot for the UI _before_ we move `job` into the
            // blocking thread.  Resolving the task lazily here lets us
            // record the real kind even when the wire payload omitted
            // the explicit `task` field (legacy image-only path).
            let resolved = job.resolved_task();
            let snapshot = CurrentJob {
                job_id: job.job_id.clone(),
                kind: resolved.kind(),
                model: job.model.clone(),
                prompt: truncate_prompt(&prompt_for(&resolved)),
                started_at: Utc::now(),
            };
            *observers.current_job.lock() = Some(snapshot.clone());

            // Drop the api on the blocking thread to avoid reqwest's
            // internal runtime dropping inside an async context.
            let logs_clone = logs.clone();
            let token_clone = token.clone();
            let worker_id_clone = worker_id.clone();
            let engine_handle = engine;
            let join = tokio::task::spawn_blocking(move || {
                run_job(
                    &api,
                    &token_clone,
                    &worker_id_clone,
                    &*engine_handle,
                    &logs_clone,
                    job,
                )
            })
            .await;
            // A job task that could not be joined is recorded as a failed
            // job rather than as a claim error.
            let outcome = match join {
                Ok(o) => o,
                Err(e) => JobOutcome::Failed {
                    reason: format!("job task panic: {e}"),
                },
            };
            // Publish the result: clear the "current" slot, append to the
            // recent-jobs ring, then release the busy flag.
            *observers.current_job.lock() = None;
            record_recent_job(
                observers,
                RecentJob {
                    job_id: snapshot.job_id,
                    kind: snapshot.kind,
                    model: snapshot.model,
                    prompt: snapshot.prompt,
                    outcome,
                    started_at: snapshot.started_at,
                    finished_at: Utc::now(),
                },
            );
            busy.store(false, Ordering::SeqCst);
            ClaimOutcome::RanJob
        }
        // The request succeeded but there was no job to claim.
        Ok(Ok((_api, None))) => ClaimOutcome::NoJobs,
        // Client construction or the claim request itself failed.
        Ok(Err(e)) => {
            push_log(
                logs,
                "warn",
                "claim",
                &format!("claim request errored: {e}"),
                None,
            );
            ClaimOutcome::Error(e.to_string())
        }
        // The blocking claim task could not be joined.
        Err(e) => {
            push_log(
                logs,
                "warn",
                "claim",
                &format!("claim task panic: {e}"),
                None,
            );
            ClaimOutcome::Error(e.to_string())
        }
    }
}
664
665/// Flush all buffered logs to the API.  Returns the number of entries
666/// shipped (0 if no logs were buffered or the worker isn't registered).
667pub async fn log_shipper_tick(cfg: &Config, logs: &Arc<Mutex<Vec<LogEntry>>>) -> usize {
668    let token = cfg.auth_token.clone().unwrap_or_default();
669    let worker_id = cfg.worker_id.clone().unwrap_or_default();
670    if worker_id.is_empty() || token.is_empty() {
671        // Drain the buffer anyway so it doesn't grow unbounded.
672        logs.lock().clear();
673        return 0;
674    }
675    let batch = {
676        let mut guard = logs.lock();
677        if guard.is_empty() {
678            return 0;
679        }
680        LogBatch {
681            entries: std::mem::take(&mut *guard),
682        }
683    };
684    let count = batch.entries.len();
685    let api_base_url = cfg.api_base_url.clone();
686    let _ = tokio::task::spawn_blocking(move || -> Result<()> {
687        let api = ApiClient::new(api_base_url)?;
688        api.ship_logs(&worker_id, &token, batch)
689    })
690    .await;
691    count
692}
693
/// What the auto-updater decided this tick (the return value of
/// `auto_update_tick`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AutoUpdateDecision {
    /// Auto-update is turned off — do nothing.
    Disabled,
    /// Worker is currently running a job — skip.
    SkippedBusy,
    /// Local version is already the latest.
    UpToDate,
    /// Check failed (network etc.) — leave a log entry, try again later.
    /// Also used when `AGENT_VERSION` cannot be parsed or the blocking
    /// task cannot be joined.  Carries the error rendered as text.
    CheckError(String),
    /// A newer version was applied successfully.  Caller should restart.
    Updated,
    /// A newer version was found but the install failed.  Carries the
    /// error rendered as text.
    UpdateError(String),
}
710
711pub async fn auto_update_tick(
712    cfg: &Config,
713    busy: bool,
714    logs: &Arc<Mutex<Vec<LogEntry>>>,
715) -> AutoUpdateDecision {
716    if !cfg.auto_update_enabled {
717        return AutoUpdateDecision::Disabled;
718    }
719    if busy {
720        push_log(
721            logs,
722            "info",
723            "auto-update",
724            "skipping check: worker is busy on a job",
725            None,
726        );
727        return AutoUpdateDecision::SkippedBusy;
728    }
729    let feed = cfg.auto_update_feed.clone();
730    let prerelease = cfg.auto_update_prerelease;
731    let logs_for_task = logs.clone();
732    let outcome = tokio::task::spawn_blocking(move || -> Result<AutoUpdateDecision> {
733        let current = semver::Version::parse(AGENT_VERSION)
734            .map_err(|e| anyhow!("invalid AGENT_VERSION {AGENT_VERSION}: {e}"))?;
735        match update::check(&feed, &current, prerelease) {
736            Ok(update::CheckOutcome::UpToDate { current }) => {
737                push_log(
738                    &logs_for_task,
739                    "info",
740                    "auto-update",
741                    &format!("up to date at {current}"),
742                    None,
743                );
744                Ok(AutoUpdateDecision::UpToDate)
745            }
746            Ok(update::CheckOutcome::NewerAvailable { current, latest }) => {
747                push_log(
748                    &logs_for_task,
749                    "info",
750                    "auto-update",
751                    &format!("update available {current} -> {latest}; applying"),
752                    None,
753                );
754                match update::apply(&feed, &latest) {
755                    Ok(()) => {
756                        push_log(
757                            &logs_for_task,
758                            "info",
759                            "auto-update",
760                            "binary replaced; restart pending",
761                            None,
762                        );
763                        Ok(AutoUpdateDecision::Updated)
764                    }
765                    Err(e) => {
766                        push_log(
767                            &logs_for_task,
768                            "error",
769                            "auto-update",
770                            &format!("update failed: {e}"),
771                            None,
772                        );
773                        Ok(AutoUpdateDecision::UpdateError(e.to_string()))
774                    }
775                }
776            }
777            Err(e) => {
778                push_log(
779                    &logs_for_task,
780                    "warn",
781                    "auto-update",
782                    &format!("check failed: {e}"),
783                    None,
784                );
785                Ok(AutoUpdateDecision::CheckError(e.to_string()))
786            }
787        }
788    })
789    .await;
790    match outcome {
791        Ok(Ok(decision)) => decision,
792        Ok(Err(e)) => AutoUpdateDecision::CheckError(e.to_string()),
793        Err(e) => AutoUpdateDecision::CheckError(e.to_string()),
794    }
795}
796
797// ---------------------------------------------------------------------------
798// Long-running task wrappers — they exist solely to call the ticks in a
799// loop on a schedule.  All real logic lives in the ticks.
800// ---------------------------------------------------------------------------
801
802pub fn spawn_heartbeat(
803    cfg: SharedConfig,
804    stop: Arc<AtomicBool>,
805    logs: Arc<Mutex<Vec<LogEntry>>>,
806    busy: Arc<AtomicBool>,
807    observers: WorkerObservers,
808    schedule: LoopSchedule,
809) -> tokio::task::JoinHandle<()> {
810    tokio::spawn(async move {
811        while !stop.load(Ordering::SeqCst) {
812            tokio::time::sleep(schedule.heartbeat).await;
813            let snapshot = cfg.lock().clone();
814            let busy_now = busy.load(Ordering::SeqCst);
815            let _ = heartbeat_tick(&snapshot, busy_now, &logs, &observers).await;
816        }
817    })
818}
819
820pub fn spawn_claim_loop(
821    cfg: SharedConfig,
822    stop: Arc<AtomicBool>,
823    logs: Arc<Mutex<Vec<LogEntry>>>,
824    busy: Arc<AtomicBool>,
825    observers: WorkerObservers,
826    schedule: LoopSchedule,
827) -> tokio::task::JoinHandle<()> {
828    tokio::spawn(async move {
829        let mut next_delay = schedule.claim_idle;
830        while !stop.load(Ordering::SeqCst) {
831            tokio::time::sleep(next_delay).await;
832            let snapshot = cfg.lock().clone();
833            let outcome = claim_tick(&snapshot, &logs, &busy, &observers).await;
834            next_delay = match outcome {
835                ClaimOutcome::RanJob => schedule.claim_idle,
836                _ => schedule.claim_after_null,
837            };
838        }
839    })
840}
841
842/// Pure helper so the schedule decision is unit-testable.
843pub fn next_delay_for(outcome: &ClaimOutcome) -> Duration {
844    match outcome {
845        ClaimOutcome::RanJob => CLAIM_INTERVAL_IDLE,
846        ClaimOutcome::NoJobs | ClaimOutcome::Error(_) | ClaimOutcome::Skipped => {
847            CLAIM_INTERVAL_AFTER_NULL
848        }
849    }
850}
851
852pub fn spawn_log_shipper(
853    cfg: SharedConfig,
854    stop: Arc<AtomicBool>,
855    logs: Arc<Mutex<Vec<LogEntry>>>,
856    schedule: LoopSchedule,
857) -> tokio::task::JoinHandle<()> {
858    tokio::spawn(async move {
859        while !stop.load(Ordering::SeqCst) {
860            tokio::time::sleep(schedule.log_flush).await;
861            let snapshot = cfg.lock().clone();
862            let _ = log_shipper_tick(&snapshot, &logs).await;
863        }
864    })
865}
866
867pub fn spawn_auto_updater(
868    cfg: SharedConfig,
869    stop: Arc<AtomicBool>,
870    logs: Arc<Mutex<Vec<LogEntry>>>,
871    busy: Arc<AtomicBool>,
872    schedule: LoopSchedule,
873) -> tokio::task::JoinHandle<()> {
874    tokio::spawn(async move {
875        let mut elapsed = Duration::from_secs(0);
876        while !stop.load(Ordering::SeqCst) {
877            tokio::time::sleep(schedule.auto_update_tick).await;
878            elapsed += schedule.auto_update_tick;
879            let snapshot = cfg.lock().clone();
880            if elapsed < Duration::from_secs(snapshot.auto_update_interval_secs) {
881                continue;
882            }
883            elapsed = Duration::from_secs(0);
884            let busy_now = busy.load(Ordering::SeqCst);
885            let decision = auto_update_tick(&snapshot, busy_now, &logs).await;
886            if matches!(decision, AutoUpdateDecision::Updated) {
887                stop.store(true, Ordering::SeqCst);
888                update::restart_self();
889            }
890        }
891    })
892}
893
894// ---------------------------------------------------------------------------
895// Job runner (shared by claim_tick) — pure-ish; HTTP via ApiClient.
896// ---------------------------------------------------------------------------
897
898fn run_job(
899    api: &ApiClient,
900    token: &str,
901    worker_id: &str,
902    engine: &dyn Engine,
903    logs: &Arc<Mutex<Vec<LogEntry>>>,
904    job: JobClaim,
905) -> JobOutcome {
906    let start = std::time::Instant::now();
907    let task = job.resolved_task();
908    let task_kind = task.kind();
909    let prompt_for_log = prompt_for(&task);
910    let result = engine.dispatch(&job.model, task);
911    match result {
912        Ok(task_result) => {
913            push_log(
914                logs,
915                "info",
916                "generate",
917                &format!(
918                    "{} task generated in {:?}",
919                    task_kind.as_str(),
920                    start.elapsed()
921                ),
922                Some(job.job_id.clone()),
923            );
924            let outcome = match task_result {
925                TaskResult::Image { bytes, ext } => {
926                    api.complete(worker_id, token, &job.job_id, &ext, &prompt_for_log, bytes)
927                }
928                TaskResult::AudioTts { bytes, ext } => {
929                    api.complete(worker_id, token, &job.job_id, &ext, &prompt_for_log, bytes)
930                }
931                TaskResult::Video { bytes, ext } => {
932                    api.complete(worker_id, token, &job.job_id, &ext, &prompt_for_log, bytes)
933                }
934                TaskResult::Llm { json } => {
935                    api.complete_json(worker_id, token, &job.job_id, &prompt_for_log, &json)
936                }
937                TaskResult::AudioStt { json } => {
938                    api.complete_json(worker_id, token, &job.job_id, &prompt_for_log, &json)
939                }
940            };
941            match outcome {
942                Err(e) => {
943                    let reason = format!("complete failed: {e}");
944                    push_log(logs, "error", "complete", &reason, Some(job.job_id.clone()));
945                    JobOutcome::Failed { reason }
946                }
947                Ok(()) => {
948                    push_log(
949                        logs,
950                        "info",
951                        "complete",
952                        "job uploaded",
953                        Some(job.job_id.clone()),
954                    );
955                    JobOutcome::Completed
956                }
957            }
958        }
959        Err(e) => {
960            warn!("generate failed: {e:#}");
961            let reason = format!("generate failed: {e}");
962            push_log(logs, "error", "generate", &reason, Some(job.job_id.clone()));
963            let retryable = !is_unsupported_kind(&e);
964            let _ = api.fail(worker_id, token, &job.job_id, &e.to_string(), retryable);
965            JobOutcome::Failed { reason }
966        }
967    }
968}
969
970pub fn prompt_for(task: &Task) -> String {
971    match task {
972        Task::Image(p) => p.prompt.clone(),
973        Task::Llm(p) => p
974            .messages
975            .last()
976            .map(|m| m.content.clone())
977            .unwrap_or_default(),
978        Task::AudioStt(p) => p.input_url.clone(),
979        Task::AudioTts(p) => p.text.clone(),
980        Task::Video(p) => p.prompt.clone(),
981    }
982}
983
984pub fn is_unsupported_kind(e: &anyhow::Error) -> bool {
985    e.to_string().contains("cannot serve")
986}
987
988// ---------------------------------------------------------------------------
989// Helpers
990// ---------------------------------------------------------------------------
991
992pub fn build_capabilities(cfg: &Config, engine: &dyn Engine) -> WorkerCapabilities {
993    let vram = sys::detect_vram_gb().unwrap_or(0.0);
994    let caps = engine.capabilities();
995    let supported_models_per_kind = caps.supported_models_per_kind.clone();
996    let task_kinds = caps.kinds();
997    // Legacy `supported_models` is a flat list across all kinds so the
998    // studio API's claim filter (which only knows about this field) can
999    // match jobs of any modality this worker can serve.
1000    let supported_models = {
1001        let mut all = caps.flat_models();
1002        all.sort();
1003        all.dedup();
1004        all
1005    };
1006    let supported_models = if cfg.supported_models_override.is_empty() {
1007        supported_models
1008    } else {
1009        cfg.supported_models_override.clone()
1010    };
1011
1012    WorkerCapabilities {
1013        machine_name: sys::machine_name(),
1014        username: sys::username(),
1015        agent_version: AGENT_VERSION.to_string(),
1016        engine: cfg.engine.clone(),
1017        vram_total_gb: vram,
1018        vram_threshold_gb: cfg.vram_threshold_gb,
1019        auto_enabled: cfg.auto_enabled,
1020        auto_start: cfg.auto_start,
1021        supported_models,
1022        task_kinds,
1023        supported_models_per_kind,
1024    }
1025}
1026
1027pub fn push_log(
1028    logs: &Arc<Mutex<Vec<LogEntry>>>,
1029    level: &str,
1030    category: &str,
1031    message: &str,
1032    job_id: Option<String>,
1033) {
1034    let entry = LogEntry {
1035        ts: Utc::now().to_rfc3339_opts(SecondsFormat::Millis, true),
1036        level: level.to_string(),
1037        category: category.to_string(),
1038        message: message.to_string(),
1039        job_id,
1040    };
1041    if level == "error" {
1042        tracing::error!(target: "studio_worker", "[{category}] {message}");
1043    } else if level == "warn" {
1044        tracing::warn!(target: "studio_worker", "[{category}] {message}");
1045    } else {
1046        info!(target: "studio_worker", "[{category}] {message}");
1047    }
1048    logs.lock().push(entry);
1049}
1050
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::engine::SyntheticEngine;

    /// Log buffer type shared by the helpers under test.
    type LogBuffer = Arc<Mutex<Vec<LogEntry>>>;

    /// Fresh, empty log buffer.
    fn empty_logs() -> LogBuffer {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Config for a registered worker on the synthetic engine, pointed at
    /// `api_base_url`, with job claiming on and auto-update off.
    fn cfg_pointing_at(api_base_url: String) -> Config {
        Config {
            api_base_url,
            worker_id: Some("w-test".into()),
            auth_token: Some("tok-test".into()),
            engine: "synthetic".into(),
            auto_enabled: true,
            auto_update_enabled: false,
            ..Config::default()
        }
    }

    // --- capabilities ---

    /// The default config plus the synthetic engine advertises every kind.
    #[test]
    fn capabilities_advertises_all_synthetic_kinds() {
        let config = Config::default();
        let synthetic = SyntheticEngine::new(vec![]);
        let caps = build_capabilities(&config, &synthetic);
        assert_eq!(caps.engine, "synthetic");
        assert_eq!(caps.task_kinds.len(), TaskKind::ALL.len());
        for kind in TaskKind::ALL {
            assert!(caps.supported_models_per_kind.contains_key(&kind));
        }
    }

    /// A non-empty override replaces the derived flat model list.
    #[test]
    fn capabilities_uses_override_for_legacy_flat_list() {
        let config = Config {
            supported_models_override: vec!["only-this".into()],
            ..Config::default()
        };
        let synthetic = SyntheticEngine::new(vec![]);
        let caps = build_capabilities(&config, &synthetic);
        assert_eq!(caps.supported_models, vec!["only-this".to_string()]);
    }

    // --- pure helpers ---

    /// Each task kind exposes the expected field as its "prompt".
    #[test]
    fn prompt_for_extracts_per_kind() {
        let image_task = Task::Image(ImageParams {
            prompt: "a stone golem".into(),
            width: 512,
            height: 512,
            steps: 20,
            seed: None,
            ext: "webp".into(),
        });
        assert_eq!(prompt_for(&image_task), "a stone golem");

        // For LLM tasks the last message wins, whatever its role.
        let system_msg = ChatMessage {
            role: "system".into(),
            content: "be helpful".into(),
        };
        let user_msg = ChatMessage {
            role: "user".into(),
            content: "hi".into(),
        };
        let llm_task = Task::Llm(LlmParams {
            messages: vec![system_msg, user_msg],
            max_tokens: 32,
            temperature: 0.5,
        });
        assert_eq!(prompt_for(&llm_task), "hi");

        // No messages at all yields an empty prompt.
        let llm_without_messages = Task::Llm(LlmParams {
            messages: vec![],
            max_tokens: 1,
            temperature: 0.0,
        });
        assert_eq!(prompt_for(&llm_without_messages), "");

        let stt_task = Task::AudioStt(AudioSttParams {
            input_url: "https://example.com/clip.wav".into(),
            language: None,
        });
        assert_eq!(prompt_for(&stt_task), "https://example.com/clip.wav");

        let tts_task = Task::AudioTts(AudioTtsParams {
            text: "hi there".into(),
            voice: "v".into(),
            ext: "wav".into(),
        });
        assert_eq!(prompt_for(&tts_task), "hi there");

        let video_task = Task::Video(VideoParams {
            prompt: "a tiny dragon".into(),
            seconds: 1.0,
            width: 256,
            height: 256,
            ext: "mp4".into(),
        });
        assert_eq!(prompt_for(&video_task), "a tiny dragon");
    }

    /// Only messages containing the engine's "cannot serve" text match.
    #[test]
    fn is_unsupported_kind_matches_engine_message() {
        let unsupported = anyhow!("gradio engine cannot serve llm tasks");
        assert!(is_unsupported_kind(&unsupported));
        let unrelated = anyhow!("network timeout");
        assert!(!is_unsupported_kind(&unrelated));
    }

    #[test]
    fn next_delay_for_picks_idle_after_a_job() {
        assert_eq!(next_delay_for(&ClaimOutcome::RanJob), CLAIM_INTERVAL_IDLE);
    }

    /// Every outcome other than `RanJob` uses the after-null interval.
    #[test]
    fn next_delay_for_backs_off_when_no_jobs_or_errors() {
        let backoff_outcomes = [
            ClaimOutcome::NoJobs,
            ClaimOutcome::Error("boom".into()),
            ClaimOutcome::Skipped,
        ];
        for outcome in &backoff_outcomes {
            assert_eq!(next_delay_for(outcome), CLAIM_INTERVAL_AFTER_NULL);
        }
    }

    // --- CLI formatting helpers ---

    #[test]
    fn format_status_includes_every_field() {
        let config = Config::default();
        let rendered = format_status(&config, std::path::Path::new("/tmp/x.toml"));
        let expected_fragments = [
            "config path:",
            "api_base_url:",
            "worker_id:",
            "(not registered)",
            "auto_update:",
            "update_interval:",
        ];
        for fragment in expected_fragments {
            assert!(rendered.contains(fragment));
        }
    }

    #[test]
    fn format_status_shows_worker_id_when_registered() {
        let config = Config {
            worker_id: Some("w-abc".into()),
            ..Config::default()
        };
        let rendered = format_status(&config, std::path::Path::new("/tmp/x.toml"));
        assert!(rendered.contains("w-abc"));
    }

    #[test]
    fn format_check_outcome_handles_both_branches() {
        let up_to_date = update::CheckOutcome::UpToDate {
            current: semver::Version::new(1, 2, 3),
        };
        assert!(format_check_outcome(&up_to_date).contains("up to date"));

        let newer_available = update::CheckOutcome::NewerAvailable {
            current: semver::Version::new(1, 2, 3),
            latest: semver::Version::new(1, 3, 0),
        };
        let rendered = format_check_outcome(&newer_available);
        assert!(rendered.contains("1.2.3 -> 1.3.0"));
    }

    // --- log buffer ---

    /// Entries are appended in call order with level and job id preserved.
    #[test]
    fn push_log_appends_an_entry() {
        let logs = empty_logs();
        push_log(&logs, "info", "test", "hi", None);
        push_log(&logs, "warn", "test", "wat", Some("j-1".into()));
        push_log(&logs, "error", "test", "boom", None);

        let entries = logs.lock();
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0].level, "info");
        assert_eq!(entries[1].level, "warn");
        assert_eq!(entries[1].job_id.as_deref(), Some("j-1"));
        assert_eq!(entries[2].level, "error");
    }

    // --- async tick tests ---

    #[tokio::test]
    async fn claim_tick_returns_skipped_when_auto_enabled_is_false() {
        let config = Config {
            auto_enabled: false,
            ..Config::default()
        };
        let logs = empty_logs();
        let busy = Arc::new(AtomicBool::new(false));
        let observers = WorkerObservers::default();
        let outcome = claim_tick(&config, &logs, &busy, &observers).await;
        assert_eq!(outcome, ClaimOutcome::Skipped);
    }

    #[tokio::test]
    async fn auto_update_tick_disabled_when_flag_off() {
        let config = Config {
            auto_update_enabled: false,
            ..Config::default()
        };
        let logs = empty_logs();
        let decision = auto_update_tick(&config, false, &logs).await;
        assert_eq!(decision, AutoUpdateDecision::Disabled);
    }

    /// A busy worker skips the update and leaves a log entry saying so.
    #[tokio::test]
    async fn auto_update_tick_skipped_when_busy() {
        let config = Config {
            auto_update_enabled: true,
            ..Config::default()
        };
        let logs = empty_logs();
        let decision = auto_update_tick(&config, true, &logs).await;
        assert_eq!(decision, AutoUpdateDecision::SkippedBusy);

        let entries = logs.lock();
        assert!(entries
            .iter()
            .any(|entry| entry.message.contains("busy on a job")));
    }

    #[tokio::test]
    async fn log_shipper_tick_returns_zero_when_buffer_empty() {
        let config = cfg_pointing_at("http://unused.invalid".into());
        let logs = empty_logs();
        let shipped = log_shipper_tick(&config, &logs).await;
        assert_eq!(shipped, 0);
    }

    /// Without a worker id / auth token the tick reports zero even though
    /// an entry is pending.
    #[tokio::test]
    async fn log_shipper_tick_returns_zero_when_unregistered() {
        let config = Config {
            worker_id: None,
            auth_token: None,
            ..cfg_pointing_at("http://unused.invalid".into())
        };
        let pending = LogEntry {
            ts: "ts".into(),
            level: "info".into(),
            category: "x".into(),
            message: "m".into(),
            job_id: None,
        };
        let logs = Arc::new(Mutex::new(vec![pending]));
        let shipped = log_shipper_tick(&config, &logs).await;
        assert_eq!(shipped, 0);
    }
}