Skip to main content

simple_agents_workflow/
worker_adapter.rs

1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3
4use async_trait::async_trait;
5
6use crate::runtime::{ToolExecutionError, ToolExecutionInput, ToolExecutor};
7use crate::worker::{WorkerOperation, WorkerPoolClient, WorkerPoolError, WorkerRequest};
8
9/// Runtime tool-executor adapter backed by a worker pool client.
10pub struct WorkerPoolToolExecutor {
11    workflow_name: String,
12    timeout_ms: Option<u64>,
13    pool: Arc<dyn WorkerPoolClient>,
14    request_seq: AtomicU64,
15}
16
17impl WorkerPoolToolExecutor {
18    /// Creates a new worker-pool-backed tool executor.
19    pub fn new(
20        workflow_name: impl Into<String>,
21        timeout_ms: Option<u64>,
22        pool: Arc<dyn WorkerPoolClient>,
23    ) -> Self {
24        Self {
25            workflow_name: workflow_name.into(),
26            timeout_ms,
27            pool,
28            request_seq: AtomicU64::new(0),
29        }
30    }
31
32    fn next_request_id(&self, node_id: &str) -> String {
33        let seq = self.request_seq.fetch_add(1, Ordering::Relaxed);
34        format!("{}-{}", node_id, seq)
35    }
36}
37
38#[async_trait]
39impl ToolExecutor for WorkerPoolToolExecutor {
40    async fn execute_tool(
41        &self,
42        input: ToolExecutionInput,
43    ) -> Result<serde_json::Value, ToolExecutionError> {
44        let request = WorkerRequest {
45            request_id: self.next_request_id(&input.node_id),
46            workflow_name: self.workflow_name.clone(),
47            node_id: input.node_id,
48            timeout_ms: self.timeout_ms,
49            operation: WorkerOperation::Tool {
50                tool: input.tool,
51                input: input.input,
52                scoped_input: input.scoped_input,
53            },
54        };
55
56        let response = self
57            .pool
58            .submit(request)
59            .await
60            .map_err(map_worker_pool_error)?;
61
62        match response.result {
63            crate::worker::WorkerResult::Success { output } => Ok(output),
64            crate::worker::WorkerResult::Error { error } => Err(ToolExecutionError::Failed(
65                format!("{:?}: {}", error.code, error.message),
66            )),
67        }
68    }
69}
70
71fn map_worker_pool_error(error: WorkerPoolError) -> ToolExecutionError {
72    match error {
73        WorkerPoolError::Worker(worker_error) => ToolExecutionError::Failed(worker_error.message),
74        WorkerPoolError::Timeout => {
75            ToolExecutionError::Failed("worker request timed out".to_string())
76        }
77        WorkerPoolError::QueueFull => {
78            ToolExecutionError::Failed("worker queue is full".to_string())
79        }
80        WorkerPoolError::NoHealthyWorker => {
81            ToolExecutionError::Failed("no healthy worker available".to_string())
82        }
83        WorkerPoolError::ShuttingDown => {
84            ToolExecutionError::Failed("worker pool is shutting down".to_string())
85        }
86        WorkerPoolError::CircuitOpen => {
87            ToolExecutionError::Failed("worker circuit is open".to_string())
88        }
89        WorkerPoolError::InvalidRequest { reason } => ToolExecutionError::Failed(format!(
90            "worker request rejected by security contract: {reason}"
91        )),
92    }
93}
94
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use serde_json::json;

    use super::*;
    use crate::worker::{
        WorkerErrorCode, WorkerHealth, WorkerHealthStatus, WorkerPoolError, WorkerProtocolError,
        WorkerResponse, WorkerResult,
    };

    /// Envelope fields captured from each request the mock receives:
    /// `(request_id, workflow_name, node_id, timeout_ms)`.
    type SeenRequest = (String, String, String, Option<u64>);

    /// In-memory pool double.
    ///
    /// - Echoes tool input back as `{"input": <input>}`.
    /// - Fails the `"fail"` tool with a protocol error.
    /// - Records the envelope of every request it is handed.
    #[derive(Default)]
    struct MockPool {
        seen: Mutex<Vec<SeenRequest>>,
    }

    impl MockPool {
        /// Snapshot of every request envelope received so far, in arrival order.
        fn seen(&self) -> Vec<SeenRequest> {
            self.seen.lock().expect("mock pool mutex poisoned").clone()
        }

        /// Records the envelope in a synchronous helper, so the mutex guard never
        /// lives inside the async `submit` future.
        fn record(&self, request: &WorkerRequest) {
            self.seen.lock().expect("mock pool mutex poisoned").push((
                request.request_id.clone(),
                request.workflow_name.clone(),
                request.node_id.clone(),
                request.timeout_ms,
            ));
        }
    }

    #[async_trait]
    impl WorkerPoolClient for MockPool {
        async fn submit(&self, request: WorkerRequest) -> Result<WorkerResponse, WorkerPoolError> {
            self.record(&request);
            if let WorkerOperation::Tool { tool, input, .. } = request.operation {
                if tool == "fail" {
                    return Err(WorkerPoolError::Worker(WorkerProtocolError {
                        code: WorkerErrorCode::ExecutionFailed,
                        message: "forced failure".to_string(),
                        retryable: false,
                    }));
                }
                return Ok(WorkerResponse {
                    request_id: request.request_id,
                    worker_id: "mock-0".to_string(),
                    result: WorkerResult::Success {
                        output: json!({"input": input}),
                    },
                    elapsed_ms: 1,
                });
            }
            unreachable!("test only uses tool requests")
        }

        async fn health_snapshot(&self) -> Vec<WorkerHealth> {
            vec![WorkerHealth {
                worker_id: "mock-0".to_string(),
                status: WorkerHealthStatus::Healthy,
                consecutive_failures: 0,
                last_probe_unix_ms: Some(1),
            }]
        }
    }

    #[tokio::test]
    async fn executes_tool_through_worker_pool_client() {
        let executor = WorkerPoolToolExecutor::new("wf", Some(500), Arc::new(MockPool::default()));
        let output = executor
            .execute_tool(ToolExecutionInput {
                node_id: "node-1".to_string(),
                tool: "echo".to_string(),
                input: json!({"x": 1}),
                scoped_input: json!({"input": {"foo": "bar"}}),
            })
            .await
            .expect("worker pool adapter should return output");
        assert_eq!(output, json!({"input": {"x": 1}}));
    }

    #[tokio::test]
    async fn maps_worker_errors_to_tool_errors() {
        let executor = WorkerPoolToolExecutor::new("wf", Some(500), Arc::new(MockPool::default()));
        let error = executor
            .execute_tool(ToolExecutionInput {
                node_id: "node-1".to_string(),
                tool: "fail".to_string(),
                input: json!({}),
                scoped_input: json!({"input": {}}),
            })
            .await
            .expect_err("worker error should map to tool error");

        // The worker's own message must survive the mapping, not just the variant.
        assert!(
            matches!(&error, ToolExecutionError::Failed(message) if message.contains("forced failure")),
            "worker error should map to ToolExecutionError::Failed carrying the worker message"
        );
    }

    #[tokio::test]
    async fn forwards_request_envelope_with_sequential_request_ids() {
        let pool = Arc::new(MockPool::default());
        let executor = WorkerPoolToolExecutor::new("wf", Some(500), pool.clone());

        for _ in 0..2 {
            executor
                .execute_tool(ToolExecutionInput {
                    node_id: "node-1".to_string(),
                    tool: "echo".to_string(),
                    input: json!({}),
                    scoped_input: json!({"input": {}}),
                })
                .await
                .expect("echo tool should succeed");
        }

        assert_eq!(
            pool.seen(),
            vec![
                (
                    "node-1-0".to_string(),
                    "wf".to_string(),
                    "node-1".to_string(),
                    Some(500),
                ),
                (
                    "node-1-1".to_string(),
                    "wf".to_string(),
                    "node-1".to_string(),
                    Some(500),
                ),
            ]
        );
    }
}