walrus_core/runtime/
mod.rs1use crate::{
9 Agent, AgentBuilder, AgentConfig, AgentEvent, AgentResponse, AgentStopReason,
10 agent::tool::{Handler, ToolRegistry},
11 model::{Message, Model, Tool},
12 runtime::hook::Hook,
13};
14use anyhow::Result;
15use async_stream::stream;
16use compact_str::CompactString;
17use futures_core::Stream;
18use futures_util::StreamExt;
19use std::{collections::BTreeMap, sync::Arc};
20use tokio::sync::{Mutex, RwLock, mpsc};
21
22pub mod hook;
23
24pub struct Runtime<M: Model, H: Hook> {
31 pub model: M,
32 pub hook: H,
33 agents: BTreeMap<CompactString, Arc<Mutex<Agent<M>>>>,
34 tools: Arc<RwLock<ToolRegistry>>,
35}
36
37impl<M: Model + Send + Sync + Clone + 'static, H: Hook + 'static> Runtime<M, H> {
38 pub async fn new(model: M, hook: H) -> Self {
43 let mut registry = ToolRegistry::new();
44 hook.on_register_tools(&mut registry).await;
45 Self {
46 model,
47 hook,
48 agents: BTreeMap::new(),
49 tools: Arc::new(RwLock::new(registry)),
50 }
51 }
52
53 pub async fn register_tool(&self, tool: Tool, handler: Handler) {
60 self.tools.write().await.insert(tool, handler);
61 }
62
63 pub async fn unregister_tool(&self, name: &str) -> bool {
65 self.tools.write().await.remove(name)
66 }
67
68 pub async fn replace_tools(
73 &self,
74 old_names: &[CompactString],
75 new_tools: Vec<(Tool, Handler)>,
76 ) {
77 let mut registry = self.tools.write().await;
78 for name in old_names {
79 registry.remove(name);
80 }
81 for (tool, handler) in new_tools {
82 registry.insert(tool, handler);
83 }
84 }
85
86 async fn dispatcher_for(&self, agent: &str) -> ToolRegistry {
91 let registry = self.tools.read().await;
92
93 let filter: Vec<CompactString> = self
94 .agents
95 .get(agent)
96 .and_then(|m| m.try_lock().ok())
97 .map(|g| g.config.tools.to_vec())
98 .unwrap_or_default();
99
100 registry.filtered_snapshot(&filter)
101 }
102
103 pub fn add_agent(&mut self, config: AgentConfig) {
110 let config = self.hook.on_build_agent(config);
111 let name = config.name.clone();
112 let agent = AgentBuilder::new(self.model.clone()).config(config).build();
113 self.agents.insert(name, Arc::new(Mutex::new(agent)));
114 }
115
116 pub async fn agent(&self, name: &str) -> Option<AgentConfig> {
118 let mutex = self.agents.get(name)?;
119 Some(mutex.lock().await.config.clone())
120 }
121
122 pub async fn agents(&self) -> Vec<AgentConfig> {
124 let mut configs = Vec::with_capacity(self.agents.len());
125 for mutex in self.agents.values() {
126 configs.push(mutex.lock().await.config.clone());
127 }
128 configs
129 }
130
131 pub fn agent_mutex(&self, name: &str) -> Option<Arc<Mutex<Agent<M>>>> {
133 self.agents.get(name).cloned()
134 }
135
136 pub async fn clear_session(&self, agent: &str) {
138 if let Some(mutex) = self.agents.get(agent) {
139 mutex.lock().await.clear_history();
140 }
141 }
142
143 pub async fn send_to(&self, agent: &str, content: &str) -> Result<AgentResponse> {
151 let mutex = self
152 .agents
153 .get(agent)
154 .ok_or_else(|| anyhow::anyhow!("agent '{agent}' not registered"))?;
155
156 let dispatcher = self.dispatcher_for(agent).await;
157 let mut guard = mutex.lock().await;
158 guard.push_message(Message::user(content));
159
160 let (tx, mut rx) = mpsc::unbounded_channel();
161 let response = guard.run(&dispatcher, tx).await;
162
163 while let Ok(event) = rx.try_recv() {
164 self.hook.on_event(agent, &event);
165 }
166
167 Ok(response)
168 }
169
170 pub fn stream_to<'a>(
176 &'a self,
177 agent: &'a str,
178 content: &'a str,
179 ) -> impl Stream<Item = AgentEvent> + 'a {
180 stream! {
181 let mutex = match self.agents.get(agent) {
182 Some(m) => m,
183 None => {
184 let resp = AgentResponse {
185 final_response: None,
186 iterations: 0,
187 stop_reason: AgentStopReason::Error(
188 format!("agent '{agent}' not registered"),
189 ),
190 steps: vec![],
191 };
192 yield AgentEvent::Done(resp);
193 return;
194 }
195 };
196
197 let dispatcher = self.dispatcher_for(agent).await;
198 let mut guard = mutex.lock().await;
199 guard.push_message(Message::user(content));
200
201 let mut event_stream = std::pin::pin!(guard.run_stream(&dispatcher));
202 while let Some(event) = event_stream.next().await {
203 self.hook.on_event(agent, &event);
204 yield event;
205 }
206 }
207 }
208}