simple_agents_workflow/worker.rs

1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, SystemTime, UNIX_EPOCH};
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use thiserror::Error;
9use tokio::sync::{mpsc, oneshot, watch, Mutex, RwLock};
10use tokio::task::JoinHandle;
11use tokio::time::timeout;
12
/// Worker protocol request payload.
///
/// Sent by the workflow runtime to a pool worker; `request_id` correlates
/// this request with the eventual [`WorkerResponse`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WorkerRequest {
    /// Stable request identifier supplied by the workflow runtime.
    pub request_id: String,
    /// Workflow name associated with this request.
    pub workflow_name: String,
    /// Node id associated with this request.
    pub node_id: String,
    /// Optional execution timeout budget in milliseconds; when `None`, the
    /// pool's `default_request_timeout` applies instead.
    pub timeout_ms: Option<u64>,
    /// Worker operation to execute.
    pub operation: WorkerOperation,
}
27
/// Worker protocol operation variants.
///
/// Serialized as an internally-tagged enum (`"kind": "llm" | "tool"`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum WorkerOperation {
    /// Execute an LLM-like operation against worker host.
    Llm {
        /// Model name.
        model: String,
        /// Prompt text.
        prompt: String,
        /// Deterministic scoped input.
        scoped_input: Value,
    },
    /// Execute a tool operation against worker host.
    Tool {
        /// Tool name.
        tool: String,
        /// Tool node static input payload.
        input: Value,
        /// Deterministic scoped input.
        scoped_input: Value,
    },
}
51
/// Worker protocol response payload.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WorkerResponse {
    /// Request identifier for correlation.
    pub request_id: String,
    /// Worker id that served the request.
    pub worker_id: String,
    /// Result payload.
    pub result: WorkerResult,
    /// Wall-clock execution latency in milliseconds, measured around the
    /// handler call by the worker task.
    pub elapsed_ms: u64,
}
64
/// Worker protocol result variants.
///
/// Serialized as an internally-tagged enum (`"status": "success" | "error"`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum WorkerResult {
    /// Successful request execution payload.
    Success {
        /// Worker-produced output.
        output: Value,
    },
    /// Failed request execution payload.
    Error {
        /// Structured worker error.
        error: WorkerProtocolError,
    },
}
80
/// Protocol-level error payload returned by workers and pool guards.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkerProtocolError {
    /// Programmatic error code.
    pub code: WorkerErrorCode,
    /// Human-readable diagnostic.
    pub message: String,
    /// Indicates whether a caller can retry safely.
    pub retryable: bool,
}
91
/// Stable worker error codes, serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkerErrorCode {
    /// Request could not be accepted by the pool queue.
    QueueFull,
    /// Worker was unavailable.
    Unavailable,
    /// Request timed out.
    Timeout,
    /// Request failed during worker execution.
    ExecutionFailed,
    /// Request rejected by circuit breaker hooks.
    CircuitOpen,
    /// Request cancelled before completion.
    Cancelled,
    /// Request violated runtime security policy.
    InvalidRequest,
}
111
/// Pool-level errors surfaced to runtime callers.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum WorkerPoolError {
    /// Worker queue reached configured limit.
    #[error("worker queue is full")]
    QueueFull,
    /// No healthy worker could accept the request.
    #[error("no healthy worker available")]
    NoHealthyWorker,
    /// Worker-side execution failed.
    #[error("worker execution failed: {0:?}")]
    Worker(WorkerProtocolError),
    /// Request timed out while waiting for worker completion.
    #[error("worker request timed out")]
    Timeout,
    /// Request was cancelled because pool is shutting down.
    #[error("worker pool is shutting down")]
    ShuttingDown,
    /// Request rejected by circuit breaker hook.
    #[error("request rejected by circuit breaker")]
    CircuitOpen,
    /// Request violated runtime security policy.
    #[error("worker request rejected: {reason}")]
    InvalidRequest {
        /// Rejection reason.
        reason: String,
    },
}
140
/// Worker lifecycle and scheduling configuration.
#[derive(Debug, Clone)]
pub struct WorkerPoolOptions {
    /// Maximum queue depth per worker (bounded mpsc channel capacity).
    pub queue_capacity: usize,
    /// Interval between worker health probes.
    pub health_probe_interval: Duration,
    /// Consecutive failures needed before marking worker unavailable.
    pub unavailable_after_failures: u32,
    /// Default timeout used when request timeout is not set; `None` means
    /// wait indefinitely for a response.
    pub default_request_timeout: Option<Duration>,
    /// Security guards for request payload and budgets.
    pub security_policy: WorkerSecurityPolicy,
}
155
/// Request hardening limits for worker pool submit path.
#[derive(Debug, Clone)]
pub struct WorkerSecurityPolicy {
    /// Maximum request timeout accepted by pool, in milliseconds.
    pub max_request_timeout_ms: u64,
    /// Maximum serialized request payload size in bytes.
    pub max_request_payload_bytes: usize,
    /// Maximum byte length for request/workflow/node identifiers.
    pub max_identifier_length: usize,
}
166
167impl Default for WorkerPoolOptions {
168    fn default() -> Self {
169        Self {
170            queue_capacity: 64,
171            health_probe_interval: Duration::from_secs(5),
172            unavailable_after_failures: 3,
173            default_request_timeout: Some(Duration::from_secs(30)),
174            security_policy: WorkerSecurityPolicy::default(),
175        }
176    }
177}
178
179impl Default for WorkerSecurityPolicy {
180    fn default() -> Self {
181        Self {
182            max_request_timeout_ms: 120_000,
183            max_request_payload_bytes: 256 * 1024,
184            max_identifier_length: 128,
185        }
186    }
187}
188
/// Health state for one pool worker.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorkerHealth {
    /// Worker id.
    pub worker_id: String,
    /// Current status.
    pub status: WorkerHealthStatus,
    /// Number of consecutive failures observed.
    pub consecutive_failures: u32,
    /// Last probe timestamp in unix milliseconds; `None` until first update.
    pub last_probe_unix_ms: Option<u64>,
}
201
/// Coarse worker health statuses used by scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkerHealthStatus {
    /// Worker is healthy and can receive traffic.
    Healthy,
    /// Worker has transient failures and may still receive traffic.
    Degraded,
    /// Worker is unavailable and should not receive traffic.
    Unavailable,
}
213
214impl WorkerHealth {
215    fn new(worker_id: String) -> Self {
216        Self {
217            worker_id,
218            status: WorkerHealthStatus::Healthy,
219            consecutive_failures: 0,
220            last_probe_unix_ms: None,
221        }
222    }
223
224    fn is_schedulable(&self) -> bool {
225        !matches!(self.status, WorkerHealthStatus::Unavailable)
226    }
227}
228
/// Host-implemented worker behavior for in-process worker pools.
#[async_trait]
pub trait WorkerHandler: Send + Sync {
    /// Handles one worker protocol request, returning the output value or a
    /// structured protocol error.
    async fn handle(&self, request: WorkerRequest) -> Result<Value, WorkerProtocolError>;

