strands_agents/multiagent/a2a/
server.rs1use std::collections::HashMap;
7use std::sync::Arc;
8
9use tokio::sync::RwLock;
10
11use super::executor::StrandsA2AExecutor;
12use super::types::{
13 A2AError, A2ARequest, A2AResponse, A2ATask, A2ATaskState, AgentCard, AgentSkill,
14};
15use crate::agent::Agent;
16
17#[derive(Debug, Clone)]
19pub struct A2AServerConfig {
20 pub host: String,
22 pub port: u16,
24 pub http_url: Option<String>,
26 pub serve_at_root: bool,
28 pub version: String,
30 pub skills: Vec<AgentSkill>,
32}
33
34impl Default for A2AServerConfig {
35 fn default() -> Self {
36 Self {
37 host: "127.0.0.1".to_string(),
38 port: 9000,
39 http_url: None,
40 serve_at_root: false,
41 version: "0.0.1".to_string(),
42 skills: Vec::new(),
43 }
44 }
45}
46
47impl A2AServerConfig {
48 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub fn with_host(mut self, host: impl Into<String>) -> Self {
53 self.host = host.into();
54 self
55 }
56
57 pub fn with_port(mut self, port: u16) -> Self {
58 self.port = port;
59 self
60 }
61
62 pub fn with_version(mut self, version: impl Into<String>) -> Self {
63 self.version = version.into();
64 self
65 }
66
67 pub fn with_skill(mut self, skill: AgentSkill) -> Self {
68 self.skills.push(skill);
69 self
70 }
71}
72
73pub struct A2AServer {
75 config: A2AServerConfig,
76 executor: Arc<StrandsA2AExecutor>,
77 agent_name: String,
78 agent_description: Option<String>,
79 tasks: Arc<RwLock<HashMap<String, A2ATask>>>,
80}
81
82impl A2AServer {
83 pub fn new(agent: Agent, config: A2AServerConfig) -> Self {
85 let agent_name = agent.name().map(|s| s.to_string()).unwrap_or_else(|| "Strands Agent".to_string());
86 let agent_description = None;
87
88 Self {
89 config,
90 executor: Arc::new(StrandsA2AExecutor::new(agent)),
91 agent_name,
92 agent_description,
93 tasks: Arc::new(RwLock::new(HashMap::new())),
94 }
95 }
96
97 pub fn agent_card(&self) -> AgentCard {
99 let url = self
100 .config
101 .http_url
102 .clone()
103 .unwrap_or_else(|| format!("http://{}:{}/", self.config.host, self.config.port));
104
105 let mut card = AgentCard::new(&self.agent_name, url, &self.config.version)
106 .with_streaming(true);
107
108 if let Some(desc) = &self.agent_description {
109 card = card.with_description(desc);
110 }
111
112 if !self.config.skills.is_empty() {
113 card = card.with_skills(self.config.skills.clone());
114 }
115
116 card
117 }
118
119 pub async fn handle_request(&self, request: A2ARequest) -> A2AResponse {
121 match request.method.as_str() {
122 "agent/card" => self.handle_agent_card(request.id).await,
123 "tasks/send" => self.handle_tasks_send(request).await,
124 "tasks/get" => self.handle_tasks_get(request).await,
125 "tasks/cancel" => self.handle_tasks_cancel(request).await,
126 _ => A2AResponse::error(
127 request.id,
128 A2AError::method_not_found(format!("Unknown method: {}", request.method)),
129 ),
130 }
131 }
132
133 async fn handle_agent_card(&self, id: serde_json::Value) -> A2AResponse {
134 let card = self.agent_card();
135 A2AResponse::success(id, serde_json::to_value(card).unwrap_or_default())
136 }
137
138 async fn handle_tasks_send(&self, request: A2ARequest) -> A2AResponse {
139 let params = match request.params {
140 Some(p) => p,
141 None => {
142 return A2AResponse::error(
143 request.id,
144 A2AError::invalid_request("Missing params"),
145 );
146 }
147 };
148
149 let message = match serde_json::from_value(params.get("message").cloned().unwrap_or_default()) {
150 Ok(m) => m,
151 Err(e) => {
152 return A2AResponse::error(
153 request.id,
154 A2AError::invalid_request(format!("Invalid message: {}", e)),
155 );
156 }
157 };
158
159 match self.executor.execute(message).await {
160 Ok(task) => {
161 let mut tasks = self.tasks.write().await;
162 tasks.insert(task.id.clone(), task.clone());
163 A2AResponse::success(request.id, serde_json::to_value(task).unwrap_or_default())
164 }
165 Err(e) => A2AResponse::error(request.id, e),
166 }
167 }
168
169 async fn handle_tasks_get(&self, request: A2ARequest) -> A2AResponse {
170 let params = match request.params {
171 Some(p) => p,
172 None => {
173 return A2AResponse::error(
174 request.id,
175 A2AError::invalid_request("Missing params"),
176 );
177 }
178 };
179
180 let task_id = match params.get("id").and_then(|v| v.as_str()) {
181 Some(id) => id,
182 None => {
183 return A2AResponse::error(
184 request.id,
185 A2AError::invalid_request("Missing task id"),
186 );
187 }
188 };
189
190 let tasks = self.tasks.read().await;
191 match tasks.get(task_id) {
192 Some(task) => {
193 A2AResponse::success(request.id, serde_json::to_value(task).unwrap_or_default())
194 }
195 None => A2AResponse::error(
196 request.id,
197 A2AError::invalid_request(format!("Task not found: {}", task_id)),
198 ),
199 }
200 }
201
202 async fn handle_tasks_cancel(&self, request: A2ARequest) -> A2AResponse {
203 let params = match request.params {
204 Some(p) => p,
205 None => {
206 return A2AResponse::error(
207 request.id,
208 A2AError::invalid_request("Missing params"),
209 );
210 }
211 };
212
213 let task_id = match params.get("id").and_then(|v| v.as_str()) {
214 Some(id) => id,
215 None => {
216 return A2AResponse::error(
217 request.id,
218 A2AError::invalid_request("Missing task id"),
219 );
220 }
221 };
222
223 let mut tasks = self.tasks.write().await;
224 match tasks.get_mut(task_id) {
225 Some(task) => {
226 task.state = A2ATaskState::Cancelled;
227 A2AResponse::success(request.id, serde_json::to_value(task.clone()).unwrap_or_default())
228 }
229 None => A2AResponse::error(
230 request.id,
231 A2AError::invalid_request(format!("Task not found: {}", task_id)),
232 ),
233 }
234 }
235
236 pub fn host(&self) -> &str {
238 &self.config.host
239 }
240
241 pub fn port(&self) -> u16 {
243 self.config.port
244 }
245
246 pub fn executor(&self) -> &Arc<StrandsA2AExecutor> {
248 &self.executor
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder setter must overwrite the corresponding default.
    #[test]
    fn test_server_config() {
        let built = A2AServerConfig::new()
            .with_version("1.0.0")
            .with_port(8080)
            .with_host("0.0.0.0");

        assert_eq!(built.host, "0.0.0.0");
        assert_eq!(built.port, 8080);
        assert_eq!(built.version, "1.0.0");
    }

    /// An agent card keeps the values it was built with.
    #[test]
    fn test_agent_card_creation() {
        let base = AgentCard::new("Test Agent", "http://localhost:9000/", "1.0.0");
        let agent_card = base
            .with_streaming(true)
            .with_description("A test agent");

        assert_eq!(agent_card.name, "Test Agent");
        assert_eq!(agent_card.version, "1.0.0");
        assert!(agent_card.capabilities.streaming);
        assert_eq!(agent_card.description, Some(String::from("A test agent")));
    }

    /// A skill exposes its id, its name and the optional description.
    #[test]
    fn test_agent_skill() {
        let web_search = AgentSkill::new("search", "Search the web")
            .with_description("Searches the web for information");

        assert_eq!(web_search.id, "search");
        assert_eq!(web_search.name, "Search the web");
        assert!(web_search.description.is_some());
    }
}
290