strands_agents/multiagent/a2a/
executor.rs

1//! Strands Agent executor for the A2A protocol.
2//!
3//! This module provides the StrandsA2AExecutor, which adapts a Strands Agent
4//! to be used as an executor in the A2A protocol.
5
6use std::collections::HashMap;
7use std::pin::pin;
8use std::sync::{Arc, LazyLock};
9
10use base64::Engine;
11use tokio::sync::Mutex;
12
13use super::types::{A2AArtifact, A2AError, A2AMessage, A2APart, A2ATask, A2ATaskState};
14use crate::agent::Agent;
15use crate::types::content::{
16    ContentBlock, DocumentContent, DocumentSource, ImageContent, ImageSource, VideoContent,
17    VideoSource,
18};
19
/// Fallback format for each file-type label (`document`, `image`, `video`,
/// `unknown`), used when a file's MIME type does not determine its format.
static DEFAULT_FORMATS: LazyLock<HashMap<&'static str, &'static str>> = LazyLock::new(|| {
    HashMap::from([
        ("document", "txt"),
        ("image", "png"),
        ("video", "mp4"),
        ("unknown", "txt"),
    ])
});
29
/// Overrides for MIME subtypes / extensions whose content-block format name
/// differs from the subtype itself (e.g. `jpg` is reported as `jpeg`).
static FORMAT_MAPPINGS: LazyLock<HashMap<&'static str, &'static str>> = LazyLock::new(|| {
    [
        ("jpg", "jpeg"),
        ("htm", "html"),
        ("3gp", "three_gp"),
        ("3gpp", "three_gp"),
        ("3g2", "three_gp"),
    ]
    .into_iter()
    .collect()
});
40
41/// Executor that adapts a Strands Agent to the A2A protocol.
42pub struct StrandsA2AExecutor {
43    agent: Arc<Mutex<Agent>>,
44}
45
46impl StrandsA2AExecutor {
47    /// Create a new A2A executor wrapping a Strands Agent.
48    pub fn new(agent: Agent) -> Self {
49        Self {
50            agent: Arc::new(Mutex::new(agent)),
51        }
52    }
53
54    /// Execute a request using the Strands Agent.
55    pub async fn execute(&self, message: A2AMessage) -> Result<A2ATask, A2AError> {
56        let content_blocks = self.convert_a2a_parts_to_content_blocks(&message.parts)?;
57
58        let mut agent = self.agent.lock().await;
59
60        let task_id = uuid::Uuid::new_v4().to_string();
61        let context_id = message.context_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
62
63        let mut task = A2ATask::new(&task_id, &context_id);
64        task.state = A2ATaskState::Working;
65
66        match agent.invoke_async(content_blocks).await {
67            Ok(result) => {
68                task.state = A2ATaskState::Completed;
69
70                let response_text = result.text();
71                let artifact = A2AArtifact {
72                    name: "response".to_string(),
73                    parts: vec![A2APart::text(response_text)],
74                    index: Some(0),
75                };
76
77                task.artifacts = Some(vec![artifact]);
78                task.message = Some(A2AMessage::agent(
79                    vec![A2APart::text(result.text())],
80                    Some(context_id),
81                    Some(task_id),
82                ));
83
84                Ok(task)
85            }
86            Err(e) => {
87                task.state = A2ATaskState::Failed;
88                Err(A2AError::internal(e.to_string()))
89            }
90        }
91    }
92
93    /// Execute a request with streaming.
94    pub async fn execute_streaming<F>(
95        &self,
96        message: A2AMessage,
97        mut on_update: F,
98    ) -> Result<A2ATask, A2AError>
99    where
100        F: FnMut(A2ATask) + Send,
101    {
102        use futures::StreamExt;
103
104        let content_blocks = self.convert_a2a_parts_to_content_blocks(&message.parts)?;
105
106        let mut agent = self.agent.lock().await;
107
108        let task_id = uuid::Uuid::new_v4().to_string();
109        let context_id = message.context_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
110
111        let mut task = A2ATask::new(&task_id, &context_id);
112        task.state = A2ATaskState::Working;
113
114        on_update(task.clone());
115
116        let stream = agent.stream_async(content_blocks).await;
117        let mut pinned_stream = pin!(stream);
118        let mut accumulated_text = String::new();
119
120        while let Some(event) = pinned_stream.next().await {
121            match event {
122                Ok(stream_event) => {
123                    if let Some(text) = stream_event.as_text() {
124                        accumulated_text.push_str(&text);
125
126                        let mut update_task = task.clone();
127                        update_task.message = Some(A2AMessage::agent(
128                            vec![A2APart::text(&accumulated_text)],
129                            Some(context_id.clone()),
130                            Some(task_id.clone()),
131                        ));
132
133                        on_update(update_task);
134                    }
135                }
136                Err(e) => {
137                    task.state = A2ATaskState::Failed;
138                    return Err(A2AError::internal(e.to_string()));
139                }
140            }
141        }
142
143        task.state = A2ATaskState::Completed;
144
145        let artifact = A2AArtifact {
146            name: "response".to_string(),
147            parts: vec![A2APart::text(&accumulated_text)],
148            index: Some(0),
149        };
150
151        task.artifacts = Some(vec![artifact]);
152        task.message = Some(A2AMessage::agent(
153            vec![A2APart::text(accumulated_text)],
154            Some(context_id),
155            Some(task_id),
156        ));
157
158        Ok(task)
159    }
160
161    /// Convert A2A message parts to Strands content blocks.
162    fn convert_a2a_parts_to_content_blocks(
163        &self,
164        parts: &[A2APart],
165    ) -> Result<Vec<ContentBlock>, A2AError> {
166        let mut content_blocks = Vec::new();
167
168        for part in parts {
169            match part {
170                A2APart::Text { text } => {
171                    content_blocks.push(ContentBlock::text(text));
172                }
173                A2APart::Data { data } => {
174                    let text = serde_json::to_string_pretty(data)
175                        .map(|json| format!("[Structured Data]\n{}", json))
176                        .unwrap_or_else(|_| data.to_string());
177                    content_blocks.push(ContentBlock::text(text));
178                }
179                A2APart::File { file } => {
180                    let file_type = Self::classify_file_type(file.mime_type.as_deref());
181                    let file_format = Self::get_file_format_from_mime_type(
182                        file.mime_type.as_deref(),
183                        file_type,
184                    );
185                    let file_name = Self::strip_file_extension(&file.name);
186
187                    if let Some(ref bytes_str) = file.bytes {
188                        match base64::engine::general_purpose::STANDARD.decode(bytes_str) {
189                            Ok(decoded_bytes) => {
190                                let bytes_base64 = base64::engine::general_purpose::STANDARD
191                                    .encode(&decoded_bytes);
192
193                                match file_type {
194                                    "image" => {
195                                        content_blocks.push(ContentBlock {
196                                            image: Some(ImageContent {
197                                                format: file_format,
198                                                source: ImageSource {
199                                                    bytes: Some(bytes_base64),
200                                                },
201                                            }),
202                                            ..Default::default()
203                                        });
204                                    }
205                                    "video" => {
206                                        content_blocks.push(ContentBlock {
207                                            video: Some(VideoContent {
208                                                format: file_format,
209                                                source: VideoSource {
210                                                    bytes: Some(bytes_base64),
211                                                },
212                                            }),
213                                            ..Default::default()
214                                        });
215                                    }
216                                    _ => {
217                                        content_blocks.push(ContentBlock {
218                                            document: Some(DocumentContent {
219                                                format: file_format,
220                                                name: file_name.to_string(),
221                                                source: DocumentSource {
222                                                    bytes: Some(bytes_base64),
223                                                },
224                                            }),
225                                            ..Default::default()
226                                        });
227                                    }
228                                }
229                            }
230                            Err(e) => {
231                                tracing::warn!(
232                                    "Failed to decode base64 data for file '{}': {}",
233                                    file.name,
234                                    e
235                                );
236                                let text = format!(
237                                    "[File: {} ({:?})] - Failed to decode base64 data",
238                                    file.name, file.mime_type
239                                );
240                                content_blocks.push(ContentBlock::text(text));
241                            }
242                        }
243                    } else if let Some(ref uri) = file.uri {
244                        let text = format!(
245                            "[File: {} ({:?})] - Referenced file at: {}",
246                            file_name, file.mime_type, uri
247                        );
248                        content_blocks.push(ContentBlock::text(text));
249                    } else {
250                        let text = format!("[File: {}]", file.name);
251                        content_blocks.push(ContentBlock::text(text));
252                    }
253                }
254            }
255        }
256
257        if content_blocks.is_empty() {
258            return Err(A2AError::invalid_request("No content blocks available"));
259        }
260
261        Ok(content_blocks)
262    }
263
264    /// Convert Strands content blocks to A2A message parts.
265    pub fn convert_content_blocks_to_a2a_parts(blocks: &[ContentBlock]) -> Vec<A2APart> {
266        blocks
267            .iter()
268            .filter_map(|block| {
269                if let Some(text) = &block.text {
270                    Some(A2APart::text(text))
271                } else {
272                    None
273                }
274            })
275            .collect()
276    }
277
278    /// Get file format from MIME type.
279    ///
280    /// Uses the MIME type to determine the appropriate file format.
281    /// Falls back to default formats if MIME type is unavailable or unrecognized.
282    pub fn get_file_format_from_mime_type(mime_type: Option<&str>, file_type: &str) -> String {
283        let Some(mime_type) = mime_type else {
284            return DEFAULT_FORMATS.get(file_type).copied().unwrap_or("txt").to_string();
285        };
286
287        let mime_lower = mime_type.to_lowercase();
288
289        if let Some(subtype) = mime_lower.split('/').last() {
290            if let Some(mapped) = FORMAT_MAPPINGS.get(subtype) {
291                return mapped.to_string();
292            }
293        }
294
295        if let Some(subtype) = mime_lower.split('/').last() {
296            if let Some(mapped) = FORMAT_MAPPINGS.get(subtype) {
297                return mapped.to_string();
298            }
299            return subtype.to_string();
300        }
301
302        DEFAULT_FORMATS.get(file_type).copied().unwrap_or("txt").to_string()
303    }
304
305    /// Strip the file extension from a file name.
306    pub fn strip_file_extension(file_name: &str) -> &str {
307        if let Some(pos) = file_name.rfind('.') {
308            &file_name[..pos]
309        } else {
310            file_name
311        }
312    }
313
314    /// Classify file type based on MIME type.
315    pub fn classify_file_type(mime_type: Option<&str>) -> &'static str {
316        let Some(mime) = mime_type else {
317            return "unknown";
318        };
319
320        let mime_lower = mime.to_lowercase();
321
322        if mime_lower.starts_with("image/") {
323            "image"
324        } else if mime_lower.starts_with("video/") {
325            "video"
326        } else if mime_lower.starts_with("text/")
327            || mime_lower.starts_with("application/pdf")
328            || mime_lower.starts_with("application/msword")
329            || mime_lower.contains("document")
330        {
331            "document"
332        } else {
333            "unknown"
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_a2a_part_text() {
        let part = A2APart::text("Hello");
        match part {
            A2APart::Text { text } => assert_eq!(text, "Hello"),
            _ => panic!("Expected text part"),
        }
    }

    #[test]
    fn test_a2a_part_data() {
        let part = A2APart::data(serde_json::json!({"key": "value"}));
        match part {
            A2APart::Data { data } => assert_eq!(data["key"], "value"),
            _ => panic!("Expected data part"),
        }
    }

    #[test]
    fn test_content_blocks_to_parts() {
        let blocks = vec![ContentBlock::text("Hello"), ContentBlock::text("World")];

        let parts = StrandsA2AExecutor::convert_content_blocks_to_a2a_parts(&blocks);
        assert_eq!(parts.len(), 2);
    }

    /// Blocks without a `text` field are dropped, not converted.
    #[test]
    fn test_content_blocks_to_parts_skips_non_text_blocks() {
        let blocks = vec![ContentBlock::text("Hello"), ContentBlock::default()];

        let parts = StrandsA2AExecutor::convert_content_blocks_to_a2a_parts(&blocks);
        assert_eq!(parts.len(), 1);
        match parts.into_iter().next() {
            Some(A2APart::Text { text }) => assert_eq!(text, "Hello"),
            _ => panic!("Expected text part"),
        }
    }

    #[test]
    fn test_classify_file_type() {
        let classify = StrandsA2AExecutor::classify_file_type;
        assert_eq!(classify(Some("image/png")), "image");
        assert_eq!(classify(Some("VIDEO/MP4")), "video");
        assert_eq!(classify(Some("text/plain")), "document");
        assert_eq!(classify(Some("application/pdf")), "document");
        assert_eq!(classify(Some("application/octet-stream")), "unknown");
        assert_eq!(classify(None), "unknown");
    }

    #[test]
    fn test_file_format_from_mime_type() {
        let format = StrandsA2AExecutor::get_file_format_from_mime_type;
        assert_eq!(format(None, "image"), "png");
        assert_eq!(format(None, "video"), "mp4");
        assert_eq!(format(None, "document"), "txt");
        assert_eq!(format(None, "something-else"), "txt");
        assert_eq!(format(Some("image/jpg"), "image"), "jpeg");
        assert_eq!(format(Some("video/3gpp"), "video"), "three_gp");
        assert_eq!(format(Some("Application/PDF"), "document"), "pdf");
    }

    #[test]
    fn test_strip_file_extension() {
        let strip = StrandsA2AExecutor::strip_file_extension;
        assert_eq!(strip("report.pdf"), "report");
        assert_eq!(strip("archive.tar.gz"), "archive.tar");
        assert_eq!(strip("README"), "README");
    }
}
371