// walrus_core/protocol/api/server.rs
use std::future::Future;

use anyhow::Result;
use futures_core::Stream;
use futures_util::StreamExt;

use crate::protocol::message::{
    DownloadEvent, DownloadRequest, HubAction, HubEvent, SendRequest, SendResponse, StreamEvent,
    StreamRequest,
    client::ClientMessage,
    server::{ServerMessage, SessionInfo, TaskInfo},
};

13pub trait Server: Sync {
23 fn send(
25 &self,
26 req: SendRequest,
27 ) -> impl std::future::Future<Output = Result<SendResponse>> + Send;
28
29 fn stream(&self, req: StreamRequest) -> impl Stream<Item = Result<StreamEvent>> + Send;
31
32 fn download(&self, req: DownloadRequest) -> impl Stream<Item = Result<DownloadEvent>> + Send;
34
35 fn ping(&self) -> impl std::future::Future<Output = Result<()>> + Send;
37
38 fn hub(
40 &self,
41 package: compact_str::CompactString,
42 action: HubAction,
43 ) -> impl Stream<Item = Result<HubEvent>> + Send;
44
45 fn list_sessions(&self) -> impl std::future::Future<Output = Result<Vec<SessionInfo>>> + Send;
47
48 fn kill_session(&self, session: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
50
51 fn list_tasks(&self) -> impl std::future::Future<Output = Result<Vec<TaskInfo>>> + Send;
53
54 fn kill_task(&self, task_id: u64) -> impl std::future::Future<Output = Result<bool>> + Send;
56
57 fn approve_task(
59 &self,
60 task_id: u64,
61 response: String,
62 ) -> impl std::future::Future<Output = Result<bool>> + Send;
63
64 fn dispatch(&self, msg: ClientMessage) -> impl Stream<Item = ServerMessage> + Send + '_ {
69 async_stream::stream! {
70 match msg {
71 ClientMessage::Send { agent, content, session, sender } => {
72 yield result_to_msg(self.send(SendRequest { agent, content, session, sender }).await);
73 }
74 ClientMessage::Stream { agent, content, session, sender } => {
75 let s = self.stream(StreamRequest { agent, content, session, sender });
76 tokio::pin!(s);
77 while let Some(result) = s.next().await {
78 yield result_to_msg(result);
79 }
80 }
81 ClientMessage::Download { model } => {
82 let s = self.download(DownloadRequest { model });
83 tokio::pin!(s);
84 while let Some(result) = s.next().await {
85 yield result_to_msg(result);
86 }
87 }
88 ClientMessage::Ping => {
89 yield match self.ping().await {
90 Ok(()) => ServerMessage::Pong,
91 Err(e) => ServerMessage::Error {
92 code: 500,
93 message: e.to_string(),
94 },
95 };
96 }
97 ClientMessage::Hub { package, action } => {
98 let s = self.hub(package, action);
99 tokio::pin!(s);
100 while let Some(result) = s.next().await {
101 yield result_to_msg(result);
102 }
103 }
104 ClientMessage::Sessions => {
105 yield match self.list_sessions().await {
106 Ok(sessions) => ServerMessage::Sessions(sessions),
107 Err(e) => ServerMessage::Error {
108 code: 500,
109 message: e.to_string(),
110 },
111 };
112 }
113 ClientMessage::Kill { session } => {
114 yield match self.kill_session(session).await {
115 Ok(true) => ServerMessage::Pong,
116 Ok(false) => ServerMessage::Error {
117 code: 404,
118 message: format!("session {session} not found"),
119 },
120 Err(e) => ServerMessage::Error {
121 code: 500,
122 message: e.to_string(),
123 },
124 };
125 }
126 ClientMessage::Tasks => {
127 yield match self.list_tasks().await {
128 Ok(tasks) => ServerMessage::Tasks(tasks),
129 Err(e) => ServerMessage::Error {
130 code: 500,
131 message: e.to_string(),
132 },
133 };
134 }
135 ClientMessage::KillTask { task_id } => {
136 yield match self.kill_task(task_id).await {
137 Ok(true) => ServerMessage::Pong,
138 Ok(false) => ServerMessage::Error {
139 code: 404,
140 message: format!("task {task_id} not found"),
141 },
142 Err(e) => ServerMessage::Error {
143 code: 500,
144 message: e.to_string(),
145 },
146 };
147 }
148 ClientMessage::Approve { task_id, response } => {
149 yield match self.approve_task(task_id, response).await {
150 Ok(true) => ServerMessage::Pong,
151 Ok(false) => ServerMessage::Error {
152 code: 404,
153 message: format!("task {task_id} not found or not blocked"),
154 },
155 Err(e) => ServerMessage::Error {
156 code: 500,
157 message: e.to_string(),
158 },
159 };
160 }
161 }
162 }
163 }
164}
165
166fn result_to_msg<T: Into<ServerMessage>>(result: Result<T>) -> ServerMessage {
168 match result {
169 Ok(resp) => resp.into(),
170 Err(e) => ServerMessage::Error {
171 code: 500,
172 message: e.to_string(),
173 },
174 }
175}