Skip to main content

walrus_core/protocol/api/
server.rs

//! Server trait — one async method per protocol operation.
2
3use crate::protocol::message::{
4    DownloadEvent, DownloadRequest, HubAction, HubEvent, SendRequest, SendResponse, StreamEvent,
5    StreamRequest,
6    client::ClientMessage,
7    server::{ServerMessage, SessionInfo, TaskInfo},
8};
9use anyhow::Result;
10use futures_core::Stream;
11use futures_util::StreamExt;
12
13/// Server-side protocol handler.
14///
15/// Each method corresponds to one `ClientMessage` variant. Implementations
16/// receive typed request structs and return typed responses — no enum matching
17/// required. Streaming operations return `impl Stream`.
18///
19/// The provided [`dispatch`](Server::dispatch) method routes a raw
20/// `ClientMessage` to the appropriate handler, returning a stream of
21/// `ServerMessage`s.
22pub trait Server: Sync {
23    /// Handle `Send` — run agent and return complete response.
24    fn send(
25        &self,
26        req: SendRequest,
27    ) -> impl std::future::Future<Output = Result<SendResponse>> + Send;
28
29    /// Handle `Stream` — run agent and stream response events.
30    fn stream(&self, req: StreamRequest) -> impl Stream<Item = Result<StreamEvent>> + Send;
31
32    /// Handle `Download` — download model files with progress.
33    fn download(&self, req: DownloadRequest) -> impl Stream<Item = Result<DownloadEvent>> + Send;
34
35    /// Handle `Ping` — keepalive.
36    fn ping(&self) -> impl std::future::Future<Output = Result<()>> + Send;
37
38    /// Handle `Hub` — install or uninstall a hub package.
39    fn hub(
40        &self,
41        package: compact_str::CompactString,
42        action: HubAction,
43    ) -> impl Stream<Item = Result<HubEvent>> + Send;
44
45    /// Handle `Sessions` — list active sessions.
46    fn list_sessions(&self) -> impl std::future::Future<Output = Result<Vec<SessionInfo>>> + Send;
47
48    /// Handle `Kill` — close a session by ID.
49    fn kill_session(&self, session: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
50
51    /// Handle `Tasks` — list tasks in the task registry.
52    fn list_tasks(&self) -> impl std::future::Future<Output = Result<Vec<TaskInfo>>> + Send;
53
54    /// Handle `KillTask` — cancel a task by ID.
55    fn kill_task(&self, task_id: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
56
57    /// Handle `Approve` — approve a blocked task's inbox item.
58    fn approve_task(
59        &self,
60        task_id: u64,
61        response: String,
62    ) -> impl std::future::Future<Output = Result<bool>> + Send;
63
64    /// Dispatch a `ClientMessage` to the appropriate handler method.
65    ///
66    /// Returns a stream of `ServerMessage`s. Request-response operations
67    /// yield exactly one message; streaming operations yield many.
68    fn dispatch(&self, msg: ClientMessage) -> impl Stream<Item = ServerMessage> + Send + '_ {
69        async_stream::stream! {
70            match msg {
71                ClientMessage::Send { agent, content, session, sender } => {
72                    yield result_to_msg(self.send(SendRequest { agent, content, session, sender }).await);
73                }
74                ClientMessage::Stream { agent, content, session, sender } => {
75                    let s = self.stream(StreamRequest { agent, content, session, sender });
76                    tokio::pin!(s);
77                    while let Some(result) = s.next().await {
78                        yield result_to_msg(result);
79                    }
80                }
81                ClientMessage::Download { model } => {
82                    let s = self.download(DownloadRequest { model });
83                    tokio::pin!(s);
84                    while let Some(result) = s.next().await {
85                        yield result_to_msg(result);
86                    }
87                }
88                ClientMessage::Ping => {
89                    yield match self.ping().await {
90                        Ok(()) => ServerMessage::Pong,
91                        Err(e) => ServerMessage::Error {
92                            code: 500,
93                            message: e.to_string(),
94                        },
95                    };
96                }
97                ClientMessage::Hub { package, action } => {
98                    let s = self.hub(package, action);
99                    tokio::pin!(s);
100                    while let Some(result) = s.next().await {
101                        yield result_to_msg(result);
102                    }
103                }
104                ClientMessage::Sessions => {
105                    yield match self.list_sessions().await {
106                        Ok(sessions) => ServerMessage::Sessions(sessions),
107                        Err(e) => ServerMessage::Error {
108                            code: 500,
109                            message: e.to_string(),
110                        },
111                    };
112                }
113                ClientMessage::Kill { session } => {
114                    yield match self.kill_session(session).await {
115                        Ok(true) => ServerMessage::Pong,
116                        Ok(false) => ServerMessage::Error {
117                            code: 404,
118                            message: format!("session {session} not found"),
119                        },
120                        Err(e) => ServerMessage::Error {
121                            code: 500,
122                            message: e.to_string(),
123                        },
124                    };
125                }
126                ClientMessage::Tasks => {
127                    yield match self.list_tasks().await {
128                        Ok(tasks) => ServerMessage::Tasks(tasks),
129                        Err(e) => ServerMessage::Error {
130                            code: 500,
131                            message: e.to_string(),
132                        },
133                    };
134                }
135                ClientMessage::KillTask { task_id } => {
136                    yield match self.kill_task(task_id).await {
137                        Ok(true) => ServerMessage::Pong,
138                        Ok(false) => ServerMessage::Error {
139                            code: 404,
140                            message: format!("task {task_id} not found"),
141                        },
142                        Err(e) => ServerMessage::Error {
143                            code: 500,
144                            message: e.to_string(),
145                        },
146                    };
147                }
148                ClientMessage::Approve { task_id, response } => {
149                    yield match self.approve_task(task_id, response).await {
150                        Ok(true) => ServerMessage::Pong,
151                        Ok(false) => ServerMessage::Error {
152                            code: 404,
153                            message: format!("task {task_id} not found or not blocked"),
154                        },
155                        Err(e) => ServerMessage::Error {
156                            code: 500,
157                            message: e.to_string(),
158                        },
159                    };
160                }
161            }
162        }
163    }
164}
165
166/// Convert a typed `Result` into a `ServerMessage`.
167fn result_to_msg<T: Into<ServerMessage>>(result: Result<T>) -> ServerMessage {
168    match result {
169        Ok(resp) => resp.into(),
170        Err(e) => ServerMessage::Error {
171            code: 500,
172            message: e.to_string(),
173        },
174    }
175}