    /// Performs worker health probe. Default implementation always reports
    /// healthy; override for handlers with real liveness signals.
    async fn probe_health(&self) -> WorkerHealthStatus {
        WorkerHealthStatus::Healthy
    }
}
240
/// Optional hook surface for circuit breaker integration.
///
/// All hooks have no-op defaults, so implementors override only what they
/// need. `allow_request` is deliberately synchronous: it runs on the hot
/// submit path before queueing.
#[async_trait]
pub trait CircuitBreakerHooks: Send + Sync {
    /// Returns false to reject a request before queueing.
    fn allow_request(&self, _worker_id: &str, _request: &WorkerRequest) -> bool {
        true
    }

    /// Called after a request is accepted for execution.
    async fn on_request_accepted(&self, _worker_id: &str, _request: &WorkerRequest) {}

    /// Called after a request succeeds.
    async fn on_request_success(&self, _worker_id: &str, _response: &WorkerResponse) {}

    /// Called after a request fails.
    async fn on_request_failure(&self, _worker_id: &str, _error: &WorkerProtocolError) {}

    /// Called after a request is rejected. `worker_id` is `None` when the
    /// rejection is pool-wide rather than tied to a specific worker.
    async fn on_request_rejected(
        &self,
        _worker_id: Option<&str>,
        _request: &WorkerRequest,
        _reason: WorkerErrorCode,
    ) {
    }
}
267
/// One queued unit of work: the request plus the oneshot used to deliver the
/// pool-level outcome back to the submitting caller.
struct WorkItem {
    request: WorkerRequest,
    // Dropped unsent if the worker exits first; the caller then sees a closed
    // channel and maps it to `ShuttingDown`.
    response_tx: oneshot::Sender<Result<WorkerResponse, WorkerPoolError>>,
}
272
/// Receiver half awaited by `submit` for the worker's outcome.
type WorkerResponseRx = oneshot::Receiver<Result<WorkerResponse, WorkerPoolError>>;
/// (slot index, worker id, queue sender) for a schedulable worker.
type WorkerCandidate = (usize, String, mpsc::Sender<WorkItem>);
/// `WorkerCandidate` plus the shared health record, used during selection.
type CandidateWithHealth = (
    usize,
    String,
    mpsc::Sender<WorkItem>,
    Arc<RwLock<WorkerHealth>>,
);
281
/// Internal bookkeeping for one spawned worker: its queue, lifecycle signal,
/// task handles, shared health record, and the handler kept for restarts.
struct WorkerSlot {
    worker_id: String,
    // Bounded queue feeding the worker task.
    sender: mpsc::Sender<WorkItem>,
    // Broadcasts `true` to stop both the worker and probe tasks.
    shutdown_tx: watch::Sender<bool>,
    worker_task: JoinHandle<()>,
    probe_task: JoinHandle<()>,
    health: Arc<RwLock<WorkerHealth>>,
    // Retained so `restart_worker` can respawn the slot with the same handler.
    handler: Arc<dyn WorkerHandler>,
}
291
/// In-process worker pool with bounded queues and health-aware routing.
pub struct WorkerPool {
    options: WorkerPoolOptions,
    // Mutex (not RwLock) because `restart_worker` replaces slots in place.
    slots: Mutex<Vec<WorkerSlot>>,
    // Round-robin cursor advanced on every candidate selection.
    next_worker: AtomicUsize,
    hooks: Option<Arc<dyn CircuitBreakerHooks>>,
}
299
/// Adapter trait for worker pools used by runtime integrations.
#[async_trait]
pub trait WorkerPoolClient: Send + Sync {
    /// Submits one request to the underlying worker pool and awaits the
    /// correlated response.
    async fn submit(&self, request: WorkerRequest) -> Result<WorkerResponse, WorkerPoolError>;

