1use serde_json::json;
2use tokio::{
3 spawn,
4 sync::{mpsc, oneshot},
5};
6use worterbuch_common::{
7 Ack, CSet, CState, CStateEvent, ClientMessage, Delete, Err, ErrorCode, Get, INTERNAL_CLIENT_ID,
8 Lock, Ls, LsState, PDelete, PGet, PLs, PState, PStateEvent, PSubscribe, Publish,
9 RegularKeySegment, RequestPattern, SPub, SPubInit, ServerInfo, ServerMessage, Set, State,
10 StateEvent, Subscribe, SubscribeLs, TransactionId, Unsubscribe, UnsubscribeLs, WbApi, Welcome,
11 error::{ConnectionResult, WorterbuchError},
12};
13
14pub struct LocalClientSocket {
15 tx: mpsc::UnboundedSender<ClientMessage>,
16 rx: mpsc::UnboundedReceiver<ServerMessage>,
17 closed: oneshot::Receiver<()>,
18}
19
20impl LocalClientSocket {
21 pub fn new(
22 tx: mpsc::UnboundedSender<ClientMessage>,
23 rx: mpsc::UnboundedReceiver<ServerMessage>,
24 closed: oneshot::Receiver<()>,
25 ) -> Self {
26 Self { tx, rx, closed }
27 }
28
29 pub async fn send_msg(&self, msg: ClientMessage) -> ConnectionResult<()> {
30 self.tx.send(msg)?;
31 Ok(())
32 }
33
34 pub async fn receive_msg(&mut self) -> ConnectionResult<Option<ServerMessage>> {
35 Ok(self.rx.recv().await)
36 }
37
38 pub async fn close(self) -> ConnectionResult<()> {
39 drop(self.tx);
40 drop(self.rx);
41 self.closed.await.ok();
42 Ok(())
43 }
44
45 pub fn spawn_api_forward_loop(
46 api: impl WbApi + Send + Sync + 'static,
47 crx: mpsc::UnboundedReceiver<ClientMessage>,
48 stx: mpsc::UnboundedSender<ServerMessage>,
49 ) {
50 let future = forward_loop(api, crx, stx);
51 spawn(future);
52 }
53}
54
55async fn forward_loop(
56 api: impl WbApi + Send + Sync + 'static,
57 mut crx: mpsc::UnboundedReceiver<ClientMessage>,
58 stx: mpsc::UnboundedSender<ServerMessage>,
59) {
60 let spv = api.supported_protocol_versions();
61 let version = api.version().to_owned();
62 let welcome = Welcome {
63 client_id: INTERNAL_CLIENT_ID.to_string(),
64 info: ServerInfo::new(version, spv.into(), false),
65 };
66
67 if stx.send(ServerMessage::Welcome(welcome)).is_err() {
68 return;
69 }
70
71 while let Some(client_message) = crx.recv().await {
72 match client_message {
73 ClientMessage::ProtocolSwitchRequest(_) => {
74 stx.send(ServerMessage::Ack(Ack { transaction_id: 0 })).ok();
75 }
76 ClientMessage::AuthorizationRequest(_) => {
77 stx.send(ServerMessage::Err(Err {
78 error_code: ErrorCode::AlreadyAuthorized,
79 metadata: "No authorization required".to_owned(),
80 transaction_id: 0,
81 }))
82 .ok();
83 }
84 ClientMessage::Get(Get {
85 transaction_id,
86 key,
87 }) => match api.get(key).await {
88 Ok(val) => {
89 stx.send(ServerMessage::State(State {
90 event: StateEvent::Value(val),
91 transaction_id,
92 }))
93 .ok();
94 }
95 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
96 },
97 ClientMessage::CGet(Get {
98 transaction_id,
99 key,
100 }) => match api.cget(key).await {
101 Ok(val) => {
102 stx.send(ServerMessage::CState(CState {
103 event: CStateEvent {
104 value: val.0,
105 version: val.1,
106 },
107 transaction_id,
108 }))
109 .ok();
110 }
111 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
112 },
113 ClientMessage::PGet(PGet {
114 transaction_id,
115 request_pattern,
116 }) => match api.pget(request_pattern.clone()).await {
117 Ok(kvps) => {
118 stx.send(ServerMessage::PState(PState {
119 event: PStateEvent::KeyValuePairs(kvps),
120 request_pattern,
121 transaction_id,
122 }))
123 .ok();
124 }
125 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
126 },
127 ClientMessage::Set(Set {
128 transaction_id,
129 key,
130 value,
131 }) => match api.set(key, value, INTERNAL_CLIENT_ID).await {
132 Ok(_) => {
133 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
134 }
135 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
136 },
137 ClientMessage::CSet(CSet {
138 transaction_id,
139 key,
140 value,
141 version,
142 }) => match api.cset(key, value, version, INTERNAL_CLIENT_ID).await {
143 Ok(_) => {
144 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
145 }
146 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
147 },
148 ClientMessage::SPubInit(SPubInit {
149 transaction_id,
150 key,
151 }) => match api.spub_init(transaction_id, key, INTERNAL_CLIENT_ID).await {
152 Ok(_) => {
153 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
154 }
155 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
156 },
157 ClientMessage::SPub(SPub {
158 transaction_id,
159 value,
160 }) => match api.spub(transaction_id, value, INTERNAL_CLIENT_ID).await {
161 Ok(_) => {
162 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
163 }
164 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
165 },
166 ClientMessage::Publish(Publish {
167 transaction_id,
168 key,
169 value,
170 }) => match api.publish(key, value).await {
171 Ok(_) => {
172 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
173 }
174 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
175 },
176 ClientMessage::Subscribe(Subscribe {
177 transaction_id,
178 key,
179 unique,
180 live_only,
181 }) => match api
182 .subscribe(
183 INTERNAL_CLIENT_ID,
184 transaction_id,
185 key,
186 unique,
187 live_only.unwrap_or(false),
188 )
189 .await
190 {
191 Ok((sub_rx, _)) => {
192 spawn_forward_sub_events_loop(sub_rx, transaction_id, stx.clone());
193 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
194 }
195 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
196 },
197 ClientMessage::PSubscribe(PSubscribe {
198 transaction_id,
199 request_pattern,
200 unique,
201 live_only,
202 aggregate_events: _,
203 }) => match api
204 .psubscribe(
205 INTERNAL_CLIENT_ID,
206 transaction_id,
207 request_pattern.clone(),
208 unique,
209 live_only.unwrap_or(false),
210 )
211 .await
212 {
213 Ok((psub_rx, _)) => {
214 spawn_forward_psub_events_loop(
215 psub_rx,
216 transaction_id,
217 request_pattern,
218 stx.clone(),
219 );
220 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
221 }
222 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
223 },
224 ClientMessage::Unsubscribe(Unsubscribe { transaction_id }) => {
225 match api.unsubscribe(INTERNAL_CLIENT_ID, transaction_id).await {
226 Ok(_) => {
227 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
228 }
229 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
230 }
231 }
232 ClientMessage::Delete(Delete {
233 transaction_id,
234 key,
235 }) => match api.delete(key, INTERNAL_CLIENT_ID).await {
236 Ok(val) => {
237 stx.send(ServerMessage::State(State {
238 transaction_id,
239 event: StateEvent::Deleted(val),
240 }))
241 .ok();
242 }
243 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
244 },
245 ClientMessage::PDelete(PDelete {
246 transaction_id,
247 request_pattern,
248 quiet,
249 }) => match api
250 .pdelete(request_pattern.clone(), INTERNAL_CLIENT_ID)
251 .await
252 {
253 Ok(kvps) => {
254 if quiet.unwrap_or(false) {
255 stx.send(ServerMessage::PState(PState {
256 transaction_id,
257 request_pattern,
258 event: PStateEvent::Deleted(kvps),
259 }))
260 .ok();
261 } else {
262 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
263 }
264 }
265 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
266 },
267 ClientMessage::Ls(Ls {
268 transaction_id,
269 parent,
270 }) => match api.ls(parent).await {
271 Ok(children) => {
272 stx.send(ServerMessage::LsState(LsState {
273 transaction_id,
274 children,
275 }))
276 .ok();
277 }
278 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
279 },
280 ClientMessage::PLs(PLs {
281 transaction_id,
282 parent_pattern,
283 }) => match api.pls(parent_pattern).await {
284 Ok(children) => {
285 stx.send(ServerMessage::LsState(LsState {
286 transaction_id,
287 children,
288 }))
289 .ok();
290 }
291 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
292 },
293 ClientMessage::SubscribeLs(SubscribeLs {
294 transaction_id,
295 parent,
296 }) => match api
297 .subscribe_ls(INTERNAL_CLIENT_ID, transaction_id, parent)
298 .await
299 {
300 Ok((lssub_rx, _)) => {
301 spawn_forward_lssub_events_loop(lssub_rx, transaction_id, stx.clone());
302 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
303 }
304 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
305 },
306 ClientMessage::UnsubscribeLs(UnsubscribeLs { transaction_id }) => {
307 match api.unsubscribe_ls(INTERNAL_CLIENT_ID, transaction_id).await {
308 Ok(_) => {
309 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
310 }
311 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
312 }
313 }
314 ClientMessage::Lock(Lock {
315 transaction_id,
316 key,
317 }) => match api.lock(key, INTERNAL_CLIENT_ID).await {
318 Ok(_) => {
319 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
320 }
321 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
322 },
323 ClientMessage::AcquireLock(Lock {
324 transaction_id,
325 key,
326 }) => match api.acquire_lock(key, INTERNAL_CLIENT_ID).await {
327 Ok(_) => {
328 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
329 }
330 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
331 },
332 ClientMessage::ReleaseLock(Lock {
333 transaction_id,
334 key,
335 }) => match api.release_lock(key, INTERNAL_CLIENT_ID).await {
336 Ok(_) => {
337 stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
338 }
339 Result::Err(e) => handle_error(&stx, e, transaction_id).await,
340 },
341 ClientMessage::Transform(_) => todo!(),
342 }
343 }
344}
345
346async fn handle_error(
347 tx: &mpsc::UnboundedSender<ServerMessage>,
348 e: WorterbuchError,
349 transaction_id: TransactionId,
350) {
351 let error_code = ErrorCode::from(&e);
352 let err_msg = format!("{e}");
353 let err = Err {
354 error_code,
355 transaction_id,
356 metadata: json!(err_msg).to_string(),
357 };
358 tx.send(ServerMessage::Err(err)).ok();
359}
360
361fn spawn_forward_sub_events_loop(
362 sub_rx: mpsc::Receiver<StateEvent>,
363 transaction_id: TransactionId,
364 stx: mpsc::UnboundedSender<ServerMessage>,
365) {
366 spawn(forward_sub_events(sub_rx, transaction_id, stx));
367}
368
369async fn forward_sub_events(
370 mut sub_rx: mpsc::Receiver<StateEvent>,
371 transaction_id: TransactionId,
372 stx: mpsc::UnboundedSender<ServerMessage>,
373) {
374 while let Some(event) = sub_rx.recv().await {
375 if stx
376 .send(ServerMessage::State(State {
377 transaction_id,
378 event,
379 }))
380 .is_err()
381 {
382 break;
383 }
384 }
385}
386
387fn spawn_forward_psub_events_loop(
388 psub_rx: mpsc::Receiver<PStateEvent>,
389 transaction_id: TransactionId,
390 request_pattern: RequestPattern,
391 stx: mpsc::UnboundedSender<ServerMessage>,
392) {
393 spawn(forward_psub_events(
394 psub_rx,
395 transaction_id,
396 request_pattern,
397 stx,
398 ));
399}
400
401async fn forward_psub_events(
402 mut psub_rx: mpsc::Receiver<PStateEvent>,
403 transaction_id: TransactionId,
404 request_pattern: RequestPattern,
405 stx: mpsc::UnboundedSender<ServerMessage>,
406) {
407 while let Some(event) = psub_rx.recv().await {
408 let request_pattern = request_pattern.clone();
409 if stx
410 .send(ServerMessage::PState(PState {
411 transaction_id,
412 request_pattern,
413 event,
414 }))
415 .is_err()
416 {
417 break;
418 }
419 }
420}
421
422fn spawn_forward_lssub_events_loop(
423 lssub_rx: mpsc::Receiver<Vec<RegularKeySegment>>,
424 transaction_id: TransactionId,
425 stx: mpsc::UnboundedSender<ServerMessage>,
426) {
427 spawn(forward_lssub_events(lssub_rx, transaction_id, stx));
428}
429
430async fn forward_lssub_events(
431 mut lssub_rx: mpsc::Receiver<Vec<RegularKeySegment>>,
432 transaction_id: TransactionId,
433 stx: mpsc::UnboundedSender<ServerMessage>,
434) {
435 while let Some(children) = lssub_rx.recv().await {
436 if stx
437 .send(ServerMessage::LsState(LsState {
438 transaction_id,
439 children,
440 }))
441 .is_err()
442 {
443 break;
444 }
445 }
446}