walrus_core/runtime/
mod.rs1use crate::{
10 Agent, AgentBuilder, AgentConfig, AgentEvent, AgentResponse, AgentStopReason,
11 agent::tool::{ToolRegistry, ToolSender},
12 model::{Message, Model},
13 runtime::hook::Hook,
14};
15use anyhow::{Result, bail};
16use async_stream::stream;
17use compact_str::CompactString;
18use futures_core::Stream;
19use futures_util::StreamExt;
20use std::{
21 collections::BTreeMap,
22 sync::{
23 Arc,
24 atomic::{AtomicU64, Ordering},
25 },
26};
27use tokio::sync::{Mutex, RwLock, mpsc};
28
29pub mod hook;
30pub mod session;
31
32pub use session::Session;
33
34pub struct Runtime<M: Model, H: Hook> {
40 pub model: M,
41 pub hook: H,
42 agents: BTreeMap<CompactString, Agent<M>>,
43 sessions: RwLock<BTreeMap<u64, Arc<Mutex<Session>>>>,
44 next_session_id: AtomicU64,
45 tools: ToolRegistry,
46 tool_tx: Option<ToolSender>,
47}
48
49impl<M: Model + Send + Sync + Clone + 'static, H: Hook + 'static> Runtime<M, H> {
50 pub async fn new(model: M, hook: H, tool_tx: Option<ToolSender>) -> Self {
56 let mut tools = ToolRegistry::new();
57 hook.on_register_tools(&mut tools).await;
58 Self {
59 model,
60 hook,
61 agents: BTreeMap::new(),
62 sessions: RwLock::new(BTreeMap::new()),
63 next_session_id: AtomicU64::new(1),
64 tools,
65 tool_tx,
66 }
67 }
68
69 pub fn register_tool(&mut self, tool: crate::model::Tool) {
73 self.tools.insert(tool);
74 }
75
76 pub fn unregister_tool(&mut self, name: &str) -> bool {
78 self.tools.remove(name)
79 }
80
81 pub fn add_agent(&mut self, config: AgentConfig) {
88 let config = self.hook.on_build_agent(config);
89 let name = config.name.clone();
90 let tools = self.tools.filtered_snapshot(&config.tools);
91 let mut builder = AgentBuilder::new(self.model.clone())
92 .config(config)
93 .tools(tools);
94 if let Some(tx) = &self.tool_tx {
95 builder = builder.tool_tx(tx.clone());
96 }
97 let agent = builder.build();
98 self.agents.insert(name, agent);
99 }
100
101 pub fn agent(&self, name: &str) -> Option<AgentConfig> {
103 self.agents.get(name).map(|a| a.config.clone())
104 }
105
106 pub fn agents(&self) -> Vec<AgentConfig> {
108 self.agents.values().map(|a| a.config.clone()).collect()
109 }
110
111 pub fn get_agent(&self, name: &str) -> Option<&Agent<M>> {
113 self.agents.get(name)
114 }
115
116 pub async fn create_session(&self, agent: &str, created_by: &str) -> Result<u64> {
120 if !self.agents.contains_key(agent) {
121 bail!("agent '{agent}' not registered");
122 }
123 let id = self.next_session_id.fetch_add(1, Ordering::Relaxed);
124 let session = Session::new(id, agent, created_by);
125 self.sessions
126 .write()
127 .await
128 .insert(id, Arc::new(Mutex::new(session)));
129 Ok(id)
130 }
131
132 pub async fn close_session(&self, id: u64) -> bool {
134 self.sessions.write().await.remove(&id).is_some()
135 }
136
137 pub async fn session(&self, id: u64) -> Option<Arc<Mutex<Session>>> {
139 self.sessions.read().await.get(&id).cloned()
140 }
141
142 pub async fn sessions(&self) -> Vec<Arc<Mutex<Session>>> {
144 self.sessions.read().await.values().cloned().collect()
145 }
146
147 pub async fn send_to(
154 &self,
155 session_id: u64,
156 content: &str,
157 sender: &str,
158 ) -> Result<AgentResponse> {
159 let session_mutex = self
160 .sessions
161 .read()
162 .await
163 .get(&session_id)
164 .cloned()
165 .ok_or_else(|| anyhow::anyhow!("session {session_id} not found"))?;
166
167 let mut session = session_mutex.lock().await;
168 let agent_ref = self
169 .agents
170 .get(&session.agent)
171 .ok_or_else(|| anyhow::anyhow!("agent '{}' not registered", session.agent))?;
172
173 if sender.is_empty() {
174 session.history.push(Message::user(content));
175 } else {
176 session
177 .history
178 .push(Message::user_with_sender(content, sender));
179 }
180
181 let (tx, mut rx) = mpsc::unbounded_channel();
182 let agent_name = session.agent.clone();
183 let response = agent_ref.run(&mut session.history, tx).await;
184
185 while let Ok(event) = rx.try_recv() {
186 self.hook.on_event(&agent_name, &event);
187 }
188
189 Ok(response)
190 }
191
192 pub fn stream_to(
197 &self,
198 session_id: u64,
199 content: &str,
200 sender: &str,
201 ) -> impl Stream<Item = AgentEvent> + '_ {
202 let content = content.to_owned();
203 let sender = sender.to_owned();
204 stream! {
205 let session_mutex = match self
206 .sessions
207 .read()
208 .await
209 .get(&session_id)
210 .cloned()
211 {
212 Some(m) => m,
213 None => {
214 let resp = AgentResponse {
215 final_response: None,
216 iterations: 0,
217 stop_reason: AgentStopReason::Error(
218 format!("session {session_id} not found"),
219 ),
220 steps: vec![],
221 };
222 yield AgentEvent::Done(resp);
223 return;
224 }
225 };
226
227 let mut session = session_mutex.lock().await;
228 let agent_ref = match self.agents.get(&session.agent) {
229 Some(a) => a,
230 None => {
231 let resp = AgentResponse {
232 final_response: None,
233 iterations: 0,
234 stop_reason: AgentStopReason::Error(
235 format!("agent '{}' not registered", session.agent),
236 ),
237 steps: vec![],
238 };
239 yield AgentEvent::Done(resp);
240 return;
241 }
242 };
243
244 if sender.is_empty() {
245 session.history.push(Message::user(&content));
246 } else {
247 session.history.push(Message::user_with_sender(&content, &sender));
248 }
249 let agent_name = session.agent.clone();
250
251 let mut event_stream = std::pin::pin!(agent_ref.run_stream(&mut session.history));
252 while let Some(event) = event_stream.next().await {
253 self.hook.on_event(&agent_name, &event);
254 yield event;
255 }
256 }
257 }
258}