strands_agents/multiagent/a2a/
server.rs

1//! A2A-compatible server for Strands Agent.
2//!
3//! This module provides the A2AServer, which wraps a Strands Agent
4//! and exposes it via the A2A protocol.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10
11use super::executor::StrandsA2AExecutor;
12use super::types::{
13    A2AError, A2ARequest, A2AResponse, A2ATask, A2ATaskState, AgentCard, AgentSkill,
14};
15use crate::agent::Agent;
16
17/// Configuration for the A2A server.
18#[derive(Debug, Clone)]
19pub struct A2AServerConfig {
20    /// Host to bind the server to.
21    pub host: String,
22    /// Port to bind the server to.
23    pub port: u16,
24    /// Public HTTP URL where the agent is accessible.
25    pub http_url: Option<String>,
26    /// Whether to serve at root path.
27    pub serve_at_root: bool,
28    /// Version of the agent.
29    pub version: String,
30    /// Skills exposed by the agent.
31    pub skills: Vec<AgentSkill>,
32}
33
34impl Default for A2AServerConfig {
35    fn default() -> Self {
36        Self {
37            host: "127.0.0.1".to_string(),
38            port: 9000,
39            http_url: None,
40            serve_at_root: false,
41            version: "0.0.1".to_string(),
42            skills: Vec::new(),
43        }
44    }
45}
46
47impl A2AServerConfig {
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    pub fn with_host(mut self, host: impl Into<String>) -> Self {
53        self.host = host.into();
54        self
55    }
56
57    pub fn with_port(mut self, port: u16) -> Self {
58        self.port = port;
59        self
60    }
61
62    pub fn with_version(mut self, version: impl Into<String>) -> Self {
63        self.version = version.into();
64        self
65    }
66
67    pub fn with_skill(mut self, skill: AgentSkill) -> Self {
68        self.skills.push(skill);
69        self
70    }
71}
72
73/// A2A-compatible server wrapping a Strands Agent.
74pub struct A2AServer {
75    config: A2AServerConfig,
76    executor: Arc<StrandsA2AExecutor>,
77    agent_name: String,
78    agent_description: Option<String>,
79    tasks: Arc<RwLock<HashMap<String, A2ATask>>>,
80}
81
82impl A2AServer {
83    /// Create a new A2A server from a Strands Agent.
84    pub fn new(agent: Agent, config: A2AServerConfig) -> Self {
85        let agent_name = agent.name().map(|s| s.to_string()).unwrap_or_else(|| "Strands Agent".to_string());
86        let agent_description = None;
87
88        Self {
89            config,
90            executor: Arc::new(StrandsA2AExecutor::new(agent)),
91            agent_name,
92            agent_description,
93            tasks: Arc::new(RwLock::new(HashMap::new())),
94        }
95    }
96
97    /// Get the agent card describing this agent.
98    pub fn agent_card(&self) -> AgentCard {
99        let url = self
100            .config
101            .http_url
102            .clone()
103            .unwrap_or_else(|| format!("http://{}:{}/", self.config.host, self.config.port));
104
105        let mut card = AgentCard::new(&self.agent_name, url, &self.config.version)
106            .with_streaming(true);
107
108        if let Some(desc) = &self.agent_description {
109            card = card.with_description(desc);
110        }
111
112        if !self.config.skills.is_empty() {
113            card = card.with_skills(self.config.skills.clone());
114        }
115
116        card
117    }
118
119    /// Handle an A2A JSON-RPC request.
120    pub async fn handle_request(&self, request: A2ARequest) -> A2AResponse {
121        match request.method.as_str() {
122            "agent/card" => self.handle_agent_card(request.id).await,
123            "tasks/send" => self.handle_tasks_send(request).await,
124            "tasks/get" => self.handle_tasks_get(request).await,
125            "tasks/cancel" => self.handle_tasks_cancel(request).await,
126            _ => A2AResponse::error(
127                request.id,
128                A2AError::method_not_found(format!("Unknown method: {}", request.method)),
129            ),
130        }
131    }
132
133    async fn handle_agent_card(&self, id: serde_json::Value) -> A2AResponse {
134        let card = self.agent_card();
135        A2AResponse::success(id, serde_json::to_value(card).unwrap_or_default())
136    }
137
138    async fn handle_tasks_send(&self, request: A2ARequest) -> A2AResponse {
139        let params = match request.params {
140            Some(p) => p,
141            None => {
142                return A2AResponse::error(
143                    request.id,
144                    A2AError::invalid_request("Missing params"),
145                );
146            }
147        };
148
149        let message = match serde_json::from_value(params.get("message").cloned().unwrap_or_default()) {
150            Ok(m) => m,
151            Err(e) => {
152                return A2AResponse::error(
153                    request.id,
154                    A2AError::invalid_request(format!("Invalid message: {}", e)),
155                );
156            }
157        };
158
159        match self.executor.execute(message).await {
160            Ok(task) => {
161                let mut tasks = self.tasks.write().await;
162                tasks.insert(task.id.clone(), task.clone());
163                A2AResponse::success(request.id, serde_json::to_value(task).unwrap_or_default())
164            }
165            Err(e) => A2AResponse::error(request.id, e),
166        }
167    }
168
169    async fn handle_tasks_get(&self, request: A2ARequest) -> A2AResponse {
170        let params = match request.params {
171            Some(p) => p,
172            None => {
173                return A2AResponse::error(
174                    request.id,
175                    A2AError::invalid_request("Missing params"),
176                );
177            }
178        };
179
180        let task_id = match params.get("id").and_then(|v| v.as_str()) {
181            Some(id) => id,
182            None => {
183                return A2AResponse::error(
184                    request.id,
185                    A2AError::invalid_request("Missing task id"),
186                );
187            }
188        };
189
190        let tasks = self.tasks.read().await;
191        match tasks.get(task_id) {
192            Some(task) => {
193                A2AResponse::success(request.id, serde_json::to_value(task).unwrap_or_default())
194            }
195            None => A2AResponse::error(
196                request.id,
197                A2AError::invalid_request(format!("Task not found: {}", task_id)),
198            ),
199        }
200    }
201
202    async fn handle_tasks_cancel(&self, request: A2ARequest) -> A2AResponse {
203        let params = match request.params {
204            Some(p) => p,
205            None => {
206                return A2AResponse::error(
207                    request.id,
208                    A2AError::invalid_request("Missing params"),
209                );
210            }
211        };
212
213        let task_id = match params.get("id").and_then(|v| v.as_str()) {
214            Some(id) => id,
215            None => {
216                return A2AResponse::error(
217                    request.id,
218                    A2AError::invalid_request("Missing task id"),
219                );
220            }
221        };
222
223        let mut tasks = self.tasks.write().await;
224        match tasks.get_mut(task_id) {
225            Some(task) => {
226                task.state = A2ATaskState::Cancelled;
227                A2AResponse::success(request.id, serde_json::to_value(task.clone()).unwrap_or_default())
228            }
229            None => A2AResponse::error(
230                request.id,
231                A2AError::invalid_request(format!("Task not found: {}", task_id)),
232            ),
233        }
234    }
235
236    /// Get the host address.
237    pub fn host(&self) -> &str {
238        &self.config.host
239    }
240
241    /// Get the port.
242    pub fn port(&self) -> u16 {
243        self.config.port
244    }
245
246    /// Get the executor.
247    pub fn executor(&self) -> &Arc<StrandsA2AExecutor> {
248        &self.executor
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Host, port and version overrides applied through the builder are stored.
    #[test]
    fn test_server_config() {
        let A2AServerConfig {
            host,
            port,
            version,
            ..
        } = A2AServerConfig::new()
            .with_host("0.0.0.0")
            .with_port(8080)
            .with_version("1.0.0");

        assert_eq!(host, "0.0.0.0");
        assert_eq!(port, 8080);
        assert_eq!(version, "1.0.0");
    }

    /// An agent card keeps its name and version and reflects the optional
    /// streaming flag and description set through its builder.
    #[test]
    fn test_agent_card_creation() {
        let base = AgentCard::new("Test Agent", "http://localhost:9000/", "1.0.0");
        let agent_card = base.with_streaming(true).with_description("A test agent");

        assert_eq!(agent_card.name, "Test Agent");
        assert_eq!(agent_card.version, "1.0.0");
        assert!(agent_card.capabilities.streaming);
        assert_eq!(agent_card.description.as_deref(), Some("A test agent"));
    }

    /// A skill keeps its id and name, and records a description once one is given.
    #[test]
    fn test_agent_skill() {
        let search_skill = AgentSkill::new("search", "Search the web")
            .with_description("Searches the web for information");

        assert_eq!(search_skill.id, "search");
        assert_eq!(search_skill.name, "Search the web");
        assert!(search_skill.description.is_some());
    }
}
290