Skip to main content

walrus_daemon/daemon/
protocol.rs

1//! Server trait implementation for the Daemon.
2
3use crate::daemon::Daemon;
4use anyhow::{Context, Result};
5use futures_util::{StreamExt, pin_mut};
6use std::sync::Arc;
7use wcore::AgentEvent;
8use wcore::protocol::{
9    api::Server,
10    message::{
11        DownloadEvent, DownloadInfo, HubAction, SendMsg, SendResponse, SessionInfo, StreamChunk,
12        StreamEnd, StreamEvent, StreamMsg, StreamStart, StreamThinking, TaskEvent, TaskInfo,
13        ToolCallInfo, ToolResultEvent, ToolStartEvent, ToolsCompleteEvent, stream_event,
14    },
15};
16
17impl Server for Daemon {
18    async fn send(&self, req: SendMsg) -> Result<SendResponse> {
19        let rt: Arc<_> = self.runtime.read().await.clone();
20        let sender = req.sender.as_deref().unwrap_or("");
21        let created_by = if sender.is_empty() { "user" } else { sender };
22        let (session_id, is_new) = match req.session {
23            Some(id) => (id, false),
24            None => (rt.create_session(&req.agent, created_by).await?, true),
25        };
26        let response = rt.send_to(session_id, &req.content, sender).await?;
27        if is_new {
28            rt.close_session(session_id).await;
29        }
30        Ok(SendResponse {
31            agent: req.agent,
32            content: response.final_response.unwrap_or_default(),
33            session: session_id,
34        })
35    }
36
37    fn stream(
38        &self,
39        req: StreamMsg,
40    ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
41        let runtime = self.runtime.clone();
42        let agent = req.agent;
43        let content = req.content;
44        let req_session = req.session;
45        let sender = req.sender.unwrap_or_default();
46        async_stream::try_stream! {
47            let rt: Arc<_> = runtime.read().await.clone();
48            let created_by = if sender.is_empty() { "user".into() } else { sender.clone() };
49            let (session_id, is_new) = match req_session {
50                Some(id) => (id, false),
51                None => (rt.create_session(&agent, created_by.as_str()).await?, true),
52            };
53
54            yield StreamEvent { event: Some(stream_event::Event::Start(StreamStart { agent: agent.clone(), session: session_id })) };
55
56            let stream = rt.stream_to(session_id, &content, &sender);
57            pin_mut!(stream);
58            while let Some(event) = stream.next().await {
59                match event {
60                    AgentEvent::TextDelta(text) => {
61                        yield StreamEvent { event: Some(stream_event::Event::Chunk(StreamChunk { content: text })) };
62                    }
63                    AgentEvent::ThinkingDelta(text) => {
64                        yield StreamEvent { event: Some(stream_event::Event::Thinking(StreamThinking { content: text })) };
65                    }
66                    AgentEvent::ToolCallsStart(calls) => {
67                        yield StreamEvent { event: Some(stream_event::Event::ToolStart(ToolStartEvent {
68                            calls: calls.into_iter().map(|c| ToolCallInfo {
69                                name: c.function.name.to_string(),
70                                arguments: c.function.arguments,
71                            }).collect(),
72                        })) };
73                    }
74                    AgentEvent::ToolResult { call_id, output } => {
75                        yield StreamEvent { event: Some(stream_event::Event::ToolResult(ToolResultEvent { call_id: call_id.to_string(), output })) };
76                    }
77                    AgentEvent::ToolCallsComplete => {
78                        yield StreamEvent { event: Some(stream_event::Event::ToolsComplete(ToolsCompleteEvent {})) };
79                    }
80                    AgentEvent::Done(resp) => {
81                        if let wcore::AgentStopReason::Error(e) = &resp.stop_reason {
82                            if is_new {
83                                rt.close_session(session_id).await;
84                            }
85                            Err(anyhow::anyhow!("{e}"))?;
86                        }
87                        break;
88                    }
89                }
90            }
91            if is_new {
92                rt.close_session(session_id).await;
93            }
94
95            yield StreamEvent { event: Some(stream_event::Event::End(StreamEnd { agent: agent.clone() })) };
96        }
97    }
98
99    async fn ping(&self) -> Result<()> {
100        Ok(())
101    }
102
103    async fn list_sessions(&self) -> Result<Vec<SessionInfo>> {
104        let rt = self.runtime.read().await.clone();
105        let sessions = rt.sessions().await;
106        let mut infos = Vec::with_capacity(sessions.len());
107        for s in sessions {
108            let s = s.lock().await;
109            infos.push(SessionInfo {
110                id: s.id,
111                agent: s.agent.to_string(),
112                created_by: s.created_by.to_string(),
113                message_count: s.history.len() as u64,
114                alive_secs: s.created_at.elapsed().as_secs(),
115            });
116        }
117        Ok(infos)
118    }
119
120    async fn kill_session(&self, session: u64) -> Result<bool> {
121        let rt = self.runtime.read().await.clone();
122        Ok(rt.close_session(session).await)
123    }
124
125    async fn list_tasks(&self) -> Result<Vec<TaskInfo>> {
126        let rt = self.runtime.read().await.clone();
127        let registry = rt.hook.tasks.lock().await;
128        let tasks = registry.list(None, None, None);
129        Ok(tasks
130            .into_iter()
131            .map(|t| TaskInfo {
132                id: t.id,
133                parent_id: t.parent_id,
134                agent: t.agent.to_string(),
135                status: t.status.to_string(),
136                description: t.description.clone(),
137                result: t.result.clone(),
138                error: t.error.clone(),
139                created_by: t.created_by.to_string(),
140                prompt_tokens: t.prompt_tokens,
141                completion_tokens: t.completion_tokens,
142                alive_secs: t.created_at.elapsed().as_secs(),
143                blocked_on: t.blocked_on.as_ref().map(|i| i.question.clone()),
144            })
145            .collect())
146    }
147
148    async fn kill_task(&self, task_id: u64) -> Result<bool> {
149        let rt = self.runtime.read().await.clone();
150        let tasks = rt.hook.tasks.clone();
151        let mut registry = tasks.lock().await;
152        let Some(task) = registry.get(task_id) else {
153            return Ok(false);
154        };
155        match task.status {
156            crate::hook::task::TaskStatus::InProgress | crate::hook::task::TaskStatus::Blocked => {
157                if let Some(handle) = &task.abort_handle {
158                    handle.abort();
159                }
160                registry.set_status(task_id, crate::hook::task::TaskStatus::Failed);
161                if let Some(task) = registry.get_mut(task_id) {
162                    task.error = Some("killed by user".into());
163                }
164                // Close associated session.
165                if let Some(sid) = registry.get(task_id).and_then(|t| t.session_id) {
166                    drop(registry);
167                    rt.close_session(sid).await;
168                    let mut registry = tasks.lock().await;
169                    registry.promote_next(tasks.clone());
170                } else {
171                    registry.promote_next(tasks.clone());
172                }
173                Ok(true)
174            }
175            crate::hook::task::TaskStatus::Queued => {
176                registry.remove(task_id);
177                Ok(true)
178            }
179            _ => Ok(false),
180        }
181    }
182
183    async fn approve_task(&self, task_id: u64, response: String) -> Result<bool> {
184        let rt = self.runtime.read().await.clone();
185        let mut registry = rt.hook.tasks.lock().await;
186        Ok(registry.approve(task_id, response))
187    }
188
189    fn hub(
190        &self,
191        package: String,
192        action: HubAction,
193        filters: Vec<String>,
194    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
195        let runtime = self.runtime.clone();
196        async_stream::try_stream! {
197            let rt = runtime.read().await.clone();
198            let registry = rt.hook.downloads.clone();
199            let package = compact_str::CompactString::from(package.as_str());
200            match action {
201                HubAction::Install => {
202                    let s = crate::ext::hub::package::install(package, registry, filters);
203                    pin_mut!(s);
204                    while let Some(event) = s.next().await {
205                        yield event?;
206                    }
207                }
208                HubAction::Uninstall => {
209                    let s = crate::ext::hub::package::uninstall(package, registry, filters);
210                    pin_mut!(s);
211                    while let Some(event) = s.next().await {
212                        yield event?;
213                    }
214                }
215            }
216        }
217    }
218
219    fn subscribe_tasks(&self) -> impl futures_core::Stream<Item = Result<TaskEvent>> + Send {
220        let runtime = self.runtime.clone();
221        async_stream::try_stream! {
222            let rt = runtime.read().await.clone();
223            let mut rx = rt.hook.tasks.lock().await.subscribe();
224            loop {
225                match rx.recv().await {
226                    Ok(event) => yield event,
227                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
228                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
229                }
230            }
231        }
232    }
233
234    async fn list_downloads(&self) -> Result<Vec<DownloadInfo>> {
235        let rt = self.runtime.read().await.clone();
236        let registry = rt.hook.downloads.lock().await;
237        Ok(registry.list())
238    }
239
240    fn subscribe_downloads(
241        &self,
242    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
243        let runtime = self.runtime.clone();
244        async_stream::try_stream! {
245            let rt = runtime.read().await.clone();
246            let mut rx = rt.hook.downloads.lock().await.subscribe();
247            loop {
248                match rx.recv().await {
249                    Ok(event) => yield event,
250                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
251                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
252                }
253            }
254        }
255    }
256
257    async fn get_config(&self) -> Result<String> {
258        let config = self.load_config()?;
259        serde_json::to_string(&config).context("failed to serialize config")
260    }
261
262    async fn set_config(&self, config: String) -> Result<()> {
263        let parsed: crate::DaemonConfig =
264            serde_json::from_str(&config).context("invalid DaemonConfig JSON")?;
265        let toml_str =
266            toml::to_string_pretty(&parsed).context("failed to serialize config to TOML")?;
267        let config_path = self.config_dir.join("walrus.toml");
268        std::fs::write(&config_path, toml_str)
269            .with_context(|| format!("failed to write {}", config_path.display()))?;
270        self.reload().await
271    }
272
273    async fn service_query(&self, service: String, query: String) -> Result<String> {
274        let rt = self.runtime.read().await.clone();
275        let registry = rt
276            .hook
277            .registry
278            .as_ref()
279            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
280        let handle = registry
281            .query
282            .get(&service)
283            .ok_or_else(|| anyhow::anyhow!("service '{}' not available", service))?;
284        let req = wcore::protocol::ext::ExtRequest {
285            msg: Some(wcore::protocol::ext::ext_request::Msg::ServiceQuery(
286                wcore::protocol::ext::ExtServiceQuery { query },
287            )),
288        };
289        let resp = handle.request(&req).await?;
290        match resp.msg {
291            Some(wcore::protocol::ext::ext_response::Msg::ServiceQueryResult(result)) => {
292                Ok(result.result)
293            }
294            Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
295                anyhow::bail!("service '{}' error: {}", service, e.message)
296            }
297            other => anyhow::bail!("unexpected response from service '{}': {other:?}", service),
298        }
299    }
300
301    async fn get_service_schema(&self, service: String) -> Result<String> {
302        let rt = self.runtime.read().await.clone();
303        let registry = rt
304            .hook
305            .registry
306            .as_ref()
307            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
308        let handle = registry
309            .query
310            .get(&service)
311            .or_else(|| registry.tools.values().find(|h| h.name.as_str() == service))
312            .ok_or_else(|| anyhow::anyhow!("service '{}' not found", service))?;
313        let req = wcore::protocol::ext::ExtRequest {
314            msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
315                wcore::protocol::ext::ExtGetSchema {},
316            )),
317        };
318        let resp = handle.request(&req).await?;
319        match resp.msg {
320            Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) => {
321                Ok(result.schema)
322            }
323            Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
324                anyhow::bail!("service '{}' schema error: {}", service, e.message)
325            }
326            other => anyhow::bail!(
327                "unexpected schema response from service '{}': {other:?}",
328                service
329            ),
330        }
331    }
332
333    async fn get_all_schemas(&self) -> Result<std::collections::HashMap<String, String>> {
334        let rt = self.runtime.read().await.clone();
335        let registry = rt
336            .hook
337            .registry
338            .as_ref()
339            .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
340        let mut schemas = std::collections::HashMap::new();
341        // Collect unique service handles from the query registry.
342        for (name, handle) in &registry.query {
343            let req = wcore::protocol::ext::ExtRequest {
344                msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
345                    wcore::protocol::ext::ExtGetSchema {},
346                )),
347            };
348            if let Ok(resp) = handle.request(&req).await
349                && let Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) =
350                    resp.msg
351            {
352                schemas.insert(name.clone(), result.schema);
353            }
354        }
355        Ok(schemas)
356    }
357
358    async fn list_services(&self) -> Result<Vec<wcore::protocol::message::ServiceInfoMsg>> {
359        let rt = self.runtime.read().await.clone();
360        let registry = rt.hook.registry.as_ref();
361        let mut services = Vec::new();
362        if let Some(reg) = registry {
363            // Collect unique service names from all capability buckets.
364            let mut seen = std::collections::HashSet::new();
365            let all_handles: Vec<_> = reg
366                .build_agent
367                .iter()
368                .chain(reg.before_run.iter())
369                .chain(reg.compact.iter())
370                .chain(reg.event_observer.iter())
371                .chain(reg.query.values())
372                .chain(reg.tools.values())
373                .collect();
374            for handle in all_handles {
375                let name = handle.name.to_string();
376                if !seen.insert(name.clone()) {
377                    continue;
378                }
379                let capabilities: Vec<String> = handle
380                    .capabilities
381                    .iter()
382                    .filter_map(|c| match &c.cap {
383                        Some(wcore::protocol::ext::capability::Cap::Tools(_)) => {
384                            Some("tools".into())
385                        }
386                        Some(wcore::protocol::ext::capability::Cap::Query(_)) => {
387                            Some("query".into())
388                        }
389                        Some(wcore::protocol::ext::capability::Cap::BuildAgent(_)) => {
390                            Some("build_agent".into())
391                        }
392                        Some(wcore::protocol::ext::capability::Cap::BeforeRun(_)) => {
393                            Some("before_run".into())
394                        }
395                        Some(wcore::protocol::ext::capability::Cap::Compact(_)) => {
396                            Some("compact".into())
397                        }
398                        Some(wcore::protocol::ext::capability::Cap::EventObserver(_)) => {
399                            Some("event_observer".into())
400                        }
401                        Some(wcore::protocol::ext::capability::Cap::AfterRun(_)) => {
402                            Some("after_run".into())
403                        }
404                        Some(wcore::protocol::ext::capability::Cap::Infer(_)) => {
405                            Some("infer".into())
406                        }
407                        None => None,
408                    })
409                    .collect();
410                services.push(wcore::protocol::message::ServiceInfoMsg {
411                    name,
412                    kind: "extension".into(),
413                    status: "running".into(),
414                    capabilities,
415                    has_config: true,
416                });
417            }
418        }
419        Ok(services)
420    }
421
422    async fn set_service_config(&self, service: String, config: String) -> Result<()> {
423        let mut daemon_config = self.load_config()?;
424        let svc = daemon_config
425            .services
426            .get_mut(&service)
427            .ok_or_else(|| anyhow::anyhow!("service '{}' not found in config", service))?;
428        let parsed: serde_json::Value =
429            serde_json::from_str(&config).context("invalid service config JSON")?;
430        svc.config = parsed;
431        let toml_str =
432            toml::to_string_pretty(&daemon_config).context("failed to serialize config to TOML")?;
433        let config_path = self.config_dir.join("walrus.toml");
434        std::fs::write(&config_path, toml_str)
435            .with_context(|| format!("failed to write {}", config_path.display()))?;
436        self.reload().await
437    }
438
439    async fn reload(&self) -> Result<()> {
440        self.reload().await
441    }
442}
443
444impl Daemon {
445    /// Load the current `DaemonConfig` from disk.
446    fn load_config(&self) -> Result<crate::DaemonConfig> {
447        crate::DaemonConfig::load(&self.config_dir.join("walrus.toml"))
448    }
449}