Skip to main content

walrus_daemon/daemon/
protocol.rs

1//! Server trait implementation for the Daemon.
2
3use crate::daemon::Daemon;
4use anyhow::{Context, Result};
5use compact_str::CompactString;
6use futures_util::{StreamExt, pin_mut};
7use std::sync::Arc;
8use wcore::protocol::{
9    api::Server,
10    message::{
11        DownloadEvent, DownloadRequest, HubAction, MemoryOp, MemoryResult, SendRequest,
12        SendResponse, StreamEvent, StreamRequest, TaskEvent,
13        server::{
14            DownloadInfo, EntityInfo, JournalInfo, RelationInfo, SessionInfo, TaskInfo,
15            ToolCallInfo,
16        },
17    },
18};
19use wcore::{AgentEvent, model::Model};
20
21impl Server for Daemon {
22    async fn send(&self, req: SendRequest) -> Result<SendResponse> {
23        let rt: Arc<_> = self.runtime.read().await.clone();
24        let sender = req.sender.as_deref().unwrap_or("");
25        let created_by = if sender.is_empty() { "user" } else { sender };
26        let (session_id, is_new) = match req.session {
27            Some(id) => (id, false),
28            None => (rt.create_session(&req.agent, created_by).await?, true),
29        };
30        let response = rt.send_to(session_id, &req.content, sender).await?;
31        if is_new {
32            rt.close_session(session_id).await;
33        }
34        Ok(SendResponse {
35            agent: req.agent,
36            content: response.final_response.unwrap_or_default(),
37            session: session_id,
38        })
39    }
40
41    fn stream(
42        &self,
43        req: StreamRequest,
44    ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
45        let runtime = self.runtime.clone();
46        let agent = req.agent;
47        let content = req.content;
48        let req_session = req.session;
49        let sender = req.sender.unwrap_or_default();
50        async_stream::try_stream! {
51            let rt: Arc<_> = runtime.read().await.clone();
52            let created_by = if sender.is_empty() { "user".into() } else { sender.clone() };
53            let (session_id, is_new) = match req_session {
54                Some(id) => (id, false),
55                None => (rt.create_session(&agent, created_by.as_str()).await?, true),
56            };
57
58            yield StreamEvent::Start { agent: agent.clone(), session: session_id };
59
60            let stream = rt.stream_to(session_id, &content, &sender);
61            pin_mut!(stream);
62            while let Some(event) = stream.next().await {
63                match event {
64                    AgentEvent::TextDelta(text) => {
65                        yield StreamEvent::Chunk { content: text };
66                    }
67                    AgentEvent::ThinkingDelta(text) => {
68                        yield StreamEvent::Thinking { content: text };
69                    }
70                    AgentEvent::ToolCallsStart(calls) => {
71                        yield StreamEvent::ToolStart {
72                            calls: calls.into_iter().map(|c| ToolCallInfo {
73                                name: CompactString::from(c.function.name.as_str()),
74                                arguments: c.function.arguments,
75                            }).collect(),
76                        };
77                    }
78                    AgentEvent::ToolResult { call_id, output } => {
79                        yield StreamEvent::ToolResult { call_id, output };
80                    }
81                    AgentEvent::ToolCallsComplete => {
82                        yield StreamEvent::ToolsComplete;
83                    }
84                    AgentEvent::Done(resp) => {
85                        if let wcore::AgentStopReason::Error(e) = &resp.stop_reason {
86                            if is_new {
87                                rt.close_session(session_id).await;
88                            }
89                            Err(anyhow::anyhow!("{e}"))?;
90                        }
91                        break;
92                    }
93                }
94            }
95            if is_new {
96                rt.close_session(session_id).await;
97            }
98
99            yield StreamEvent::End { agent: agent.clone() };
100        }
101    }
102
103    fn download(
104        &self,
105        req: DownloadRequest,
106    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
107        let runtime = self.runtime.clone();
108        async_stream::try_stream! {
109            let rt = runtime.read().await.clone();
110            let registry = rt.hook.downloads.clone();
111            let s = crate::ext::hub::model::download(req.model, registry);
112            pin_mut!(s);
113            while let Some(event) = s.next().await {
114                yield event?;
115            }
116        }
117    }
118
119    async fn ping(&self) -> Result<()> {
120        Ok(())
121    }
122
123    async fn list_sessions(&self) -> Result<Vec<SessionInfo>> {
124        let rt = self.runtime.read().await.clone();
125        let sessions = rt.sessions().await;
126        let mut infos = Vec::with_capacity(sessions.len());
127        for s in sessions {
128            let s = s.lock().await;
129            infos.push(SessionInfo {
130                id: s.id,
131                agent: s.agent.clone(),
132                created_by: s.created_by.clone(),
133                message_count: s.history.len(),
134                alive_secs: s.created_at.elapsed().as_secs(),
135            });
136        }
137        Ok(infos)
138    }
139
140    async fn kill_session(&self, session: u64) -> Result<bool> {
141        let rt = self.runtime.read().await.clone();
142        Ok(rt.close_session(session).await)
143    }
144
145    async fn list_tasks(&self) -> Result<Vec<TaskInfo>> {
146        let rt = self.runtime.read().await.clone();
147        let registry = rt.hook.tasks.lock().await;
148        let tasks = registry.list(None, None, None);
149        Ok(tasks
150            .into_iter()
151            .map(|t| TaskInfo {
152                id: t.id,
153                parent_id: t.parent_id,
154                agent: t.agent.clone(),
155                status: t.status.to_string(),
156                description: t.description.clone(),
157                result: t.result.clone(),
158                error: t.error.clone(),
159                created_by: t.created_by.clone(),
160                prompt_tokens: t.prompt_tokens,
161                completion_tokens: t.completion_tokens,
162                alive_secs: t.created_at.elapsed().as_secs(),
163                blocked_on: t.blocked_on.as_ref().map(|i| i.question.clone()),
164            })
165            .collect())
166    }
167
168    async fn kill_task(&self, task_id: u64) -> Result<bool> {
169        let rt = self.runtime.read().await.clone();
170        let tasks = rt.hook.tasks.clone();
171        let mut registry = tasks.lock().await;
172        let Some(task) = registry.get(task_id) else {
173            return Ok(false);
174        };
175        match task.status {
176            crate::hook::task::TaskStatus::InProgress | crate::hook::task::TaskStatus::Blocked => {
177                if let Some(handle) = &task.abort_handle {
178                    handle.abort();
179                }
180                registry.set_status(task_id, crate::hook::task::TaskStatus::Failed);
181                if let Some(task) = registry.get_mut(task_id) {
182                    task.error = Some("killed by user".into());
183                }
184                // Close associated session.
185                if let Some(sid) = registry.get(task_id).and_then(|t| t.session_id) {
186                    drop(registry);
187                    rt.close_session(sid).await;
188                    let mut registry = tasks.lock().await;
189                    registry.promote_next(tasks.clone());
190                } else {
191                    registry.promote_next(tasks.clone());
192                }
193                Ok(true)
194            }
195            crate::hook::task::TaskStatus::Queued => {
196                registry.remove(task_id);
197                Ok(true)
198            }
199            _ => Ok(false),
200        }
201    }
202
203    async fn approve_task(&self, task_id: u64, response: String) -> Result<bool> {
204        let rt = self.runtime.read().await.clone();
205        let mut registry = rt.hook.tasks.lock().await;
206        Ok(registry.approve(task_id, response))
207    }
208
209    async fn evaluate(&self, req: SendRequest) -> Result<bool> {
210        let rt: Arc<_> = self.runtime.read().await.clone();
211        let agent = rt
212            .get_agent(&req.agent)
213            .ok_or_else(|| anyhow::anyhow!("agent '{}' not found", req.agent))?;
214
215        let sender = req.sender.as_deref().unwrap_or("");
216
217        // Build sender context from memory.
218        let sender_context = if !sender.is_empty() {
219            let query = format!("{sender} profile");
220            let args = serde_json::json!({ "query": query, "entity_type": "profile", "limit": 3 });
221            let recall_result = rt.hook.memory.dispatch_recall(&args.to_string()).await;
222            if recall_result == "no entities found" {
223                String::new()
224            } else {
225                recall_result
226            }
227        } else {
228            String::new()
229        };
230
231        // Build a minimal evaluation prompt.
232        let mut eval_prompt = String::from(
233            "You are deciding whether to respond to a message in a group chat. \
234             Reply with exactly \"yes\" or \"no\".\n\n",
235        );
236        if !sender_context.is_empty() {
237            eval_prompt.push_str("Sender profile:\n");
238            eval_prompt.push_str(&sender_context);
239            eval_prompt.push('\n');
240        }
241        eval_prompt.push_str("Message: ");
242        eval_prompt.push_str(&req.content);
243        eval_prompt.push_str("\n\nShould you respond? (yes/no)");
244
245        let model_name = agent
246            .config
247            .model
248            .clone()
249            .unwrap_or_else(|| rt.model.active_model());
250
251        let messages = vec![
252            wcore::model::Message::system(&agent.config.system_prompt),
253            wcore::model::Message::user(eval_prompt),
254        ];
255
256        let request = wcore::model::Request::new(model_name).with_messages(messages);
257
258        match rt.model.send(&request).await {
259            Ok(response) => {
260                let text = response.message().map(|m| m.content).unwrap_or_default();
261                let lower = text.trim().to_lowercase();
262                Ok(lower.starts_with("yes"))
263            }
264            Err(e) => {
265                tracing::warn!(agent = %req.agent, "evaluate LLM call failed: {e}, defaulting to respond");
266                Ok(true)
267            }
268        }
269    }
270
271    fn hub(
272        &self,
273        package: CompactString,
274        action: HubAction,
275    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
276        let runtime = self.runtime.clone();
277        async_stream::try_stream! {
278            let rt = runtime.read().await.clone();
279            let registry = rt.hook.downloads.clone();
280            match action {
281                HubAction::Install => {
282                    let s = crate::ext::hub::package::install(package, registry);
283                    pin_mut!(s);
284                    while let Some(event) = s.next().await {
285                        yield event?;
286                    }
287                }
288                HubAction::Uninstall => {
289                    let s = crate::ext::hub::package::uninstall(package, registry);
290                    pin_mut!(s);
291                    while let Some(event) = s.next().await {
292                        yield event?;
293                    }
294                }
295            }
296        }
297    }
298
299    fn subscribe_tasks(&self) -> impl futures_core::Stream<Item = Result<TaskEvent>> + Send {
300        let runtime = self.runtime.clone();
301        async_stream::try_stream! {
302            let rt = runtime.read().await.clone();
303            let mut rx = rt.hook.tasks.lock().await.subscribe();
304            loop {
305                match rx.recv().await {
306                    Ok(event) => yield event,
307                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
308                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
309                }
310            }
311        }
312    }
313
314    async fn list_downloads(&self) -> Result<Vec<DownloadInfo>> {
315        let rt = self.runtime.read().await.clone();
316        let registry = rt.hook.downloads.lock().await;
317        Ok(registry.list())
318    }
319
320    fn subscribe_downloads(
321        &self,
322    ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
323        let runtime = self.runtime.clone();
324        async_stream::try_stream! {
325            let rt = runtime.read().await.clone();
326            let mut rx = rt.hook.downloads.lock().await.subscribe();
327            loop {
328                match rx.recv().await {
329                    Ok(event) => yield event,
330                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
331                    Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
332                }
333            }
334        }
335    }
336
337    async fn get_config(&self) -> Result<String> {
338        let config = self.load_config()?;
339        serde_json::to_string(&config).context("failed to serialize config")
340    }
341
342    async fn set_config(&self, config: String) -> Result<()> {
343        let parsed: crate::DaemonConfig =
344            serde_json::from_str(&config).context("invalid DaemonConfig JSON")?;
345        let toml_str =
346            toml::to_string_pretty(&parsed).context("failed to serialize config to TOML")?;
347        let config_path = self.config_dir.join("walrus.toml");
348        std::fs::write(&config_path, toml_str)
349            .with_context(|| format!("failed to write {}", config_path.display()))?;
350        self.reload().await
351    }
352
353    async fn memory_query(&self, query: MemoryOp) -> Result<MemoryResult> {
354        let rt = self.runtime.read().await.clone();
355        let lance = &rt.hook.memory.lance;
356        let default_limit = 50;
357
358        match query {
359            MemoryOp::Entities { entity_type, limit } => {
360                let limit = limit.unwrap_or(default_limit) as usize;
361                let entities = lance.list_entities(entity_type.as_deref(), limit).await?;
362                Ok(MemoryResult::Entities(
363                    entities
364                        .into_iter()
365                        .map(|e| EntityInfo {
366                            entity_type: e.entity_type.into(),
367                            key: e.key.into(),
368                            value: e.value,
369                            created_at: e.created_at,
370                        })
371                        .collect(),
372                ))
373            }
374            MemoryOp::Relations { entity_id, limit } => {
375                let limit = limit.unwrap_or(default_limit) as usize;
376                let relations = lance.list_relations(entity_id.as_deref(), limit).await?;
377                Ok(MemoryResult::Relations(
378                    relations
379                        .into_iter()
380                        .map(|r| RelationInfo {
381                            source_id: r.source.into(),
382                            relation: r.relation.into(),
383                            target_id: r.target.into(),
384                            created_at: r.created_at,
385                        })
386                        .collect(),
387                ))
388            }
389            MemoryOp::Journals { agent, limit } => {
390                let limit = limit.unwrap_or(default_limit) as usize;
391                let journals = lance.list_journals(agent.as_deref(), limit).await?;
392                Ok(MemoryResult::Journals(
393                    journals
394                        .into_iter()
395                        .map(|j| JournalInfo {
396                            summary: j.summary,
397                            agent: j.agent.into(),
398                            created_at: j.created_at,
399                        })
400                        .collect(),
401                ))
402            }
403            MemoryOp::Search {
404                query,
405                entity_type,
406                limit,
407            } => {
408                let limit = limit.unwrap_or(default_limit) as usize;
409                let entities = lance
410                    .search_entities(&query, entity_type.as_deref(), limit)
411                    .await?;
412                Ok(MemoryResult::Entities(
413                    entities
414                        .into_iter()
415                        .map(|e| EntityInfo {
416                            entity_type: e.entity_type.into(),
417                            key: e.key.into(),
418                            value: e.value,
419                            created_at: e.created_at,
420                        })
421                        .collect(),
422                ))
423            }
424        }
425    }
426}
427
428impl Daemon {
429    /// Load the current `DaemonConfig` from disk.
430    fn load_config(&self) -> Result<crate::DaemonConfig> {
431        crate::DaemonConfig::load(&self.config_dir.join("walrus.toml"))
432    }
433}