    /// Returns a snapshot of worker health state.
    async fn health_snapshot(&self) -> Vec<WorkerHealth>;
}
309
#[async_trait]
impl WorkerPoolClient for WorkerPool {
    // Thin delegation to the inherent methods; fully-qualified calls make it
    // explicit that the trait method is not recursing into itself.
    async fn submit(&self, request: WorkerRequest) -> Result<WorkerResponse, WorkerPoolError> {
        WorkerPool::submit(self, request).await
    }

    async fn health_snapshot(&self) -> Vec<WorkerHealth> {
        WorkerPool::health_snapshot(self).await
    }
}
320
impl WorkerPool {
    /// Creates and starts an in-process worker pool.
    ///
    /// Spawns one worker task and one health-probe task per handler.
    ///
    /// # Errors
    /// Returns `NoHealthyWorker` when `handlers` is empty.
    pub fn new_inprocess(
        handlers: Vec<Arc<dyn WorkerHandler>>,
        options: WorkerPoolOptions,
        hooks: Option<Arc<dyn CircuitBreakerHooks>>,
    ) -> Result<Self, WorkerPoolError> {
        if handlers.is_empty() {
            return Err(WorkerPoolError::NoHealthyWorker);
        }

        let mut slots = Vec::with_capacity(handlers.len());
        for (index, handler) in handlers.into_iter().enumerate() {
            // Worker ids are positional and stable for the pool's lifetime.
            let worker_id = format!("worker-{}", index);
            slots.push(spawn_worker_slot(
                worker_id,
                handler,
                options.queue_capacity,
                options.health_probe_interval,
                options.unavailable_after_failures,
            ));
        }

        Ok(Self {
            options,
            slots: Mutex::new(slots),
            next_worker: AtomicUsize::new(0),
            hooks,
        })
    }

