simple_agents_workflow/
worker_adapter.rs1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::runtime::{ToolExecutionError, ToolExecutionInput, ToolExecutor};
7use crate::worker::{WorkerOperation, WorkerPoolClient, WorkerPoolError, WorkerRequest};
8
9pub struct WorkerPoolToolExecutor {
11 workflow_name: String,
12 timeout_ms: Option<u64>,
13 pool: Arc<dyn WorkerPoolClient>,
14 request_seq: AtomicU64,
15}
16
17impl WorkerPoolToolExecutor {
18 pub fn new(
20 workflow_name: impl Into<String>,
21 timeout_ms: Option<u64>,
22 pool: Arc<dyn WorkerPoolClient>,
23 ) -> Self {
24 Self {
25 workflow_name: workflow_name.into(),
26 timeout_ms,
27 pool,
28 request_seq: AtomicU64::new(0),
29 }
30 }
31
32 fn next_request_id(&self, node_id: &str) -> String {
33 let seq = self.request_seq.fetch_add(1, Ordering::Relaxed);
34 format!("{}-{}", node_id, seq)
35 }
36}
37
38#[async_trait]
39impl ToolExecutor for WorkerPoolToolExecutor {
40 async fn execute_tool(
41 &self,
42 input: ToolExecutionInput,
43 ) -> Result<serde_json::Value, ToolExecutionError> {
44 let request = WorkerRequest {
45 request_id: self.next_request_id(&input.node_id),
46 workflow_name: self.workflow_name.clone(),
47 node_id: input.node_id,
48 timeout_ms: self.timeout_ms,
49 operation: WorkerOperation::Tool {
50 tool: input.tool,
51 input: input.input,
52 scoped_input: input.scoped_input,
53 },
54 };
55
56 let response = self
57 .pool
58 .submit(request)
59 .await
60 .map_err(map_worker_pool_error)?;
61
62 match response.result {
63 crate::worker::WorkerResult::Success { output } => Ok(output),
64 crate::worker::WorkerResult::Error { error } => Err(ToolExecutionError::Failed(
65 format!("{:?}: {}", error.code, error.message),
66 )),
67 }
68 }
69}
70
71fn map_worker_pool_error(error: WorkerPoolError) -> ToolExecutionError {
72 match error {
73 WorkerPoolError::Worker(worker_error) => ToolExecutionError::Failed(worker_error.message),
74 WorkerPoolError::Timeout => {
75 ToolExecutionError::Failed("worker request timed out".to_string())
76 }
77 WorkerPoolError::QueueFull => {
78 ToolExecutionError::Failed("worker queue is full".to_string())
79 }
80 WorkerPoolError::NoHealthyWorker => {
81 ToolExecutionError::Failed("no healthy worker available".to_string())
82 }
83 WorkerPoolError::ShuttingDown => {
84 ToolExecutionError::Failed("worker pool is shutting down".to_string())
85 }
86 WorkerPoolError::CircuitOpen => {
87 ToolExecutionError::Failed("worker circuit is open".to_string())
88 }
89 WorkerPoolError::InvalidRequest { reason } => ToolExecutionError::Failed(format!(
90 "worker request rejected by security contract: {reason}"
91 )),
92 }
93}
94
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use serde_json::json;

    use super::*;
    use crate::worker::{
        WorkerErrorCode, WorkerHealth, WorkerHealthStatus, WorkerPoolError, WorkerProtocolError,
        WorkerResponse, WorkerResult,
    };

    /// In-memory pool whose reply depends on the tool name:
    /// - `"fail"`: pool-level `WorkerPoolError::Worker` error;
    /// - `"timeout"`: pool-level `WorkerPoolError::Timeout`;
    /// - `"describe"`: success whose output echoes the request metadata;
    /// - anything else: success whose output is `{"input": <tool input>}`.
    struct MockPool;

    #[async_trait]
    impl WorkerPoolClient for MockPool {
        async fn submit(&self, request: WorkerRequest) -> Result<WorkerResponse, WorkerPoolError> {
            if let WorkerOperation::Tool { tool, input, .. } = request.operation {
                match tool.as_str() {
                    "fail" => {
                        return Err(WorkerPoolError::Worker(WorkerProtocolError {
                            code: WorkerErrorCode::ExecutionFailed,
                            message: "forced failure".to_string(),
                            retryable: false,
                        }));
                    }
                    "timeout" => return Err(WorkerPoolError::Timeout),
                    _ => {}
                }

                let output = if tool == "describe" {
                    json!({
                        "request_id": &request.request_id,
                        "workflow_name": &request.workflow_name,
                        "node_id": &request.node_id,
                        "timeout_ms": request.timeout_ms,
                    })
                } else {
                    json!({"input": input})
                };

                return Ok(WorkerResponse {
                    request_id: request.request_id,
                    worker_id: "mock-0".to_string(),
                    result: WorkerResult::Success { output },
                    elapsed_ms: 1,
                });
            }
            unreachable!("test only uses tool requests")
        }

        async fn health_snapshot(&self) -> Vec<WorkerHealth> {
            vec![WorkerHealth {
                worker_id: "mock-0".to_string(),
                status: WorkerHealthStatus::Healthy,
                consecutive_failures: 0,
                last_probe_unix_ms: Some(1),
            }]
        }
    }

    #[tokio::test]
    async fn executes_tool_through_worker_pool_client() {
        let executor = WorkerPoolToolExecutor::new("wf", Some(500), Arc::new(MockPool));
        let output = executor
            .execute_tool(ToolExecutionInput {
                node_id: "node-1".to_string(),
                tool: "echo".to_string(),
                input: json!({"x": 1}),
                scoped_input: json!({"input": {"foo": "bar"}}),
            })
            .await
            .expect("worker pool adapter should return output");
        assert_eq!(output, json!({"input": {"x": 1}}));
    }

    #[tokio::test]
    async fn maps_worker_errors_to_tool_errors() {
        let executor = WorkerPoolToolExecutor::new("wf", Some(500), Arc::new(MockPool));
        let error = executor
            .execute_tool(ToolExecutionInput {
                node_id: "node-1".to_string(),
                tool: "fail".to_string(),
                input: json!({}),
                scoped_input: json!({"input": {}}),
            })
            .await
            .expect_err("worker error should map to tool error");

        // The worker's own message must survive the mapping.
        assert!(
            matches!(&error, ToolExecutionError::Failed(message) if message.contains("forced failure")),
            "unexpected error: {error:?}"
        );
    }

    #[tokio::test]
    async fn maps_pool_timeouts_to_tool_errors() {
        let executor = WorkerPoolToolExecutor::new("wf", Some(500), Arc::new(MockPool));
        let error = executor
            .execute_tool(ToolExecutionInput {
                node_id: "node-1".to_string(),
                tool: "timeout".to_string(),
                input: json!({}),
                scoped_input: json!({"input": {}}),
            })
            .await
            .expect_err("pool timeout should map to tool error");

        assert!(
            matches!(&error, ToolExecutionError::Failed(message) if message == "worker request timed out"),
            "unexpected error: {error:?}"
        );
    }

    #[tokio::test]
    async fn forwards_request_metadata_with_sequential_ids() {
        let executor = WorkerPoolToolExecutor::new("wf", Some(500), Arc::new(MockPool));

        // Each call on the same executor must get the next id in sequence.
        for seq in 0..2 {
            let output = executor
                .execute_tool(ToolExecutionInput {
                    node_id: "node-1".to_string(),
                    tool: "describe".to_string(),
                    input: json!({}),
                    scoped_input: json!({"input": {}}),
                })
                .await
                .expect("describe request should succeed");

            assert_eq!(
                output,
                json!({
                    "request_id": format!("node-1-{seq}"),
                    "workflow_name": "wf",
                    "node_id": "node-1",
                    "timeout_ms": 500,
                })
            );
        }
    }
}