walrus_core/protocol/api/
server.rs1use crate::protocol::message::{
4 AllSchemasMsg, ClientMessage, ConfigMsg, DownloadEvent, DownloadInfo, DownloadList, ErrorMsg,
5 HubAction, Pong, SendMsg, SendResponse, ServerMessage, ServiceInfoMsg, ServiceListMsg,
6 ServiceQueryResultMsg, ServiceSchemaMsg, SessionInfo, SessionList, StreamEvent, StreamMsg,
7 TaskEvent, TaskInfo, TaskList, client_message, server_message,
8};
9use anyhow::Result;
10use futures_core::Stream;
11use futures_util::StreamExt;
12
13fn server_error(code: u32, message: String) -> ServerMessage {
15 ServerMessage {
16 msg: Some(server_message::Msg::Error(ErrorMsg { code, message })),
17 }
18}
19
20fn server_pong() -> ServerMessage {
22 ServerMessage {
23 msg: Some(server_message::Msg::Pong(Pong {})),
24 }
25}
26
27fn result_to_msg<T: Into<ServerMessage>>(result: Result<T>) -> ServerMessage {
29 match result {
30 Ok(resp) => resp.into(),
31 Err(e) => server_error(500, e.to_string()),
32 }
33}
34
35pub trait Server: Sync {
45 fn send(&self, req: SendMsg) -> impl std::future::Future<Output = Result<SendResponse>> + Send;
47
48 fn stream(&self, req: StreamMsg) -> impl Stream<Item = Result<StreamEvent>> + Send;
50
51 fn ping(&self) -> impl std::future::Future<Output = Result<()>> + Send;
53
54 fn hub(
59 &self,
60 package: String,
61 action: HubAction,
62 filters: Vec<String>,
63 ) -> impl Stream<Item = Result<DownloadEvent>> + Send;
64
65 fn list_sessions(&self) -> impl std::future::Future<Output = Result<Vec<SessionInfo>>> + Send;
67
68 fn kill_session(&self, session: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
70
71 fn list_tasks(&self) -> impl std::future::Future<Output = Result<Vec<TaskInfo>>> + Send;
73
74 fn kill_task(&self, task_id: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
76
77 fn approve_task(
79 &self,
80 task_id: u64,
81 response: String,
82 ) -> impl std::future::Future<Output = Result<bool>> + Send;
83
84 fn subscribe_tasks(&self) -> impl Stream<Item = Result<TaskEvent>> + Send;
86
87 fn list_downloads(&self)
89 -> impl std::future::Future<Output = Result<Vec<DownloadInfo>>> + Send;
90
91 fn subscribe_downloads(&self) -> impl Stream<Item = Result<DownloadEvent>> + Send;
93
94 fn get_config(&self) -> impl std::future::Future<Output = Result<String>> + Send;
96
97 fn set_config(&self, config: String) -> impl std::future::Future<Output = Result<()>> + Send;
99
100 fn service_query(
102 &self,
103 service: String,
104 query: String,
105 ) -> impl std::future::Future<Output = Result<String>> + Send;
106
107 fn get_service_schema(
109 &self,
110 service: String,
111 ) -> impl std::future::Future<Output = Result<String>> + Send;
112
113 fn get_all_schemas(
115 &self,
116 ) -> impl std::future::Future<Output = Result<std::collections::HashMap<String, String>>> + Send;
117
118 fn list_services(
120 &self,
121 ) -> impl std::future::Future<Output = Result<Vec<ServiceInfoMsg>>> + Send;
122
123 fn set_service_config(
125 &self,
126 service: String,
127 config: String,
128 ) -> impl std::future::Future<Output = Result<()>> + Send;
129
130 fn reload(&self) -> impl std::future::Future<Output = Result<()>> + Send;
132
133 fn dispatch(&self, msg: ClientMessage) -> impl Stream<Item = ServerMessage> + Send + '_ {
138 async_stream::stream! {
139 let Some(inner) = msg.msg else {
140 yield server_error(400, "empty client message".to_string());
141 return;
142 };
143
144 match inner {
145 client_message::Msg::Send(send_msg) => {
146 yield result_to_msg(self.send(send_msg).await);
147 }
148 client_message::Msg::Stream(stream_msg) => {
149 let s = self.stream(stream_msg);
150 tokio::pin!(s);
151 while let Some(result) = s.next().await {
152 yield result_to_msg(result);
153 }
154 }
155 client_message::Msg::Ping(_) => {
156 yield match self.ping().await {
157 Ok(()) => server_pong(),
158 Err(e) => server_error(500, e.to_string()),
159 };
160 }
161 client_message::Msg::Hub(hub_msg) => {
162 let action = hub_msg.action();
163 let s = self.hub(hub_msg.package, action, hub_msg.filters);
164 tokio::pin!(s);
165 while let Some(result) = s.next().await {
166 yield result_to_msg(result);
167 }
168 }
169 client_message::Msg::Sessions(_) => {
170 yield match self.list_sessions().await {
171 Ok(sessions) => ServerMessage {
172 msg: Some(server_message::Msg::Sessions(SessionList { sessions })),
173 },
174 Err(e) => server_error(500, e.to_string()),
175 };
176 }
177 client_message::Msg::Kill(kill_msg) => {
178 yield match self.kill_session(kill_msg.session).await {
179 Ok(true) => server_pong(),
180 Ok(false) => server_error(
181 404,
182 format!("session {} not found", kill_msg.session),
183 ),
184 Err(e) => server_error(500, e.to_string()),
185 };
186 }
187 client_message::Msg::Tasks(_) => {
188 yield match self.list_tasks().await {
189 Ok(tasks) => ServerMessage {
190 msg: Some(server_message::Msg::Tasks(TaskList { tasks })),
191 },
192 Err(e) => server_error(500, e.to_string()),
193 };
194 }
195 client_message::Msg::KillTask(kill_task_msg) => {
196 yield match self.kill_task(kill_task_msg.task_id).await {
197 Ok(true) => server_pong(),
198 Ok(false) => server_error(
199 404,
200 format!("task {} not found", kill_task_msg.task_id),
201 ),
202 Err(e) => server_error(500, e.to_string()),
203 };
204 }
205 client_message::Msg::Approve(approve_msg) => {
206 yield match self.approve_task(approve_msg.task_id, approve_msg.response).await {
207 Ok(true) => server_pong(),
208 Ok(false) => server_error(
209 404,
210 format!("task {} not found or not blocked", approve_msg.task_id),
211 ),
212 Err(e) => server_error(500, e.to_string()),
213 };
214 }
215 client_message::Msg::SubscribeTasks(_) => {
216 let s = self.subscribe_tasks();
217 tokio::pin!(s);
218 while let Some(result) = s.next().await {
219 yield result_to_msg(result);
220 }
221 }
222 client_message::Msg::Downloads(_) => {
223 yield match self.list_downloads().await {
224 Ok(downloads) => ServerMessage {
225 msg: Some(server_message::Msg::Downloads(DownloadList { downloads })),
226 },
227 Err(e) => server_error(500, e.to_string()),
228 };
229 }
230 client_message::Msg::SubscribeDownloads(_) => {
231 let s = self.subscribe_downloads();
232 tokio::pin!(s);
233 while let Some(result) = s.next().await {
234 yield result_to_msg(result);
235 }
236 }
237 client_message::Msg::GetConfig(_) => {
238 yield match self.get_config().await {
239 Ok(config) => ServerMessage {
240 msg: Some(server_message::Msg::Config(ConfigMsg { config })),
241 },
242 Err(e) => server_error(500, e.to_string()),
243 };
244 }
245 client_message::Msg::SetConfig(set_config_msg) => {
246 yield match self.set_config(set_config_msg.config).await {
247 Ok(()) => server_pong(),
248 Err(e) => server_error(500, e.to_string()),
249 };
250 }
251 client_message::Msg::ServiceQuery(sq) => {
252 yield match self.service_query(sq.service, sq.query).await {
253 Ok(result) => ServerMessage {
254 msg: Some(server_message::Msg::ServiceQueryResult(
255 ServiceQueryResultMsg { result },
256 )),
257 },
258 Err(e) => server_error(500, e.to_string()),
259 };
260 }
261 client_message::Msg::GetServiceSchema(req) => {
262 let service = req.service;
263 yield match self.get_service_schema(service.clone()).await {
264 Ok(schema) => ServerMessage {
265 msg: Some(server_message::Msg::ServiceSchema(ServiceSchemaMsg {
266 service,
267 schema,
268 })),
269 },
270 Err(e) => server_error(500, e.to_string()),
271 };
272 }
273 client_message::Msg::GetAllSchemas(_) => {
274 yield match self.get_all_schemas().await {
275 Ok(schemas) => ServerMessage {
276 msg: Some(server_message::Msg::AllSchemas(AllSchemasMsg { schemas })),
277 },
278 Err(e) => server_error(500, e.to_string()),
279 };
280 }
281 client_message::Msg::GetServices(_) => {
282 yield match self.list_services().await {
283 Ok(services) => ServerMessage {
284 msg: Some(server_message::Msg::ServiceList(ServiceListMsg { services })),
285 },
286 Err(e) => server_error(500, e.to_string()),
287 };
288 }
289 client_message::Msg::SetServiceConfig(req) => {
290 yield match self.set_service_config(req.service, req.config).await {
291 Ok(()) => server_pong(),
292 Err(e) => server_error(500, e.to_string()),
293 };
294 }
295 client_message::Msg::Reload(_) => {
296 yield match self.reload().await {
297 Ok(()) => server_pong(),
298 Err(e) => server_error(500, e.to_string()),
299 };
300 }
301 }
302 }
303 }
304}