strands_agents/multiagent/base.rs

//! Multi-agent base types and traits.
//!
//! Provides the foundation for multi-agent orchestration patterns (Swarm, Graph).
4
5use std::collections::HashMap;
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use crate::agent::AgentResult;
11use crate::types::content::ContentBlock;
12use crate::types::errors::Result;
13use crate::types::streaming::{Metrics, Usage};
14
15/// Execution status for both graphs and nodes.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
17#[serde(rename_all = "lowercase")]
18pub enum Status {
19    Pending,
20    Executing,
21    Completed,
22    Failed,
23    Interrupted,
24}
25
26impl Default for Status {
27    fn default() -> Self {
28        Self::Pending
29    }
30}
31
32use crate::types::interrupt::InterruptResponseContent;
33
34/// Input type for multi-agent systems.
35///
36/// This type can represent:
37/// - Plain text input
38/// - A sequence of content blocks
39/// - Interrupt responses for resuming after an interrupt
40#[derive(Debug, Clone)]
41pub enum MultiAgentInput {
42    /// Plain text input.
43    Text(String),
44    /// A sequence of content blocks (text, images, etc.).
45    ContentBlocks(Vec<ContentBlock>),
46    /// Interrupt responses for resuming execution after an interrupt.
47    InterruptResponses(Vec<InterruptResponseContent>),
48}
49
50impl From<&str> for MultiAgentInput {
51    fn from(s: &str) -> Self {
52        MultiAgentInput::Text(s.to_string())
53    }
54}
55
56impl From<String> for MultiAgentInput {
57    fn from(s: String) -> Self {
58        MultiAgentInput::Text(s)
59    }
60}
61
62impl From<Vec<ContentBlock>> for MultiAgentInput {
63    fn from(blocks: Vec<ContentBlock>) -> Self {
64        MultiAgentInput::ContentBlocks(blocks)
65    }
66}
67
68impl From<Vec<InterruptResponseContent>> for MultiAgentInput {
69    fn from(responses: Vec<InterruptResponseContent>) -> Self {
70        MultiAgentInput::InterruptResponses(responses)
71    }
72}
73
74impl MultiAgentInput {
75    /// Returns the input as text if it is a Text variant.
76    pub fn as_text(&self) -> Option<&str> {
77        match self {
78            MultiAgentInput::Text(s) => Some(s),
79            _ => None,
80        }
81    }
82
83    /// Returns the input as content blocks if it is a ContentBlocks variant.
84    pub fn as_content_blocks(&self) -> Option<&[ContentBlock]> {
85        match self {
86            MultiAgentInput::ContentBlocks(blocks) => Some(blocks),
87            _ => None,
88        }
89    }
90
91    /// Returns the input as interrupt responses if it is an InterruptResponses variant.
92    pub fn as_interrupt_responses(&self) -> Option<&[InterruptResponseContent]> {
93        match self {
94            MultiAgentInput::InterruptResponses(responses) => Some(responses),
95            _ => None,
96        }
97    }
98
99    /// Returns true if this is an interrupt response input.
100    pub fn is_interrupt_response(&self) -> bool {
101        matches!(self, MultiAgentInput::InterruptResponses(_))
102    }
103
104    /// Converts the input to a string, losing type information.
105    pub fn to_string_lossy(&self) -> String {
106        match self {
107            MultiAgentInput::Text(s) => s.clone(),
108            MultiAgentInput::ContentBlocks(blocks) => blocks
109                .iter()
110                .filter_map(|b| b.text.as_ref())
111                .cloned()
112                .collect::<Vec<_>>()
113                .join("\n"),
114            MultiAgentInput::InterruptResponses(responses) => responses
115                .iter()
116                .map(|r| {
117                    format!(
118                        "{}:{}",
119                        r.interrupt_response.interrupt_id,
120                        r.interrupt_response.response
121                    )
122                })
123                .collect::<Vec<_>>()
124                .join("; "),
125        }
126    }
127}
128
129/// Interrupt for pausing multi-agent execution.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct Interrupt {
132    pub id: String,
133    pub tool_name: String,
134    pub tool_use_id: String,
135    pub message: Option<String>,
136    #[serde(skip_serializing_if = "Option::is_none")]
137    pub response: Option<serde_json::Value>,
138}
139
140impl Interrupt {
141    pub fn new(id: impl Into<String>, tool_name: impl Into<String>, tool_use_id: impl Into<String>) -> Self {
142        Self {
143            id: id.into(),
144            tool_name: tool_name.into(),
145            tool_use_id: tool_use_id.into(),
146            message: None,
147            response: None,
148        }
149    }
150
151    pub fn with_message(mut self, message: impl Into<String>) -> Self {
152        self.message = Some(message.into());
153        self
154    }
155
156    pub fn with_response(mut self, response: serde_json::Value) -> Self {
157        self.response = Some(response);
158        self
159    }
160
161    pub fn has_response(&self) -> bool {
162        self.response.is_some()
163    }
164}
165
166/// Result from a single node execution.
167#[derive(Debug, Clone)]
168pub struct NodeResult {
169    pub result: NodeResultValue,
170    pub execution_time_ms: u64,
171    pub status: Status,
172    pub accumulated_usage: Usage,
173    pub accumulated_metrics: Metrics,
174    pub execution_count: u32,
175    pub interrupts: Vec<Interrupt>,
176}
177
178/// The value contained in a NodeResult.
179#[derive(Debug, Clone)]
180pub enum NodeResultValue {
181    Agent(AgentResult),
182    MultiAgent(Box<MultiAgentResult>),
183    Error(String),
184}
185
186impl NodeResult {
187    pub fn from_agent(result: AgentResult, execution_time_ms: u64) -> Self {
188        Self {
189            result: NodeResultValue::Agent(result),
190            execution_time_ms,
191            status: Status::Completed,
192            accumulated_usage: Usage::default(),
193            accumulated_metrics: Metrics::default(),
194            execution_count: 1,
195            interrupts: Vec::new(),
196        }
197    }
198
199    pub fn from_error(error: impl Into<String>, execution_time_ms: u64) -> Self {
200        Self {
201            result: NodeResultValue::Error(error.into()),
202            execution_time_ms,
203            status: Status::Failed,
204            accumulated_usage: Usage::default(),
205            accumulated_metrics: Metrics::default(),
206            execution_count: 1,
207            interrupts: Vec::new(),
208        }
209    }
210
211    pub fn get_agent_results(&self) -> Vec<&AgentResult> {
212        match &self.result {
213            NodeResultValue::Agent(r) => vec![r],
214            NodeResultValue::MultiAgent(m) => m
215                .results
216                .values()
217                .flat_map(|nr| nr.get_agent_results())
218                .collect(),
219            NodeResultValue::Error(_) => vec![],
220        }
221    }
222
223    pub fn is_error(&self) -> bool {
224        matches!(self.result, NodeResultValue::Error(_))
225    }
226
227    pub fn is_interrupted(&self) -> bool {
228        self.status == Status::Interrupted
229    }
230}
231
232/// Result from multi-agent execution with accumulated metrics.
233#[derive(Debug, Clone, Default)]
234pub struct MultiAgentResult {
235    pub status: Status,
236    pub results: HashMap<String, NodeResult>,
237    pub accumulated_usage: Usage,
238    pub accumulated_metrics: Metrics,
239    pub execution_count: u32,
240    pub execution_time_ms: u64,
241    pub interrupts: Vec<Interrupt>,
242}
243
244impl MultiAgentResult {
245    pub fn new() -> Self {
246        Self::default()
247    }
248
249    pub fn with_status(mut self, status: Status) -> Self {
250        self.status = status;
251        self
252    }
253
254    pub fn add_node_result(&mut self, node_id: impl Into<String>, result: NodeResult) {
255        self.accumulated_usage.add(&result.accumulated_usage);
256        self.accumulated_metrics.latency_ms += result.accumulated_metrics.latency_ms;
257        self.execution_count += result.execution_count;
258        self.results.insert(node_id.into(), result);
259    }
260}
261
262/// Events emitted during multi-agent execution.
263#[derive(Debug, Clone)]
264pub enum MultiAgentEvent {
265    /// A node has started execution.
266    NodeStart {
267        node_id: String,
268        node_type: String,
269    },
270    /// A node has stopped execution.
271    NodeStop {
272        node_id: String,
273        node_result: NodeResult,
274    },
275    /// A handoff occurred between nodes.
276    Handoff {
277        from_node_ids: Vec<String>,
278        to_node_ids: Vec<String>,
279        message: Option<String>,
280    },
281    /// Streaming event from a node.
282    NodeStream {
283        node_id: String,
284        event: serde_json::Value,
285    },
286    /// A node was cancelled.
287    NodeCancel {
288        node_id: String,
289        message: String,
290    },
291    /// A node was interrupted.
292    NodeInterrupt {
293        node_id: String,
294        interrupts: Vec<Interrupt>,
295    },
296    /// Final result.
297    Result(MultiAgentResult),
298}
299
300impl MultiAgentEvent {
301    pub fn node_start(node_id: impl Into<String>, node_type: impl Into<String>) -> Self {
302        Self::NodeStart {
303            node_id: node_id.into(),
304            node_type: node_type.into(),
305        }
306    }
307
308    pub fn node_stop(node_id: impl Into<String>, node_result: NodeResult) -> Self {
309        Self::NodeStop {
310            node_id: node_id.into(),
311            node_result,
312        }
313    }
314
315    pub fn handoff(
316        from_node_ids: Vec<String>,
317        to_node_ids: Vec<String>,
318        message: Option<String>,
319    ) -> Self {
320        Self::Handoff {
321            from_node_ids,
322            to_node_ids,
323            message,
324        }
325    }
326
327    pub fn node_stream(node_id: impl Into<String>, event: serde_json::Value) -> Self {
328        Self::NodeStream {
329            node_id: node_id.into(),
330            event,
331        }
332    }
333
334    pub fn node_cancel(node_id: impl Into<String>, message: impl Into<String>) -> Self {
335        Self::NodeCancel {
336            node_id: node_id.into(),
337            message: message.into(),
338        }
339    }
340
341    pub fn node_interrupt(node_id: impl Into<String>, interrupts: Vec<Interrupt>) -> Self {
342        Self::NodeInterrupt {
343            node_id: node_id.into(),
344            interrupts,
345        }
346    }
347
348    pub fn result(result: MultiAgentResult) -> Self {
349        Self::Result(result)
350    }
351
352    pub fn is_result(&self) -> bool {
353        matches!(self, Self::Result(_))
354    }
355
356    pub fn as_result(&self) -> Option<&MultiAgentResult> {
357        match self {
358            Self::Result(r) => Some(r),
359            _ => None,
360        }
361    }
362}
363
364/// Stream of multi-agent events.
365pub type MultiAgentEventStream<'a> =
366    std::pin::Pin<Box<dyn futures::Stream<Item = MultiAgentEvent> + Send + 'a>>;
367
368/// Invocation state passed through multi-agent execution.
369#[derive(Debug, Clone, Default)]
370pub struct InvocationState {
371    pub data: HashMap<String, serde_json::Value>,
372}
373
374impl InvocationState {
375    pub fn new() -> Self {
376        Self::default()
377    }
378
379    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
380        self.data.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
381    }
382
383    pub fn set(&mut self, key: impl Into<String>, value: impl serde::Serialize) {
384        if let Ok(v) = serde_json::to_value(value) {
385            self.data.insert(key.into(), v);
386        }
387    }
388}
389
390/// Base trait for multi-agent orchestration patterns.
391#[async_trait]
392pub trait MultiAgentBase: Send + Sync {
393    /// Returns the unique ID of this multi-agent system.
394    fn id(&self) -> &str;
395
396    /// Invokes the multi-agent system asynchronously.
397    async fn invoke_async(
398        &mut self,
399        task: MultiAgentInput,
400        invocation_state: Option<&InvocationState>,
401    ) -> Result<MultiAgentResult>;
402
403    /// Streams events during multi-agent execution.
404    fn stream_async<'a>(
405        &'a mut self,
406        task: MultiAgentInput,
407        invocation_state: Option<&'a InvocationState>,
408    ) -> MultiAgentEventStream<'a>;
409
410    /// Serializes the current state for persistence.
411    fn serialize_state(&self) -> serde_json::Value;
412
413    /// Deserializes state from persistence.
414    fn deserialize_state(&mut self, payload: &serde_json::Value) -> Result<()>;
415}
416
417/// State tracking for interrupt handling.
418#[derive(Debug, Clone, Default)]
419pub struct InterruptState {
420    pub activated: bool,
421    pub interrupts: HashMap<String, Interrupt>,
422    pub context: HashMap<String, serde_json::Value>,
423    pub responses: Option<serde_json::Value>,
424}
425
426impl InterruptState {
427    pub fn new() -> Self {
428        Self::default()
429    }
430
431    pub fn activate(&mut self) {
432        self.activated = true;
433    }
434
435    pub fn deactivate(&mut self) {
436        self.activated = false;
437        self.interrupts.clear();
438        self.context.clear();
439        self.responses = None;
440    }
441
442    pub fn resume(&mut self, responses: serde_json::Value) {
443        self.responses = Some(responses);
444    }
445
446    /// Add an interrupt to the state.
447    pub fn add(&mut self, interrupt: Interrupt) {
448        self.interrupts.insert(interrupt.id.clone(), interrupt);
449    }
450
451    /// Serialize to a dictionary for session persistence.
452    pub fn to_dict(&self) -> HashMap<String, serde_json::Value> {
453        let mut dict = HashMap::new();
454        dict.insert("activated".to_string(), serde_json::json!(self.activated));
455        dict.insert(
456            "interrupts".to_string(),
457            serde_json::json!(self.interrupts
458                .iter()
459                .map(|(k, v)| (k.clone(), serde_json::json!({
460                    "id": v.id,
461                    "tool_name": v.tool_name,
462                    "tool_use_id": v.tool_use_id,
463                    "message": v.message,
464                    "response": v.response,
465                })))
466                .collect::<HashMap<_, _>>()),
467        );
468        dict.insert("context".to_string(), serde_json::json!(self.context));
469        dict.insert("responses".to_string(), serde_json::json!(self.responses));
470        dict
471    }
472
473    /// Restore from a dictionary.
474    pub fn from_dict(data: HashMap<String, serde_json::Value>) -> Self {
475        let activated = data
476            .get("activated")
477            .and_then(|v| v.as_bool())
478            .unwrap_or(false);
479
480        let interrupts = data
481            .get("interrupts")
482            .and_then(|v| v.as_object())
483            .map(|obj| {
484                obj.iter()
485                    .filter_map(|(k, v)| {
486                        let id = v.get("id")?.as_str()?.to_string();
487                        let tool_name = v.get("tool_name")?.as_str()?.to_string();
488                        let tool_use_id = v.get("tool_use_id")?.as_str()?.to_string();
489                        let message = v.get("message").and_then(|m| m.as_str().map(|s| s.to_string()));
490                        let response = v.get("response").cloned();
491                        Some((k.clone(), Interrupt {
492                            id,
493                            tool_name,
494                            tool_use_id,
495                            message,
496                            response,
497                        }))
498                    })
499                    .collect()
500            })
501            .unwrap_or_default();
502
503        let context = data
504            .get("context")
505            .and_then(|v| v.as_object())
506            .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
507            .unwrap_or_default();
508
509        let responses = data.get("responses").cloned();
510
511        Self {
512            activated,
513            interrupts,
514            context,
515            responses,
516        }
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_status_default() {
        let status: Status = Default::default();
        assert_eq!(status, Status::Pending);
    }

    #[test]
    fn test_multi_agent_input_from_str() {
        let input: MultiAgentInput = "test task".into();
        assert_eq!(input.as_text(), Some("test task"));
    }

    #[test]
    fn test_interrupt_creation() {
        let expected_message = "Please provide more info";
        let interrupt = Interrupt::new("int-1", "my_tool", "tu-1").with_message(expected_message);
        assert_eq!(interrupt.id, "int-1");
        assert_eq!(interrupt.message.as_deref(), Some(expected_message));
    }

    #[test]
    fn test_multi_agent_event_variants() {
        let start = MultiAgentEvent::node_start("node1", "agent");
        assert!(!start.is_result());

        let finished = MultiAgentEvent::result(MultiAgentResult::new());
        assert!(finished.is_result());
    }
}
553