use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use tokio::sync::{mpsc, oneshot, watch, Mutex, RwLock};
use tokio::task::JoinHandle;
use tokio::time::timeout;

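/// A single unit of work submitted to the pool: routing metadata plus the
/// operation to execute. `timeout_ms`, when set, overrides the pool-wide
/// default request timeout; values above the security policy's cap are
/// rejected.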
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WorkerRequest {
    pub request_id: String,
    pub workflow_name: String,
    pub node_id: String,
    pub timeout_ms: Option<u64>,
    pub operation: WorkerOperation,
}

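/// The operation a worker executes, serialized with an internal `kind` tag
/// (`"llm"` or `"tool"`).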
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum WorkerOperation {
    Llm {
        model: String,
        prompt: String,
        scoped_input: Value,
    },
    Tool {
        tool: String,
        input: Value,
        scoped_input: Value,
    },
}

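/// A worker's reply to a [`WorkerRequest`], echoing the request id and
/// reporting which worker ran it and how long the handler took.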
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WorkerResponse {
    pub request_id: String,
    pub worker_id: String,
    pub result: WorkerResult,
    pub elapsed_ms: u64,
}

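/// Success or failure payload, serialized with an internal `status` tag.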
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum WorkerResult {
    Success { output: Value },
    Error { error: WorkerProtocolError },
}

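/// A structured, serializable error returned by a worker, carrying a stable
/// error code and a retryability hint for callers.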
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkerProtocolError {
    pub code: WorkerErrorCode,
    pub message: String,
    pub retryable: bool,
}

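/// Stable error codes shared between the pool and its workers.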
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkerErrorCode {
    QueueFull,
    Unavailable,
    Timeout,
    ExecutionFailed,
    CircuitOpen,
    Cancelled,
    InvalidRequest,
}

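/// Errors surfaced to callers of [`WorkerPool::submit`].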
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum WorkerPoolError {
    #[error("worker queue is full")]
    QueueFull,
    #[error("no healthy worker available")]
    NoHealthyWorker,
    #[error("worker execution failed: {0:?}")]
    Worker(WorkerProtocolError),
    #[error("worker request timed out")]
    Timeout,
    #[error("worker pool is shutting down")]
    ShuttingDown,
    #[error("request rejected by circuit breaker")]
    CircuitOpen,
    #[error("worker request rejected: {reason}")]
    InvalidRequest { reason: String },
}

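/// Tunables for pool construction: per-worker queue depth, health-probe
/// cadence, the failure threshold that marks a worker unavailable, the
/// default request timeout, and the request-validation policy.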
#[derive(Debug, Clone)]
pub struct WorkerPoolOptions {
    pub queue_capacity: usize,
    pub health_probe_interval: Duration,
    pub unavailable_after_failures: u32,
    pub default_request_timeout: Option<Duration>,
    pub security_policy: WorkerSecurityPolicy,
}

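/// Hard limits applied to every request before it is queued; violations are
/// rejected with [`WorkerPoolError::InvalidRequest`].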
#[derive(Debug, Clone)]
pub struct WorkerSecurityPolicy {
    pub max_request_timeout_ms: u64,
    pub max_request_payload_bytes: usize,
    pub max_identifier_length: usize,
}

impl Default for WorkerPoolOptions {
    fn default() -> Self {
        Self {
            queue_capacity: 64,
            health_probe_interval: Duration::from_secs(5),
            unavailable_after_failures: 3,
            default_request_timeout: Some(Duration::from_secs(30)),
            security_policy: WorkerSecurityPolicy::default(),
        }
    }
}

impl Default for WorkerSecurityPolicy {
    fn default() -> Self {
        Self {
            max_request_timeout_ms: 120_000,
            max_request_payload_bytes: 256 * 1024,
            max_identifier_length: 128,
        }
    }
}

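/// A point-in-time view of one worker's health, maintained by the probe task
/// and by per-request bookkeeping.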
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkerHealth {
    pub worker_id: String,
    pub status: WorkerHealthStatus,
    pub consecutive_failures: u32,
    pub last_probe_unix_ms: Option<u64>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkerHealthStatus {
    Healthy,
    Degraded,
    Unavailable,
}

impl WorkerHealth {
    fn new(worker_id: String) -> Self {
        Self {
            worker_id,
            status: WorkerHealthStatus::Healthy,
            consecutive_failures: 0,
            last_probe_unix_ms: None,
        }
    }

    fn is_schedulable(&self) -> bool {
        !matches!(self.status, WorkerHealthStatus::Unavailable)
    }
}

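/// The unit of execution. Each worker task runs one request at a time;
/// `probe_health` defaults to always-healthy and can be overridden to
/// surface real liveness checks.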
#[async_trait]
pub trait WorkerHandler: Send + Sync {
    async fn handle(&self, request: WorkerRequest) -> Result<Value, WorkerProtocolError>;

    async fn probe_health(&self) -> WorkerHealthStatus {
        WorkerHealthStatus::Healthy
    }
}

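/// Admission-control and observability callbacks invoked around each
/// request. Every method defaults to a no-op, so implementors override only
/// what they need.
///
/// A minimal sketch of a breaker that trips after three consecutive failures
/// (the `SimpleBreaker` type and its threshold are illustrative, not part of
/// this module):
///
/// ```ignore
/// struct SimpleBreaker {
///     consecutive_failures: AtomicUsize,
/// }
///
/// #[async_trait]
/// impl CircuitBreakerHooks for SimpleBreaker {
///     // Reject new work while the failure streak is at or above the threshold.
///     fn allow_request(&self, _worker_id: &str, _request: &WorkerRequest) -> bool {
///         self.consecutive_failures.load(Ordering::Relaxed) < 3
///     }
///
///     // Any success closes the breaker again.
///     async fn on_request_success(&self, _worker_id: &str, _response: &WorkerResponse) {
///         self.consecutive_failures.store(0, Ordering::Relaxed);
///     }
///
///     async fn on_request_failure(&self, _worker_id: &str, _error: &WorkerProtocolError) {
///         self.consecutive_failures.fetch_add(1, Ordering::Relaxed);
///     }
/// }
/// ```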
#[async_trait]
pub trait CircuitBreakerHooks: Send + Sync {
    fn allow_request(&self, _worker_id: &str, _request: &WorkerRequest) -> bool {
        true
    }

    async fn on_request_accepted(&self, _worker_id: &str, _request: &WorkerRequest) {}

    async fn on_request_success(&self, _worker_id: &str, _response: &WorkerResponse) {}

    async fn on_request_failure(&self, _worker_id: &str, _error: &WorkerProtocolError) {}

    async fn on_request_rejected(
        &self,
        _worker_id: Option<&str>,
        _request: &WorkerRequest,
        _reason: WorkerErrorCode,
    ) {
    }
}

struct WorkItem {
    request: WorkerRequest,
    response_tx: oneshot::Sender<Result<WorkerResponse, WorkerPoolError>>,
}

type WorkerResponseRx = oneshot::Receiver<Result<WorkerResponse, WorkerPoolError>>;
type WorkerCandidate = (usize, String, mpsc::Sender<WorkItem>);
type CandidateWithHealth = (
    usize,
    String,
    mpsc::Sender<WorkItem>,
    Arc<RwLock<WorkerHealth>>,
);

struct WorkerSlot {
    worker_id: String,
    sender: mpsc::Sender<WorkItem>,
    shutdown_tx: watch::Sender<bool>,
    worker_task: JoinHandle<()>,
    probe_task: JoinHandle<()>,
    health: Arc<RwLock<WorkerHealth>>,
    handler: Arc<dyn WorkerHandler>,
}

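/// An in-process pool of [`WorkerHandler`]s with round-robin candidate
/// selection, bounded per-worker queues, background health probing, and
/// optional circuit-breaker hooks.
///
/// A minimal usage sketch (the `Echo` handler and the surrounding async
/// context are illustrative, not part of this module):
///
/// ```ignore
/// struct Echo;
///
/// #[async_trait]
/// impl WorkerHandler for Echo {
///     async fn handle(&self, req: WorkerRequest) -> Result<Value, WorkerProtocolError> {
///         Ok(serde_json::json!({ "node": req.node_id }))
///     }
/// }
///
/// let pool = WorkerPool::new_inprocess(
///     vec![Arc::new(Echo)],
///     WorkerPoolOptions::default(),
///     None,
/// )
/// .expect("pool should initialize");
///
/// let response = pool
///     .submit(WorkerRequest {
///         request_id: "req-1".to_string(),
///         workflow_name: "wf".to_string(),
///         node_id: "node-1".to_string(),
///         timeout_ms: None,
///         operation: WorkerOperation::Tool {
///             tool: "echo".to_string(),
///             input: serde_json::json!({}),
///             scoped_input: serde_json::json!({}),
///         },
///     })
///     .await
///     .expect("request should succeed");
/// assert_eq!(response.request_id, "req-1");
/// ```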
pub struct WorkerPool {
    options: WorkerPoolOptions,
    slots: Mutex<Vec<WorkerSlot>>,
    next_worker: AtomicUsize,
    hooks: Option<Arc<dyn CircuitBreakerHooks>>,
}

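/// Object-safe client facade over the pool, useful for handing `dyn`
/// references to callers that should not depend on the concrete pool type.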
#[async_trait]
pub trait WorkerPoolClient: Send + Sync {
    async fn submit(&self, request: WorkerRequest) -> Result<WorkerResponse, WorkerPoolError>;

    async fn health_snapshot(&self) -> Vec<WorkerHealth>;
}

#[async_trait]
impl WorkerPoolClient for WorkerPool {
    async fn submit(&self, request: WorkerRequest) -> Result<WorkerResponse, WorkerPoolError> {
        WorkerPool::submit(self, request).await
    }

    async fn health_snapshot(&self) -> Vec<WorkerHealth> {
        WorkerPool::health_snapshot(self).await
    }
}

impl WorkerPool {
    pub fn new_inprocess(
        handlers: Vec<Arc<dyn WorkerHandler>>,
        options: WorkerPoolOptions,
        hooks: Option<Arc<dyn CircuitBreakerHooks>>,
    ) -> Result<Self, WorkerPoolError> {
        if handlers.is_empty() {
            return Err(WorkerPoolError::NoHealthyWorker);
        }

        let mut slots = Vec::with_capacity(handlers.len());
        for (index, handler) in handlers.into_iter().enumerate() {
            let worker_id = format!("worker-{}", index);
            slots.push(spawn_worker_slot(
                worker_id,
                handler,
                options.queue_capacity,
                options.health_probe_interval,
                options.unavailable_after_failures,
            ));
        }

        Ok(Self {
            options,
            slots: Mutex::new(slots),
            next_worker: AtomicUsize::new(0),
            hooks,
        })
    }

    pub async fn submit(&self, request: WorkerRequest) -> Result<WorkerResponse, WorkerPoolError> {
        validate_request_contract(&request, &self.options.security_policy)?;
        let candidates = self.select_worker_candidates(&request).await?;
        let mut saw_queue_full = false;
        let mut saw_circuit_open = false;

        let mut selected_slot: Option<(usize, String, WorkerResponseRx)> = None;

        // Walk the round-robin candidate list until one worker admits the
        // request; remember why candidates were skipped so the final error
        // reflects the dominant rejection reason.
        for (slot_index, worker_id, sender) in candidates {
            if let Some(hooks) = &self.hooks {
                if !hooks.allow_request(&worker_id, &request) {
                    saw_circuit_open = true;
                    hooks
                        .on_request_rejected(
                            Some(&worker_id),
                            &request,
                            WorkerErrorCode::CircuitOpen,
                        )
                        .await;
                    continue;
                }
                hooks.on_request_accepted(&worker_id, &request).await;
            }

            let (response_tx, response_rx) = oneshot::channel();
            let work_item = WorkItem {
                request: request.clone(),
                response_tx,
            };

            match sender.try_send(work_item) {
                Ok(()) => {
                    selected_slot = Some((slot_index, worker_id, response_rx));
                    break;
                }
                Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
                    saw_queue_full = true;
                    if let Some(hooks) = &self.hooks {
                        hooks
                            .on_request_rejected(
                                Some(&worker_id),
                                &request,
                                WorkerErrorCode::QueueFull,
                            )
                            .await;
                    }
                }
                Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
                    if let Some(hooks) = &self.hooks {
                        hooks
                            .on_request_rejected(
                                Some(&worker_id),
                                &request,
                                WorkerErrorCode::Unavailable,
                            )
                            .await;
                    }
                }
            }
        }

        let Some((slot_index, worker_id, response_rx)) = selected_slot else {
            return if saw_queue_full {
                Err(WorkerPoolError::QueueFull)
            } else if saw_circuit_open {
                Err(WorkerPoolError::CircuitOpen)
            } else {
                Err(WorkerPoolError::NoHealthyWorker)
            };
        };

        // Per-request timeout wins over the pool default; no timeout at all
        // means waiting indefinitely for the worker's response.
        let timeout_budget = request
            .timeout_ms
            .map(Duration::from_millis)
            .or(self.options.default_request_timeout);

        let outcome = if let Some(duration) = timeout_budget {
            match timeout(duration, response_rx).await {
                Ok(result) => result,
                Err(_) => {
                    // The worker did not respond within the budget: mark it
                    // unavailable so the scheduler routes around it until a
                    // probe or a later success clears the state.
                    self.mark_unavailable(slot_index).await;
                    if let Some(hooks) = &self.hooks {
                        hooks
                            .on_request_rejected(
                                Some(&worker_id),
                                &request,
                                WorkerErrorCode::Timeout,
                            )
                            .await;
                    }
                    return Err(WorkerPoolError::Timeout);
                }
            }
        } else {
            response_rx.await
        };

        // A dropped response channel means the worker task is gone.
        let response = outcome.map_err(|_| WorkerPoolError::ShuttingDown)??;

        if let Some(hooks) = &self.hooks {
            match &response.result {
                WorkerResult::Success { .. } => {
                    hooks.on_request_success(&worker_id, &response).await
                }
                WorkerResult::Error { error } => hooks.on_request_failure(&worker_id, error).await,
            }
        }

        match &response.result {
            WorkerResult::Success { .. } => Ok(response),
            WorkerResult::Error { error } => Err(WorkerPoolError::Worker(error.clone())),
        }
    }

    pub async fn health_snapshot(&self) -> Vec<WorkerHealth> {
        // Clone the health handles while holding the slot lock, then read
        // them after releasing it so a slow reader cannot block the pool.
        let (health_refs, worker_count) = {
            let slots = self.slots.lock().await;
            (
                slots
                    .iter()
                    .map(|slot| Arc::clone(&slot.health))
                    .collect::<Vec<_>>(),
                slots.len(),
            )
        };

        let mut snapshot = Vec::with_capacity(worker_count);
        for health in health_refs {
            snapshot.push(health.read().await.clone());
        }
        snapshot
    }

    pub async fn restart_worker(&self, worker_id: &str) -> Result<(), WorkerPoolError> {
        let mut slots = self.slots.lock().await;
        let slot_index = slots
            .iter()
            .position(|slot| slot.worker_id == worker_id)
            .ok_or(WorkerPoolError::NoHealthyWorker)?;

        // Signal the old tasks to stop, then swap in a fresh slot that
        // reuses the same handler; items still queued in the old slot are
        // dropped and their callers observe `ShuttingDown`.
        let old_slot = &slots[slot_index];
        let _ = old_slot.shutdown_tx.send(true);

        let replacement = spawn_worker_slot(
            worker_id.to_string(),
            Arc::clone(&old_slot.handler),
            self.options.queue_capacity,
            self.options.health_probe_interval,
            self.options.unavailable_after_failures,
        );
        slots[slot_index] = replacement;
        Ok(())
    }

    pub async fn shutdown(&self) {
        // Signal every worker and probe task, then abort them outright.
        let mut slots = self.slots.lock().await;
        for slot in slots.iter_mut() {
            let _ = slot.shutdown_tx.send(true);
            slot.worker_task.abort();
            slot.probe_task.abort();
        }
    }

    async fn select_worker_candidates(
        &self,
        request: &WorkerRequest,
    ) -> Result<Vec<WorkerCandidate>, WorkerPoolError> {
        // Rotate the starting index so load spreads across workers, and
        // clone senders and health handles so the slot lock is released
        // before any health locks are taken.
        let candidates = {
            let slots = self.slots.lock().await;
            if slots.is_empty() {
                Vec::<CandidateWithHealth>::new()
            } else {
                let start = self.next_worker.fetch_add(1, Ordering::Relaxed) % slots.len();
                let mut candidates = Vec::<CandidateWithHealth>::with_capacity(slots.len());
                for offset in 0..slots.len() {
                    let idx = (start + offset) % slots.len();
                    let slot = &slots[idx];
                    candidates.push((
                        idx,
                        slot.worker_id.clone(),
                        slot.sender.clone(),
                        Arc::clone(&slot.health),
                    ));
                }
                candidates
            }
        };

        if candidates.is_empty() {
            if let Some(hooks) = &self.hooks {
                hooks
                    .on_request_rejected(None, request, WorkerErrorCode::Unavailable)
                    .await;
            }
            return Err(WorkerPoolError::NoHealthyWorker);
        }

        let mut schedulable = Vec::new();
        for (idx, worker_id, sender, health_ref) in candidates {
            if health_ref.read().await.is_schedulable() {
                schedulable.push((idx, worker_id, sender));
            }
        }

        if !schedulable.is_empty() {
            return Ok(schedulable);
        }

        if let Some(hooks) = &self.hooks {
            hooks
                .on_request_rejected(None, request, WorkerErrorCode::Unavailable)
                .await;
        }
        Err(WorkerPoolError::NoHealthyWorker)
    }

    async fn mark_unavailable(&self, slot_index: usize) {
        let health_ref = {
            let slots = self.slots.lock().await;
            slots.get(slot_index).map(|slot| Arc::clone(&slot.health))
        };

        if let Some(health_ref) = health_ref {
            let mut health = health_ref.write().await;
            health.status = WorkerHealthStatus::Unavailable;
            health.consecutive_failures = health.consecutive_failures.saturating_add(1);
            health.last_probe_unix_ms = Some(now_unix_ms());
        }
    }
}

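/// Spawns the request loop and the health-probe task for one worker and
/// bundles their handles into a [`WorkerSlot`].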
fn spawn_worker_slot(
    worker_id: String,
    handler: Arc<dyn WorkerHandler>,
    queue_capacity: usize,
    probe_interval: Duration,
    unavailable_after_failures: u32,
) -> WorkerSlot {
    let (sender, mut receiver) = mpsc::channel::<WorkItem>(queue_capacity);
    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    let health = Arc::new(RwLock::new(WorkerHealth::new(worker_id.clone())));

    let worker_id_for_loop = worker_id.clone();
    let handler_for_loop = Arc::clone(&handler);
    let health_for_loop = Arc::clone(&health);
    let mut shutdown_worker_rx = shutdown_rx.clone();
    let worker_task = tokio::spawn(async move {
        loop {
            tokio::select! {
                maybe_item = receiver.recv() => {
                    let Some(item) = maybe_item else {
                        break;
                    };

                    let started = std::time::Instant::now();
                    let result = handler_for_loop.handle(item.request.clone()).await;
                    let elapsed_ms = started.elapsed().as_millis() as u64;
                    let response = match result {
                        Ok(output) => {
                            update_health_on_success(&health_for_loop).await;
                            WorkerResponse {
                                request_id: item.request.request_id.clone(),
                                worker_id: worker_id_for_loop.clone(),
                                result: WorkerResult::Success { output },
                                elapsed_ms,
                            }
                        }
                        Err(error) => {
                            update_health_on_failure(
                                &health_for_loop,
                                unavailable_after_failures,
                            )
                            .await;
                            WorkerResponse {
                                request_id: item.request.request_id.clone(),
                                worker_id: worker_id_for_loop.clone(),
                                result: WorkerResult::Error { error },
                                elapsed_ms,
                            }
                        }
                    };
                    // The submitter may have timed out and dropped the
                    // receiver; ignore the send error in that case.
                    let _ = item.response_tx.send(Ok(response));
                }
                changed = shutdown_worker_rx.changed() => {
                    // Stop on an explicit shutdown signal, and also when the
                    // watch sender is gone (the slot was dropped); re-polling
                    // a closed channel would busy-spin.
                    if changed.is_err() || *shutdown_worker_rx.borrow() {
                        break;
                    }
                }
            }
        }
    });

    let worker_id_for_probe = worker_id.clone();
    let handler_for_probe = Arc::clone(&handler);
    let health_for_probe = Arc::clone(&health);
    let mut shutdown_probe_rx = shutdown_rx.clone();
    let probe_task = tokio::spawn(async move {
        let mut ticker = tokio::time::interval(probe_interval);
        loop {
            tokio::select! {
                _ = ticker.tick() => {
                    let status = handler_for_probe.probe_health().await;
                    let mut health = health_for_probe.write().await;
                    health.status = status;
                    if status == WorkerHealthStatus::Healthy {
                        health.consecutive_failures = 0;
                    }
                    health.last_probe_unix_ms = Some(now_unix_ms());
                }
                changed = shutdown_probe_rx.changed() => {
                    if changed.is_err() || *shutdown_probe_rx.borrow() {
                        break;
                    }
                }
            }
        }
        // On shutdown, leave the slot marked unavailable so lingering
        // snapshot readers see a terminal state.
        let mut health = health_for_probe.write().await;
        health.status = WorkerHealthStatus::Unavailable;
        health.last_probe_unix_ms = Some(now_unix_ms());
        health.worker_id = worker_id_for_probe;
    });

    WorkerSlot {
        worker_id,
        sender,
        shutdown_tx,
        worker_task,
        probe_task,
        health,
        handler,
    }
}

async fn update_health_on_success(health_ref: &Arc<RwLock<WorkerHealth>>) {
    let mut health = health_ref.write().await;
    health.status = WorkerHealthStatus::Healthy;
    health.consecutive_failures = 0;
    health.last_probe_unix_ms = Some(now_unix_ms());
}

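/// Escalates from `Degraded` to `Unavailable` once the consecutive-failure
/// count reaches the configured threshold.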
async fn update_health_on_failure(
    health_ref: &Arc<RwLock<WorkerHealth>>,
    unavailable_after_failures: u32,
) {
    let mut health = health_ref.write().await;
    health.consecutive_failures = health.consecutive_failures.saturating_add(1);
    health.status = if health.consecutive_failures >= unavailable_after_failures {
        WorkerHealthStatus::Unavailable
    } else {
        WorkerHealthStatus::Degraded
    };
    health.last_probe_unix_ms = Some(now_unix_ms());
}

fn now_unix_ms() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as u64
}

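/// Enforces the [`WorkerSecurityPolicy`]: identifier lengths, the
/// per-request timeout cap, and the serialized payload size limit.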
fn validate_request_contract(
    request: &WorkerRequest,
    policy: &WorkerSecurityPolicy,
) -> Result<(), WorkerPoolError> {
    if request.request_id.len() > policy.max_identifier_length {
        return Err(WorkerPoolError::InvalidRequest {
            reason: format!(
                "request_id length {} exceeds max {}",
                request.request_id.len(),
                policy.max_identifier_length
            ),
        });
    }
    if request.workflow_name.len() > policy.max_identifier_length {
        return Err(WorkerPoolError::InvalidRequest {
            reason: format!(
                "workflow_name length {} exceeds max {}",
                request.workflow_name.len(),
                policy.max_identifier_length
            ),
        });
    }
    if request.node_id.len() > policy.max_identifier_length {
        return Err(WorkerPoolError::InvalidRequest {
            reason: format!(
                "node_id length {} exceeds max {}",
                request.node_id.len(),
                policy.max_identifier_length
            ),
        });
    }
    if let Some(timeout_ms) = request.timeout_ms {
        if timeout_ms > policy.max_request_timeout_ms {
            return Err(WorkerPoolError::InvalidRequest {
                reason: format!(
                    "timeout_ms {} exceeds max {}",
                    timeout_ms, policy.max_request_timeout_ms
                ),
            });
        }
    }

    let payload_size = estimate_payload_size(request);
    if payload_size > policy.max_request_payload_bytes {
        return Err(WorkerPoolError::InvalidRequest {
            reason: format!(
                "request payload {} bytes exceeds max {}",
                payload_size, policy.max_request_payload_bytes
            ),
        });
    }
    Ok(())
}

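/// Sizes the request via its JSON encoding; a serialization failure maps to
/// `usize::MAX` so the payload check fails closed.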
fn estimate_payload_size(request: &WorkerRequest) -> usize {
    serde_json::to_vec(request)
        .map(|payload| payload.len())
        .unwrap_or(usize::MAX)
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    use serde_json::json;
    use tokio::time::sleep;

    use super::*;

    struct EchoWorker;

    #[async_trait]
    impl WorkerHandler for EchoWorker {
        async fn handle(&self, request: WorkerRequest) -> Result<Value, WorkerProtocolError> {
            Ok(json!({"node": request.node_id}))
        }
    }

    struct SlowWorker {
        delay: Duration,
    }

    #[async_trait]
    impl WorkerHandler for SlowWorker {
        async fn handle(&self, _request: WorkerRequest) -> Result<Value, WorkerProtocolError> {
            sleep(self.delay).await;
            Ok(json!({"status": "ok"}))
        }
    }

    struct FlakyWorker {
        available: AtomicBool,
        calls: AtomicUsize,
    }

    #[async_trait]
    impl WorkerHandler for FlakyWorker {
        async fn handle(&self, _request: WorkerRequest) -> Result<Value, WorkerProtocolError> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            if self.available.load(Ordering::Relaxed) {
                Ok(json!({"status": "up"}))
            } else {
                Err(WorkerProtocolError {
                    code: WorkerErrorCode::Unavailable,
                    message: "worker unavailable".to_string(),
                    retryable: true,
                })
            }
        }

        async fn probe_health(&self) -> WorkerHealthStatus {
            if self.available.load(Ordering::Relaxed) {
                WorkerHealthStatus::Healthy
            } else {
                WorkerHealthStatus::Unavailable
            }
        }
    }

    fn sample_request(id: &str) -> WorkerRequest {
        WorkerRequest {
            request_id: id.to_string(),
            workflow_name: "wf".to_string(),
            node_id: "node-1".to_string(),
            timeout_ms: None,
            operation: WorkerOperation::Tool {
                tool: "echo".to_string(),
                input: json!({"x": 1}),
                scoped_input: json!({"input": {}}),
            },
        }
    }

    #[test]
    fn worker_protocol_roundtrip() {
        let request = sample_request("req-1");
        let serialized =
            serde_json::to_string(&request).expect("request serialization should work");
        let decoded: WorkerRequest =
            serde_json::from_str(&serialized).expect("request deserialization should work");
        assert_eq!(request, decoded);
    }

    #[tokio::test]
    async fn routes_requests_across_worker_pool() {
        let pool = WorkerPool::new_inprocess(
            vec![Arc::new(EchoWorker), Arc::new(EchoWorker)],
            WorkerPoolOptions {
                queue_capacity: 4,
                health_probe_interval: Duration::from_millis(10),
                ..WorkerPoolOptions::default()
            },
            None,
        )
        .expect("pool should initialize");

        let response = pool
            .submit(sample_request("req-2"))
            .await
            .expect("request should succeed");
        assert_eq!(response.request_id, "req-2");
        assert_eq!(
            response.result,
            WorkerResult::Success {
                output: json!({"node": "node-1"})
            }
        );

        let health = pool.health_snapshot().await;
        assert_eq!(health.len(), 2);
        assert!(health.iter().all(|entry| entry.is_schedulable()));

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn enforces_queue_backpressure_limits() {
        let pool = WorkerPool::new_inprocess(
            vec![Arc::new(SlowWorker {
                delay: Duration::from_millis(80),
            })],
            WorkerPoolOptions {
                queue_capacity: 1,
                health_probe_interval: Duration::from_millis(100),
                default_request_timeout: Some(Duration::from_secs(1)),
                ..WorkerPoolOptions::default()
            },
            None,
        )
        .expect("pool should initialize");

        let first = pool.submit(sample_request("q1"));
        let second = pool.submit(sample_request("q2"));
        let third = pool.submit(sample_request("q3"));

        let (first_result, second_result, third_result) = tokio::join!(first, second, third);
        let failures = [&first_result, &second_result, &third_result]
            .iter()
            .filter(|result| matches!(result, Err(WorkerPoolError::QueueFull)))
            .count();
        let successes = [&first_result, &second_result, &third_result]
            .iter()
            .filter(|result| result.is_ok())
            .count();
        assert!(failures >= 1);
        assert!(successes >= 1);

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn marks_worker_unavailable_after_failures_and_recovers_on_restart() {
        let flaky = Arc::new(FlakyWorker {
            available: AtomicBool::new(false),
            calls: AtomicUsize::new(0),
        });
        let pool = WorkerPool::new_inprocess(
            vec![Arc::clone(&flaky) as Arc<dyn WorkerHandler>],
            WorkerPoolOptions {
                queue_capacity: 2,
                unavailable_after_failures: 1,
                health_probe_interval: Duration::from_millis(15),
                default_request_timeout: Some(Duration::from_secs(1)),
                ..WorkerPoolOptions::default()
            },
            None,
        )
        .expect("pool should initialize");

        let error = pool
            .submit(sample_request("down"))
            .await
            .expect_err("request should fail while worker is unavailable");
        assert!(matches!(error, WorkerPoolError::Worker(_)));

        sleep(Duration::from_millis(25)).await;
        let health_before = pool.health_snapshot().await;
        assert_eq!(health_before[0].status, WorkerHealthStatus::Unavailable);

        flaky.available.store(true, Ordering::Relaxed);
        pool.restart_worker("worker-0")
            .await
            .expect("restart should succeed");

        sleep(Duration::from_millis(25)).await;
        let response = pool
            .submit(sample_request("up"))
            .await
            .expect("request should succeed after restart");
        assert_eq!(
            response.result,
            WorkerResult::Success {
                output: json!({"status": "up"})
            }
        );

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn returns_timeout_for_slow_worker() {
        let pool = WorkerPool::new_inprocess(
            vec![Arc::new(SlowWorker {
                delay: Duration::from_millis(100),
            })],
            WorkerPoolOptions {
                queue_capacity: 2,
                default_request_timeout: Some(Duration::from_millis(5)),
                ..WorkerPoolOptions::default()
            },
            None,
        )
        .expect("pool should initialize");

        let error = pool
            .submit(sample_request("timeout"))
            .await
            .expect_err("request should time out");
        assert!(matches!(error, WorkerPoolError::Timeout));

        pool.shutdown().await;
    }

    #[tokio::test]
    async fn rejects_request_when_security_contract_is_violated() {
        let pool = WorkerPool::new_inprocess(
            vec![Arc::new(EchoWorker)],
            WorkerPoolOptions {
                security_policy: WorkerSecurityPolicy {
                    max_request_timeout_ms: 10,
                    max_request_payload_bytes: 256,
                    max_identifier_length: 12,
                },
                ..WorkerPoolOptions::default()
            },
            None,
        )
        .expect("pool should initialize");

        let mut request = sample_request("req-too-large");
        request.timeout_ms = Some(99);
        request.operation = WorkerOperation::Tool {
            tool: "echo".to_string(),
            input: json!({"payload": "x".repeat(1024)}),
            scoped_input: json!({"input": {}}),
        };

        let error = pool
            .submit(request)
            .await
            .expect_err("request should be rejected by security policy");

        assert!(matches!(error, WorkerPoolError::InvalidRequest { .. }));
        pool.shutdown().await;
    }

    #[tokio::test]
    async fn handles_parallel_submissions_without_deadlock() {
        let pool = Arc::new(
            WorkerPool::new_inprocess(
                vec![Arc::new(EchoWorker), Arc::new(EchoWorker)],
                WorkerPoolOptions {
                    queue_capacity: 32,
                    health_probe_interval: Duration::from_millis(5),
                    default_request_timeout: Some(Duration::from_secs(1)),
                    ..WorkerPoolOptions::default()
                },
                None,
            )
            .expect("pool should initialize"),
        );

        let mut tasks = Vec::new();
        for idx in 0..32usize {
            let pool = Arc::clone(&pool);
            tasks.push(tokio::spawn(async move {
                pool.submit(sample_request(&format!("parallel-{idx}")))
                    .await
            }));
        }

        let joined = tokio::time::timeout(Duration::from_secs(3), async {
            for task in tasks {
                let result = task.await.expect("join should succeed");
                assert!(result.is_ok(), "submit should succeed under parallel load");
            }
        })
        .await;

        assert!(joined.is_ok(), "parallel submissions should not deadlock");
        pool.shutdown().await;
    }
}
1069}