Skip to main content

walrus_daemon/daemon/
protocol.rs

1//! Server trait implementation for the Daemon.
2
3use crate::daemon::Daemon;
4use anyhow::{Context, Result};
5use futures_util::{StreamExt, pin_mut};
6use std::sync::Arc;
7use wcore::AgentEvent;
8use wcore::protocol::{
9    api::Server,
10    message::{
11        DownloadEvent, DownloadInfo, HubAction, SendMsg, SendResponse, SessionInfo, StreamChunk,
12        StreamEnd, StreamEvent, StreamMsg, StreamStart, StreamThinking, TaskEvent, TaskInfo,
13        ToolCallInfo, ToolResultEvent, ToolStartEvent, ToolsCompleteEvent, stream_event,
14    },
15};
16
17impl Server for Daemon {
18    async fn send(&self, req: SendMsg) -> Result<SendResponse> {
19        let rt: Arc<_> = self.runtime.read().await.clone();
20        let sender = req.sender.as_deref().unwrap_or("");
21        let created_by = if sender.is_empty() { "user" } else { sender };
22        let (session_id, is_new) = match req.session {
23            Some(id) => (id, false),
24            None => (rt.create_session(&req.agent, created_by).await?, true),
25        };
26        let response = rt.send_to(session_id, &req.content, sender).await?;
27        if is_new {
28            rt.close_session(session_id).await;
29        }
30        Ok(SendResponse {
31            agent: req.agent,
32            content: response.final_response.unwrap_or_default(),
33            session: session_id,
34        })
35    }
36
37    fn stream(
38        &self,
39        req: StreamMsg,
40    ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
41        let runtime = self.runtime.clone();
42        let agent = req.agent;
43        let content = req.content;
44        let req_session = req.session;
45        let sender = req.sender.unwrap_or_default();
46        async_stream::try_stream! {
47            let rt: Arc<_> = runtime.read().await.clone();
48            let created_by = if sender.is_empty() { "user".into() } else { sender.clone() };
49            let (session_id, is_new) = match req_session {
50                Some(id) => (id, false),
51                None => (rt.create_session(&agent, created_by.as_str()).await?, true),
52            };
53
54            yield StreamEvent { event: Some(stream_event::Event::Start(StreamStart { agent: agent.clone(), session: session_id })) };
55
56            let stream = rt.stream_to(session_id, &content, &sender);
57            pin_mut!(stream);
58            while let Some(event) = stream.next().await {
59                match event {
60                    AgentEvent::TextDelta(text) => {
61                        yield StreamEvent { event: Some(stream_event::Event::Chunk(StreamChunk { content: text })) };
62                    }
63                    AgentEvent::ThinkingDelta(text) => {
64                        yield StreamEvent { event: Some(stream_event::Event::Thinking(StreamThinking { content: text })) };
65                    }
66                    AgentEvent::ToolCallsStart(calls) => {
67                        yield StreamEvent { event: Some(stream_event::Event::ToolStart(ToolStartEvent {
68                            calls: calls.into_iter().map(|c| ToolCallInfo {
69                                name: c.function.name.to_string(),
70                                arguments: c.function.arguments,
71                            }).collect(),
72                        })) };
73                    }
74                    AgentEvent::ToolResult { call_id, output } => {
75                        yield StreamEvent { event: Some(stream_event::Event::ToolResult(ToolResultEvent { call_id: call_id.to_string(), output })) };
76                    }
77                    AgentEvent::ToolCallsComplete => {
78                        yield StreamEvent { event: Some(stream_event::Event::ToolsComplete(ToolsCompleteEvent {})) };
79                    }
80                    AgentEvent::Compact { .. } => {
81                        // Compact events are handled by on_event in the hook layer.
82                    }
83                    AgentEvent::Done(resp) => {
84                        if let wcore::AgentStopReason::Error(e) = &resp.stop_reason {
85                            if is_new {
86                                rt.close_session(session_id).await;
87                            }
88                            Err(anyhow::anyhow!("{e}"))?;
89                        }
90                        break;
91                    }
92                }
93            }
94            yield StreamEvent { event: Some(stream_event::Event::End(StreamEnd { agent: agent.clone() })) };
95        }
96    }
97
98    async fn ping(&self) -> Result<()> {
99        Ok(())
100    }
101
102    async fn list_sessions(&self) -> Result<Vec<SessionInfo>> {
103        let rt = self.runtime.read().await.clone();
104        let sessions = rt.sessions().await;
105        let mut infos = Vec::with_capacity(sessions.len());
106        for s in sessions {
107            let s = s.lock().await;
108            infos.push(SessionInfo {
109                id: s.id,
110                agent: s.agent.to_string(),
111                created_by: s.created_by.to_string(),
112                message_count: s.history.len() as u64,
113                alive_secs: s.created_at.elapsed().as_secs(),
114            });
115        }
116        Ok(infos)
117    }
118
119    async fn kill_session(&self, session: u64) -> Result<bool> {
120        let rt = self.runtime.read().await.clone();
121        Ok(rt.close_session(session).await)
122    }
123
124    async fn list_tasks(&self) -> Result<Vec<TaskInfo>> {
125        let rt = self.runtime.read().await.clone();
126        let registry = rt.hook.tasks.lock().await;
127        let tasks = registry.list(None, None, None);
128        Ok(tasks
129            .into_iter()
130            .map(|t| crate::hook::system::task::TaskRegistry::task_info(t))
131            .collect())
132    }
133
134    async fn kill_task(&self, task_id: u64) -> Result<bool> {
135        use crate::hook::system::task::TaskStatus;
136        let rt = self.runtime.read().await.clone();
137        let tasks = rt.hook.tasks.clone();
138        let mut registry = tasks.lock().await;
139        let Some(task) = registry.get(task_id) else {
140            return Ok(false);
141        };
142        match task.status {
143            TaskStatus::InProgress | TaskStatus::Blocked => {
144                let session_id = task.session_id;
145                registry.kill(task_id);
146                crate::hook::system::task::tool::try_promote(
147                    &mut registry,
148                    tasks.clone(),
149                    rt.hook.event_tx.clone(),
150                    rt.hook.task_timeout,
151                );
152                // Close associated session outside the lock.
153                if let Some(sid) = session_id {
154                    drop(registry);
155                    rt.close_session(sid).await;
156                }
157                Ok(true)
158            }
159            TaskStatus::Queued => {
160                registry.remove(task_id);
161                Ok(true)
162            }
163            _ => Ok(false),
164        }
165    }
166
167    async fn approve_task(&self, task_id: u64, response: String) -> Result<bool> {
168        let rt = self.runtime.read().await.clone();
169        let mut registry = rt.hook.tasks.lock().await;
170        Ok(registry.approve(task_id, response))
171    }
172
173    fn hub(
174        &self,
175        package: String,
176        action: HubAction,
177        filters: Vec<String>,
178    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
179        let runtime = self.runtime.clone();
180        async_stream::try_stream! {
181            let rt = runtime.read().await.clone();
182            let registry = rt.hook.downloads.clone();
183            let package = compact_str::CompactString::from(package.as_str());
184            match action {
185                HubAction::Install => {
186                    let s = crate::ext::hub::package::install(package, registry, filters);
187                    pin_mut!(s);
188                    while let Some(event) = s.next().await {
189                        yield event?;
190                    }
191                }
192                HubAction::Uninstall => {
193                    let s = crate::ext::hub::package::uninstall(package, registry, filters);
194                    pin_mut!(s);
195                    while let Some(event) = s.next().await {
196                        yield event?;
197                    }
198                }
199            }
200        }
201    }
202
203    fn subscribe_tasks(&self) -> impl futures_core::Stream<Item = Result<TaskEvent>> + Send {
204        let runtime = self.runtime.clone();
205        async_stream::try_stream! {
206            let rt = runtime.read().await.clone();
207            let mut rx = rt.hook.tasks.lock().await.subscribe();
208            loop {
209                match rx.recv().await {
210                    Ok(event) => yield event,
211                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
212                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
213                }
214            }
215        }
216    }
217
218    async fn list_downloads(&self) -> Result<Vec<DownloadInfo>> {
219        let rt = self.runtime.read().await.clone();
220        let registry = rt.hook.downloads.lock().await;
221        Ok(registry.list())
222    }
223
224    fn subscribe_downloads(
225        &self,
226    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
227        let runtime = self.runtime.clone();
228        async_stream::try_stream! {
229            let rt = runtime.read().await.clone();
230            let mut rx = rt.hook.downloads.lock().await.subscribe();
231            loop {
232                match rx.recv().await {
233                    Ok(event) => yield event,
234                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
235                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
236                }
237            }
238        }
239    }
240
241    async fn get_config(&self) -> Result<String> {
242        let config = self.load_config()?;
243        serde_json::to_string(&config).context("failed to serialize config")
244    }
245
246    async fn set_config(&self, config: String) -> Result<()> {
247        let parsed: crate::DaemonConfig =
248            serde_json::from_str(&config).context("invalid DaemonConfig JSON")?;
249        let toml_str =
250            toml::to_string_pretty(&parsed).context("failed to serialize config to TOML")?;
251        let config_path = self.config_dir.join("walrus.toml");
252        std::fs::write(&config_path, toml_str)
253            .with_context(|| format!("failed to write {}", config_path.display()))?;
254        self.reload().await
255    }
256
257    async fn service_query(&self, service: String, query: String) -> Result<String> {
258        let rt = self.runtime.read().await.clone();
259        let registry = rt
260            .hook
261            .registry
262            .as_ref()
263            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
264        let handle = registry
265            .query
266            .get(&service)
267            .ok_or_else(|| anyhow::anyhow!("service '{}' not available", service))?;
268        let req = wcore::protocol::ext::ExtRequest {
269            msg: Some(wcore::protocol::ext::ext_request::Msg::ServiceQuery(
270                wcore::protocol::ext::ExtServiceQuery { query },
271            )),
272        };
273        let resp = handle.request(&req).await?;
274        match resp.msg {
275            Some(wcore::protocol::ext::ext_response::Msg::ServiceQueryResult(result)) => {
276                Ok(result.result)
277            }
278            Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
279                anyhow::bail!("service '{}' error: {}", service, e.message)
280            }
281            other => anyhow::bail!("unexpected response from service '{}': {other:?}", service),
282        }
283    }
284
285    async fn get_service_schema(&self, service: String) -> Result<String> {
286        let rt = self.runtime.read().await.clone();
287        let registry = rt
288            .hook
289            .registry
290            .as_ref()
291            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
292        let handle = registry
293            .query
294            .get(&service)
295            .or_else(|| registry.tools.values().find(|h| h.name.as_str() == service))
296            .ok_or_else(|| anyhow::anyhow!("service '{}' not found", service))?;
297        let req = wcore::protocol::ext::ExtRequest {
298            msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
299                wcore::protocol::ext::ExtGetSchema {},
300            )),
301        };
302        let resp = handle.request(&req).await?;
303        match resp.msg {
304            Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) => {
305                Ok(result.schema)
306            }
307            Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
308                anyhow::bail!("service '{}' schema error: {}", service, e.message)
309            }
310            other => anyhow::bail!(
311                "unexpected schema response from service '{}': {other:?}",
312                service
313            ),
314        }
315    }
316
317    async fn get_all_schemas(&self) -> Result<std::collections::HashMap<String, String>> {
318        let rt = self.runtime.read().await.clone();
319        let registry = rt
320            .hook
321            .registry
322            .as_ref()
323            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
324        let mut schemas = std::collections::HashMap::new();
325        // Collect unique service handles from the query registry.
326        for (name, handle) in &registry.query {
327            let req = wcore::protocol::ext::ExtRequest {
328                msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
329                    wcore::protocol::ext::ExtGetSchema {},
330                )),
331            };
332            if let Ok(resp) = handle.request(&req).await
333                && let Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) =
334                    resp.msg
335            {
336                schemas.insert(name.clone(), result.schema);
337            }
338        }
339        Ok(schemas)
340    }
341
342    async fn list_services(&self) -> Result<Vec<wcore::protocol::message::ServiceInfoMsg>> {
343        let rt = self.runtime.read().await.clone();
344        let registry = rt.hook.registry.as_ref();
345        let mut services = Vec::new();
346        if let Some(reg) = registry {
347            // Collect unique service names from all capability buckets.
348            let mut seen = std::collections::HashSet::new();
349            let all_handles: Vec<_> = reg
350                .build_agent
351                .iter()
352                .chain(reg.before_run.iter())
353                .chain(reg.compact.iter())
354                .chain(reg.event_observer.iter())
355                .chain(reg.after_run.iter())
356                .chain(reg.after_compact.iter())
357                .chain(reg.query.values())
358                .chain(reg.tools.values())
359                .collect();
360            for handle in all_handles {
361                let name = handle.name.to_string();
362                if !seen.insert(name.clone()) {
363                    continue;
364                }
365                let capabilities: Vec<String> = handle
366                    .capabilities
367                    .iter()
368                    .filter_map(|c| match &c.cap {
369                        Some(wcore::protocol::ext::capability::Cap::Tools(_)) => {
370                            Some("tools".into())
371                        }
372                        Some(wcore::protocol::ext::capability::Cap::Query(_)) => {
373                            Some("query".into())
374                        }
375                        Some(wcore::protocol::ext::capability::Cap::BuildAgent(_)) => {
376                            Some("build_agent".into())
377                        }
378                        Some(wcore::protocol::ext::capability::Cap::BeforeRun(_)) => {
379                            Some("before_run".into())
380                        }
381                        Some(wcore::protocol::ext::capability::Cap::Compact(_)) => {
382                            Some("compact".into())
383                        }
384                        Some(wcore::protocol::ext::capability::Cap::EventObserver(_)) => {
385                            Some("event_observer".into())
386                        }
387                        Some(wcore::protocol::ext::capability::Cap::AfterRun(_)) => {
388                            Some("after_run".into())
389                        }
390                        Some(wcore::protocol::ext::capability::Cap::Infer(_)) => {
391                            Some("infer".into())
392                        }
393                        Some(wcore::protocol::ext::capability::Cap::AfterCompact(_)) => {
394                            Some("after_compact".into())
395                        }
396                        None => None,
397                    })
398                    .collect();
399                services.push(wcore::protocol::message::ServiceInfoMsg {
400                    name,
401                    kind: "extension".into(),
402                    status: "running".into(),
403                    capabilities,
404                    has_config: true,
405                });
406            }
407        }
408        Ok(services)
409    }
410
411    async fn set_service_config(&self, service: String, config: String) -> Result<()> {
412        let mut daemon_config = self.load_config()?;
413        let svc = daemon_config
414            .services
415            .get_mut(&service)
416            .ok_or_else(|| anyhow::anyhow!("service '{}' not found in config", service))?;
417        let parsed: serde_json::Value =
418            serde_json::from_str(&config).context("invalid service config JSON")?;
419        svc.config = parsed;
420        let toml_str =
421            toml::to_string_pretty(&daemon_config).context("failed to serialize config to TOML")?;
422        let config_path = self.config_dir.join("walrus.toml");
423        std::fs::write(&config_path, toml_str)
424            .with_context(|| format!("failed to write {}", config_path.display()))?;
425        self.reload().await
426    }
427
428    async fn reload(&self) -> Result<()> {
429        self.reload().await
430    }
431}
432
433impl Daemon {
434    /// Load the current `DaemonConfig` from disk.
435    fn load_config(&self) -> Result<crate::DaemonConfig> {
436        crate::DaemonConfig::load(&self.config_dir.join("walrus.toml"))
437    }
438}