    /// Submits one request to the pool and waits for completion.
    ///
    /// Flow: validate against the security policy, build a round-robin list
    /// of schedulable workers, enqueue on the first worker whose bounded
    /// queue accepts the item, then await the response under the request's
    /// timeout budget (falling back to the pool default).
    ///
    /// # Errors
    /// `InvalidRequest` on policy violations, `QueueFull`/`CircuitOpen`/
    /// `NoHealthyWorker` when no worker accepts the item, `Timeout` when the
    /// budget elapses, `ShuttingDown` when the worker exits before replying,
    /// and `Worker(..)` when the handler itself fails.
    pub async fn submit(&self, request: WorkerRequest) -> Result<WorkerResponse, WorkerPoolError> {
        validate_request_contract(&request, &self.options.security_policy)?;
        let candidates = self.select_worker_candidates(&request).await?;
        // Track why candidates were skipped so the final error is specific.
        let mut saw_queue_full = false;
        let mut saw_circuit_open = false;

        let mut selected_slot: Option<(usize, String, WorkerResponseRx)> = None;

        for (slot_index, worker_id, sender) in candidates {
            if let Some(hooks) = &self.hooks {
                if !hooks.allow_request(&worker_id, &request) {
                    saw_circuit_open = true;
                    hooks
                        .on_request_rejected(
                            Some(&worker_id),
                            &request,
                            WorkerErrorCode::CircuitOpen,
                        )
                        .await;
                    continue;
                }
                // NOTE(review): `on_request_accepted` fires before `try_send`;
                // if the queue then turns out to be full the hook has already
                // reported acceptance. Confirm implementations tolerate this
                // accepted-then-rejected ordering.
                hooks.on_request_accepted(&worker_id, &request).await;
            }

            let (response_tx, response_rx) = oneshot::channel();
            let work_item = WorkItem {
                request: request.clone(),
                response_tx,
            };

            match sender.try_send(work_item) {
                Ok(()) => {
                    selected_slot = Some((slot_index, worker_id, response_rx));
                    break;
                }
                // Queue full: note it and fall through to the next candidate.
                Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
                    saw_queue_full = true;
                    if let Some(hooks) = &self.hooks {
                        hooks
                            .on_request_rejected(
                                Some(&worker_id),
                                &request,
                                WorkerErrorCode::QueueFull,
                            )
                            .await;
                    }
                }
                // Worker task has exited; its queue receiver is gone.
                Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
                    if let Some(hooks) = &self.hooks {
                        hooks
                            .on_request_rejected(
                                Some(&worker_id),
                                &request,
                                WorkerErrorCode::Unavailable,
                            )
                            .await;
                    }
                }
            }
        }

        // Nothing accepted the item: report the most specific reason seen,
        // preferring QueueFull over CircuitOpen over NoHealthyWorker.
        let Some((slot_index, worker_id, response_rx)) = selected_slot else {
            return if saw_queue_full {
                Err(WorkerPoolError::QueueFull)
            } else if saw_circuit_open {
                Err(WorkerPoolError::CircuitOpen)
            } else {
                Err(WorkerPoolError::NoHealthyWorker)
            };
        };

        // Per-request timeout wins over the pool default; `None` waits
        // indefinitely.
        let timeout_budget = request
            .timeout_ms
            .map(Duration::from_millis)
            .or(self.options.default_request_timeout);

        let outcome = if let Some(duration) = timeout_budget {
            match timeout(duration, response_rx).await {
                Ok(result) => result,
                Err(_) => {
                    // A timed-out worker is pulled from rotation immediately;
                    // probes may later restore it.
                    self.mark_unavailable(slot_index).await;
                    if let Some(hooks) = &self.hooks {
                        hooks
                            .on_request_rejected(
                                Some(&worker_id),
                                &request,
                                WorkerErrorCode::Timeout,
                            )
                            .await;
                    }
                    return Err(WorkerPoolError::Timeout);
                }
            }
        } else {
            response_rx.await
        };

        // A dropped sender means the worker shut down before responding.
        let response = outcome.map_err(|_| WorkerPoolError::ShuttingDown)??;

        if let Some(hooks) = &self.hooks {
            match &response.result {
                WorkerResult::Success { .. } => {
                    hooks.on_request_success(&worker_id, &response).await
                }
                WorkerResult::Error { error } => hooks.on_request_failure(&worker_id, error).await,
            }
        }

        match &response.result {
            WorkerResult::Success { .. } => Ok(response),
            WorkerResult::Error { error } => Err(WorkerPoolError::Worker(error.clone())),
        }
    }

    /// Returns current worker health snapshot.
    pub async fn health_snapshot(&self) -> Vec<WorkerHealth> {
        // Clone the health Arcs under the slots lock, then read each record
        // outside it so snapshotting never holds the lock across awaits on
        // the per-worker RwLocks.
        let (health_refs, worker_count) = {
            let slots = self.slots.lock().await;
            (
                slots
                    .iter()
                    .map(|slot| Arc::clone(&slot.health))
                    .collect::<Vec<_>>(),
                slots.len(),
            )
        };

        let mut snapshot = Vec::with_capacity(worker_count);
        for health in health_refs {
            snapshot.push(health.read().await.clone());
        }
        snapshot
    }

    /// Restarts one worker task in place and resets health counters.
    ///
    /// Items still queued on the old slot are dropped when the replacement
    /// takes over; their callers observe `ShuttingDown`.
    ///
    /// # Errors
    /// Returns `NoHealthyWorker` when no slot matches `worker_id`.
    pub async fn restart_worker(&self, worker_id: &str) -> Result<(), WorkerPoolError> {
        let mut slots = self.slots.lock().await;
        let slot_index = slots
            .iter()
            .position(|slot| slot.worker_id == worker_id)
            .ok_or(WorkerPoolError::NoHealthyWorker)?;

        let old_slot = &slots[slot_index];
        // Signal the old tasks to stop; a send failure only means they have
        // already exited.
        let _ = old_slot.shutdown_tx.send(true);

        let replacement = spawn_worker_slot(
            worker_id.to_string(),
            Arc::clone(&old_slot.handler),
            self.options.queue_capacity,
            self.options.health_probe_interval,
            self.options.unavailable_after_failures,
        );
        slots[slot_index] = replacement;
        Ok(())
    }

    /// Gracefully shuts down all worker and probe tasks.
    pub async fn shutdown(&self) {
        let mut slots = self.slots.lock().await;
        for slot in slots.iter_mut() {
            let _ = slot.shutdown_tx.send(true);
            // Abort as well so a task stuck inside a long handler call or
            // probe cannot outlive the pool.
            slot.worker_task.abort();
            slot.probe_task.abort();
        }
    }

    /// Builds a round-robin-ordered list of schedulable worker candidates.
    ///
    /// Starts from an atomically advanced cursor so consecutive submissions
    /// prefer different workers, then filters out slots whose health is
    /// `Unavailable`. Fires the rejection hook (with no worker id) and
    /// returns `NoHealthyWorker` when nothing is schedulable.
    async fn select_worker_candidates(
        &self,
        request: &WorkerRequest,
    ) -> Result<Vec<WorkerCandidate>, WorkerPoolError> {
        let candidates = {
            let slots = self.slots.lock().await;
            if slots.is_empty() {
                Vec::<CandidateWithHealth>::new()
            } else {
                let start = self.next_worker.fetch_add(1, Ordering::Relaxed) % slots.len();
                let mut candidates = Vec::<CandidateWithHealth>::with_capacity(slots.len());
                for offset in 0..slots.len() {
                    let idx = (start + offset) % slots.len();
                    let slot = &slots[idx];
                    candidates.push((
                        idx,
                        slot.worker_id.clone(),
                        slot.sender.clone(),
                        Arc::clone(&slot.health),
                    ));
                }
                candidates
            }
        };

        if candidates.is_empty() {
            if let Some(hooks) = &self.hooks {
                hooks
                    .on_request_rejected(None, request, WorkerErrorCode::Unavailable)
                    .await;
            }
            return Err(WorkerPoolError::NoHealthyWorker);
        }

        // Health reads happen outside the slots lock (Arcs cloned above).
        let mut schedulable = Vec::new();
        for (idx, worker_id, sender, health_ref) in candidates {
            if health_ref.read().await.is_schedulable() {
                schedulable.push((idx, worker_id, sender));
            }
        }

        if !schedulable.is_empty() {
            return Ok(schedulable);
        }

        if let Some(hooks) = &self.hooks {
            hooks
                .on_request_rejected(None, request, WorkerErrorCode::Unavailable)
                .await;
        }
        Err(WorkerPoolError::NoHealthyWorker)
    }

    /// Forces one slot's health to `Unavailable`; used when a request times
    /// out. Silently does nothing if the slot index no longer exists.
    async fn mark_unavailable(&self, slot_index: usize) {
        let health_ref = {
            let slots = self.slots.lock().await;
            slots.get(slot_index).map(|slot| Arc::clone(&slot.health))
        };

        if let Some(health_ref) = health_ref {
            let mut health = health_ref.write().await;
            health.status = WorkerHealthStatus::Unavailable;
            health.consecutive_failures = health.consecutive_failures.saturating_add(1);
            health.last_probe_unix_ms = Some(now_unix_ms());
        }
    }
}
586
/// Spawns one worker slot: a bounded work queue, a worker task that drains
/// it through `handler`, and a probe task that periodically refreshes the
/// shared health record. Both tasks stop when the returned slot's
/// `shutdown_tx` broadcasts `true`.
fn spawn_worker_slot(
    worker_id: String,
    handler: Arc<dyn WorkerHandler>,
    queue_capacity: usize,
    probe_interval: Duration,
    unavailable_after_failures: u32,
) -> WorkerSlot {
    let (sender, mut receiver) = mpsc::channel::<WorkItem>(queue_capacity);
    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    let health = Arc::new(RwLock::new(WorkerHealth::new(worker_id.clone())));

    let worker_id_for_loop = worker_id.clone();
    let handler_for_loop = Arc::clone(&handler);
    let health_for_loop = Arc::clone(&health);
    let mut shutdown_worker_rx = shutdown_rx.clone();
    let worker_task = tokio::spawn(async move {
        loop {
            tokio::select! {
                maybe_item = receiver.recv() => {
                    // `None` means all senders dropped; the queue is closed.
                    let Some(item) = maybe_item else {
                        break;
                    };

                    // Handler runs inline, so shutdown cannot interrupt a
                    // request already in flight; `elapsed_ms` measures only
                    // the handler call.
                    let started = std::time::Instant::now();
                    let result = handler_for_loop.handle(item.request.clone()).await;
                    let elapsed_ms = started.elapsed().as_millis() as u64;
                    let response = match result {
                        Ok(output) => {
                            update_health_on_success(&health_for_loop).await;
                            WorkerResponse {
                                request_id: item.request.request_id.clone(),
                                worker_id: worker_id_for_loop.clone(),
                                result: WorkerResult::Success { output },
                                elapsed_ms,
                            }
                        }
                        Err(error) => {
                            update_health_on_failure(
                                &health_for_loop,
                                unavailable_after_failures,
                            )
                            .await;
                            WorkerResponse {
                                request_id: item.request.request_id.clone(),
                                worker_id: worker_id_for_loop.clone(),
                                result: WorkerResult::Error { error },
                                elapsed_ms,
                            }
                        }
                    };
                    // Send failure means the submitter stopped waiting
                    // (timeout or cancellation); the response is discarded.
                    let _ = item.response_tx.send(Ok(response));
                }
                changed = shutdown_worker_rx.changed() => {
                    if changed.is_ok() && *shutdown_worker_rx.borrow() {
                        break;
                    }
                }
            }
        }
    });

    let worker_id_for_probe = worker_id.clone();
    let handler_for_probe = Arc::clone(&handler);
    let health_for_probe = Arc::clone(&health);
    let mut shutdown_probe_rx = shutdown_rx.clone();
    let probe_task = tokio::spawn(async move {
        // `interval` fires its first tick immediately, so the initial probe
        // happens right after spawn.
        let mut ticker = tokio::time::interval(probe_interval);
        loop {
            tokio::select! {
                _ = ticker.tick() => {
                    let status = handler_for_probe.probe_health().await;
                    let mut health = health_for_probe.write().await;
                    // Probe status overwrites request-driven status; only a
                    // healthy probe clears the failure counter.
                    health.status = status;
                    if status == WorkerHealthStatus::Healthy {
                        health.consecutive_failures = 0;
                    }
                    health.last_probe_unix_ms = Some(now_unix_ms());
                }
                changed = shutdown_probe_rx.changed() => {
                    if changed.is_ok() && *shutdown_probe_rx.borrow() {
                        break;
                    }
                }
            }
        }
        // On shutdown, leave the record marked unavailable so stale slots
        // are never schedulable. This does not run if the task is aborted.
        let mut health = health_for_probe.write().await;
        health.status = WorkerHealthStatus::Unavailable;
        health.last_probe_unix_ms = Some(now_unix_ms());
        health.worker_id = worker_id_for_probe;
    });

    WorkerSlot {
        worker_id,
        sender,
        shutdown_tx,
        worker_task,
        probe_task,
        health,
        handler,
    }
}
688
689async fn update_health_on_success(health_ref: &Arc<RwLock<WorkerHealth>>) {
690    let mut health = health_ref.write().await;
691    health.status = WorkerHealthStatus::Healthy;
692    health.consecutive_failures = 0;
693    health.last_probe_unix_ms = Some(now_unix_ms());
694}
695
696async fn update_health_on_failure(
697    health_ref: &Arc<RwLock<WorkerHealth>>,
698    unavailable_after_failures: u32,
699) {
700    let mut health = health_ref.write().await;
701    health.consecutive_failures = health.consecutive_failures.saturating_add(1);
702    health.status = if health.consecutive_failures >= unavailable_after_failures {
703        WorkerHealthStatus::Unavailable
704    } else {
705        WorkerHealthStatus::Degraded
706    };
707    health.last_probe_unix_ms = Some(now_unix_ms());
708}
709
/// Returns the current wall-clock time as unix milliseconds.
///
/// If the system clock reads before the epoch, this clamps to 0 (matching
/// the zero `Duration` a defaulted error would yield).
fn now_unix_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
716
717fn validate_request_contract(
718    request: &WorkerRequest,
719    policy: &WorkerSecurityPolicy,
720) -> Result<(), WorkerPoolError> {
721    if request.request_id.len() > policy.max_identifier_length {
722        return Err(WorkerPoolError::InvalidRequest {
723            reason: format!(
724                "request_id length {} exceeds max {}",
725                request.request_id.len(),
726                policy.max_identifier_length
727            ),
728        });
729    }
730    if request.workflow_name.len() > policy.max_identifier_length {
731        return Err(WorkerPoolError::InvalidRequest {
732            reason: format!(
733                "workflow_name length {} exceeds max {}",
734                request.workflow_name.len(),
735                policy.max_identifier_length
736            ),
737        });
738    }
739    if request.node_id.len() > policy.max_identifier_length {
740        return Err(WorkerPoolError::InvalidRequest {
741            reason: format!(
742                "node_id length {} exceeds max {}",
743                request.node_id.len(),
744                policy.max_identifier_length
745            ),
746        });
747    }
748    if let Some(timeout_ms) = request.timeout_ms {
749        if timeout_ms > policy.max_request_timeout_ms {
750            return Err(WorkerPoolError::InvalidRequest {
751                reason: format!(
752                    "timeout_ms {} exceeds max {}",
753                    timeout_ms, policy.max_request_timeout_ms
754                ),
755            });
756        }
757    }
758
759    let payload_size = estimate_payload_size(request);
760    if payload_size > policy.max_request_payload_bytes {
761        return Err(WorkerPoolError::InvalidRequest {
762            reason: format!(
763                "request payload {} bytes exceeds max {}",
764                payload_size, policy.max_request_payload_bytes
765            ),
766        });
767    }
768    Ok(())
769}
770
771fn estimate_payload_size(request: &WorkerRequest) -> usize {
772    serde_json::to_vec(request)
773        .map(|payload| payload.len())
774        .unwrap_or(usize::MAX)
775}
776
777#[cfg(test)]
778mod tests {
779    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
780
781    use serde_json::json;
782    use tokio::time::sleep;
783
784    use super::*;
785
786    struct EchoWorker;
787
788    #[async_trait]
789    impl WorkerHandler for EchoWorker {
790        async fn handle(&self, request: WorkerRequest) -> Result<Value, WorkerProtocolError> {
791            Ok(json!({"node": request.node_id}))
792        }
793    }
794
795    struct SlowWorker {
796        delay: Duration,
797    }
798
799    #[async_trait]
800    impl WorkerHandler for SlowWorker {
801        async fn handle(&self, _request: WorkerRequest) -> Result<Value, WorkerProtocolError> {
802            sleep(self.delay).await;
803            Ok(json!({"status": "ok"}))
804        }
805    }
806
807    struct FlakyWorker {
808        available: AtomicBool,
809        calls: AtomicUsize,
810    }
811
812    #[async_trait]
813    impl WorkerHandler for FlakyWorker {
814        async fn handle(&self, _request: WorkerRequest) -> Result<Value, WorkerProtocolError> {
815            self.calls.fetch_add(1, Ordering::Relaxed);
816            if self.available.load(Ordering::Relaxed) {
817                Ok(json!({"status": "up"}))
818            } else {
819                Err(WorkerProtocolError {
820                    code: WorkerErrorCode::Unavailable,
821                    message: "worker unavailable".to_string(),
822                    retryable: true,
823                })
824            }
825        }
826
827        async fn probe_health(&self) -> WorkerHealthStatus {
828            if self.available.load(Ordering::Relaxed) {
829                WorkerHealthStatus::Healthy
830            } else {
831                WorkerHealthStatus::Unavailable
832            }
833        }
834    }
835
836    fn sample_request(id: &str) -> WorkerRequest {
837        WorkerRequest {
838            request_id: id.to_string(),
839            workflow_name: "wf".to_string(),
840            node_id: "node-1".to_string(),
841            timeout_ms: None,
842            operation: WorkerOperation::Tool {
843                tool: "echo".to_string(),
844                input: json!({"x": 1}),
845                scoped_input: json!({"input": {}}),
846            },
847        }
848    }
849
850    #[test]
851    fn worker_protocol_roundtrip() {
852        let request = sample_request("req-1");
853        let serialized =
854            serde_json::to_string(&request).expect("request serialization should work");
855        let decoded: WorkerRequest =
856            serde_json::from_str(&serialized).expect("request deserialization should work");
857        assert_eq!(request, decoded);
858    }
859
860    #[tokio::test]
861    async fn routes_requests_across_worker_pool() {
862        let pool = WorkerPool::new_inprocess(
863            vec![Arc::new(EchoWorker), Arc::new(EchoWorker)],
864            WorkerPoolOptions {
865                queue_capacity: 4,
866                health_probe_interval: Duration::from_millis(10),
867                ..WorkerPoolOptions::default()
868            },
869            None,
870        )
871        .expect("pool should initialize");
872
873        let response = pool
874            .submit(sample_request("req-2"))
875            .await
876            .expect("request should succeed");
877        assert_eq!(response.request_id, "req-2");
878        assert_eq!(
879            response.result,
880            WorkerResult::Success {
881                output: json!({"node": "node-1"})
882            }
883        );
884
885        let health = pool.health_snapshot().await;
886        assert_eq!(health.len(), 2);
887        assert!(health.iter().all(|entry| entry.is_schedulable()));
888
889        pool.shutdown().await;
890    }
891
892    #[tokio::test]
893    async fn enforces_queue_backpressure_limits() {
894        let pool = WorkerPool::new_inprocess(
895            vec![Arc::new(SlowWorker {
896                delay: Duration::from_millis(80),
897            })],
898            WorkerPoolOptions {
899                queue_capacity: 1,
900                health_probe_interval: Duration::from_millis(100),
901                default_request_timeout: Some(Duration::from_secs(1)),
902                ..WorkerPoolOptions::default()
903            },
904            None,
905        )
906        .expect("pool should initialize");
907
908        let first = pool.submit(sample_request("q1"));
909        let second = pool.submit(sample_request("q2"));
910        let third = pool.submit(sample_request("q3"));
911
912        let (first_result, second_result, third_result) = tokio::join!(first, second, third);
913        let failures = [&first_result, &second_result, &third_result]
914            .iter()
915            .filter(|result| matches!(result, Err(WorkerPoolError::QueueFull)))
916            .count();
917        let successes = [&first_result, &second_result, &third_result]
918            .iter()
919            .filter(|result| result.is_ok())
920            .count();
921        assert!(failures >= 1);
922        assert!(successes >= 1);
923
924        pool.shutdown().await;
925    }
926
927    #[tokio::test]
928    async fn marks_worker_unavailable_after_failures_and_recovers_on_restart() {
929        let flaky = Arc::new(FlakyWorker {
930            available: AtomicBool::new(false),
931            calls: AtomicUsize::new(0),
932        });
933        let pool = WorkerPool::new_inprocess(
934            vec![Arc::clone(&flaky) as Arc<dyn WorkerHandler>],
935            WorkerPoolOptions {
936                queue_capacity: 2,
937                unavailable_after_failures: 1,
938                health_probe_interval: Duration::from_millis(15),
939                default_request_timeout: Some(Duration::from_secs(1)),
940                ..WorkerPoolOptions::default()
941            },
942            None,
943        )
944        .expect("pool should initialize");
945
946        let error = pool
947            .submit(sample_request("down"))
948            .await
949            .expect_err("request should fail while worker is unavailable");
950        assert!(matches!(error, WorkerPoolError::Worker(_)));
951
952        sleep(Duration::from_millis(25)).await;
953        let health_before = pool.health_snapshot().await;
954        assert_eq!(health_before[0].status, WorkerHealthStatus::Unavailable);
955
956        flaky.available.store(true, Ordering::Relaxed);
957        pool.restart_worker("worker-0")
958            .await
959            .expect("restart should succeed");
960
961        sleep(Duration::from_millis(25)).await;
962        let response = pool
963            .submit(sample_request("up"))
964            .await
965            .expect("request should succeed after restart");
966        assert_eq!(
967            response.result,
968            WorkerResult::Success {
969                output: json!({"status": "up"})
970            }
971        );
972
973        pool.shutdown().await;
974    }
975
976    #[tokio::test]
977    async fn returns_timeout_for_slow_worker() {
978        let pool = WorkerPool::new_inprocess(
979            vec![Arc::new(SlowWorker {
980                delay: Duration::from_millis(100),
981            })],
982            WorkerPoolOptions {
983                queue_capacity: 2,
984                default_request_timeout: Some(Duration::from_millis(5)),
985                ..WorkerPoolOptions::default()
986            },
987            None,
988        )
989        .expect("pool should initialize");
990
991        let error = pool
992            .submit(sample_request("timeout"))
993            .await
994            .expect_err("request should time out");
995        assert!(matches!(error, WorkerPoolError::Timeout));
996
997        pool.shutdown().await;
998    }
999
1000    #[tokio::test]
1001    async fn rejects_request_when_security_contract_is_violated() {
1002        let pool = WorkerPool::new_inprocess(
1003            vec![Arc::new(EchoWorker)],
1004            WorkerPoolOptions {
1005                security_policy: WorkerSecurityPolicy {
1006                    max_request_timeout_ms: 10,
1007                    max_request_payload_bytes: 256,
1008                    max_identifier_length: 12,
1009                },
1010                ..WorkerPoolOptions::default()
1011            },
1012            None,
1013        )
1014        .expect("pool should initialize");
1015
1016        let mut request = sample_request("req-too-large");
1017        request.timeout_ms = Some(99);
1018        request.operation = WorkerOperation::Tool {
1019            tool: "echo".to_string(),
1020            input: json!({"payload": "x".repeat(1024)}),
1021            scoped_input: json!({"input": {}}),
1022        };
1023
1024        let error = pool
1025            .submit(request)
1026            .await
1027            .expect_err("request should be rejected by security policy");
1028
1029        assert!(matches!(error, WorkerPoolError::InvalidRequest { .. }));
1030        pool.shutdown().await;
1031    }
1032
1033    #[tokio::test]
1034    async fn handles_parallel_submissions_without_deadlock() {
1035        let pool = Arc::new(
1036            WorkerPool::new_inprocess(
1037                vec![Arc::new(EchoWorker), Arc::new(EchoWorker)],
1038                WorkerPoolOptions {
1039                    queue_capacity: 32,
1040                    health_probe_interval: Duration::from_millis(5),
1041                    default_request_timeout: Some(Duration::from_secs(1)),
1042                    ..WorkerPoolOptions::default()
1043                },
1044                None,
1045            )
1046            .expect("pool should initialize"),
1047        );
1048
1049        let mut tasks = Vec::new();
1050        for idx in 0..32usize {
1051            let pool = Arc::clone(&pool);
1052            tasks.push(tokio::spawn(async move {
1053                pool.submit(sample_request(&format!("parallel-{idx}")))
1054                    .await
1055            }));
1056        }
1057
1058        let joined = tokio::time::timeout(Duration::from_secs(3), async {
1059            for task in tasks {
1060                let result = task.await.expect("join should succeed");
1061                assert!(result.is_ok(), "submit should succeed under parallel load");
1062            }
1063        })
1064        .await;
1065
1066        assert!(joined.is_ok(), "parallel submissions should not deadlock");
1067        pool.shutdown().await;
1068    }
1069}