// walrus_daemon/daemon/protocol.rs
use crate::daemon::Daemon;

use anyhow::{Result, bail};
use futures_util::{StreamExt, pin_mut};
use memory::Memory;
use protocol::{
    api::Server,
    message::{
        AgentDetail, AgentInfoRequest, AgentList, AgentSummary, ClearSessionRequest, DownloadEvent,
        DownloadRequest, GetMemoryRequest, McpAddRequest, McpAdded, McpReloaded, McpRemoveRequest,
        McpRemoved, McpServerList, McpServerSummary, MemoryEntry, MemoryList, SendRequest,
        SendResponse, SessionCleared, SkillsReloaded, StreamEvent, StreamRequest,
    },
};
use wcore::AgentEvent;

18impl Server for Daemon {
19 async fn send(&self, req: SendRequest) -> Result<SendResponse> {
20 let response = self.runtime.send_to(&req.agent, &req.content).await?;
21 Ok(SendResponse {
22 agent: req.agent,
23 content: response.final_response.unwrap_or_default(),
24 })
25 }
26
27 fn stream(
28 &self,
29 req: StreamRequest,
30 ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
31 let runtime = self.runtime.clone();
32 let agent = req.agent;
33 let content = req.content;
34 async_stream::try_stream! {
35 yield StreamEvent::Start { agent: agent.clone() };
36
37 let stream = runtime.stream_to(&agent, &content);
38 pin_mut!(stream);
39 while let Some(event) = stream.next().await {
40 match event {
41 AgentEvent::TextDelta(text) => {
42 yield StreamEvent::Chunk { content: text };
43 }
44 AgentEvent::Done(_) => break,
45 _ => {}
46 }
47 }
48
49 yield StreamEvent::End { agent: agent.clone() };
50 }
51 }
52
53 async fn clear_session(&self, req: ClearSessionRequest) -> Result<SessionCleared> {
54 self.runtime.clear_session(&req.agent).await;
55 Ok(SessionCleared { agent: req.agent })
56 }
57
58 async fn list_agents(&self) -> Result<AgentList> {
59 let agents = self
60 .runtime
61 .agents()
62 .await
63 .into_iter()
64 .map(|a| AgentSummary {
65 name: a.name.clone(),
66 description: a.description.clone(),
67 })
68 .collect();
69 Ok(AgentList { agents })
70 }
71
72 async fn agent_info(&self, req: AgentInfoRequest) -> Result<AgentDetail> {
73 match self.runtime.agent(&req.agent).await {
74 Some(a) => Ok(AgentDetail {
75 name: a.name.clone(),
76 description: a.description.clone(),
77 tools: a.tools.to_vec(),
78 skill_tags: a.skill_tags.to_vec(),
79 system_prompt: a.system_prompt.clone(),
80 }),
81 None => bail!("agent not found: {}", req.agent),
82 }
83 }
84
85 async fn list_memory(&self) -> Result<MemoryList> {
86 let entries = self.runtime.hook.memory.entries();
87 Ok(MemoryList { entries })
88 }
89
90 async fn get_memory(&self, req: GetMemoryRequest) -> Result<MemoryEntry> {
91 let value = self.runtime.hook.memory.get(&req.key);
92 Ok(MemoryEntry {
93 key: req.key,
94 value,
95 })
96 }
97
98 fn download(
99 &self,
100 req: DownloadRequest,
101 ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
102 #[cfg(feature = "local")]
103 {
104 use tokio::sync::mpsc;
105 async_stream::try_stream! {
106 yield DownloadEvent::Start { model: req.model.clone() };
107
108 let (dtx, mut drx) = mpsc::unbounded_channel();
109 let model_str = req.model.to_string();
110 let download_handle = tokio::spawn(async move {
111 model::local::download::download_model(&model_str, dtx).await
112 });
113
114 while let Some(event) = drx.recv().await {
115 let dl_event = match event {
116 model::local::download::DownloadEvent::FileStart { filename, size } => {
117 DownloadEvent::FileStart { filename, size }
118 }
119 model::local::download::DownloadEvent::Progress { bytes } => {
120 DownloadEvent::Progress { bytes }
121 }
122 model::local::download::DownloadEvent::FileEnd { filename } => {
123 DownloadEvent::FileEnd { filename }
124 }
125 };
126 yield dl_event;
127 }
128
129 match download_handle.await {
130 Ok(Ok(())) => {
131 yield DownloadEvent::End { model: req.model };
132 }
133 Ok(Err(e)) => {
134 Err(anyhow::anyhow!("download failed: {e}"))?;
135 }
136 Err(e) => {
137 Err(anyhow::anyhow!("download task panicked: {e}"))?;
138 }
139 }
140 }
141 }
142 #[cfg(not(feature = "local"))]
143 {
144 let _ = req;
145 async_stream::stream! {
146 yield Err(anyhow::anyhow!("this daemon was built without local model support"));
147 }
148 }
149 }
150
151 async fn reload_skills(&self) -> Result<SkillsReloaded> {
152 let count = self.runtime.hook.skills.reload().await?;
153 tracing::info!("reloaded {count} skill(s)");
154 Ok(SkillsReloaded { count })
155 }
156
157 async fn mcp_add(&self, req: McpAddRequest) -> Result<McpAdded> {
158 let config = mcp::McpServerConfig {
159 name: req.name.clone(),
160 command: req.command,
161 args: req.args,
162 env: req.env,
163 auto_restart: true,
164 };
165 let tools = self.runtime.hook.mcp.add(config).await?;
166
167 for (tool, handler) in self.runtime.hook.mcp.tool_handlers().await {
169 if tools.iter().any(|t| t == &*tool.name) {
170 self.runtime.register_tool(tool, handler).await;
171 }
172 }
173
174 Ok(McpAdded {
175 name: req.name,
176 tools,
177 })
178 }
179
180 async fn mcp_remove(&self, req: McpRemoveRequest) -> Result<McpRemoved> {
181 let tools = self.runtime.hook.mcp.remove(&req.name).await?;
182
183 for tool_name in &tools {
185 self.runtime.unregister_tool(tool_name).await;
186 }
187
188 Ok(McpRemoved {
189 name: req.name,
190 tools,
191 })
192 }
193
194 async fn mcp_reload(&self) -> Result<McpReloaded> {
195 let old_tool_names: Vec<compact_str::CompactString> = self
197 .runtime
198 .hook
199 .mcp
200 .tool_handlers()
201 .await
202 .into_iter()
203 .map(|(t, _)| t.name)
204 .collect();
205
206 let servers = self
207 .runtime
208 .hook
209 .mcp
210 .reload(|path| {
211 let config = crate::DaemonConfig::load(path)?;
212 Ok(config.mcp_servers)
213 })
214 .await?;
215
216 let new_tools = self.runtime.hook.mcp.tool_handlers().await;
218 self.runtime.replace_tools(&old_tool_names, new_tools).await;
219
220 let servers = servers
221 .into_iter()
222 .map(|(name, tools)| McpServerSummary { name, tools })
223 .collect();
224 Ok(McpReloaded { servers })
225 }
226
227 async fn mcp_list(&self) -> Result<McpServerList> {
228 let servers = self
229 .runtime
230 .hook
231 .mcp
232 .list()
233 .await
234 .into_iter()
235 .map(|(name, tools)| McpServerSummary { name, tools })
236 .collect();
237 Ok(McpServerList { servers })
238 }
239
240 async fn ping(&self) -> Result<()> {
241 Ok(())
242 }
243}