walrus_core/protocol/api/
server.rs1use crate::protocol::message::{
4 DownloadEvent, DownloadRequest, HubAction, MemoryOp, MemoryResult, SendRequest, SendResponse,
5 StreamEvent, StreamRequest, TaskEvent,
6 client::ClientMessage,
7 server::{DownloadInfo, ServerMessage, SessionInfo, TaskInfo},
8};
9use anyhow::Result;
10use futures_core::Stream;
11use futures_util::StreamExt;
12
13pub trait Server: Sync {
23 fn send(
25 &self,
26 req: SendRequest,
27 ) -> impl std::future::Future<Output = Result<SendResponse>> + Send;
28
29 fn stream(&self, req: StreamRequest) -> impl Stream<Item = Result<StreamEvent>> + Send;
31
32 fn download(&self, req: DownloadRequest) -> impl Stream<Item = Result<DownloadEvent>> + Send;
34
35 fn ping(&self) -> impl std::future::Future<Output = Result<()>> + Send;
37
38 fn hub(
40 &self,
41 package: compact_str::CompactString,
42 action: HubAction,
43 ) -> impl Stream<Item = Result<DownloadEvent>> + Send;
44
45 fn list_sessions(&self) -> impl std::future::Future<Output = Result<Vec<SessionInfo>>> + Send;
47
48 fn kill_session(&self, session: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
50
51 fn list_tasks(&self) -> impl std::future::Future<Output = Result<Vec<TaskInfo>>> + Send;
53
54 fn kill_task(&self, task_id: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
56
57 fn approve_task(
59 &self,
60 task_id: u64,
61 response: String,
62 ) -> impl std::future::Future<Output = Result<bool>> + Send;
63
64 fn evaluate(&self, req: SendRequest) -> impl std::future::Future<Output = Result<bool>> + Send;
66
67 fn subscribe_tasks(&self) -> impl Stream<Item = Result<TaskEvent>> + Send;
69
70 fn list_downloads(&self)
72 -> impl std::future::Future<Output = Result<Vec<DownloadInfo>>> + Send;
73
74 fn subscribe_downloads(&self) -> impl Stream<Item = Result<DownloadEvent>> + Send;
76
77 fn get_config(&self) -> impl std::future::Future<Output = Result<String>> + Send;
79
80 fn set_config(&self, config: String) -> impl std::future::Future<Output = Result<()>> + Send;
82
83 fn memory_query(
85 &self,
86 query: MemoryOp,
87 ) -> impl std::future::Future<Output = Result<MemoryResult>> + Send;
88
89 fn dispatch(&self, msg: ClientMessage) -> impl Stream<Item = ServerMessage> + Send + '_ {
94 async_stream::stream! {
95 match msg {
96 ClientMessage::Send { agent, content, session, sender } => {
97 yield result_to_msg(self.send(SendRequest { agent, content, session, sender }).await);
98 }
99 ClientMessage::Stream { agent, content, session, sender } => {
100 let s = self.stream(StreamRequest { agent, content, session, sender });
101 tokio::pin!(s);
102 while let Some(result) = s.next().await {
103 yield result_to_msg(result);
104 }
105 }
106 ClientMessage::Download { model } => {
107 let s = self.download(DownloadRequest { model });
108 tokio::pin!(s);
109 while let Some(result) = s.next().await {
110 yield result_to_msg(result);
111 }
112 }
113 ClientMessage::Ping => {
114 yield match self.ping().await {
115 Ok(()) => ServerMessage::Pong,
116 Err(e) => ServerMessage::Error {
117 code: 500,
118 message: e.to_string(),
119 },
120 };
121 }
122 ClientMessage::Hub { package, action } => {
123 let s = self.hub(package, action);
124 tokio::pin!(s);
125 while let Some(result) = s.next().await {
126 yield result_to_msg(result);
127 }
128 }
129 ClientMessage::Sessions => {
130 yield match self.list_sessions().await {
131 Ok(sessions) => ServerMessage::Sessions(sessions),
132 Err(e) => ServerMessage::Error {
133 code: 500,
134 message: e.to_string(),
135 },
136 };
137 }
138 ClientMessage::Kill { session } => {
139 yield match self.kill_session(session).await {
140 Ok(true) => ServerMessage::Pong,
141 Ok(false) => ServerMessage::Error {
142 code: 404,
143 message: format!("session {session} not found"),
144 },
145 Err(e) => ServerMessage::Error {
146 code: 500,
147 message: e.to_string(),
148 },
149 };
150 }
151 ClientMessage::Tasks => {
152 yield match self.list_tasks().await {
153 Ok(tasks) => ServerMessage::Tasks(tasks),
154 Err(e) => ServerMessage::Error {
155 code: 500,
156 message: e.to_string(),
157 },
158 };
159 }
160 ClientMessage::KillTask { task_id } => {
161 yield match self.kill_task(task_id).await {
162 Ok(true) => ServerMessage::Pong,
163 Ok(false) => ServerMessage::Error {
164 code: 404,
165 message: format!("task {task_id} not found"),
166 },
167 Err(e) => ServerMessage::Error {
168 code: 500,
169 message: e.to_string(),
170 },
171 };
172 }
173 ClientMessage::Approve { task_id, response } => {
174 yield match self.approve_task(task_id, response).await {
175 Ok(true) => ServerMessage::Pong,
176 Ok(false) => ServerMessage::Error {
177 code: 404,
178 message: format!("task {task_id} not found or not blocked"),
179 },
180 Err(e) => ServerMessage::Error {
181 code: 500,
182 message: e.to_string(),
183 },
184 };
185 }
186 ClientMessage::Evaluate { agent, content, session, sender } => {
187 yield match self.evaluate(SendRequest { agent, content, session, sender }).await {
188 Ok(respond) => ServerMessage::Evaluation { respond },
189 Err(e) => ServerMessage::Error {
190 code: 500,
191 message: e.to_string(),
192 },
193 };
194 }
195 ClientMessage::SubscribeTasks => {
196 let s = self.subscribe_tasks();
197 tokio::pin!(s);
198 while let Some(result) = s.next().await {
199 yield result_to_msg(result);
200 }
201 }
202 ClientMessage::Downloads => {
203 yield match self.list_downloads().await {
204 Ok(downloads) => ServerMessage::Downloads(downloads),
205 Err(e) => ServerMessage::Error {
206 code: 500,
207 message: e.to_string(),
208 },
209 };
210 }
211 ClientMessage::SubscribeDownloads => {
212 let s = self.subscribe_downloads();
213 tokio::pin!(s);
214 while let Some(result) = s.next().await {
215 yield result_to_msg(result);
216 }
217 }
218 ClientMessage::GetConfig => {
219 yield match self.get_config().await {
220 Ok(config) => ServerMessage::Config { config },
221 Err(e) => ServerMessage::Error {
222 code: 500,
223 message: e.to_string(),
224 },
225 };
226 }
227 ClientMessage::SetConfig { config } => {
228 yield match self.set_config(config).await {
229 Ok(()) => ServerMessage::Pong,
230 Err(e) => ServerMessage::Error {
231 code: 500,
232 message: e.to_string(),
233 },
234 };
235 }
236 ClientMessage::MemoryQuery { query } => {
237 yield match self.memory_query(query).await {
238 Ok(result) => ServerMessage::Memory(result),
239 Err(e) => ServerMessage::Error {
240 code: 500,
241 message: e.to_string(),
242 },
243 };
244 }
245 }
246 }
247 }
248}
249
250fn result_to_msg<T: Into<ServerMessage>>(result: Result<T>) -> ServerMessage {
252 match result {
253 Ok(resp) => resp.into(),
254 Err(e) => ServerMessage::Error {
255 code: 500,
256 message: e.to_string(),
257 },
258 }
259}