1use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU32, Ordering};
18use std::time::Instant;
19
20use async_trait::async_trait;
21use base64::Engine as _;
22use base64::engine::general_purpose::STANDARD as B64;
23use bytes::Bytes;
24use serde_json::{Value, json};
25use tokio::sync::{Mutex, mpsc};
26use tokio::task::JoinHandle;
27use tokio::time::Duration;
28use tokio_tungstenite::tungstenite::Message;
29use tracing::{debug, error, info, warn};
30use uuid::Uuid;
31
32use crate::backend::{Backend, BackendError, Job, JobResult, JobSink};
33use crate::discovery::DiscoveredBackend;
34
/// Longest time a single `job_chunk` frame may wait for room on the outbound
/// channel before `WsJobSink::send_chunk` gives up with an error.
const SEND_TIMEOUT: Duration = Duration::from_millis(5_000);
38
/// Bookkeeping for one accepted job, stored in `JobExecutor::active` keyed by
/// the job id.
struct JobHandle {
    /// Handle of the spawned task running the job; `JobExecutor::cancel`
    /// aborts it.
    task: JoinHandle<()>,
}
44
/// Routing entry: the backend that serves a particular model id.
///
/// `dispatch` clones the entry out of the routing table so the spawned task
/// owns it; cloning only bumps the `Arc` refcount.
#[derive(Clone)]
struct ModelRoute {
    backend: Arc<dyn Backend>,
}
50
/// Accepts jobs, routes each to the backend serving its model, and streams
/// the results as frames on the outbound message channel.
pub struct JobExecutor {
    /// `model_id` -> backend. Built once in `new`; when several backends
    /// advertise the same model, the first one seen wins.
    routes: HashMap<String, ModelRoute>,
    /// Maximum number of jobs `dispatch` will have in flight at once.
    max_concurrent: u32,
    /// Handles of spawned jobs, keyed by job id, so `cancel` can abort them.
    active: Arc<Mutex<HashMap<Uuid, JobHandle>>>,
    /// Count of accepted jobs currently holding a capacity slot; reported by
    /// `queue_depth`.
    in_flight: Arc<AtomicU32>,
    /// Outbound channel for all job frames (tungstenite `Message`s).
    out_tx: mpsc::Sender<Message>,
}
66
67impl JobExecutor {
68 pub fn new(
69 backends: Vec<DiscoveredBackend>,
70 max_concurrent: u32,
71 out_tx: mpsc::Sender<Message>,
72 ) -> Self {
73 let mut routes: HashMap<String, ModelRoute> = HashMap::new();
74 for db in backends {
75 for m in &db.models {
76 routes
77 .entry(m.model_id.clone())
78 .or_insert_with(|| ModelRoute {
79 backend: db.backend.clone(),
80 });
81 }
82 }
83 info!(models = routes.len(), max_concurrent, "job executor ready");
84 Self {
85 routes,
86 max_concurrent,
87 active: Arc::new(Mutex::new(HashMap::new())),
88 in_flight: Arc::new(AtomicU32::new(0)),
89 out_tx,
90 }
91 }
92
93 pub fn queue_depth(&self) -> u32 {
95 self.in_flight.load(Ordering::Relaxed)
96 }
97
98 pub async fn dispatch(&self, job: Job) {
104 let prev = self.in_flight.fetch_add(1, Ordering::AcqRel);
106 if prev >= self.max_concurrent {
107 self.in_flight.fetch_sub(1, Ordering::AcqRel);
108 warn!(job_id = %job.job_id, "rejecting job: out_of_capacity");
109 let _ = send_error(
110 &self.out_tx,
111 job.job_id,
112 "out_of_capacity",
113 "agent at max_concurrent",
114 0,
115 )
116 .await;
117 return;
118 }
119
120 let route = match self.routes.get(&job.model_id).cloned() {
121 Some(r) => r,
122 None => {
123 self.in_flight.fetch_sub(1, Ordering::AcqRel);
124 warn!(
125 job_id = %job.job_id,
126 model_id = %job.model_id,
127 "rejecting job: model_not_loaded"
128 );
129 let _ = send_error(
130 &self.out_tx,
131 job.job_id,
132 "model_not_loaded",
133 "no backend serves this model",
134 0,
135 )
136 .await;
137 return;
138 }
139 };
140
141 let job_id = job.job_id;
142 let out_tx = self.out_tx.clone();
143 let active = self.active.clone();
144 let in_flight = self.in_flight.clone();
145 let deadline = Duration::from_millis(job.deadline_ms.max(1) as u64);
146
147 let task = tokio::spawn(async move {
148 let started = Instant::now();
149 let mut sink = WsJobSink::new(job_id, out_tx.clone());
150
151 let exec = route.backend.execute(&job, &mut sink);
152 let outcome = tokio::time::timeout(deadline, exec).await;
153
154 let final_msg: Value = match outcome {
155 Ok(Ok(JobResult {
156 input_tokens,
157 output_tokens,
158 duration_ms,
159 })) => {
160 let dur = if duration_ms == 0 {
161 started.elapsed().as_millis().min(u32::MAX as u128) as u32
162 } else {
163 duration_ms
164 };
165 json!({
166 "type": "job_done",
167 "job_id": job_id,
168 "tokens": {
169 "input": input_tokens.unwrap_or(0),
170 "output": output_tokens.unwrap_or(0),
171 "input_tokens": input_tokens.unwrap_or(0),
172 "output_tokens": output_tokens.unwrap_or(0),
173 },
174 "duration_ms": dur,
175 })
176 }
177 Ok(Err(err)) => {
178 let (code, msg) = map_backend_error(&err);
179 warn!(%job_id, error_code = code, error = %err, "backend error");
180 json!({
181 "type": "job_error",
182 "job_id": job_id,
183 "error_code": code,
184 "message": msg,
185 "tokens_emitted": sink.bytes_sent(),
186 })
187 }
188 Err(_elapsed) => {
189 warn!(%job_id, "backend timeout exceeded deadline");
190 json!({
191 "type": "job_error",
192 "job_id": job_id,
193 "error_code": "backend_timeout",
194 "message": "backend exceeded deadline_ms",
195 "tokens_emitted": sink.bytes_sent(),
196 })
197 }
198 };
199
200 if out_tx
203 .send(Message::Text(final_msg.to_string().into()))
204 .await
205 .is_err()
206 {
207 debug!(%job_id, "outbound closed before terminal frame");
208 }
209
210 in_flight.fetch_sub(1, Ordering::AcqRel);
211 active.lock().await.remove(&job_id);
212 });
213
214 self.active.lock().await.insert(job_id, JobHandle { task });
215 }
216
217 pub async fn cancel(&self, job_id: Uuid) {
221 let removed = self.active.lock().await.remove(&job_id);
222 match removed {
223 Some(h) => {
224 h.task.abort();
225 self.in_flight.fetch_sub(1, Ordering::AcqRel);
226 info!(%job_id, "job cancelled");
227 }
228 None => debug!(%job_id, "cancel for unknown job (already done?)"),
229 }
230 }
231}
232
/// `JobSink` for a single job. It wraps each chunk in a `job_chunk` frame on
/// the outbound message channel.
struct WsJobSink {
    /// Job the frames belong to; copied into every frame and log line.
    job_id: Uuid,
    /// Outbound channel the frames are queued on.
    out_tx: mpsc::Sender<Message>,
    /// Running total of raw (pre-base64) chunk bytes successfully queued.
    bytes_sent: u64,
}
242
243impl WsJobSink {
244 fn new(job_id: Uuid, out_tx: mpsc::Sender<Message>) -> Self {
245 Self {
246 job_id,
247 out_tx,
248 bytes_sent: 0,
249 }
250 }
251
252 fn bytes_sent(&self) -> u64 {
253 self.bytes_sent
254 }
255}
256
257#[async_trait]
258impl JobSink for WsJobSink {
259 async fn send_chunk(&mut self, bytes: Bytes) -> Result<(), BackendError> {
260 let frame = json!({
261 "type": "job_chunk",
262 "job_id": self.job_id,
263 "data": B64.encode(&bytes),
264 });
265 let msg = Message::Text(frame.to_string().into());
266 match tokio::time::timeout(SEND_TIMEOUT, self.out_tx.send(msg)).await {
267 Ok(Ok(())) => {
268 self.bytes_sent = self.bytes_sent.saturating_add(bytes.len() as u64);
269 Ok(())
270 }
271 Ok(Err(_closed)) => {
272 error!(job_id = %self.job_id, "ws outbound closed mid-stream");
273 Err(BackendError::Other("ws outbound closed".into()))
274 }
275 Err(_elapsed) => {
276 error!(job_id = %self.job_id, "ws outbound backpressured >5s");
277 Err(BackendError::Other("ws outbound send timeout".into()))
278 }
279 }
280 }
281}
282
283fn map_backend_error(err: &BackendError) -> (&'static str, String) {
284 match err {
285 BackendError::Unreachable(_) => ("backend_unreachable", err.to_string()),
286 BackendError::Timeout => ("backend_timeout", err.to_string()),
287 BackendError::ModelNotFound(_) => ("model_not_loaded", err.to_string()),
288 BackendError::MissingApiKey(_) => ("auth_rejected_by_backend", err.to_string()),
289 BackendError::BadStatus { status, .. } if *status == 401 || *status == 403 => {
290 ("auth_rejected_by_backend", err.to_string())
291 }
292 _ => ("internal", err.to_string()),
293 }
294}
295
296async fn send_error(
297 out_tx: &mpsc::Sender<Message>,
298 job_id: Uuid,
299 code: &str,
300 msg: &str,
301 tokens_emitted: u64,
302) -> Result<(), mpsc::error::SendError<Message>> {
303 let frame = json!({
304 "type": "job_error",
305 "job_id": job_id,
306 "error_code": code,
307 "message": msg,
308 "tokens_emitted": tokens_emitted,
309 });
310 out_tx.send(Message::Text(frame.to_string().into())).await
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `BackendError` shape maps to its expected wire `error_code`, and
    /// the message is always the error's `Display` text.
    ///
    /// The `BadStatus` cases pin both sides of the 401/403 guard: the auth
    /// statuses map to `auth_rejected_by_backend`, and any other status falls
    /// through to `internal`.
    #[test]
    fn map_backend_error_codes() {
        let cases = [
            (BackendError::Unreachable("x".into()), "backend_unreachable"),
            (BackendError::Timeout, "backend_timeout"),
            (BackendError::ModelNotFound("m".into()), "model_not_loaded"),
            (
                BackendError::MissingApiKey("openrouter"),
                "auth_rejected_by_backend",
            ),
            (
                BackendError::BadStatus {
                    status: 401,
                    body: "x".into(),
                },
                "auth_rejected_by_backend",
            ),
            (
                BackendError::BadStatus {
                    status: 403,
                    body: "x".into(),
                },
                "auth_rejected_by_backend",
            ),
            (
                BackendError::BadStatus {
                    status: 500,
                    body: "x".into(),
                },
                "internal",
            ),
            (BackendError::Other("x".into()), "internal"),
        ];

        for (err, want) in &cases {
            let (code, message) = map_backend_error(err);
            assert_eq!(code, *want, "wrong error_code for {}", err);
            assert_eq!(message, err.to_string());
        }
    }
}