sayr_engine/
team.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use serde_json::Value;
5use tokio::sync::{broadcast, Mutex, RwLock};
6
7use crate::agent::Agent;
8use crate::memory::ConversationMemory;
9use crate::{LanguageModel, Result};
10
/// Notifications published on a team's broadcast bus.
#[derive(Debug, Clone)]
pub enum TeamEvent {
    /// A message announced to every subscriber of the bus.
    Broadcast {
        /// Identifier of the sender, as passed to the broadcast call.
        from: String,
        /// Text of the message.
        content: String,
    },
    /// A fact was appended to the team's shared knowledge; carries the fact text.
    KnowledgeAdded(String),
}
17
18/// A coordination surface for multiple agents that share context and a message bus.
19pub struct Team<M: LanguageModel> {
20    name: String,
21    members: BTreeMap<String, Arc<Mutex<Agent<M>>>>,
22    shared_memory: Arc<RwLock<ConversationMemory>>,
23    shared_context: Arc<RwLock<Value>>,
24    knowledge: Arc<RwLock<Vec<String>>>,
25    tx: broadcast::Sender<TeamEvent>,
26}
27
28impl<M: LanguageModel> Clone for Team<M> {
29    fn clone(&self) -> Self {
30        Self {
31            name: self.name.clone(),
32            members: self.members.clone(),
33            shared_memory: Arc::clone(&self.shared_memory),
34            shared_context: Arc::clone(&self.shared_context),
35            knowledge: Arc::clone(&self.knowledge),
36            tx: self.tx.clone(),
37        }
38    }
39}
40
41impl<M: LanguageModel> Team<M> {
42    /// Create an empty team with a broadcast bus and shared memory.
43    pub fn new(name: impl Into<String>) -> Self {
44        let (tx, _) = broadcast::channel(128);
45        Self {
46            name: name.into(),
47            members: BTreeMap::new(),
48            shared_memory: Arc::new(RwLock::new(ConversationMemory::default())),
49            shared_context: Arc::new(RwLock::new(Value::Null)),
50            knowledge: Arc::new(RwLock::new(Vec::new())),
51            tx,
52        }
53    }
54
55    pub fn name(&self) -> &str {
56        &self.name
57    }
58
59    /// Register a new agent under the given identifier.
60    pub fn add_agent(&mut self, id: impl Into<String>, agent: Agent<M>) {
61        self.members.insert(id.into(), Arc::new(Mutex::new(agent)));
62    }
63
64    /// Number of registered agents.
65    pub fn size(&self) -> usize {
66        self.members.len()
67    }
68
69    /// Subscribe to the broadcast bus for inter-agent notifications.
70    pub fn subscribe(&self) -> broadcast::Receiver<TeamEvent> {
71        self.tx.subscribe()
72    }
73
74    /// Append shared knowledge that all agents can reference.
75    pub async fn add_knowledge(&self, fact: impl Into<String>) {
76        let fact = fact.into();
77        self.knowledge.write().await.push(fact.clone());
78        let _ = self.tx.send(TeamEvent::KnowledgeAdded(fact));
79    }
80
81    /// Update the shared context blob (typically JSON state shared across steps).
82    pub async fn set_context(&self, ctx: Value) {
83        *self.shared_context.write().await = ctx;
84    }
85
86    /// Retrieve a copy of the shared context.
87    pub async fn context(&self) -> Value {
88        self.shared_context.read().await.clone()
89    }
90
91    /// Send a broadcast message to all listeners and append to shared memory.
92    pub async fn broadcast(&self, from: impl Into<String>, content: impl Into<String>) {
93        let from = from.into();
94        let content = content.into();
95        if let Ok(mut memory) = self.shared_memory.try_write() {
96            memory.push(crate::message::Message::assistant(format!(
97                "[{from}] {content}"
98            )));
99        }
100        let _ = self.tx.send(TeamEvent::Broadcast { from, content });
101    }
102
103    /// Run the same prompt through every agent, synchronizing memory back into the shared
104    /// transcript after each response. Returns agent replies in registration order.
105    pub async fn fan_out(&self, prompt: &str) -> Result<Vec<(String, String)>> {
106        let mut replies = Vec::new();
107        for (id, agent) in &self.members {
108            let mut guard = agent.lock().await;
109            // Share the latest transcript with the agent before it responds.
110            let snapshot = { self.shared_memory.read().await.clone() };
111            guard.sync_memory_from(&snapshot);
112            let reply = guard.respond(prompt).await?;
113            // Persist the updated transcript back into the shared memory.
114            let updated = guard.take_memory_snapshot();
115            *self.shared_memory.write().await = updated;
116            replies.push((id.clone(), reply));
117        }
118        Ok(replies)
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Agent, StubModel};

    /// Two scripted agents answer two consecutive fan-out rounds; each round must yield the
    /// next scripted reply from every agent, with "alpha" reported before "beta".
    #[tokio::test]
    async fn runs_agents_with_shared_memory() {
        let mut team = Team::new("demo");
        team.add_agent(
            "alpha",
            Agent::new(StubModel::new(vec![
                r#"{"action":"respond","content":"a1"}"#.into(),
                r#"{"action":"respond","content":"a2"}"#.into(),
            ])),
        );
        team.add_agent(
            "beta",
            Agent::new(StubModel::new(vec![
                r#"{"action":"respond","content":"b1"}"#.into(),
                r#"{"action":"respond","content":"b2"}"#.into(),
            ])),
        );

        // First round: every agent consumes its first scripted reply.
        let first_round = team.fan_out("hello world").await.unwrap();
        assert_eq!(first_round.len(), 2);
        assert_eq!(first_round[0].1, "a1");
        assert_eq!(first_round[1].1, "b1");

        // Second round goes through the shared transcript read/write cycle again.
        let second_round = team.fan_out("follow up").await.unwrap();
        assert_eq!(second_round[0].1, "a2");
        assert_eq!(second_round[1].1, "b2");
    }
}