1use crate::daemon::Daemon;
4use anyhow::{Context, Result};
5use futures_util::{StreamExt, pin_mut};
6use std::sync::Arc;
7use wcore::AgentEvent;
8use wcore::protocol::{
9 api::Server,
10 message::{
11 DownloadEvent, DownloadInfo, HubAction, SendMsg, SendResponse, SessionInfo, StreamChunk,
12 StreamEnd, StreamEvent, StreamMsg, StreamStart, StreamThinking, TaskEvent, TaskInfo,
13 ToolCallInfo, ToolResultEvent, ToolStartEvent, ToolsCompleteEvent, stream_event,
14 },
15};
16
17impl Server for Daemon {
18 async fn send(&self, req: SendMsg) -> Result<SendResponse> {
19 let rt: Arc<_> = self.runtime.read().await.clone();
20 let sender = req.sender.as_deref().unwrap_or("");
21 let created_by = if sender.is_empty() { "user" } else { sender };
22 let (session_id, is_new) = match req.session {
23 Some(id) => (id, false),
24 None => (rt.create_session(&req.agent, created_by).await?, true),
25 };
26 let response = rt.send_to(session_id, &req.content, sender).await?;
27 if is_new {
28 rt.close_session(session_id).await;
29 }
30 Ok(SendResponse {
31 agent: req.agent,
32 content: response.final_response.unwrap_or_default(),
33 session: session_id,
34 })
35 }
36
37 fn stream(
38 &self,
39 req: StreamMsg,
40 ) -> impl futures_core::Stream<Item = Result<StreamEvent>> + Send {
41 let runtime = self.runtime.clone();
42 let agent = req.agent;
43 let content = req.content;
44 let req_session = req.session;
45 let sender = req.sender.unwrap_or_default();
46 async_stream::try_stream! {
47 let rt: Arc<_> = runtime.read().await.clone();
48 let created_by = if sender.is_empty() { "user".into() } else { sender.clone() };
49 let (session_id, is_new) = match req_session {
50 Some(id) => (id, false),
51 None => (rt.create_session(&agent, created_by.as_str()).await?, true),
52 };
53
54 yield StreamEvent { event: Some(stream_event::Event::Start(StreamStart { agent: agent.clone(), session: session_id })) };
55
56 let stream = rt.stream_to(session_id, &content, &sender);
57 pin_mut!(stream);
58 while let Some(event) = stream.next().await {
59 match event {
60 AgentEvent::TextDelta(text) => {
61 yield StreamEvent { event: Some(stream_event::Event::Chunk(StreamChunk { content: text })) };
62 }
63 AgentEvent::ThinkingDelta(text) => {
64 yield StreamEvent { event: Some(stream_event::Event::Thinking(StreamThinking { content: text })) };
65 }
66 AgentEvent::ToolCallsStart(calls) => {
67 yield StreamEvent { event: Some(stream_event::Event::ToolStart(ToolStartEvent {
68 calls: calls.into_iter().map(|c| ToolCallInfo {
69 name: c.function.name.to_string(),
70 arguments: c.function.arguments,
71 }).collect(),
72 })) };
73 }
74 AgentEvent::ToolResult { call_id, output } => {
75 yield StreamEvent { event: Some(stream_event::Event::ToolResult(ToolResultEvent { call_id: call_id.to_string(), output })) };
76 }
77 AgentEvent::ToolCallsComplete => {
78 yield StreamEvent { event: Some(stream_event::Event::ToolsComplete(ToolsCompleteEvent {})) };
79 }
80 AgentEvent::Compact { .. } => {
81 }
83 AgentEvent::Done(resp) => {
84 if let wcore::AgentStopReason::Error(e) = &resp.stop_reason {
85 if is_new {
86 rt.close_session(session_id).await;
87 }
88 Err(anyhow::anyhow!("{e}"))?;
89 }
90 break;
91 }
92 }
93 }
94 yield StreamEvent { event: Some(stream_event::Event::End(StreamEnd { agent: agent.clone() })) };
95 }
96 }
97
98 async fn ping(&self) -> Result<()> {
99 Ok(())
100 }
101
102 async fn list_sessions(&self) -> Result<Vec<SessionInfo>> {
103 let rt = self.runtime.read().await.clone();
104 let sessions = rt.sessions().await;
105 let mut infos = Vec::with_capacity(sessions.len());
106 for s in sessions {
107 let s = s.lock().await;
108 infos.push(SessionInfo {
109 id: s.id,
110 agent: s.agent.to_string(),
111 created_by: s.created_by.to_string(),
112 message_count: s.history.len() as u64,
113 alive_secs: s.created_at.elapsed().as_secs(),
114 });
115 }
116 Ok(infos)
117 }
118
119 async fn kill_session(&self, session: u64) -> Result<bool> {
120 let rt = self.runtime.read().await.clone();
121 Ok(rt.close_session(session).await)
122 }
123
124 async fn list_tasks(&self) -> Result<Vec<TaskInfo>> {
125 let rt = self.runtime.read().await.clone();
126 let registry = rt.hook.tasks.lock().await;
127 let tasks = registry.list(None, None, None);
128 Ok(tasks
129 .into_iter()
130 .map(|t| crate::hook::system::task::TaskRegistry::task_info(t))
131 .collect())
132 }
133
134 async fn kill_task(&self, task_id: u64) -> Result<bool> {
135 use crate::hook::system::task::TaskStatus;
136 let rt = self.runtime.read().await.clone();
137 let tasks = rt.hook.tasks.clone();
138 let mut registry = tasks.lock().await;
139 let Some(task) = registry.get(task_id) else {
140 return Ok(false);
141 };
142 match task.status {
143 TaskStatus::InProgress | TaskStatus::Blocked => {
144 let session_id = task.session_id;
145 registry.kill(task_id);
146 crate::hook::system::task::tool::try_promote(
147 &mut registry,
148 tasks.clone(),
149 rt.hook.event_tx.clone(),
150 rt.hook.task_timeout,
151 );
152 if let Some(sid) = session_id {
154 drop(registry);
155 rt.close_session(sid).await;
156 }
157 Ok(true)
158 }
159 TaskStatus::Queued => {
160 registry.remove(task_id);
161 Ok(true)
162 }
163 _ => Ok(false),
164 }
165 }
166
167 async fn approve_task(&self, task_id: u64, response: String) -> Result<bool> {
168 let rt = self.runtime.read().await.clone();
169 let mut registry = rt.hook.tasks.lock().await;
170 Ok(registry.approve(task_id, response))
171 }
172
173 fn hub(
174 &self,
175 package: String,
176 action: HubAction,
177 filters: Vec<String>,
178 ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
179 let runtime = self.runtime.clone();
180 async_stream::try_stream! {
181 let rt = runtime.read().await.clone();
182 let registry = rt.hook.downloads.clone();
183 let package = compact_str::CompactString::from(package.as_str());
184 match action {
185 HubAction::Install => {
186 let s = crate::ext::hub::package::install(package, registry, filters);
187 pin_mut!(s);
188 while let Some(event) = s.next().await {
189 yield event?;
190 }
191 }
192 HubAction::Uninstall => {
193 let s = crate::ext::hub::package::uninstall(package, registry, filters);
194 pin_mut!(s);
195 while let Some(event) = s.next().await {
196 yield event?;
197 }
198 }
199 }
200 }
201 }
202
203 fn subscribe_tasks(&self) -> impl futures_core::Stream<Item = Result<TaskEvent>> + Send {
204 let runtime = self.runtime.clone();
205 async_stream::try_stream! {
206 let rt = runtime.read().await.clone();
207 let mut rx = rt.hook.tasks.lock().await.subscribe();
208 loop {
209 match rx.recv().await {
210 Ok(event) => yield event,
211 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
212 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
213 }
214 }
215 }
216 }
217
218 async fn list_downloads(&self) -> Result<Vec<DownloadInfo>> {
219 let rt = self.runtime.read().await.clone();
220 let registry = rt.hook.downloads.lock().await;
221 Ok(registry.list())
222 }
223
224 fn subscribe_downloads(
225 &self,
226 ) -> impl futures_core::Stream<Item = Result<DownloadEvent>> + Send {
227 let runtime = self.runtime.clone();
228 async_stream::try_stream! {
229 let rt = runtime.read().await.clone();
230 let mut rx = rt.hook.downloads.lock().await.subscribe();
231 loop {
232 match rx.recv().await {
233 Ok(event) => yield event,
234 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
235 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
236 }
237 }
238 }
239 }
240
241 async fn get_config(&self) -> Result<String> {
242 let config = self.load_config()?;
243 serde_json::to_string(&config).context("failed to serialize config")
244 }
245
246 async fn set_config(&self, config: String) -> Result<()> {
247 let parsed: crate::DaemonConfig =
248 serde_json::from_str(&config).context("invalid DaemonConfig JSON")?;
249 let toml_str =
250 toml::to_string_pretty(&parsed).context("failed to serialize config to TOML")?;
251 let config_path = self.config_dir.join("walrus.toml");
252 std::fs::write(&config_path, toml_str)
253 .with_context(|| format!("failed to write {}", config_path.display()))?;
254 self.reload().await
255 }
256
257 async fn service_query(&self, service: String, query: String) -> Result<String> {
258 let rt = self.runtime.read().await.clone();
259 let registry = rt
260 .hook
261 .registry
262 .as_ref()
263 .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
264 let handle = registry
265 .query
266 .get(&service)
267 .ok_or_else(|| anyhow::anyhow!("service '{}' not available", service))?;
268 let req = wcore::protocol::ext::ExtRequest {
269 msg: Some(wcore::protocol::ext::ext_request::Msg::ServiceQuery(
270 wcore::protocol::ext::ExtServiceQuery { query },
271 )),
272 };
273 let resp = handle.request(&req).await?;
274 match resp.msg {
275 Some(wcore::protocol::ext::ext_response::Msg::ServiceQueryResult(result)) => {
276 Ok(result.result)
277 }
278 Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
279 anyhow::bail!("service '{}' error: {}", service, e.message)
280 }
281 other => anyhow::bail!("unexpected response from service '{}': {other:?}", service),
282 }
283 }
284
285 async fn get_service_schema(&self, service: String) -> Result<String> {
286 let rt = self.runtime.read().await.clone();
287 let registry = rt
288 .hook
289 .registry
290 .as_ref()
291 .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
292 let handle = registry
293 .query
294 .get(&service)
295 .or_else(|| registry.tools.values().find(|h| h.name.as_str() == service))
296 .ok_or_else(|| anyhow::anyhow!("service '{}' not found", service))?;
297 let req = wcore::protocol::ext::ExtRequest {
298 msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
299 wcore::protocol::ext::ExtGetSchema {},
300 )),
301 };
302 let resp = handle.request(&req).await?;
303 match resp.msg {
304 Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) => {
305 Ok(result.schema)
306 }
307 Some(wcore::protocol::ext::ext_response::Msg::Error(e)) => {
308 anyhow::bail!("service '{}' schema error: {}", service, e.message)
309 }
310 other => anyhow::bail!(
311 "unexpected schema response from service '{}': {other:?}",
312 service
313 ),
314 }
315 }
316
317 async fn get_all_schemas(&self) -> Result<std::collections::HashMap<String, String>> {
318 let rt = self.runtime.read().await.clone();
319 let registry = rt
320 .hook
321 .registry
322 .as_ref()
323 .ok_or_else(|| anyhow::anyhow!("no service registry"))?;
324 let mut schemas = std::collections::HashMap::new();
325 for (name, handle) in ®istry.query {
327 let req = wcore::protocol::ext::ExtRequest {
328 msg: Some(wcore::protocol::ext::ext_request::Msg::GetSchema(
329 wcore::protocol::ext::ExtGetSchema {},
330 )),
331 };
332 if let Ok(resp) = handle.request(&req).await
333 && let Some(wcore::protocol::ext::ext_response::Msg::SchemaResult(result)) =
334 resp.msg
335 {
336 schemas.insert(name.clone(), result.schema);
337 }
338 }
339 Ok(schemas)
340 }
341
342 async fn list_services(&self) -> Result<Vec<wcore::protocol::message::ServiceInfoMsg>> {
343 let rt = self.runtime.read().await.clone();
344 let registry = rt.hook.registry.as_ref();
345 let mut services = Vec::new();
346 if let Some(reg) = registry {
347 let mut seen = std::collections::HashSet::new();
349 let all_handles: Vec<_> = reg
350 .build_agent
351 .iter()
352 .chain(reg.before_run.iter())
353 .chain(reg.compact.iter())
354 .chain(reg.event_observer.iter())
355 .chain(reg.after_run.iter())
356 .chain(reg.after_compact.iter())
357 .chain(reg.query.values())
358 .chain(reg.tools.values())
359 .collect();
360 for handle in all_handles {
361 let name = handle.name.to_string();
362 if !seen.insert(name.clone()) {
363 continue;
364 }
365 let capabilities: Vec<String> = handle
366 .capabilities
367 .iter()
368 .filter_map(|c| match &c.cap {
369 Some(wcore::protocol::ext::capability::Cap::Tools(_)) => {
370 Some("tools".into())
371 }
372 Some(wcore::protocol::ext::capability::Cap::Query(_)) => {
373 Some("query".into())
374 }
375 Some(wcore::protocol::ext::capability::Cap::BuildAgent(_)) => {
376 Some("build_agent".into())
377 }
378 Some(wcore::protocol::ext::capability::Cap::BeforeRun(_)) => {
379 Some("before_run".into())
380 }
381 Some(wcore::protocol::ext::capability::Cap::Compact(_)) => {
382 Some("compact".into())
383 }
384 Some(wcore::protocol::ext::capability::Cap::EventObserver(_)) => {
385 Some("event_observer".into())
386 }
387 Some(wcore::protocol::ext::capability::Cap::AfterRun(_)) => {
388 Some("after_run".into())
389 }
390 Some(wcore::protocol::ext::capability::Cap::Infer(_)) => {
391 Some("infer".into())
392 }
393 Some(wcore::protocol::ext::capability::Cap::AfterCompact(_)) => {
394 Some("after_compact".into())
395 }
396 None => None,
397 })
398 .collect();
399 services.push(wcore::protocol::message::ServiceInfoMsg {
400 name,
401 kind: "extension".into(),
402 status: "running".into(),
403 capabilities,
404 has_config: true,
405 });
406 }
407 }
408 Ok(services)
409 }
410
411 async fn set_service_config(&self, service: String, config: String) -> Result<()> {
412 let mut daemon_config = self.load_config()?;
413 let svc = daemon_config
414 .services
415 .get_mut(&service)
416 .ok_or_else(|| anyhow::anyhow!("service '{}' not found in config", service))?;
417 let parsed: serde_json::Value =
418 serde_json::from_str(&config).context("invalid service config JSON")?;
419 svc.config = parsed;
420 let toml_str =
421 toml::to_string_pretty(&daemon_config).context("failed to serialize config to TOML")?;
422 let config_path = self.config_dir.join("walrus.toml");
423 std::fs::write(&config_path, toml_str)
424 .with_context(|| format!("failed to write {}", config_path.display()))?;
425 self.reload().await
426 }
427
428 async fn reload(&self) -> Result<()> {
429 self.reload().await
430 }
431}
432
433impl Daemon {
434 fn load_config(&self) -> Result<crate::DaemonConfig> {
436 crate::DaemonConfig::load(&self.config_dir.join("walrus.toml"))
437 }
438}