1use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU32, Ordering};
18use std::time::Instant;
19
20use async_trait::async_trait;
21use base64::Engine as _;
22use base64::engine::general_purpose::STANDARD as B64;
23use bytes::Bytes;
24use serde_json::{Value, json};
25use tokio::sync::{Mutex, mpsc};
26use tokio::task::JoinHandle;
27use tokio::time::Duration;
28use tokio_tungstenite::tungstenite::Message;
29use tracing::{debug, error, info, warn};
30use uuid::Uuid;
31
32use crate::backend::{Backend, BackendError, Job, JobResult, JobSink};
33use crate::discovery::DiscoveredBackend;
34
/// Upper bound on how long a single chunk may wait for room on the outbound
/// channel; `WsJobSink::send_chunk` fails the chunk with a send-timeout error
/// once this elapses.
const SEND_TIMEOUT: Duration = Duration::from_secs(5);
38
39struct JobHandle {
42 task: JoinHandle<()>,
43}
44
45#[derive(Clone)]
47struct ModelRoute {
48 backend: Arc<dyn Backend>,
49}
50
51pub struct JobExecutor {
52 routes: HashMap<String, ModelRoute>,
55 max_concurrent: u32,
57 active: Arc<Mutex<HashMap<Uuid, JobHandle>>>,
60 in_flight: Arc<AtomicU32>,
63 out_tx: mpsc::Sender<Message>,
65}
66
67impl JobExecutor {
68 pub fn new(
69 backends: Vec<DiscoveredBackend>,
70 max_concurrent: u32,
71 out_tx: mpsc::Sender<Message>,
72 ) -> Self {
73 let mut routes: HashMap<String, ModelRoute> = HashMap::new();
74 for db in backends {
75 for m in &db.models {
76 routes
77 .entry(m.model_id.clone())
78 .or_insert_with(|| ModelRoute {
79 backend: db.backend.clone(),
80 });
81 }
82 }
83 info!(
84 models = routes.len(),
85 max_concurrent, "job executor ready"
86 );
87 Self {
88 routes,
89 max_concurrent,
90 active: Arc::new(Mutex::new(HashMap::new())),
91 in_flight: Arc::new(AtomicU32::new(0)),
92 out_tx,
93 }
94 }
95
96 pub fn queue_depth(&self) -> u32 {
98 self.in_flight.load(Ordering::Relaxed)
99 }
100
101 pub async fn dispatch(&self, job: Job) {
107 let prev = self.in_flight.fetch_add(1, Ordering::AcqRel);
109 if prev >= self.max_concurrent {
110 self.in_flight.fetch_sub(1, Ordering::AcqRel);
111 warn!(job_id = %job.job_id, "rejecting job: out_of_capacity");
112 let _ = send_error(
113 &self.out_tx,
114 job.job_id,
115 "out_of_capacity",
116 "agent at max_concurrent",
117 0,
118 )
119 .await;
120 return;
121 }
122
123 let route = match self.routes.get(&job.model_id).cloned() {
124 Some(r) => r,
125 None => {
126 self.in_flight.fetch_sub(1, Ordering::AcqRel);
127 warn!(
128 job_id = %job.job_id,
129 model_id = %job.model_id,
130 "rejecting job: model_not_loaded"
131 );
132 let _ = send_error(
133 &self.out_tx,
134 job.job_id,
135 "model_not_loaded",
136 "no backend serves this model",
137 0,
138 )
139 .await;
140 return;
141 }
142 };
143
144 let job_id = job.job_id;
145 let out_tx = self.out_tx.clone();
146 let active = self.active.clone();
147 let in_flight = self.in_flight.clone();
148 let deadline = Duration::from_millis(job.deadline_ms.max(1) as u64);
149
150 let task = tokio::spawn(async move {
151 let started = Instant::now();
152 let mut sink = WsJobSink::new(job_id, out_tx.clone());
153
154 let exec = route.backend.execute(&job, &mut sink);
155 let outcome = tokio::time::timeout(deadline, exec).await;
156
157 let final_msg: Value = match outcome {
158 Ok(Ok(JobResult { input_tokens, output_tokens, duration_ms })) => {
159 let dur = if duration_ms == 0 {
160 started.elapsed().as_millis().min(u32::MAX as u128) as u32
161 } else {
162 duration_ms
163 };
164 json!({
165 "type": "job_done",
166 "job_id": job_id,
167 "tokens": {
168 "input": input_tokens.unwrap_or(0),
169 "output": output_tokens.unwrap_or(0),
170 },
171 "duration_ms": dur,
172 })
173 }
174 Ok(Err(err)) => {
175 let (code, msg) = map_backend_error(&err);
176 warn!(%job_id, error_code = code, error = %err, "backend error");
177 json!({
178 "type": "job_error",
179 "job_id": job_id,
180 "error_code": code,
181 "message": msg,
182 "tokens_emitted": sink.bytes_sent(),
183 })
184 }
185 Err(_elapsed) => {
186 warn!(%job_id, "backend timeout exceeded deadline");
187 json!({
188 "type": "job_error",
189 "job_id": job_id,
190 "error_code": "backend_timeout",
191 "message": "backend exceeded deadline_ms",
192 "tokens_emitted": sink.bytes_sent(),
193 })
194 }
195 };
196
197 if out_tx
200 .send(Message::Text(final_msg.to_string().into()))
201 .await
202 .is_err()
203 {
204 debug!(%job_id, "outbound closed before terminal frame");
205 }
206
207 in_flight.fetch_sub(1, Ordering::AcqRel);
208 active.lock().await.remove(&job_id);
209 });
210
211 self.active
212 .lock()
213 .await
214 .insert(job_id, JobHandle { task });
215 }
216
217 pub async fn cancel(&self, job_id: Uuid) {
221 let removed = self.active.lock().await.remove(&job_id);
222 match removed {
223 Some(h) => {
224 h.task.abort();
225 self.in_flight.fetch_sub(1, Ordering::AcqRel);
226 info!(%job_id, "job cancelled");
227 }
228 None => debug!(%job_id, "cancel for unknown job (already done?)"),
229 }
230 }
231}
232
233struct WsJobSink {
238 job_id: Uuid,
239 out_tx: mpsc::Sender<Message>,
240 bytes_sent: u64,
241}
242
243impl WsJobSink {
244 fn new(job_id: Uuid, out_tx: mpsc::Sender<Message>) -> Self {
245 Self { job_id, out_tx, bytes_sent: 0 }
246 }
247
248 fn bytes_sent(&self) -> u64 {
249 self.bytes_sent
250 }
251}
252
253#[async_trait]
254impl JobSink for WsJobSink {
255 async fn send_chunk(&mut self, bytes: Bytes) -> Result<(), BackendError> {
256 let frame = json!({
257 "type": "job_chunk",
258 "job_id": self.job_id,
259 "data": B64.encode(&bytes),
260 });
261 let msg = Message::Text(frame.to_string().into());
262 match tokio::time::timeout(SEND_TIMEOUT, self.out_tx.send(msg)).await {
263 Ok(Ok(())) => {
264 self.bytes_sent = self.bytes_sent.saturating_add(bytes.len() as u64);
265 Ok(())
266 }
267 Ok(Err(_closed)) => {
268 error!(job_id = %self.job_id, "ws outbound closed mid-stream");
269 Err(BackendError::Other("ws outbound closed".into()))
270 }
271 Err(_elapsed) => {
272 error!(job_id = %self.job_id, "ws outbound backpressured >5s");
273 Err(BackendError::Other("ws outbound send timeout".into()))
274 }
275 }
276 }
277}
278
279fn map_backend_error(err: &BackendError) -> (&'static str, String) {
280 match err {
281 BackendError::Unreachable(_) => ("backend_unreachable", err.to_string()),
282 BackendError::Timeout => ("backend_timeout", err.to_string()),
283 BackendError::ModelNotFound(_) => ("model_not_loaded", err.to_string()),
284 BackendError::MissingApiKey(_) => ("auth_rejected_by_backend", err.to_string()),
285 BackendError::BadStatus { status, .. } if *status == 401 || *status == 403 => {
286 ("auth_rejected_by_backend", err.to_string())
287 }
288 _ => ("internal", err.to_string()),
289 }
290}
291
292async fn send_error(
293 out_tx: &mpsc::Sender<Message>,
294 job_id: Uuid,
295 code: &str,
296 msg: &str,
297 tokens_emitted: u64,
298) -> Result<(), mpsc::error::SendError<Message>> {
299 let frame = json!({
300 "type": "job_error",
301 "job_id": job_id,
302 "error_code": code,
303 "message": msg,
304 "tokens_emitted": tokens_emitted,
305 });
306 out_tx.send(Message::Text(frame.to_string().into())).await
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Each backend error variant maps to its expected wire error code.
    #[test]
    fn map_backend_error_codes() {
        let cases = [
            (BackendError::Unreachable("x".into()), "backend_unreachable"),
            (BackendError::Timeout, "backend_timeout"),
            (BackendError::ModelNotFound("m".into()), "model_not_loaded"),
            (
                BackendError::MissingApiKey("openrouter"),
                "auth_rejected_by_backend",
            ),
            (
                BackendError::BadStatus { status: 401, body: "x".into() },
                "auth_rejected_by_backend",
            ),
            (BackendError::Other("x".into()), "internal"),
        ];

        for (err, expected) in &cases {
            assert_eq!(map_backend_error(err).0, *expected);
        }
    }
}