strands_agents/multiagent/a2a/
executor.rs1use std::collections::HashMap;
7use std::pin::pin;
8use std::sync::{Arc, LazyLock};
9
10use base64::Engine;
11use tokio::sync::Mutex;
12
13use super::types::{A2AArtifact, A2AError, A2AMessage, A2APart, A2ATask, A2ATaskState};
14use crate::agent::Agent;
15use crate::types::content::{
16 ContentBlock, DocumentContent, DocumentSource, ImageContent, ImageSource, VideoContent,
17 VideoSource,
18};
19
/// Fallback format for each coarse file type (`"document"`, `"image"`, `"video"`,
/// `"unknown"`), used when no MIME type is available to derive a format from.
static DEFAULT_FORMATS: LazyLock<HashMap<&'static str, &'static str>> = LazyLock::new(|| {
    HashMap::from([
        ("document", "txt"),
        ("image", "png"),
        ("video", "mp4"),
        ("unknown", "txt"),
    ])
});
29
/// Maps file extensions and MIME subtypes to the format names used by Strands content
/// blocks (e.g. `jpeg`, `txt`, `three_gp`) wherever the two spellings differ.
///
/// The format lookup uses the lower-cased MIME subtype verbatim when it is not listed
/// here, so every common subtype whose name is *not* a format name must appear below
/// (otherwise `text/plain` would be sent as format `"plain"`).
static FORMAT_MAPPINGS: LazyLock<HashMap<&'static str, &'static str>> = LazyLock::new(|| {
    HashMap::from([
        // Extension-style aliases (also seen in non-standard MIME types like `image/jpg`).
        ("jpg", "jpeg"),
        ("htm", "html"),
        ("3gp", "three_gp"),
        ("3gpp", "three_gp"),
        ("3g2", "three_gp"),
        // Registered MIME subtypes whose names differ from the format name.
        ("3gpp2", "three_gp"),
        ("plain", "txt"),
        ("markdown", "md"),
        ("x-markdown", "md"),
        ("msword", "doc"),
        (
            "vnd.openxmlformats-officedocument.wordprocessingml.document",
            "docx",
        ),
        ("vnd.ms-excel", "xls"),
        (
            "vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            "xlsx",
        ),
        ("quicktime", "mov"),
        ("x-matroska", "mkv"),
        ("x-flv", "flv"),
        ("x-ms-wmv", "wmv"),
    ])
});
40
41pub struct StrandsA2AExecutor {
43 agent: Arc<Mutex<Agent>>,
44}
45
46impl StrandsA2AExecutor {
47 pub fn new(agent: Agent) -> Self {
49 Self {
50 agent: Arc::new(Mutex::new(agent)),
51 }
52 }
53
54 pub async fn execute(&self, message: A2AMessage) -> Result<A2ATask, A2AError> {
56 let content_blocks = self.convert_a2a_parts_to_content_blocks(&message.parts)?;
57
58 let mut agent = self.agent.lock().await;
59
60 let task_id = uuid::Uuid::new_v4().to_string();
61 let context_id = message.context_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
62
63 let mut task = A2ATask::new(&task_id, &context_id);
64 task.state = A2ATaskState::Working;
65
66 match agent.invoke_async(content_blocks).await {
67 Ok(result) => {
68 task.state = A2ATaskState::Completed;
69
70 let response_text = result.text();
71 let artifact = A2AArtifact {
72 name: "response".to_string(),
73 parts: vec![A2APart::text(response_text)],
74 index: Some(0),
75 };
76
77 task.artifacts = Some(vec![artifact]);
78 task.message = Some(A2AMessage::agent(
79 vec![A2APart::text(result.text())],
80 Some(context_id),
81 Some(task_id),
82 ));
83
84 Ok(task)
85 }
86 Err(e) => {
87 task.state = A2ATaskState::Failed;
88 Err(A2AError::internal(e.to_string()))
89 }
90 }
91 }
92
93 pub async fn execute_streaming<F>(
95 &self,
96 message: A2AMessage,
97 mut on_update: F,
98 ) -> Result<A2ATask, A2AError>
99 where
100 F: FnMut(A2ATask) + Send,
101 {
102 use futures::StreamExt;
103
104 let content_blocks = self.convert_a2a_parts_to_content_blocks(&message.parts)?;
105
106 let mut agent = self.agent.lock().await;
107
108 let task_id = uuid::Uuid::new_v4().to_string();
109 let context_id = message.context_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
110
111 let mut task = A2ATask::new(&task_id, &context_id);
112 task.state = A2ATaskState::Working;
113
114 on_update(task.clone());
115
116 let stream = agent.stream_async(content_blocks).await;
117 let mut pinned_stream = pin!(stream);
118 let mut accumulated_text = String::new();
119
120 while let Some(event) = pinned_stream.next().await {
121 match event {
122 Ok(stream_event) => {
123 if let Some(text) = stream_event.as_text() {
124 accumulated_text.push_str(&text);
125
126 let mut update_task = task.clone();
127 update_task.message = Some(A2AMessage::agent(
128 vec![A2APart::text(&accumulated_text)],
129 Some(context_id.clone()),
130 Some(task_id.clone()),
131 ));
132
133 on_update(update_task);
134 }
135 }
136 Err(e) => {
137 task.state = A2ATaskState::Failed;
138 return Err(A2AError::internal(e.to_string()));
139 }
140 }
141 }
142
143 task.state = A2ATaskState::Completed;
144
145 let artifact = A2AArtifact {
146 name: "response".to_string(),
147 parts: vec![A2APart::text(&accumulated_text)],
148 index: Some(0),
149 };
150
151 task.artifacts = Some(vec![artifact]);
152 task.message = Some(A2AMessage::agent(
153 vec![A2APart::text(accumulated_text)],
154 Some(context_id),
155 Some(task_id),
156 ));
157
158 Ok(task)
159 }
160
161 fn convert_a2a_parts_to_content_blocks(
163 &self,
164 parts: &[A2APart],
165 ) -> Result<Vec<ContentBlock>, A2AError> {
166 let mut content_blocks = Vec::new();
167
168 for part in parts {
169 match part {
170 A2APart::Text { text } => {
171 content_blocks.push(ContentBlock::text(text));
172 }
173 A2APart::Data { data } => {
174 let text = serde_json::to_string_pretty(data)
175 .map(|json| format!("[Structured Data]\n{}", json))
176 .unwrap_or_else(|_| data.to_string());
177 content_blocks.push(ContentBlock::text(text));
178 }
179 A2APart::File { file } => {
180 let file_type = Self::classify_file_type(file.mime_type.as_deref());
181 let file_format = Self::get_file_format_from_mime_type(
182 file.mime_type.as_deref(),
183 file_type,
184 );
185 let file_name = Self::strip_file_extension(&file.name);
186
187 if let Some(ref bytes_str) = file.bytes {
188 match base64::engine::general_purpose::STANDARD.decode(bytes_str) {
189 Ok(decoded_bytes) => {
190 let bytes_base64 = base64::engine::general_purpose::STANDARD
191 .encode(&decoded_bytes);
192
193 match file_type {
194 "image" => {
195 content_blocks.push(ContentBlock {
196 image: Some(ImageContent {
197 format: file_format,
198 source: ImageSource {
199 bytes: Some(bytes_base64),
200 },
201 }),
202 ..Default::default()
203 });
204 }
205 "video" => {
206 content_blocks.push(ContentBlock {
207 video: Some(VideoContent {
208 format: file_format,
209 source: VideoSource {
210 bytes: Some(bytes_base64),
211 },
212 }),
213 ..Default::default()
214 });
215 }
216 _ => {
217 content_blocks.push(ContentBlock {
218 document: Some(DocumentContent {
219 format: file_format,
220 name: file_name.to_string(),
221 source: DocumentSource {
222 bytes: Some(bytes_base64),
223 },
224 }),
225 ..Default::default()
226 });
227 }
228 }
229 }
230 Err(e) => {
231 tracing::warn!(
232 "Failed to decode base64 data for file '{}': {}",
233 file.name,
234 e
235 );
236 let text = format!(
237 "[File: {} ({:?})] - Failed to decode base64 data",
238 file.name, file.mime_type
239 );
240 content_blocks.push(ContentBlock::text(text));
241 }
242 }
243 } else if let Some(ref uri) = file.uri {
244 let text = format!(
245 "[File: {} ({:?})] - Referenced file at: {}",
246 file_name, file.mime_type, uri
247 );
248 content_blocks.push(ContentBlock::text(text));
249 } else {
250 let text = format!("[File: {}]", file.name);
251 content_blocks.push(ContentBlock::text(text));
252 }
253 }
254 }
255 }
256
257 if content_blocks.is_empty() {
258 return Err(A2AError::invalid_request("No content blocks available"));
259 }
260
261 Ok(content_blocks)
262 }
263
264 pub fn convert_content_blocks_to_a2a_parts(blocks: &[ContentBlock]) -> Vec<A2APart> {
266 blocks
267 .iter()
268 .filter_map(|block| {
269 if let Some(text) = &block.text {
270 Some(A2APart::text(text))
271 } else {
272 None
273 }
274 })
275 .collect()
276 }
277
278 pub fn get_file_format_from_mime_type(mime_type: Option<&str>, file_type: &str) -> String {
283 let Some(mime_type) = mime_type else {
284 return DEFAULT_FORMATS.get(file_type).copied().unwrap_or("txt").to_string();
285 };
286
287 let mime_lower = mime_type.to_lowercase();
288
289 if let Some(subtype) = mime_lower.split('/').last() {
290 if let Some(mapped) = FORMAT_MAPPINGS.get(subtype) {
291 return mapped.to_string();
292 }
293 }
294
295 if let Some(subtype) = mime_lower.split('/').last() {
296 if let Some(mapped) = FORMAT_MAPPINGS.get(subtype) {
297 return mapped.to_string();
298 }
299 return subtype.to_string();
300 }
301
302 DEFAULT_FORMATS.get(file_type).copied().unwrap_or("txt").to_string()
303 }
304
305 pub fn strip_file_extension(file_name: &str) -> &str {
307 if let Some(pos) = file_name.rfind('.') {
308 &file_name[..pos]
309 } else {
310 file_name
311 }
312 }
313
314 pub fn classify_file_type(mime_type: Option<&str>) -> &'static str {
316 let Some(mime) = mime_type else {
317 return "unknown";
318 };
319
320 let mime_lower = mime.to_lowercase();
321
322 if mime_lower.starts_with("image/") {
323 "image"
324 } else if mime_lower.starts_with("video/") {
325 "video"
326 } else if mime_lower.starts_with("text/")
327 || mime_lower.starts_with("application/pdf")
328 || mime_lower.starts_with("application/msword")
329 || mime_lower.contains("document")
330 {
331 "document"
332 } else {
333 "unknown"
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_a2a_part_text() {
        let A2APart::Text { text } = A2APart::text("Hello") else {
            panic!("Expected text part");
        };
        assert_eq!(text, "Hello");
    }

    #[test]
    fn test_a2a_part_data() {
        let A2APart::Data { data } = A2APart::data(serde_json::json!({"key": "value"})) else {
            panic!("Expected data part");
        };
        assert_eq!(data["key"], "value");
    }

    #[test]
    fn test_content_blocks_to_parts() {
        let blocks = vec![ContentBlock::text("Hello"), ContentBlock::text("World")];

        let parts = StrandsA2AExecutor::convert_content_blocks_to_a2a_parts(&blocks);
        assert_eq!(parts.len(), 2);
    }

    #[test]
    fn test_content_blocks_to_parts_skips_non_text_blocks() {
        let blocks = vec![ContentBlock::text("Hello"), ContentBlock::default()];

        let parts = StrandsA2AExecutor::convert_content_blocks_to_a2a_parts(&blocks);
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], A2APart::Text { text } if text == "Hello"));
    }

    #[test]
    fn test_strip_file_extension() {
        assert_eq!(StrandsA2AExecutor::strip_file_extension("report.pdf"), "report");
        assert_eq!(
            StrandsA2AExecutor::strip_file_extension("archive.tar.gz"),
            "archive.tar"
        );
        assert_eq!(StrandsA2AExecutor::strip_file_extension("README"), "README");
    }

    #[test]
    fn test_classify_file_type() {
        assert_eq!(StrandsA2AExecutor::classify_file_type(Some("IMAGE/PNG")), "image");
        assert_eq!(StrandsA2AExecutor::classify_file_type(Some("video/mp4")), "video");
        assert_eq!(
            StrandsA2AExecutor::classify_file_type(Some("application/pdf")),
            "document"
        );
        assert_eq!(
            StrandsA2AExecutor::classify_file_type(Some("application/zip")),
            "unknown"
        );
        assert_eq!(StrandsA2AExecutor::classify_file_type(None), "unknown");
    }

    #[test]
    fn test_get_file_format_from_mime_type() {
        assert_eq!(
            StrandsA2AExecutor::get_file_format_from_mime_type(Some("image/jpg"), "image"),
            "jpeg"
        );
        assert_eq!(
            StrandsA2AExecutor::get_file_format_from_mime_type(Some("image/PNG"), "image"),
            "png"
        );
        assert_eq!(
            StrandsA2AExecutor::get_file_format_from_mime_type(None, "video"),
            "mp4"
        );
        assert_eq!(
            StrandsA2AExecutor::get_file_format_from_mime_type(None, "unknown"),
            "txt"
        );
    }
}
371