Skip to main content

walrus_core/protocol/api/
server.rs

1//! Server trait — one async method per protocol operation.
2
3use crate::protocol::message::{
4    DownloadEvent, DownloadRequest, HubAction, MemoryOp, MemoryResult, SendRequest, SendResponse,
5    StreamEvent, StreamRequest, TaskEvent,
6    client::ClientMessage,
7    server::{DownloadInfo, ServerMessage, SessionInfo, TaskInfo},
8};
9use anyhow::Result;
10use futures_core::Stream;
11use futures_util::StreamExt;
12
13/// Server-side protocol handler.
14///
15/// Each method corresponds to one `ClientMessage` variant. Implementations
16/// receive typed request structs and return typed responses — no enum matching
17/// required. Streaming operations return `impl Stream`.
18///
19/// The provided [`dispatch`](Server::dispatch) method routes a raw
20/// `ClientMessage` to the appropriate handler, returning a stream of
21/// `ServerMessage`s.
22pub trait Server: Sync {
23    /// Handle `Send` — run agent and return complete response.
24    fn send(
25        &self,
26        req: SendRequest,
27    ) -> impl std::future::Future<Output = Result<SendResponse>> + Send;
28
29    /// Handle `Stream` — run agent and stream response events.
30    fn stream(&self, req: StreamRequest) -> impl Stream<Item = Result<StreamEvent>> + Send;
31
32    /// Handle `Download` — download model files with progress.
33    fn download(&self, req: DownloadRequest) -> impl Stream<Item = Result<DownloadEvent>> + Send;
34
35    /// Handle `Ping` — keepalive.
36    fn ping(&self) -> impl std::future::Future<Output = Result<()>> + Send;
37
38    /// Handle `Hub` — install or uninstall a hub package.
39    fn hub(
40        &self,
41        package: compact_str::CompactString,
42        action: HubAction,
43    ) -> impl Stream<Item = Result<DownloadEvent>> + Send;
44
45    /// Handle `Sessions` — list active sessions.
46    fn list_sessions(&self) -> impl std::future::Future<Output = Result<Vec<SessionInfo>>> + Send;
47
48    /// Handle `Kill` — close a session by ID.
49    fn kill_session(&self, session: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
50
51    /// Handle `Tasks` — list tasks in the task registry.
52    fn list_tasks(&self) -> impl std::future::Future<Output = Result<Vec<TaskInfo>>> + Send;
53
54    /// Handle `KillTask` — cancel a task by ID.
55    fn kill_task(&self, task_id: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
56
57    /// Handle `Approve` — approve a blocked task's inbox item.
58    fn approve_task(
59        &self,
60        task_id: u64,
61        response: String,
62    ) -> impl std::future::Future<Output = Result<bool>> + Send;
63
64    /// Handle `Evaluate` — decide whether the agent should respond (DD#39).
65    fn evaluate(&self, req: SendRequest) -> impl std::future::Future<Output = Result<bool>> + Send;
66
67    /// Handle `SubscribeTasks` — stream task lifecycle events.
68    fn subscribe_tasks(&self) -> impl Stream<Item = Result<TaskEvent>> + Send;
69
70    /// Handle `Downloads` — list downloads in the registry.
71    fn list_downloads(&self)
72    -> impl std::future::Future<Output = Result<Vec<DownloadInfo>>> + Send;
73
74    /// Handle `SubscribeDownloads` — stream download lifecycle events.
75    fn subscribe_downloads(&self) -> impl Stream<Item = Result<DownloadEvent>> + Send;
76
77    /// Handle `GetConfig` — return the full daemon config as JSON.
78    fn get_config(&self) -> impl std::future::Future<Output = Result<String>> + Send;
79
80    /// Handle `SetConfig` — replace the daemon config from JSON.
81    fn set_config(&self, config: String) -> impl std::future::Future<Output = Result<()>> + Send;
82
83    /// Handle `MemoryQuery` — query the memory graph.
84    fn memory_query(
85        &self,
86        query: MemoryOp,
87    ) -> impl std::future::Future<Output = Result<MemoryResult>> + Send;
88
89    /// Dispatch a `ClientMessage` to the appropriate handler method.
90    ///
91    /// Returns a stream of `ServerMessage`s. Request-response operations
92    /// yield exactly one message; streaming operations yield many.
93    fn dispatch(&self, msg: ClientMessage) -> impl Stream<Item = ServerMessage> + Send + '_ {
94        async_stream::stream! {
95            match msg {
96                ClientMessage::Send { agent, content, session, sender } => {
97                    yield result_to_msg(self.send(SendRequest { agent, content, session, sender }).await);
98                }
99                ClientMessage::Stream { agent, content, session, sender } => {
100                    let s = self.stream(StreamRequest { agent, content, session, sender });
101                    tokio::pin!(s);
102                    while let Some(result) = s.next().await {
103                        yield result_to_msg(result);
104                    }
105                }
106                ClientMessage::Download { model } => {
107                    let s = self.download(DownloadRequest { model });
108                    tokio::pin!(s);
109                    while let Some(result) = s.next().await {
110                        yield result_to_msg(result);
111                    }
112                }
113                ClientMessage::Ping => {
114                    yield match self.ping().await {
115                        Ok(()) => ServerMessage::Pong,
116                        Err(e) => ServerMessage::Error {
117                            code: 500,
118                            message: e.to_string(),
119                        },
120                    };
121                }
122                ClientMessage::Hub { package, action } => {
123                    let s = self.hub(package, action);
124                    tokio::pin!(s);
125                    while let Some(result) = s.next().await {
126                        yield result_to_msg(result);
127                    }
128                }
129                ClientMessage::Sessions => {
130                    yield match self.list_sessions().await {
131                        Ok(sessions) => ServerMessage::Sessions(sessions),
132                        Err(e) => ServerMessage::Error {
133                            code: 500,
134                            message: e.to_string(),
135                        },
136                    };
137                }
138                ClientMessage::Kill { session } => {
139                    yield match self.kill_session(session).await {
140                        Ok(true) => ServerMessage::Pong,
141                        Ok(false) => ServerMessage::Error {
142                            code: 404,
143                            message: format!("session {session} not found"),
144                        },
145                        Err(e) => ServerMessage::Error {
146                            code: 500,
147                            message: e.to_string(),
148                        },
149                    };
150                }
151                ClientMessage::Tasks => {
152                    yield match self.list_tasks().await {
153                        Ok(tasks) => ServerMessage::Tasks(tasks),
154                        Err(e) => ServerMessage::Error {
155                            code: 500,
156                            message: e.to_string(),
157                        },
158                    };
159                }
160                ClientMessage::KillTask { task_id } => {
161                    yield match self.kill_task(task_id).await {
162                        Ok(true) => ServerMessage::Pong,
163                        Ok(false) => ServerMessage::Error {
164                            code: 404,
165                            message: format!("task {task_id} not found"),
166                        },
167                        Err(e) => ServerMessage::Error {
168                            code: 500,
169                            message: e.to_string(),
170                        },
171                    };
172                }
173                ClientMessage::Approve { task_id, response } => {
174                    yield match self.approve_task(task_id, response).await {
175                        Ok(true) => ServerMessage::Pong,
176                        Ok(false) => ServerMessage::Error {
177                            code: 404,
178                            message: format!("task {task_id} not found or not blocked"),
179                        },
180                        Err(e) => ServerMessage::Error {
181                            code: 500,
182                            message: e.to_string(),
183                        },
184                    };
185                }
186                ClientMessage::Evaluate { agent, content, session, sender } => {
187                    yield match self.evaluate(SendRequest { agent, content, session, sender }).await {
188                        Ok(respond) => ServerMessage::Evaluation { respond },
189                        Err(e) => ServerMessage::Error {
190                            code: 500,
191                            message: e.to_string(),
192                        },
193                    };
194                }
195                ClientMessage::SubscribeTasks => {
196                    let s = self.subscribe_tasks();
197                    tokio::pin!(s);
198                    while let Some(result) = s.next().await {
199                        yield result_to_msg(result);
200                    }
201                }
202                ClientMessage::Downloads => {
203                    yield match self.list_downloads().await {
204                        Ok(downloads) => ServerMessage::Downloads(downloads),
205                        Err(e) => ServerMessage::Error {
206                            code: 500,
207                            message: e.to_string(),
208                        },
209                    };
210                }
211                ClientMessage::SubscribeDownloads => {
212                    let s = self.subscribe_downloads();
213                    tokio::pin!(s);
214                    while let Some(result) = s.next().await {
215                        yield result_to_msg(result);
216                    }
217                }
218                ClientMessage::GetConfig => {
219                    yield match self.get_config().await {
220                        Ok(config) => ServerMessage::Config { config },
221                        Err(e) => ServerMessage::Error {
222                            code: 500,
223                            message: e.to_string(),
224                        },
225                    };
226                }
227                ClientMessage::SetConfig { config } => {
228                    yield match self.set_config(config).await {
229                        Ok(()) => ServerMessage::Pong,
230                        Err(e) => ServerMessage::Error {
231                            code: 500,
232                            message: e.to_string(),
233                        },
234                    };
235                }
236                ClientMessage::MemoryQuery { query } => {
237                    yield match self.memory_query(query).await {
238                        Ok(result) => ServerMessage::Memory(result),
239                        Err(e) => ServerMessage::Error {
240                            code: 500,
241                            message: e.to_string(),
242                        },
243                    };
244                }
245            }
246        }
247    }
248}
249
250/// Convert a typed `Result` into a `ServerMessage`.
251fn result_to_msg<T: Into<ServerMessage>>(result: Result<T>) -> ServerMessage {
252    match result {
253        Ok(resp) => resp.into(),
254        Err(e) => ServerMessage::Error {
255            code: 500,
256            message: e.to_string(),
257        },
258    }
259}