worterbuch_client/
local.rs

1use serde_json::json;
2use tokio::{
3    spawn,
4    sync::{mpsc, oneshot},
5};
6use worterbuch_common::{
7    Ack, CSet, CState, CStateEvent, ClientMessage, Delete, Err, ErrorCode, Get, INTERNAL_CLIENT_ID,
8    Lock, Ls, LsState, PDelete, PGet, PLs, PState, PStateEvent, PSubscribe, Publish,
9    RegularKeySegment, RequestPattern, SPub, SPubInit, ServerInfo, ServerMessage, Set, State,
10    StateEvent, Subscribe, SubscribeLs, TransactionId, Unsubscribe, UnsubscribeLs, WbApi, Welcome,
11    error::{ConnectionResult, WorterbuchError},
12};
13
14pub struct LocalClientSocket {
15    tx: mpsc::UnboundedSender<ClientMessage>,
16    rx: mpsc::UnboundedReceiver<ServerMessage>,
17    closed: oneshot::Receiver<()>,
18}
19
20impl LocalClientSocket {
21    pub fn new(
22        tx: mpsc::UnboundedSender<ClientMessage>,
23        rx: mpsc::UnboundedReceiver<ServerMessage>,
24        closed: oneshot::Receiver<()>,
25    ) -> Self {
26        Self { tx, rx, closed }
27    }
28
29    pub async fn send_msg(&self, msg: ClientMessage) -> ConnectionResult<()> {
30        self.tx.send(msg)?;
31        Ok(())
32    }
33
34    pub async fn receive_msg(&mut self) -> ConnectionResult<Option<ServerMessage>> {
35        Ok(self.rx.recv().await)
36    }
37
38    pub async fn close(self) -> ConnectionResult<()> {
39        drop(self.tx);
40        drop(self.rx);
41        self.closed.await.ok();
42        Ok(())
43    }
44
45    pub fn spawn_api_forward_loop(
46        api: impl WbApi + Send + Sync + 'static,
47        crx: mpsc::UnboundedReceiver<ClientMessage>,
48        stx: mpsc::UnboundedSender<ServerMessage>,
49    ) {
50        let future = forward_loop(api, crx, stx);
51        spawn(future);
52    }
53}
54
55async fn forward_loop(
56    api: impl WbApi + Send + Sync + 'static,
57    mut crx: mpsc::UnboundedReceiver<ClientMessage>,
58    stx: mpsc::UnboundedSender<ServerMessage>,
59) {
60    let spv = api.supported_protocol_versions();
61    let version = api.version().to_owned();
62    let welcome = Welcome {
63        client_id: INTERNAL_CLIENT_ID.to_string(),
64        info: ServerInfo::new(version, spv.into(), false),
65    };
66
67    if stx.send(ServerMessage::Welcome(welcome)).is_err() {
68        return;
69    }
70
71    while let Some(client_message) = crx.recv().await {
72        match client_message {
73            ClientMessage::ProtocolSwitchRequest(_) => {
74                stx.send(ServerMessage::Ack(Ack { transaction_id: 0 })).ok();
75            }
76            ClientMessage::AuthorizationRequest(_) => {
77                stx.send(ServerMessage::Err(Err {
78                    error_code: ErrorCode::AlreadyAuthorized,
79                    metadata: "No authorization required".to_owned(),
80                    transaction_id: 0,
81                }))
82                .ok();
83            }
84            ClientMessage::Get(Get {
85                transaction_id,
86                key,
87            }) => match api.get(key).await {
88                Ok(val) => {
89                    stx.send(ServerMessage::State(State {
90                        event: StateEvent::Value(val),
91                        transaction_id,
92                    }))
93                    .ok();
94                }
95                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
96            },
97            ClientMessage::CGet(Get {
98                transaction_id,
99                key,
100            }) => match api.cget(key).await {
101                Ok(val) => {
102                    stx.send(ServerMessage::CState(CState {
103                        event: CStateEvent {
104                            value: val.0,
105                            version: val.1,
106                        },
107                        transaction_id,
108                    }))
109                    .ok();
110                }
111                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
112            },
113            ClientMessage::PGet(PGet {
114                transaction_id,
115                request_pattern,
116            }) => match api.pget(request_pattern.clone()).await {
117                Ok(kvps) => {
118                    stx.send(ServerMessage::PState(PState {
119                        event: PStateEvent::KeyValuePairs(kvps),
120                        request_pattern,
121                        transaction_id,
122                    }))
123                    .ok();
124                }
125                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
126            },
127            ClientMessage::Set(Set {
128                transaction_id,
129                key,
130                value,
131            }) => match api.set(key, value, INTERNAL_CLIENT_ID).await {
132                Ok(_) => {
133                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
134                }
135                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
136            },
137            ClientMessage::CSet(CSet {
138                transaction_id,
139                key,
140                value,
141                version,
142            }) => match api.cset(key, value, version, INTERNAL_CLIENT_ID).await {
143                Ok(_) => {
144                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
145                }
146                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
147            },
148            ClientMessage::SPubInit(SPubInit {
149                transaction_id,
150                key,
151            }) => match api.spub_init(transaction_id, key, INTERNAL_CLIENT_ID).await {
152                Ok(_) => {
153                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
154                }
155                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
156            },
157            ClientMessage::SPub(SPub {
158                transaction_id,
159                value,
160            }) => match api.spub(transaction_id, value, INTERNAL_CLIENT_ID).await {
161                Ok(_) => {
162                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
163                }
164                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
165            },
166            ClientMessage::Publish(Publish {
167                transaction_id,
168                key,
169                value,
170            }) => match api.publish(key, value).await {
171                Ok(_) => {
172                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
173                }
174                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
175            },
176            ClientMessage::Subscribe(Subscribe {
177                transaction_id,
178                key,
179                unique,
180                live_only,
181            }) => match api
182                .subscribe(
183                    INTERNAL_CLIENT_ID,
184                    transaction_id,
185                    key,
186                    unique,
187                    live_only.unwrap_or(false),
188                )
189                .await
190            {
191                Ok((sub_rx, _)) => {
192                    spawn_forward_sub_events_loop(sub_rx, transaction_id, stx.clone());
193                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
194                }
195                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
196            },
197            ClientMessage::PSubscribe(PSubscribe {
198                transaction_id,
199                request_pattern,
200                unique,
201                live_only,
202                aggregate_events: _,
203            }) => match api
204                .psubscribe(
205                    INTERNAL_CLIENT_ID,
206                    transaction_id,
207                    request_pattern.clone(),
208                    unique,
209                    live_only.unwrap_or(false),
210                )
211                .await
212            {
213                Ok((psub_rx, _)) => {
214                    spawn_forward_psub_events_loop(
215                        psub_rx,
216                        transaction_id,
217                        request_pattern,
218                        stx.clone(),
219                    );
220                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
221                }
222                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
223            },
224            ClientMessage::Unsubscribe(Unsubscribe { transaction_id }) => {
225                match api.unsubscribe(INTERNAL_CLIENT_ID, transaction_id).await {
226                    Ok(_) => {
227                        stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
228                    }
229                    Result::Err(e) => handle_error(&stx, e, transaction_id).await,
230                }
231            }
232            ClientMessage::Delete(Delete {
233                transaction_id,
234                key,
235            }) => match api.delete(key, INTERNAL_CLIENT_ID).await {
236                Ok(val) => {
237                    stx.send(ServerMessage::State(State {
238                        transaction_id,
239                        event: StateEvent::Deleted(val),
240                    }))
241                    .ok();
242                }
243                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
244            },
245            ClientMessage::PDelete(PDelete {
246                transaction_id,
247                request_pattern,
248                quiet,
249            }) => match api
250                .pdelete(request_pattern.clone(), INTERNAL_CLIENT_ID)
251                .await
252            {
253                Ok(kvps) => {
254                    if quiet.unwrap_or(false) {
255                        stx.send(ServerMessage::PState(PState {
256                            transaction_id,
257                            request_pattern,
258                            event: PStateEvent::Deleted(kvps),
259                        }))
260                        .ok();
261                    } else {
262                        stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
263                    }
264                }
265                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
266            },
267            ClientMessage::Ls(Ls {
268                transaction_id,
269                parent,
270            }) => match api.ls(parent).await {
271                Ok(children) => {
272                    stx.send(ServerMessage::LsState(LsState {
273                        transaction_id,
274                        children,
275                    }))
276                    .ok();
277                }
278                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
279            },
280            ClientMessage::PLs(PLs {
281                transaction_id,
282                parent_pattern,
283            }) => match api.pls(parent_pattern).await {
284                Ok(children) => {
285                    stx.send(ServerMessage::LsState(LsState {
286                        transaction_id,
287                        children,
288                    }))
289                    .ok();
290                }
291                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
292            },
293            ClientMessage::SubscribeLs(SubscribeLs {
294                transaction_id,
295                parent,
296            }) => match api
297                .subscribe_ls(INTERNAL_CLIENT_ID, transaction_id, parent)
298                .await
299            {
300                Ok((lssub_rx, _)) => {
301                    spawn_forward_lssub_events_loop(lssub_rx, transaction_id, stx.clone());
302                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
303                }
304                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
305            },
306            ClientMessage::UnsubscribeLs(UnsubscribeLs { transaction_id }) => {
307                match api.unsubscribe_ls(INTERNAL_CLIENT_ID, transaction_id).await {
308                    Ok(_) => {
309                        stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
310                    }
311                    Result::Err(e) => handle_error(&stx, e, transaction_id).await,
312                }
313            }
314            ClientMessage::Lock(Lock {
315                transaction_id,
316                key,
317            }) => match api.lock(key, INTERNAL_CLIENT_ID).await {
318                Ok(_) => {
319                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
320                }
321                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
322            },
323            ClientMessage::AcquireLock(Lock {
324                transaction_id,
325                key,
326            }) => match api.acquire_lock(key, INTERNAL_CLIENT_ID).await {
327                Ok(_) => {
328                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
329                }
330                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
331            },
332            ClientMessage::ReleaseLock(Lock {
333                transaction_id,
334                key,
335            }) => match api.release_lock(key, INTERNAL_CLIENT_ID).await {
336                Ok(_) => {
337                    stx.send(ServerMessage::Ack(Ack { transaction_id })).ok();
338                }
339                Result::Err(e) => handle_error(&stx, e, transaction_id).await,
340            },
341            ClientMessage::Transform(_) => todo!(),
342        }
343    }
344}
345
346async fn handle_error(
347    tx: &mpsc::UnboundedSender<ServerMessage>,
348    e: WorterbuchError,
349    transaction_id: TransactionId,
350) {
351    let error_code = ErrorCode::from(&e);
352    let err_msg = format!("{e}");
353    let err = Err {
354        error_code,
355        transaction_id,
356        metadata: json!(err_msg).to_string(),
357    };
358    tx.send(ServerMessage::Err(err)).ok();
359}
360
361fn spawn_forward_sub_events_loop(
362    sub_rx: mpsc::Receiver<StateEvent>,
363    transaction_id: TransactionId,
364    stx: mpsc::UnboundedSender<ServerMessage>,
365) {
366    spawn(forward_sub_events(sub_rx, transaction_id, stx));
367}
368
369async fn forward_sub_events(
370    mut sub_rx: mpsc::Receiver<StateEvent>,
371    transaction_id: TransactionId,
372    stx: mpsc::UnboundedSender<ServerMessage>,
373) {
374    while let Some(event) = sub_rx.recv().await {
375        if stx
376            .send(ServerMessage::State(State {
377                transaction_id,
378                event,
379            }))
380            .is_err()
381        {
382            break;
383        }
384    }
385}
386
387fn spawn_forward_psub_events_loop(
388    psub_rx: mpsc::Receiver<PStateEvent>,
389    transaction_id: TransactionId,
390    request_pattern: RequestPattern,
391    stx: mpsc::UnboundedSender<ServerMessage>,
392) {
393    spawn(forward_psub_events(
394        psub_rx,
395        transaction_id,
396        request_pattern,
397        stx,
398    ));
399}
400
401async fn forward_psub_events(
402    mut psub_rx: mpsc::Receiver<PStateEvent>,
403    transaction_id: TransactionId,
404    request_pattern: RequestPattern,
405    stx: mpsc::UnboundedSender<ServerMessage>,
406) {
407    while let Some(event) = psub_rx.recv().await {
408        let request_pattern = request_pattern.clone();
409        if stx
410            .send(ServerMessage::PState(PState {
411                transaction_id,
412                request_pattern,
413                event,
414            }))
415            .is_err()
416        {
417            break;
418        }
419    }
420}
421
422fn spawn_forward_lssub_events_loop(
423    lssub_rx: mpsc::Receiver<Vec<RegularKeySegment>>,
424    transaction_id: TransactionId,
425    stx: mpsc::UnboundedSender<ServerMessage>,
426) {
427    spawn(forward_lssub_events(lssub_rx, transaction_id, stx));
428}
429
430async fn forward_lssub_events(
431    mut lssub_rx: mpsc::Receiver<Vec<RegularKeySegment>>,
432    transaction_id: TransactionId,
433    stx: mpsc::UnboundedSender<ServerMessage>,
434) {
435    while let Some(children) = lssub_rx.recv().await {
436        if stx
437            .send(ServerMessage::LsState(LsState {
438                transaction_id,
439                children,
440            }))
441            .is_err()
442        {
443            break;
444        }
445    }
446}