Skip to main content

provider_agent/
job_executor.rs

1//! Agent-side job executor. See `plan/V2_AGENT_SPEC.md` §7.
2//!
3//! Owns the dispatch table (`Vec<DiscoveredBackend>`) produced by
4//! [`crate::discovery`] and the outbound WSS channel. Receives `Job` payloads
5//! from the WSS read loop, looks up the backend that serves the requested
6//! `model_id`, spawns a tokio task that drives `Backend::execute`, and pumps
7//! emitted bytes back as `job_chunk` messages followed by either `job_done`
8//! or `job_error`.
9//!
10//! Concurrency is bounded by `limits.max_concurrent` from `agent.toml`: any
11//! job that arrives once that many are in flight is rejected immediately
12//! with `out_of_capacity`. Active jobs are tracked in a map so `job_cancel`
13//! from the coordinator can abort the underlying task.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU32, Ordering};
18use std::time::Instant;
19
20use async_trait::async_trait;
21use base64::Engine as _;
22use base64::engine::general_purpose::STANDARD as B64;
23use bytes::Bytes;
24use serde_json::{Value, json};
25use tokio::sync::{Mutex, mpsc};
26use tokio::task::JoinHandle;
27use tokio::time::Duration;
28use tokio_tungstenite::tungstenite::Message;
29use tracing::{debug, error, info, warn};
30use uuid::Uuid;
31
32use crate::backend::{Backend, BackendError, Job, JobResult, JobSink};
33use crate::discovery::DiscoveredBackend;
34
/// Upper bound on how long a single send may wait for room in a
/// backpressured outbound channel; past this the WSS is treated as fatally
/// stuck for the job that is sending.
const SEND_TIMEOUT: Duration = Duration::from_millis(5_000);
38
/// Live record of a dispatched job, keyed by `job_id` in
/// [`JobExecutor::active`]. The driver task removes its own entry when it
/// finishes; `JobExecutor::cancel` removes it early.
struct JobHandle {
    /// Handle to the spawned driver task; `JobExecutor::cancel` calls
    /// `abort()` on it.
    task: JoinHandle<()>,
}
44
/// Per-model dispatch entry derived from [`DiscoveredBackend`].
///
/// Cloned out of the routing table on every dispatch so the spawned driver
/// task owns its own copy; the clone is only an `Arc` bump.
#[derive(Clone)]
struct ModelRoute {
    /// Backend that executes jobs for this model. A backend serving several
    /// models is shared (`Arc`) across all of their routes.
    backend: Arc<dyn Backend>,
}
50
/// Routes incoming jobs to the backend serving their `model_id`, enforces the
/// `max_concurrent` cap, and tracks running jobs so they can be cancelled.
pub struct JobExecutor {
    /// model_id → backend. Local backends already win priority during
    /// discovery, so the first hit per model_id is canonical.
    routes: HashMap<String, ModelRoute>,
    /// `limits.max_concurrent` from `agent.toml`.
    max_concurrent: u32,
    /// Currently in-flight jobs, by job_id. The lock is taken only to insert
    /// or remove entries — never held across the `await` of job work.
    active: Arc<Mutex<HashMap<Uuid, JobHandle>>>,
    /// Number of capacity slots in use, compared against `max_concurrent`
    /// when a job is accepted and released when it finishes. Atomic, so the
    /// capacity check itself needs no lock.
    in_flight: Arc<AtomicU32>,
    /// Shared outbound channel into the WSS write half.
    out_tx: mpsc::Sender<Message>,
}
66
67impl JobExecutor {
68    pub fn new(
69        backends: Vec<DiscoveredBackend>,
70        max_concurrent: u32,
71        out_tx: mpsc::Sender<Message>,
72    ) -> Self {
73        let mut routes: HashMap<String, ModelRoute> = HashMap::new();
74        for db in backends {
75            for m in &db.models {
76                routes
77                    .entry(m.model_id.clone())
78                    .or_insert_with(|| ModelRoute {
79                        backend: db.backend.clone(),
80                    });
81            }
82        }
83        info!(
84            models = routes.len(),
85            max_concurrent, "job executor ready"
86        );
87        Self {
88            routes,
89            max_concurrent,
90            active: Arc::new(Mutex::new(HashMap::new())),
91            in_flight: Arc::new(AtomicU32::new(0)),
92            out_tx,
93        }
94    }
95
96    /// Number of currently-executing jobs. Useful for heartbeat reporting.
97    pub fn queue_depth(&self) -> u32 {
98        self.in_flight.load(Ordering::Relaxed)
99    }
100
101    /// Accept a job from the WSS read loop and spawn its driver task.
102    ///
103    /// Never blocks the caller: capacity violations and unknown models are
104    /// signalled to the coordinator via `job_error` and the function returns
105    /// immediately.
106    pub async fn dispatch(&self, job: Job) {
107        // Capacity check first — cheap, no lookup needed.
108        let prev = self.in_flight.fetch_add(1, Ordering::AcqRel);
109        if prev >= self.max_concurrent {
110            self.in_flight.fetch_sub(1, Ordering::AcqRel);
111            warn!(job_id = %job.job_id, "rejecting job: out_of_capacity");
112            let _ = send_error(
113                &self.out_tx,
114                job.job_id,
115                "out_of_capacity",
116                "agent at max_concurrent",
117                0,
118            )
119            .await;
120            return;
121        }
122
123        let route = match self.routes.get(&job.model_id).cloned() {
124            Some(r) => r,
125            None => {
126                self.in_flight.fetch_sub(1, Ordering::AcqRel);
127                warn!(
128                    job_id = %job.job_id,
129                    model_id = %job.model_id,
130                    "rejecting job: model_not_loaded"
131                );
132                let _ = send_error(
133                    &self.out_tx,
134                    job.job_id,
135                    "model_not_loaded",
136                    "no backend serves this model",
137                    0,
138                )
139                .await;
140                return;
141            }
142        };
143
144        let job_id = job.job_id;
145        let out_tx = self.out_tx.clone();
146        let active = self.active.clone();
147        let in_flight = self.in_flight.clone();
148        let deadline = Duration::from_millis(job.deadline_ms.max(1) as u64);
149
150        let task = tokio::spawn(async move {
151            let started = Instant::now();
152            let mut sink = WsJobSink::new(job_id, out_tx.clone());
153
154            let exec = route.backend.execute(&job, &mut sink);
155            let outcome = tokio::time::timeout(deadline, exec).await;
156
157            let final_msg: Value = match outcome {
158                Ok(Ok(JobResult { input_tokens, output_tokens, duration_ms })) => {
159                    let dur = if duration_ms == 0 {
160                        started.elapsed().as_millis().min(u32::MAX as u128) as u32
161                    } else {
162                        duration_ms
163                    };
164                    json!({
165                        "type": "job_done",
166                        "job_id": job_id,
167                        "tokens": {
168                            "input": input_tokens.unwrap_or(0),
169                            "output": output_tokens.unwrap_or(0),
170                        },
171                        "duration_ms": dur,
172                    })
173                }
174                Ok(Err(err)) => {
175                    let (code, msg) = map_backend_error(&err);
176                    warn!(%job_id, error_code = code, error = %err, "backend error");
177                    json!({
178                        "type": "job_error",
179                        "job_id": job_id,
180                        "error_code": code,
181                        "message": msg,
182                        "tokens_emitted": sink.bytes_sent(),
183                    })
184                }
185                Err(_elapsed) => {
186                    warn!(%job_id, "backend timeout exceeded deadline");
187                    json!({
188                        "type": "job_error",
189                        "job_id": job_id,
190                        "error_code": "backend_timeout",
191                        "message": "backend exceeded deadline_ms",
192                        "tokens_emitted": sink.bytes_sent(),
193                    })
194                }
195            };
196
197            // Best-effort terminal frame; if the WSS sender is dead the
198            // surrounding read loop will tear down anyway.
199            if out_tx
200                .send(Message::Text(final_msg.to_string().into()))
201                .await
202                .is_err()
203            {
204                debug!(%job_id, "outbound closed before terminal frame");
205            }
206
207            in_flight.fetch_sub(1, Ordering::AcqRel);
208            active.lock().await.remove(&job_id);
209        });
210
211        self.active
212            .lock()
213            .await
214            .insert(job_id, JobHandle { task });
215    }
216
217    /// Cancel a job in response to a coordinator `job_cancel`. Aborts the
218    /// driver task and removes the entry. The terminal frame the driver was
219    /// going to send is suppressed by the abort.
220    pub async fn cancel(&self, job_id: Uuid) {
221        let removed = self.active.lock().await.remove(&job_id);
222        match removed {
223            Some(h) => {
224                h.task.abort();
225                self.in_flight.fetch_sub(1, Ordering::AcqRel);
226                info!(%job_id, "job cancelled");
227            }
228            None => debug!(%job_id, "cancel for unknown job (already done?)"),
229        }
230    }
231}
232
233/// `JobSink` impl that wraps the WSS outbound channel. Each `send_chunk`
234/// base64-encodes the bytes and emits one `job_chunk` JSON frame. Backpressure
235/// surfaces as a `BackendError::Other("ws send timeout/closed")`, which the
236/// driver task converts to a `job_error` terminal frame.
237struct WsJobSink {
238    job_id: Uuid,
239    out_tx: mpsc::Sender<Message>,
240    bytes_sent: u64,
241}
242
243impl WsJobSink {
244    fn new(job_id: Uuid, out_tx: mpsc::Sender<Message>) -> Self {
245        Self { job_id, out_tx, bytes_sent: 0 }
246    }
247
248    fn bytes_sent(&self) -> u64 {
249        self.bytes_sent
250    }
251}
252
253#[async_trait]
254impl JobSink for WsJobSink {
255    async fn send_chunk(&mut self, bytes: Bytes) -> Result<(), BackendError> {
256        let frame = json!({
257            "type": "job_chunk",
258            "job_id": self.job_id,
259            "data": B64.encode(&bytes),
260        });
261        let msg = Message::Text(frame.to_string().into());
262        match tokio::time::timeout(SEND_TIMEOUT, self.out_tx.send(msg)).await {
263            Ok(Ok(())) => {
264                self.bytes_sent = self.bytes_sent.saturating_add(bytes.len() as u64);
265                Ok(())
266            }
267            Ok(Err(_closed)) => {
268                error!(job_id = %self.job_id, "ws outbound closed mid-stream");
269                Err(BackendError::Other("ws outbound closed".into()))
270            }
271            Err(_elapsed) => {
272                error!(job_id = %self.job_id, "ws outbound backpressured >5s");
273                Err(BackendError::Other("ws outbound send timeout".into()))
274            }
275        }
276    }
277}
278
279fn map_backend_error(err: &BackendError) -> (&'static str, String) {
280    match err {
281        BackendError::Unreachable(_) => ("backend_unreachable", err.to_string()),
282        BackendError::Timeout => ("backend_timeout", err.to_string()),
283        BackendError::ModelNotFound(_) => ("model_not_loaded", err.to_string()),
284        BackendError::MissingApiKey(_) => ("auth_rejected_by_backend", err.to_string()),
285        BackendError::BadStatus { status, .. } if *status == 401 || *status == 403 => {
286            ("auth_rejected_by_backend", err.to_string())
287        }
288        _ => ("internal", err.to_string()),
289    }
290}
291
292async fn send_error(
293    out_tx: &mpsc::Sender<Message>,
294    job_id: Uuid,
295    code: &str,
296    msg: &str,
297    tokens_emitted: u64,
298) -> Result<(), mpsc::error::SendError<Message>> {
299    let frame = json!({
300        "type": "job_error",
301        "job_id": job_id,
302        "error_code": code,
303        "message": msg,
304        "tokens_emitted": tokens_emitted,
305    });
306    out_tx.send(Message::Text(frame.to_string().into())).await
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Every error class maps to its wire code, including both auth statuses
    /// and a non-auth HTTP status that must fall through to `internal`.
    #[test]
    fn map_backend_error_codes() {
        let cases = [
            (BackendError::Unreachable("x".into()), "backend_unreachable"),
            (BackendError::Timeout, "backend_timeout"),
            (BackendError::ModelNotFound("m".into()), "model_not_loaded"),
            (BackendError::MissingApiKey("openrouter"), "auth_rejected_by_backend"),
            (
                BackendError::BadStatus { status: 401, body: "x".into() },
                "auth_rejected_by_backend",
            ),
            (
                BackendError::BadStatus { status: 403, body: "x".into() },
                "auth_rejected_by_backend",
            ),
            (BackendError::BadStatus { status: 500, body: "x".into() }, "internal"),
            (BackendError::Other("x".into()), "internal"),
        ];
        for (err, want) in &cases {
            assert_eq!(map_backend_error(err).0, *want, "wrong code for {}", err);
        }
    }

    /// The message half of the mapping is the error's `Display` text.
    #[test]
    fn map_backend_error_message_is_display_text() {
        let err = BackendError::ModelNotFound("m".into());
        assert_eq!(map_backend_error(&err).1, err.to_string());
    }
}