Skip to main content

walrus_core/protocol/api/
server.rs

1//! Server trait — one async method per protocol operation.
2
3use crate::protocol::message::{
4    AllSchemasMsg, ClientMessage, ConfigMsg, DownloadEvent, DownloadInfo, DownloadList, ErrorMsg,
5    HubAction, Pong, SendMsg, SendResponse, ServerMessage, ServiceInfoMsg, ServiceListMsg,
6    ServiceQueryResultMsg, ServiceSchemaMsg, SessionInfo, SessionList, StreamEvent, StreamMsg,
7    TaskEvent, TaskInfo, TaskList, client_message, server_message,
8};
9use anyhow::Result;
10use futures_core::Stream;
11use futures_util::StreamExt;
12
13/// Construct an error `ServerMessage`.
14fn server_error(code: u32, message: String) -> ServerMessage {
15    ServerMessage {
16        msg: Some(server_message::Msg::Error(ErrorMsg { code, message })),
17    }
18}
19
20/// Construct a pong `ServerMessage`.
21fn server_pong() -> ServerMessage {
22    ServerMessage {
23        msg: Some(server_message::Msg::Pong(Pong {})),
24    }
25}
26
27/// Convert a typed `Result` into a `ServerMessage`.
28fn result_to_msg<T: Into<ServerMessage>>(result: Result<T>) -> ServerMessage {
29    match result {
30        Ok(resp) => resp.into(),
31        Err(e) => server_error(500, e.to_string()),
32    }
33}
34
35/// Server-side protocol handler.
36///
37/// Each method corresponds to one `ClientMessage` variant. Implementations
38/// receive typed request structs and return typed responses — no enum matching
39/// required. Streaming operations return `impl Stream`.
40///
41/// The provided [`dispatch`](Server::dispatch) method routes a raw
42/// `ClientMessage` to the appropriate handler, returning a stream of
43/// `ServerMessage`s.
44pub trait Server: Sync {
45    /// Handle `Send` — run agent and return complete response.
46    fn send(&self, req: SendMsg) -> impl std::future::Future<Output = Result<SendResponse>> + Send;
47
48    /// Handle `Stream` — run agent and stream response events.
49    fn stream(&self, req: StreamMsg) -> impl Stream<Item = Result<StreamEvent>> + Send;
50
51    /// Handle `Ping` — keepalive.
52    fn ping(&self) -> impl std::future::Future<Output = Result<()>> + Send;
53
54    /// Handle `Hub` — install or uninstall a hub package.
55    ///
56    /// `filters` restricts which components to install. Format: `"kind:name"`
57    /// (e.g. `"skill:playwright-cli"`, `"mcp:playwright"`). Empty = install all.
58    fn hub(
59        &self,
60        package: String,
61        action: HubAction,
62        filters: Vec<String>,
63    ) -> impl Stream<Item = Result<DownloadEvent>> + Send;
64
65    /// Handle `Sessions` — list active sessions.
66    fn list_sessions(&self) -> impl std::future::Future<Output = Result<Vec<SessionInfo>>> + Send;
67
68    /// Handle `Kill` — close a session by ID.
69    fn kill_session(&self, session: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
70
71    /// Handle `Tasks` — list tasks in the task registry.
72    fn list_tasks(&self) -> impl std::future::Future<Output = Result<Vec<TaskInfo>>> + Send;
73
74    /// Handle `KillTask` — cancel a task by ID.
75    fn kill_task(&self, task_id: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
76
77    /// Handle `Approve` — approve a blocked task's inbox item.
78    fn approve_task(
79        &self,
80        task_id: u64,
81        response: String,
82    ) -> impl std::future::Future<Output = Result<bool>> + Send;
83
84    /// Handle `SubscribeTasks` — stream task lifecycle events.
85    fn subscribe_tasks(&self) -> impl Stream<Item = Result<TaskEvent>> + Send;
86
87    /// Handle `Downloads` — list downloads in the registry.
88    fn list_downloads(&self)
89    -> impl std::future::Future<Output = Result<Vec<DownloadInfo>>> + Send;
90
91    /// Handle `SubscribeDownloads` — stream download lifecycle events.
92    fn subscribe_downloads(&self) -> impl Stream<Item = Result<DownloadEvent>> + Send;
93
94    /// Handle `GetConfig` — return the full daemon config as JSON.
95    fn get_config(&self) -> impl std::future::Future<Output = Result<String>> + Send;
96
97    /// Handle `SetConfig` — replace the daemon config from JSON.
98    fn set_config(&self, config: String) -> impl std::future::Future<Output = Result<()>> + Send;
99
100    /// Handle `ServiceQuery` — route to a named service.
101    fn service_query(
102        &self,
103        service: String,
104        query: String,
105    ) -> impl std::future::Future<Output = Result<String>> + Send;
106
107    /// Handle `GetServiceSchema` — return JSON Schema for one service's config.
108    fn get_service_schema(
109        &self,
110        service: String,
111    ) -> impl std::future::Future<Output = Result<String>> + Send;
112
113    /// Handle `GetAllSchemas` — return JSON Schemas for all services.
114    fn get_all_schemas(
115        &self,
116    ) -> impl std::future::Future<Output = Result<std::collections::HashMap<String, String>>> + Send;
117
118    /// Handle `GetServices` — list registered services with status.
119    fn list_services(
120        &self,
121    ) -> impl std::future::Future<Output = Result<Vec<ServiceInfoMsg>>> + Send;
122
123    /// Handle `SetServiceConfig` — update a single service's config.
124    fn set_service_config(
125        &self,
126        service: String,
127        config: String,
128    ) -> impl std::future::Future<Output = Result<()>> + Send;
129
130    /// Handle `Reload` — hot-reload runtime from disk.
131    fn reload(&self) -> impl std::future::Future<Output = Result<()>> + Send;
132
133    /// Dispatch a `ClientMessage` to the appropriate handler method.
134    ///
135    /// Returns a stream of `ServerMessage`s. Request-response operations
136    /// yield exactly one message; streaming operations yield many.
137    fn dispatch(&self, msg: ClientMessage) -> impl Stream<Item = ServerMessage> + Send + '_ {
138        async_stream::stream! {
139            let Some(inner) = msg.msg else {
140                yield server_error(400, "empty client message".to_string());
141                return;
142            };
143
144            match inner {
145                client_message::Msg::Send(send_msg) => {
146                    yield result_to_msg(self.send(send_msg).await);
147                }
148                client_message::Msg::Stream(stream_msg) => {
149                    let s = self.stream(stream_msg);
150                    tokio::pin!(s);
151                    while let Some(result) = s.next().await {
152                        yield result_to_msg(result);
153                    }
154                }
155                client_message::Msg::Ping(_) => {
156                    yield match self.ping().await {
157                        Ok(()) => server_pong(),
158                        Err(e) => server_error(500, e.to_string()),
159                    };
160                }
161                client_message::Msg::Hub(hub_msg) => {
162                    let action = hub_msg.action();
163                    let s = self.hub(hub_msg.package, action, hub_msg.filters);
164                    tokio::pin!(s);
165                    while let Some(result) = s.next().await {
166                        yield result_to_msg(result);
167                    }
168                }
169                client_message::Msg::Sessions(_) => {
170                    yield match self.list_sessions().await {
171                        Ok(sessions) => ServerMessage {
172                            msg: Some(server_message::Msg::Sessions(SessionList { sessions })),
173                        },
174                        Err(e) => server_error(500, e.to_string()),
175                    };
176                }
177                client_message::Msg::Kill(kill_msg) => {
178                    yield match self.kill_session(kill_msg.session).await {
179                        Ok(true) => server_pong(),
180                        Ok(false) => server_error(
181                            404,
182                            format!("session {} not found", kill_msg.session),
183                        ),
184                        Err(e) => server_error(500, e.to_string()),
185                    };
186                }
187                client_message::Msg::Tasks(_) => {
188                    yield match self.list_tasks().await {
189                        Ok(tasks) => ServerMessage {
190                            msg: Some(server_message::Msg::Tasks(TaskList { tasks })),
191                        },
192                        Err(e) => server_error(500, e.to_string()),
193                    };
194                }
195                client_message::Msg::KillTask(kill_task_msg) => {
196                    yield match self.kill_task(kill_task_msg.task_id).await {
197                        Ok(true) => server_pong(),
198                        Ok(false) => server_error(
199                            404,
200                            format!("task {} not found", kill_task_msg.task_id),
201                        ),
202                        Err(e) => server_error(500, e.to_string()),
203                    };
204                }
205                client_message::Msg::Approve(approve_msg) => {
206                    yield match self.approve_task(approve_msg.task_id, approve_msg.response).await {
207                        Ok(true) => server_pong(),
208                        Ok(false) => server_error(
209                            404,
210                            format!("task {} not found or not blocked", approve_msg.task_id),
211                        ),
212                        Err(e) => server_error(500, e.to_string()),
213                    };
214                }
215                client_message::Msg::SubscribeTasks(_) => {
216                    let s = self.subscribe_tasks();
217                    tokio::pin!(s);
218                    while let Some(result) = s.next().await {
219                        yield result_to_msg(result);
220                    }
221                }
222                client_message::Msg::Downloads(_) => {
223                    yield match self.list_downloads().await {
224                        Ok(downloads) => ServerMessage {
225                            msg: Some(server_message::Msg::Downloads(DownloadList { downloads })),
226                        },
227                        Err(e) => server_error(500, e.to_string()),
228                    };
229                }
230                client_message::Msg::SubscribeDownloads(_) => {
231                    let s = self.subscribe_downloads();
232                    tokio::pin!(s);
233                    while let Some(result) = s.next().await {
234                        yield result_to_msg(result);
235                    }
236                }
237                client_message::Msg::GetConfig(_) => {
238                    yield match self.get_config().await {
239                        Ok(config) => ServerMessage {
240                            msg: Some(server_message::Msg::Config(ConfigMsg { config })),
241                        },
242                        Err(e) => server_error(500, e.to_string()),
243                    };
244                }
245                client_message::Msg::SetConfig(set_config_msg) => {
246                    yield match self.set_config(set_config_msg.config).await {
247                        Ok(()) => server_pong(),
248                        Err(e) => server_error(500, e.to_string()),
249                    };
250                }
251                client_message::Msg::ServiceQuery(sq) => {
252                    yield match self.service_query(sq.service, sq.query).await {
253                        Ok(result) => ServerMessage {
254                            msg: Some(server_message::Msg::ServiceQueryResult(
255                                ServiceQueryResultMsg { result },
256                            )),
257                        },
258                        Err(e) => server_error(500, e.to_string()),
259                    };
260                }
261                client_message::Msg::GetServiceSchema(req) => {
262                    let service = req.service;
263                    yield match self.get_service_schema(service.clone()).await {
264                        Ok(schema) => ServerMessage {
265                            msg: Some(server_message::Msg::ServiceSchema(ServiceSchemaMsg {
266                                service,
267                                schema,
268                            })),
269                        },
270                        Err(e) => server_error(500, e.to_string()),
271                    };
272                }
273                client_message::Msg::GetAllSchemas(_) => {
274                    yield match self.get_all_schemas().await {
275                        Ok(schemas) => ServerMessage {
276                            msg: Some(server_message::Msg::AllSchemas(AllSchemasMsg { schemas })),
277                        },
278                        Err(e) => server_error(500, e.to_string()),
279                    };
280                }
281                client_message::Msg::GetServices(_) => {
282                    yield match self.list_services().await {
283                        Ok(services) => ServerMessage {
284                            msg: Some(server_message::Msg::ServiceList(ServiceListMsg { services })),
285                        },
286                        Err(e) => server_error(500, e.to_string()),
287                    };
288                }
289                client_message::Msg::SetServiceConfig(req) => {
290                    yield match self.set_service_config(req.service, req.config).await {
291                        Ok(()) => server_pong(),
292                        Err(e) => server_error(500, e.to_string()),
293                    };
294                }
295                client_message::Msg::Reload(_) => {
296                    yield match self.reload().await {
297                        Ok(()) => server_pong(),
298                        Err(e) => server_error(500, e.to_string()),
299                    };
300                }
301            }
302        }
303    }
304}