walrus_core/protocol/api/
server.rs1use crate::protocol::message::{
4 DownloadEvent, DownloadRequest, HubAction, HubEvent, SendRequest, SendResponse, StreamEvent,
5 StreamRequest, client::ClientMessage, server::ServerMessage,
6};
7use anyhow::Result;
8use futures_core::Stream;
9use futures_util::StreamExt;
10
11pub trait Server: Sync {
21 fn send(
23 &self,
24 req: SendRequest,
25 ) -> impl std::future::Future<Output = Result<SendResponse>> + Send;
26
27 fn stream(&self, req: StreamRequest) -> impl Stream<Item = Result<StreamEvent>> + Send;
29
30 fn download(&self, req: DownloadRequest) -> impl Stream<Item = Result<DownloadEvent>> + Send;
32
33 fn ping(&self) -> impl std::future::Future<Output = Result<()>> + Send;
35
36 fn hub(
38 &self,
39 package: compact_str::CompactString,
40 action: HubAction,
41 ) -> impl Stream<Item = Result<HubEvent>> + Send;
42
43 fn dispatch(&self, msg: ClientMessage) -> impl Stream<Item = ServerMessage> + Send + '_ {
48 async_stream::stream! {
49 match msg {
50 ClientMessage::Send { agent, content } => {
51 yield result_to_msg(self.send(SendRequest { agent, content }).await);
52 }
53 ClientMessage::Stream { agent, content } => {
54 let s = self.stream(StreamRequest { agent, content });
55 tokio::pin!(s);
56 while let Some(result) = s.next().await {
57 yield result_to_msg(result);
58 }
59 }
60 ClientMessage::Download { model } => {
61 let s = self.download(DownloadRequest { model });
62 tokio::pin!(s);
63 while let Some(result) = s.next().await {
64 yield result_to_msg(result);
65 }
66 }
67 ClientMessage::Ping => {
68 yield match self.ping().await {
69 Ok(()) => ServerMessage::Pong,
70 Err(e) => ServerMessage::Error {
71 code: 500,
72 message: e.to_string(),
73 },
74 };
75 }
76 ClientMessage::Hub { package, action } => {
77 let s = self.hub(package, action);
78 tokio::pin!(s);
79 while let Some(result) = s.next().await {
80 yield result_to_msg(result);
81 }
82 }
83 }
84 }
85 }
86}
87
88fn result_to_msg<T: Into<ServerMessage>>(result: Result<T>) -> ServerMessage {
90 match result {
91 Ok(resp) => resp.into(),
92 Err(e) => ServerMessage::Error {
93 code: 500,
94 message: e.to_string(),
95 },
96 }
97}