Skip to main content

provider_agent/
job_executor.rs

1//! Agent-side job executor. See `plan/V2_AGENT_SPEC.md` §7.
2//!
3//! Owns the dispatch table (`Vec<DiscoveredBackend>`) produced by
4//! [`crate::discovery`] and the outbound WSS channel. Receives `Job` payloads
5//! from the WSS read loop, looks up the backend that serves the requested
6//! `model_id`, spawns a tokio task that drives `Backend::execute`, and pumps
7//! emitted bytes back as `job_chunk` messages followed by either `job_done`
8//! or `job_error`.
9//!
10//! Concurrency is bounded by `limits.max_concurrent` from `agent.toml`: any
11//! job that arrives once that many are in flight is rejected immediately
12//! with `out_of_capacity`. Active jobs are tracked in a map so `job_cancel`
13//! from the coordinator can abort the underlying task.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU32, Ordering};
18use std::time::Instant;
19
20use async_trait::async_trait;
21use base64::Engine as _;
22use base64::engine::general_purpose::STANDARD as B64;
23use bytes::Bytes;
24use serde_json::{Value, json};
25use tokio::sync::{Mutex, mpsc};
26use tokio::task::JoinHandle;
27use tokio::time::Duration;
28use tokio_tungstenite::tungstenite::Message;
29use tracing::{debug, error, info, warn};
30use uuid::Uuid;
31
32use crate::backend::{Backend, BackendError, Job, JobResult, JobSink};
33use crate::discovery::DiscoveredBackend;
34
/// Upper bound on how long one outbound send may wait behind a full channel.
/// Once this elapses the WSS is treated as fatally stuck for the job that
/// was trying to send.
const SEND_TIMEOUT: Duration = Duration::from_millis(5_000);
38
39/// Live record of a dispatched job. Held in [`JobExecutor::active`] until the
40/// driver task finishes (success, error, or cancel).
41struct JobHandle {
42    task: JoinHandle<()>,
43}
44
45/// Per-model dispatch entry derived from [`DiscoveredBackend`].
46#[derive(Clone)]
47struct ModelRoute {
48    backend: Arc<dyn Backend>,
49}
50
51pub struct JobExecutor {
52    /// model_id → backend. Local backends already win priority during
53    /// discovery, so the first hit per model_id is canonical.
54    routes: HashMap<String, ModelRoute>,
55    /// `limits.max_concurrent` from `agent.toml`.
56    max_concurrent: u32,
57    /// Currently in-flight jobs, by job_id. We `lock().await` on insert and
58    /// remove only — never held across `await` of work.
59    active: Arc<Mutex<HashMap<Uuid, JobHandle>>>,
60    /// Counter for capacity enforcement. Incremented on accept, decremented
61    /// on completion. Atomic so the WSS read loop never blocks on accept.
62    in_flight: Arc<AtomicU32>,
63    /// Shared outbound channel into the WSS write half.
64    out_tx: mpsc::Sender<Message>,
65}
66
67impl JobExecutor {
68    pub fn new(
69        backends: Vec<DiscoveredBackend>,
70        max_concurrent: u32,
71        out_tx: mpsc::Sender<Message>,
72    ) -> Self {
73        let mut routes: HashMap<String, ModelRoute> = HashMap::new();
74        for db in backends {
75            for m in &db.models {
76                routes
77                    .entry(m.model_id.clone())
78                    .or_insert_with(|| ModelRoute {
79                        backend: db.backend.clone(),
80                    });
81            }
82        }
83        info!(models = routes.len(), max_concurrent, "job executor ready");
84        Self {
85            routes,
86            max_concurrent,
87            active: Arc::new(Mutex::new(HashMap::new())),
88            in_flight: Arc::new(AtomicU32::new(0)),
89            out_tx,
90        }
91    }
92
93    /// Number of currently-executing jobs. Useful for heartbeat reporting.
94    pub fn queue_depth(&self) -> u32 {
95        self.in_flight.load(Ordering::Relaxed)
96    }
97
98    /// Accept a job from the WSS read loop and spawn its driver task.
99    ///
100    /// Never blocks the caller: capacity violations and unknown models are
101    /// signalled to the coordinator via `job_error` and the function returns
102    /// immediately.
103    pub async fn dispatch(&self, job: Job) {
104        // Capacity check first — cheap, no lookup needed.
105        let prev = self.in_flight.fetch_add(1, Ordering::AcqRel);
106        if prev >= self.max_concurrent {
107            self.in_flight.fetch_sub(1, Ordering::AcqRel);
108            warn!(job_id = %job.job_id, "rejecting job: out_of_capacity");
109            let _ = send_error(
110                &self.out_tx,
111                job.job_id,
112                "out_of_capacity",
113                "agent at max_concurrent",
114                0,
115            )
116            .await;
117            return;
118        }
119
120        let route = match self.routes.get(&job.model_id).cloned() {
121            Some(r) => r,
122            None => {
123                self.in_flight.fetch_sub(1, Ordering::AcqRel);
124                warn!(
125                    job_id = %job.job_id,
126                    model_id = %job.model_id,
127                    "rejecting job: model_not_loaded"
128                );
129                let _ = send_error(
130                    &self.out_tx,
131                    job.job_id,
132                    "model_not_loaded",
133                    "no backend serves this model",
134                    0,
135                )
136                .await;
137                return;
138            }
139        };
140
141        let job_id = job.job_id;
142        let out_tx = self.out_tx.clone();
143        let active = self.active.clone();
144        let in_flight = self.in_flight.clone();
145        let deadline = Duration::from_millis(job.deadline_ms.max(1) as u64);
146
147        let task = tokio::spawn(async move {
148            let started = Instant::now();
149            let mut sink = WsJobSink::new(job_id, out_tx.clone());
150
151            let exec = route.backend.execute(&job, &mut sink);
152            let outcome = tokio::time::timeout(deadline, exec).await;
153
154            let final_msg: Value = match outcome {
155                Ok(Ok(JobResult {
156                    input_tokens,
157                    output_tokens,
158                    duration_ms,
159                })) => {
160                    let dur = if duration_ms == 0 {
161                        started.elapsed().as_millis().min(u32::MAX as u128) as u32
162                    } else {
163                        duration_ms
164                    };
165                    json!({
166                        "type": "job_done",
167                        "job_id": job_id,
168                        "tokens": {
169                            "input": input_tokens.unwrap_or(0),
170                            "output": output_tokens.unwrap_or(0),
171                            "input_tokens": input_tokens.unwrap_or(0),
172                            "output_tokens": output_tokens.unwrap_or(0),
173                        },
174                        "duration_ms": dur,
175                    })
176                }
177                Ok(Err(err)) => {
178                    let (code, msg) = map_backend_error(&err);
179                    warn!(%job_id, error_code = code, error = %err, "backend error");
180                    json!({
181                        "type": "job_error",
182                        "job_id": job_id,
183                        "error_code": code,
184                        "message": msg,
185                        "tokens_emitted": sink.bytes_sent(),
186                    })
187                }
188                Err(_elapsed) => {
189                    warn!(%job_id, "backend timeout exceeded deadline");
190                    json!({
191                        "type": "job_error",
192                        "job_id": job_id,
193                        "error_code": "backend_timeout",
194                        "message": "backend exceeded deadline_ms",
195                        "tokens_emitted": sink.bytes_sent(),
196                    })
197                }
198            };
199
200            // Best-effort terminal frame; if the WSS sender is dead the
201            // surrounding read loop will tear down anyway.
202            if out_tx
203                .send(Message::Text(final_msg.to_string().into()))
204                .await
205                .is_err()
206            {
207                debug!(%job_id, "outbound closed before terminal frame");
208            }
209
210            in_flight.fetch_sub(1, Ordering::AcqRel);
211            active.lock().await.remove(&job_id);
212        });
213
214        self.active.lock().await.insert(job_id, JobHandle { task });
215    }
216
217    /// Cancel a job in response to a coordinator `job_cancel`. Aborts the
218    /// driver task and removes the entry. The terminal frame the driver was
219    /// going to send is suppressed by the abort.
220    pub async fn cancel(&self, job_id: Uuid) {
221        let removed = self.active.lock().await.remove(&job_id);
222        match removed {
223            Some(h) => {
224                h.task.abort();
225                self.in_flight.fetch_sub(1, Ordering::AcqRel);
226                info!(%job_id, "job cancelled");
227            }
228            None => debug!(%job_id, "cancel for unknown job (already done?)"),
229        }
230    }
231}
232
233/// `JobSink` impl that wraps the WSS outbound channel. Each `send_chunk`
234/// base64-encodes the bytes and emits one `job_chunk` JSON frame. Backpressure
235/// surfaces as a `BackendError::Other("ws send timeout/closed")`, which the
236/// driver task converts to a `job_error` terminal frame.
237struct WsJobSink {
238    job_id: Uuid,
239    out_tx: mpsc::Sender<Message>,
240    bytes_sent: u64,
241}
242
243impl WsJobSink {
244    fn new(job_id: Uuid, out_tx: mpsc::Sender<Message>) -> Self {
245        Self {
246            job_id,
247            out_tx,
248            bytes_sent: 0,
249        }
250    }
251
252    fn bytes_sent(&self) -> u64 {
253        self.bytes_sent
254    }
255}
256
257#[async_trait]
258impl JobSink for WsJobSink {
259    async fn send_chunk(&mut self, bytes: Bytes) -> Result<(), BackendError> {
260        let frame = json!({
261            "type": "job_chunk",
262            "job_id": self.job_id,
263            "data": B64.encode(&bytes),
264        });
265        let msg = Message::Text(frame.to_string().into());
266        match tokio::time::timeout(SEND_TIMEOUT, self.out_tx.send(msg)).await {
267            Ok(Ok(())) => {
268                self.bytes_sent = self.bytes_sent.saturating_add(bytes.len() as u64);
269                Ok(())
270            }
271            Ok(Err(_closed)) => {
272                error!(job_id = %self.job_id, "ws outbound closed mid-stream");
273                Err(BackendError::Other("ws outbound closed".into()))
274            }
275            Err(_elapsed) => {
276                error!(job_id = %self.job_id, "ws outbound backpressured >5s");
277                Err(BackendError::Other("ws outbound send timeout".into()))
278            }
279        }
280    }
281}
282
283fn map_backend_error(err: &BackendError) -> (&'static str, String) {
284    match err {
285        BackendError::Unreachable(_) => ("backend_unreachable", err.to_string()),
286        BackendError::Timeout => ("backend_timeout", err.to_string()),
287        BackendError::ModelNotFound(_) => ("model_not_loaded", err.to_string()),
288        BackendError::MissingApiKey(_) => ("auth_rejected_by_backend", err.to_string()),
289        BackendError::BadStatus { status, .. } if *status == 401 || *status == 403 => {
290            ("auth_rejected_by_backend", err.to_string())
291        }
292        _ => ("internal", err.to_string()),
293    }
294}
295
296async fn send_error(
297    out_tx: &mpsc::Sender<Message>,
298    job_id: Uuid,
299    code: &str,
300    msg: &str,
301    tokens_emitted: u64,
302) -> Result<(), mpsc::error::SendError<Message>> {
303    let frame = json!({
304        "type": "job_error",
305        "job_id": job_id,
306        "error_code": code,
307        "message": msg,
308        "tokens_emitted": tokens_emitted,
309    });
310    out_tx.send(Message::Text(frame.to_string().into())).await
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Each listed `BackendError` must map to its expected wire error code.
    #[test]
    fn map_backend_error_codes() {
        let cases = [
            (BackendError::Unreachable("x".into()), "backend_unreachable"),
            (BackendError::Timeout, "backend_timeout"),
            (BackendError::ModelNotFound("m".into()), "model_not_loaded"),
            (
                BackendError::MissingApiKey("openrouter"),
                "auth_rejected_by_backend",
            ),
            (
                BackendError::BadStatus {
                    status: 401,
                    body: "x".into(),
                },
                "auth_rejected_by_backend",
            ),
            (BackendError::Other("x".into()), "internal"),
        ];

        for (err, expected) in &cases {
            let (code, _message) = map_backend_error(err);
            assert_eq!(code, *expected);
        }
    }
}