1pub mod buffer;
21pub mod config;
22pub mod error;
23pub mod local;
24#[cfg(feature = "tcp")]
25pub mod tcp;
26#[cfg(all(target_family = "unix", feature = "unix"))]
27pub mod unix;
28#[cfg(any(feature = "ws", feature = "wasm"))]
29pub mod ws;
30
31use crate::{config::Config, local::LocalClientSocket};
32use buffer::SendBuffer;
33use error::SubscriptionError;
34use futures_util::FutureExt;
35#[cfg(any(feature = "ws", feature = "wasm"))]
36use futures_util::{SinkExt, StreamExt};
37use hashbrown::HashMap;
38use serde::{Serialize, de::DeserializeOwned};
39use serde_json::{self as json};
40use std::{
41 fmt::{Debug, Display},
42 future::Future,
43 io,
44 net::SocketAddr,
45 ops::ControlFlow,
46 time::Duration,
47};
48#[cfg(feature = "tcp")]
49use tcp::TcpClientSocket;
50#[cfg(feature = "tcp")]
51use tokio::net::TcpStream;
52#[cfg(all(target_family = "unix", feature = "unix"))]
53use tokio::net::UnixStream;
54#[cfg(any(feature = "tcp", feature = "unix"))]
55use tokio::{
56 io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
57 time::sleep,
58};
59#[cfg(feature = "tokio")]
60use tokio::{
61 select, spawn,
62 sync::{mpsc, oneshot},
63};
64#[cfg(feature = "ws")]
65use tokio_tungstenite::{connect_async_with_config, tungstenite::Message};
66#[cfg(feature = "wasm")]
67use tokio_tungstenite_wasm::{Message, connect as connect_wasm};
68use tracing::{Level, debug, error, info, instrument, trace, warn};
69#[cfg(all(target_family = "unix", feature = "unix"))]
70use unix::UnixClientSocket;
71use worterbuch_common::error::WorterbuchError;
72#[cfg(any(feature = "ws", feature = "wasm"))]
73use ws::WsClientSocket;
74
75pub use worterbuch_common::*;
76pub use worterbuch_common::{
77 self, Ack, AuthorizationRequest, ClientMessage as CM, Delete, Err, Get, GraveGoods, Key,
78 KeyValuePairs, LastWill, LsState, PState, PStateEvent, ProtocolVersion, RegularKeySegment,
79 ServerMessage as SM, Set, State, StateEvent, TransactionId,
80 error::{ConnectionError, ConnectionResult},
81};
82
/// Protocol version this client is written against (`ProtocolVersion::new(1, 1)`).
const PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::new(1, 1);
84
/// A request sent from a [`Worterbuch`] handle over its `commands` channel.
///
/// Most operations come in pairs. The plain variant carries a callback that is
/// resolved with the outcome of the request, and [`Worterbuch`] awaits it. The
/// `…Async` variant carries an [`AsyncTicket`] that only yields the [`TransactionId`]
/// assigned to the request. NOTE(review): the code that resolves the callbacks lives
/// outside this view.
#[derive(Debug)]
pub(crate) enum Command {
    /// Store a value under a key.
    Set(Key, Value, AckCallback),
    SetAsync(Key, Value, AsyncTicket),
    /// Store a value under a key, guarded by a CAS version (see `Worterbuch::cset`).
    CSet(Key, Value, CasVersion, AckCallback),
    CSetAsync(Key, Value, CasVersion, AsyncTicket),
    /// Register a key for subsequent `SPub` publishes; the ack carries the transaction id.
    SPubInit(Key, AckCallback),
    SPubInitAsync(Key, AsyncTicket),
    /// Publish a value on a key previously registered with `SPubInit`.
    SPub(TransactionId, Value, AckCallback),
    SPubAsync(TransactionId, Value, AsyncTicket),
    /// Publish a value on a key.
    Publish(Key, Value, AckCallback),
    PublishAsync(Key, Value, AsyncTicket),
    /// Read the value of a single key.
    Get(Key, StateCallback),
    GetAsync(Key, AsyncTicket),
    /// Read the value of a single key together with its CAS version.
    CGet(Key, CStateCallback),
    CGetAsync(Key, AsyncTicket),
    /// Read all key/value pairs matching a pattern.
    PGet(Key, PStateCallback),
    PGetAsync(Key, AsyncTicket),
    /// Delete a single key.
    Delete(Key, StateCallback),
    DeleteAsync(Key, AsyncTicket),
    /// Delete all keys matching a pattern; the `bool` is the caller's `quiet` flag.
    PDelete(Key, bool, PStateCallback),
    PDeleteAsync(Key, bool, AsyncTicket),
    /// List the children of a key (`None` for the root).
    Ls(Option<Key>, LsStateCallback),
    LsAsync(Option<Key>, AsyncTicket),
    /// List the children of all keys matching a pattern (`None` for the root).
    PLs(Option<RequestPattern>, LsStateCallback),
    PLsAsync(Option<RequestPattern>, AsyncTicket),
    /// Subscribe to a single key. Values are delivered on the sender; `None` is
    /// forwarded as-is (NOTE(review): presumably signals deletion — confirm).
    Subscribe(
        Key,
        UniqueFlag,
        AckCallback,
        mpsc::UnboundedSender<Option<Value>>,
        LiveOnlyFlag,
    ),
    SubscribeAsync(Key, UniqueFlag, AsyncTicket, LiveOnlyFlag),
    /// Subscribe to a pattern. The `Option<u64>` is the aggregation window in
    /// milliseconds, as produced by `Worterbuch::psubscribe*`.
    PSubscribe(
        Key,
        UniqueFlag,
        AckCallback,
        mpsc::UnboundedSender<PStateEvent>,
        Option<u64>,
        LiveOnlyFlag,
    ),
    PSubscribeAsync(Key, UniqueFlag, AsyncTicket, Option<u64>, LiveOnlyFlag),
    /// Cancel the subscription created by the given transaction.
    Unsubscribe(TransactionId, AckCallback),
    UnsubscribeAsync(TransactionId, AsyncTicket),
    /// Subscribe to changes in the children of a key (`None` for the root).
    SubscribeLs(
        Option<Key>,
        AckCallback,
        mpsc::UnboundedSender<Vec<RegularKeySegment>>,
    ),
    SubscribeLsAsync(Option<Key>, AsyncTicket),
    /// Cancel the `SubscribeLs` subscription created by the given transaction.
    UnsubscribeLs(TransactionId, AckCallback),
    UnsubscribeLsAsync(TransactionId, AsyncTicket),
    /// Take the lock on a key (see `Worterbuch::lock`).
    Lock(Key, AckCallback),
    LockAsync(Key, AsyncTicket),
    /// Take the lock on a key (see `Worterbuch::acquire_lock`).
    AcquireLock(Key, AckCallback),
    /// Release a lock held on a key.
    ReleaseLock(Key, AckCallback),
    ReleaseLockAsync(Key, AsyncTicket),
    /// Register a listener for the server messages forwarded to generic callbacks.
    AllMessages(GenericCallback),
}
145
/// Transport used by the client.
///
/// Which variants exist depends on the enabled cargo features; `Local` is always
/// available.
enum ClientSocket {
    /// TCP transport (feature `tcp`).
    #[cfg(feature = "tcp")]
    Tcp(TcpClientSocket),
    /// WebSocket transport (feature `ws` or `wasm`).
    #[cfg(any(feature = "ws", feature = "wasm"))]
    Ws(WsClientSocket),
    /// Unix domain socket transport (feature `unix`, Unix targets only).
    #[cfg(all(target_family = "unix", feature = "unix"))]
    Unix(UnixClientSocket),
    /// In-process transport backed by [`LocalClientSocket`].
    Local(LocalClientSocket),
}
155
156impl ClientSocket {
157 #[instrument(skip(self), level = "trace", err)]
158 pub async fn send_msg(&mut self, msg: CM, wait: bool) -> ConnectionResult<()> {
159 match self {
160 #[cfg(feature = "tcp")]
161 ClientSocket::Tcp(sock) => sock.send_msg(msg, wait).await,
162 #[cfg(any(feature = "ws", feature = "wasm"))]
163 ClientSocket::Ws(sock) => sock.send_msg(&msg).await,
164 #[cfg(all(target_family = "unix", feature = "unix"))]
165 ClientSocket::Unix(sock) => sock.send_msg(msg).await,
166 ClientSocket::Local(sock) => sock.send_msg(msg).await,
167 }
168 }
169
170 #[instrument(skip(self), level = "trace", err)]
171 pub async fn receive_msg(&mut self) -> ConnectionResult<Option<ServerMessage>> {
172 match self {
173 #[cfg(feature = "tcp")]
174 ClientSocket::Tcp(sock) => sock.receive_msg().await,
175 #[cfg(any(feature = "ws", feature = "wasm"))]
176 ClientSocket::Ws(sock) => sock.receive_msg().await,
177 #[cfg(all(target_family = "unix", feature = "unix"))]
178 ClientSocket::Unix(sock) => sock.receive_msg().await,
179 ClientSocket::Local(sock) => sock.receive_msg().await,
180 }
181 }
182
183 #[instrument(skip(self), level = "debug", err)]
184 pub async fn close(self) -> ConnectionResult<()> {
185 match self {
186 #[cfg(feature = "tcp")]
187 ClientSocket::Tcp(tcp_client_socket) => tcp_client_socket.close().await?,
188 #[cfg(any(feature = "ws", feature = "wasm"))]
189 ClientSocket::Ws(ws_client_socket) => ws_client_socket.close().await?,
190 #[cfg(all(target_family = "unix", feature = "unix"))]
191 ClientSocket::Unix(unix_client_socket) => unix_client_socket.close().await?,
192 ClientSocket::Local(unix_client_socket) => unix_client_socket.close().await?,
193 }
194 Ok(())
195 }
196}
197
/// Cloneable handle to a worterbuch client connection.
///
/// Every operation is forwarded as a command over `commands`. Cloning the handle
/// clones the channel senders, so all clones feed the same receivers.
#[derive(Clone)]
pub struct Worterbuch {
    /// Queue on which requests are sent.
    commands: mpsc::Sender<Command>,
    /// Shutdown requests; `close` waits on the enclosed oneshot after sending one.
    stop: mpsc::Sender<oneshot::Sender<()>>,
    /// This client's id, used to build its system topics (see `set_last_will`).
    client_id: String,
}
204
205impl Worterbuch {
206 fn new(
207 commands: mpsc::Sender<Command>,
208 stop: mpsc::Sender<oneshot::Sender<()>>,
209 client_id: String,
210 ) -> Self {
211 Self {
212 commands,
213 stop,
214 client_id,
215 }
216 }
217
218 #[instrument(skip(self), err)]
219 pub async fn set_last_will(&self, last_will: &[KeyValuePair]) -> ConnectionResult<()> {
220 self.set(
221 topic!(
222 SYSTEM_TOPIC_ROOT,
223 SYSTEM_TOPIC_CLIENTS,
224 &self.client_id,
225 SYSTEM_TOPIC_LAST_WILL
226 ),
227 last_will,
228 )
229 .await
230 }
231
232 #[instrument(skip(self), err)]
233 pub async fn set_grave_goods(&self, grave_goods: &[&str]) -> ConnectionResult<()> {
234 self.set(
235 topic!(
236 SYSTEM_TOPIC_ROOT,
237 SYSTEM_TOPIC_CLIENTS,
238 &self.client_id,
239 SYSTEM_TOPIC_GRAVE_GOODS
240 ),
241 grave_goods,
242 )
243 .await
244 }
245
246 #[instrument(skip(self), err)]
247 pub async fn set_client_name<T: Display + Debug>(
248 &self,
249 client_name: T,
250 ) -> ConnectionResult<()> {
251 self.set(
252 topic!(
253 SYSTEM_TOPIC_ROOT,
254 SYSTEM_TOPIC_CLIENTS,
255 &self.client_id,
256 SYSTEM_TOPIC_CLIENT_NAME
257 ),
258 client_name.to_string(),
259 )
260 .await
261 }
262
263 #[instrument(skip(self), err)]
264 pub async fn set_generic(&self, key: Key, value: Value) -> ConnectionResult<()> {
265 let (tx, rx) = oneshot::channel();
266 let cmd = Command::Set(key, value, tx);
267 debug!("Queuing command {cmd:?}");
268 self.commands.send(cmd).await?;
269 debug!("Command queued.");
270 rx.await??;
271 Ok(())
272 }
273
274 #[instrument(skip(self, value), fields(value), err)]
275 pub async fn set<T: Serialize>(&self, key: Key, value: T) -> ConnectionResult<()> {
276 let value = json::to_value(value)?;
277 self.set_generic(key, value).await
278 }
279
280 #[instrument(skip(self), err)]
281 pub async fn set_generic_async(
282 &self,
283 key: Key,
284 value: Value,
285 ) -> ConnectionResult<TransactionId> {
286 let (tx, rx) = oneshot::channel();
287 let cmd = Command::SetAsync(key, value, tx);
288 debug!("Queuing command {cmd:?}");
289 self.commands.send(cmd).await?;
290 debug!("Command queued.");
291 let res = rx.await?;
292 Ok(res)
293 }
294
295 #[instrument(skip(self, value), fields(value), err)]
296 pub async fn set_async<T: Serialize>(
297 &self,
298 key: Key,
299 value: T,
300 ) -> ConnectionResult<TransactionId> {
301 let value = json::to_value(value)?;
302 self.set_generic_async(key, value).await
303 }
304
305 #[instrument(skip(self), err)]
306 pub async fn cset_generic(
307 &self,
308 key: Key,
309 value: Value,
310 version: CasVersion,
311 ) -> ConnectionResult<()> {
312 let (tx, rx) = oneshot::channel();
313 let cmd = Command::CSet(key, value, version, tx);
314 debug!("Queuing command {cmd:?}");
315 self.commands.send(cmd).await?;
316 debug!("Command queued.");
317 rx.await??;
318 Ok(())
319 }
320
321 #[instrument(skip(self, value), fields(value), err)]
322 pub async fn cset<T: Serialize>(
323 &self,
324 key: Key,
325 value: T,
326 version: CasVersion,
327 ) -> ConnectionResult<()> {
328 let value = json::to_value(value)?;
329 self.cset_generic(key, value, version).await
330 }
331
332 #[instrument(skip(self), err)]
333 pub async fn cset_generic_async(
334 &self,
335 key: Key,
336 value: Value,
337 version: CasVersion,
338 ) -> ConnectionResult<TransactionId> {
339 let (tx, rx) = oneshot::channel();
340 let cmd = Command::CSetAsync(key, value, version, tx);
341 debug!("Queuing command {cmd:?}");
342 self.commands.send(cmd).await?;
343 debug!("Command queued.");
344 let res = rx.await?;
345 Ok(res)
346 }
347
348 #[instrument(skip(self, value), fields(value), err)]
349 pub async fn cset_async<T: Serialize>(
350 &self,
351 key: Key,
352 value: &T,
353 version: CasVersion,
354 ) -> ConnectionResult<TransactionId> {
355 let value = serde_json::to_value(value)?;
356 self.cset_generic_async(key, value, version).await
357 }
358
359 #[instrument(skip(self), err)]
360 pub async fn spub_init(&self, key: Key) -> ConnectionResult<TransactionId> {
361 let (tx, rx) = oneshot::channel();
362 let cmd = Command::SPubInit(key, tx);
363 debug!("Queuing command {cmd:?}");
364 self.commands.send(cmd).await?;
365 debug!("Command queued.");
366 let res = rx.await??;
367 Ok(res.transaction_id)
368 }
369
370 #[instrument(skip(self), err)]
371 pub async fn spub_init_async(&self, key: Key) -> ConnectionResult<TransactionId> {
372 let (tx, rx) = oneshot::channel();
373 let cmd = Command::SPubInitAsync(key, tx);
374 debug!("Queuing command {cmd:?}");
375 self.commands.send(cmd).await?;
376 debug!("Command queued.");
377 let transaction_id = rx.await?;
378 Ok(transaction_id)
379 }
380
381 #[instrument(skip(self), err)]
382 pub async fn spub_generic(
383 &self,
384 transaction_id: TransactionId,
385 value: Value,
386 ) -> ConnectionResult<()> {
387 let (tx, rx) = oneshot::channel();
388 let cmd = Command::SPub(transaction_id, value, tx);
389 debug!("Queuing command {cmd:?}");
390 self.commands.send(cmd).await?;
391 debug!("Command queued.");
392 rx.await??;
393 Ok(())
394 }
395
396 #[instrument(skip(self, value), fields(value), err)]
397 pub async fn spub<T: Serialize>(
398 &self,
399 transaction_id: TransactionId,
400 value: &T,
401 ) -> ConnectionResult<()> {
402 let value = serde_json::to_value(value)?;
403 self.spub_generic(transaction_id, value).await
404 }
405
406 #[instrument(skip(self), err)]
407 pub async fn spub_generic_async(
408 &self,
409 transaction_id: TransactionId,
410 value: Value,
411 ) -> ConnectionResult<TransactionId> {
412 let (tx, rx) = oneshot::channel();
413 let cmd = Command::SPubAsync(transaction_id, value, tx);
414 debug!("Queuing command {cmd:?}");
415 self.commands.send(cmd).await?;
416 debug!("Command queued.");
417 let res = rx.await?;
418 Ok(res)
419 }
420
421 #[instrument(skip(self, value), fields(value), err)]
422 pub async fn spub_async<T: Serialize>(
423 &self,
424 transaction_id: TransactionId,
425 value: &T,
426 ) -> ConnectionResult<TransactionId> {
427 let value = serde_json::to_value(value)?;
428 self.spub_generic_async(transaction_id, value).await
429 }
430
431 #[instrument(skip(self), err)]
432 pub async fn publish_generic(&self, key: Key, value: Value) -> ConnectionResult<()> {
433 let (tx, rx) = oneshot::channel();
434 let cmd = Command::Publish(key, value, tx);
435 debug!("Queuing command {cmd:?}");
436 self.commands.send(cmd).await?;
437 debug!("Command queued.");
438 rx.await??;
439 Ok(())
440 }
441
442 #[instrument(skip(self, value), fields(value), err)]
443 pub async fn publish<T: Serialize>(&self, key: Key, value: &T) -> ConnectionResult<()> {
444 let value = json::to_value(value)?;
445 self.publish_generic(key, value).await
446 }
447
448 #[instrument(skip(self), err)]
449 pub async fn publish_generic_async(
450 &self,
451 key: Key,
452 value: Value,
453 ) -> ConnectionResult<TransactionId> {
454 let (tx, rx) = oneshot::channel();
455 let cmd = Command::PublishAsync(key, value, tx);
456 debug!("Queuing command {cmd:?}");
457 self.commands.send(cmd).await?;
458 debug!("Command queued.");
459 let res = rx.await?;
460 Ok(res)
461 }
462
463 #[instrument(skip(self, value), fields(value), err)]
464 pub async fn publish_async<T: Serialize>(
465 &self,
466 key: Key,
467 value: &T,
468 ) -> ConnectionResult<TransactionId> {
469 let value = json::to_value(value)?;
470 self.publish_generic_async(key, value).await
471 }
472
473 #[instrument(skip(self), err)]
474 pub async fn get_async(&self, key: Key) -> ConnectionResult<TransactionId> {
475 let (tx, rx) = oneshot::channel();
476 let cmd = Command::GetAsync(key, tx);
477 debug!("Queuing command {cmd:?}");
478 self.commands.send(cmd).await?;
479 debug!("Command queued.");
480 let res = rx.await?;
481 Ok(res)
482 }
483
484 #[instrument(skip(self), ret, err)]
485 pub async fn get_generic(&self, key: Key) -> ConnectionResult<Option<Value>> {
486 let (tx, rx) = oneshot::channel();
487 let cmd = Command::Get(key, tx);
488 debug!("Queuing command {cmd:?}");
489 self.commands.send(cmd).await?;
490 debug!("Command queued.");
491 match rx.await? {
492 Ok(state) => {
493 if let StateEvent::Value(val) = state.event {
494 Ok(Some(val))
495 } else {
496 Ok(None)
497 }
498 }
499 Result::Err(e) => {
500 if e.error_code == ErrorCode::NoSuchValue {
501 Ok(None)
502 } else {
503 Err(e.into())
504 }
505 }
506 }
507 }
508
509 #[instrument(skip(self), err)]
510 pub async fn get<T: DeserializeOwned>(&self, key: Key) -> ConnectionResult<Option<T>> {
511 if let Some(val) = self.get_generic(key).await? {
512 Ok(Some(json::from_value(val)?))
513 } else {
514 Ok(None)
515 }
516 }
517
518 #[instrument(skip(self), err)]
519 pub async fn cget_async(&self, key: Key) -> ConnectionResult<TransactionId> {
520 let (tx, rx) = oneshot::channel();
521 let cmd = Command::CGetAsync(key, tx);
522 debug!("Queuing command {cmd:?}");
523 self.commands.send(cmd).await?;
524 debug!("Command queued.");
525 let res = rx.await?;
526 Ok(res)
527 }
528
529 #[instrument(skip(self), ret, err)]
530 pub async fn cget_generic(&self, key: Key) -> ConnectionResult<Option<(Value, CasVersion)>> {
531 let (tx, rx) = oneshot::channel();
532 let cmd = Command::CGet(key, tx);
533 debug!("Queuing command {cmd:?}");
534 self.commands.send(cmd).await?;
535 debug!("Command queued.");
536 match rx.await? {
537 Ok(state) => Ok(Some((state.event.value, state.event.version))),
538 Result::Err(e) => {
539 if e.error_code == ErrorCode::NoSuchValue {
540 Ok(None)
541 } else {
542 Err(e.into())
543 }
544 }
545 }
546 }
547
548 #[instrument(skip(self), err)]
549 pub async fn cget<T: DeserializeOwned>(
550 &self,
551 key: Key,
552 ) -> ConnectionResult<Option<(T, CasVersion)>> {
553 if let Some((val, version)) = self.cget_generic(key).await? {
554 Ok(Some((json::from_value(val)?, version)))
555 } else {
556 Ok(None)
557 }
558 }
559
560 #[instrument(skip(self), err)]
561 pub async fn pget_async(&self, key: Key) -> ConnectionResult<TransactionId> {
562 let (tx, rx) = oneshot::channel();
563 let cmd = Command::PGetAsync(key, tx);
564 debug!("Queuing command {cmd:?}");
565 self.commands.send(cmd).await?;
566 debug!("Command queued.");
567 let tid = rx.await?;
568 Ok(tid)
569 }
570
571 #[instrument(skip(self), ret, err)]
572 pub async fn pget_generic(&self, key: Key) -> ConnectionResult<KeyValuePairs> {
573 let (tx, rx) = oneshot::channel();
574 let cmd = Command::PGet(key, tx);
575 debug!("Queuing command {cmd:?}");
576 self.commands.send(cmd).await?;
577 debug!("Command queued.");
578 match rx.await??.event {
579 PStateEvent::KeyValuePairs(kvps) => Ok(kvps),
580 PStateEvent::Deleted(_) => Ok(vec![]),
581 }
582 }
583
584 #[instrument(skip(self), err)]
585 pub async fn pget<T: DeserializeOwned + Debug>(
586 &self,
587 key: Key,
588 ) -> ConnectionResult<TypedKeyValuePairs<T>> {
589 let kvps = self.pget_generic(key).await?;
590 let typed_kvps = deserialize_key_value_pairs(kvps)?;
591 Ok(typed_kvps)
592 }
593
594 #[instrument(skip(self), err)]
595 pub async fn delete_async(&self, key: Key) -> ConnectionResult<TransactionId> {
596 let (tx, rx) = oneshot::channel();
597 let cmd = Command::DeleteAsync(key, tx);
598 debug!("Queuing command {cmd:?}");
599 self.commands.send(cmd).await?;
600 debug!("Command queued.");
601 let tid = rx.await?;
602 Ok(tid)
603 }
604
605 #[instrument(skip(self), ret, err)]
606 pub async fn delete_generic(&self, key: Key) -> ConnectionResult<Option<Value>> {
607 let (tx, rx) = oneshot::channel();
608 let cmd = Command::Delete(key, tx);
609 debug!("Queuing command {cmd:?}");
610 self.commands.send(cmd).await?;
611 debug!("Command queued.");
612 match rx.await? {
613 Ok(state) => match state.event {
614 StateEvent::Value(_) => Ok(None),
615 StateEvent::Deleted(value) => Ok(Some(value)),
616 },
617 Result::Err(e) => {
618 if e.error_code == ErrorCode::NoSuchValue {
619 Ok(None)
620 } else {
621 Err(e.into())
622 }
623 }
624 }
625 }
626
627 #[instrument(skip(self), err)]
628 pub async fn delete<T: DeserializeOwned + Debug>(
629 &self,
630 key: Key,
631 ) -> ConnectionResult<Option<T>> {
632 if let Some(val) = self.delete_generic(key).await? {
633 Ok(Some(json::from_value(val)?))
634 } else {
635 Ok(None)
636 }
637 }
638
639 #[instrument(skip(self), err)]
640 pub async fn pdelete_async(&self, key: Key, quiet: bool) -> ConnectionResult<TransactionId> {
641 let (tx, rx) = oneshot::channel();
642 let cmd = Command::PDeleteAsync(key, quiet, tx);
643 debug!("Queuing command {cmd:?}");
644 self.commands.send(cmd).await?;
645 debug!("Command queued.");
646 let tid = rx.await?;
647 Ok(tid)
648 }
649
650 #[instrument(skip(self), ret, err)]
651 pub async fn pdelete_generic(&self, key: Key, quiet: bool) -> ConnectionResult<KeyValuePairs> {
652 let (tx, rx) = oneshot::channel();
653 let cmd = Command::PDelete(key, quiet, tx);
654 debug!("Queuing command {cmd:?}");
655 self.commands.send(cmd).await?;
656 debug!("Command queued.");
657 match rx.await??.event {
658 PStateEvent::KeyValuePairs(_) => Ok(vec![]),
659 PStateEvent::Deleted(kvps) => Ok(kvps),
660 }
661 }
662
663 #[instrument(skip(self), err)]
664 pub async fn pdelete<T: DeserializeOwned + Debug>(
665 &self,
666 key: Key,
667 quiet: bool,
668 ) -> ConnectionResult<TypedKeyValuePairs<T>> {
669 let kvps = self.pdelete_generic(key, quiet).await?;
670 let typed_kvps = deserialize_key_value_pairs(kvps)?;
671 Ok(typed_kvps)
672 }
673
674 #[instrument(skip(self), err)]
675 pub async fn ls_async(&self, parent: Option<Key>) -> ConnectionResult<TransactionId> {
676 let (tx, rx) = oneshot::channel();
677 let cmd = Command::LsAsync(parent, tx);
678 debug!("Queuing command {cmd:?}");
679 self.commands.send(cmd).await?;
680 debug!("Command queued.");
681 let tid = rx.await?;
682 Ok(tid)
683 }
684
685 #[instrument(skip(self), ret, err)]
686 pub async fn ls(&self, parent: Option<Key>) -> ConnectionResult<Vec<RegularKeySegment>> {
687 let (tx, rx) = oneshot::channel();
688 let cmd = Command::Ls(parent, tx);
689 debug!("Queuing command {cmd:?}");
690 self.commands.send(cmd).await?;
691 debug!("Command queued.");
692 let children = rx.await??.children;
693 Ok(children)
694 }
695
696 #[instrument(skip(self), err)]
697 pub async fn pls_async(
698 &self,
699 parent: Option<RequestPattern>,
700 ) -> ConnectionResult<TransactionId> {
701 let (tx, rx) = oneshot::channel();
702 let cmd = Command::PLsAsync(parent, tx);
703 debug!("Queuing command {cmd:?}");
704 self.commands.send(cmd).await?;
705 debug!("Command queued.");
706 let tid = rx.await?;
707 Ok(tid)
708 }
709
710 #[instrument(skip(self), ret, err)]
711 pub async fn pls(
712 &self,
713 parent: Option<RequestPattern>,
714 ) -> ConnectionResult<Vec<RegularKeySegment>> {
715 let (tx, rx) = oneshot::channel();
716 let cmd = Command::PLs(parent, tx);
717 debug!("Queuing command {cmd:?}");
718 self.commands.send(cmd).await?;
719 debug!("Command queued.");
720 let children = rx.await??.children;
721 Ok(children)
722 }
723
724 #[instrument(skip(self), err)]
725 pub async fn subscribe_async(
726 &self,
727 key: Key,
728 unique: bool,
729 live_only: bool,
730 ) -> ConnectionResult<TransactionId> {
731 let (tx, rx) = oneshot::channel();
732 self.commands
733 .send(Command::SubscribeAsync(key, unique, tx, live_only))
734 .await?;
735 let tid = rx.await?;
736 Ok(tid)
737 }
738
739 #[instrument(skip(self), err)]
740 pub async fn subscribe_generic(
741 &self,
742 key: Key,
743 unique: bool,
744 live_only: bool,
745 ) -> ConnectionResult<(mpsc::UnboundedReceiver<Option<Value>>, TransactionId)> {
746 let (tid_tx, tid_rx) = oneshot::channel();
747 let (val_tx, val_rx) = mpsc::unbounded_channel();
748 self.commands
749 .send(Command::Subscribe(key, unique, tid_tx, val_tx, live_only))
750 .await?;
751 let res = tid_rx.await??;
752 Ok((val_rx, res.transaction_id))
753 }
754
755 #[instrument(skip(self), err)]
756 pub async fn subscribe<T: DeserializeOwned + Send + 'static>(
757 &self,
758 key: Key,
759 unique: bool,
760 live_only: bool,
761 ) -> ConnectionResult<(mpsc::UnboundedReceiver<Option<T>>, TransactionId)> {
762 let (val_rx, transaction_id) = self.subscribe_generic(key, unique, live_only).await?;
763 let (typed_val_tx, typed_val_rx) = mpsc::unbounded_channel();
764 spawn(deserialize_values(val_rx, typed_val_tx));
765 Ok((typed_val_rx, transaction_id))
766 }
767
768 #[instrument(skip(self), err)]
769 pub async fn psubscribe_async(
770 &self,
771 request_pattern: RequestPattern,
772 unique: bool,
773 live_only: bool,
774 aggregation_duration: Option<Duration>,
775 ) -> ConnectionResult<TransactionId> {
776 let (tx, rx) = oneshot::channel();
777 self.commands
778 .send(Command::PSubscribeAsync(
779 request_pattern,
780 unique,
781 tx,
782 aggregation_duration.map(|d| d.as_millis() as u64),
783 live_only,
784 ))
785 .await?;
786 let tid = rx.await?;
787 Ok(tid)
788 }
789
790 #[instrument(skip(self), err)]
791 pub async fn psubscribe_generic(
792 &self,
793 request_pattern: RequestPattern,
794 unique: bool,
795 live_only: bool,
796 aggregation_duration: Option<Duration>,
797 ) -> ConnectionResult<(mpsc::UnboundedReceiver<PStateEvent>, TransactionId)> {
798 let (tid_tx, tid_rx) = oneshot::channel();
799 let (event_tx, event_rx) = mpsc::unbounded_channel();
800 self.commands
801 .send(Command::PSubscribe(
802 request_pattern,
803 unique,
804 tid_tx,
805 event_tx,
806 aggregation_duration.map(|d| d.as_millis() as u64),
807 live_only,
808 ))
809 .await?;
810 let res = tid_rx.await??;
811 Ok((event_rx, res.transaction_id))
812 }
813
814 #[instrument(skip(self), err)]
815 pub async fn psubscribe<T: DeserializeOwned + Debug + Send + 'static>(
816 &self,
817 request_pattern: RequestPattern,
818 unique: bool,
819 live_only: bool,
820 aggregation_duration: Option<Duration>,
821 ) -> ConnectionResult<(mpsc::UnboundedReceiver<TypedPStateEvent<T>>, TransactionId)> {
822 let (event_rx, transaction_id) = self
823 .psubscribe_generic(request_pattern, unique, live_only, aggregation_duration)
824 .await?;
825 let (typed_event_tx, typed_event_rx) = mpsc::unbounded_channel();
826 spawn(deserialize_events(event_rx, typed_event_tx));
827 Ok((typed_event_rx, transaction_id))
828 }
829
830 #[instrument(skip(self), err)]
831 pub async fn unsubscribe(&self, transaction_id: TransactionId) -> ConnectionResult<()> {
832 let (tx, rx) = oneshot::channel();
833 self.commands
834 .send(Command::Unsubscribe(transaction_id, tx))
835 .await?;
836 rx.await??;
837 Ok(())
838 }
839
840 #[instrument(skip(self), err)]
841 pub async fn unsubscribe_async(
842 &self,
843 transaction_id: TransactionId,
844 ) -> ConnectionResult<TransactionId> {
845 let (tx, rx) = oneshot::channel();
846 self.commands
847 .send(Command::UnsubscribeAsync(transaction_id, tx))
848 .await?;
849 let res = rx.await?;
850 Ok(res)
851 }
852
853 #[instrument(skip(self), err)]
854 pub async fn subscribe_ls_async(&self, parent: Option<Key>) -> ConnectionResult<TransactionId> {
855 let (tx, rx) = oneshot::channel();
856 self.commands
857 .send(Command::SubscribeLsAsync(parent, tx))
858 .await?;
859 let tid = rx.await?;
860 Ok(tid)
861 }
862
863 #[instrument(skip(self), err)]
864 pub async fn subscribe_ls(
865 &self,
866 parent: Option<Key>,
867 ) -> ConnectionResult<(
868 mpsc::UnboundedReceiver<Vec<RegularKeySegment>>,
869 TransactionId,
870 )> {
871 let (tid_tx, tid_rx) = oneshot::channel();
872 let (children_tx, children_rx) = mpsc::unbounded_channel();
873 self.commands
874 .send(Command::SubscribeLs(parent, tid_tx, children_tx))
875 .await?;
876 let res = tid_rx.await??;
877 Ok((children_rx, res.transaction_id))
878 }
879
880 #[instrument(skip(self), err)]
881 pub async fn unsubscribe_ls(&self, transaction_id: TransactionId) -> ConnectionResult<()> {
882 let (tx, rx) = oneshot::channel();
883 self.commands
884 .send(Command::UnsubscribeLs(transaction_id, tx))
885 .await?;
886 rx.await??;
887 Ok(())
888 }
889
890 #[instrument(skip(self), err)]
891 pub async fn unsubscribe_ls_async(
892 &self,
893 transaction_id: TransactionId,
894 ) -> ConnectionResult<TransactionId> {
895 let (tx, rx) = oneshot::channel();
896 self.commands
897 .send(Command::UnsubscribeLsAsync(transaction_id, tx))
898 .await?;
899 let res = rx.await?;
900 Ok(res)
901 }
902
903 #[instrument(skip(self), err)]
904 pub async fn lock(&self, key: Key) -> ConnectionResult<()> {
905 let (tx, rx) = oneshot::channel();
906 self.commands.send(Command::Lock(key, tx)).await?;
907 rx.await??;
908 Ok(())
909 }
910
911 #[instrument(skip(self), err)]
912 pub async fn lock_async(&self, key: Key) -> ConnectionResult<TransactionId> {
913 let (tx, rx) = oneshot::channel();
914 self.commands.send(Command::LockAsync(key, tx)).await?;
915 let res = rx.await?;
916 Ok(res)
917 }
918
919 #[instrument(skip(self), err)]
920 pub async fn acquire_lock(&self, key: Key) -> ConnectionResult<()> {
921 let (tx, rx) = oneshot::channel();
922 self.commands.send(Command::AcquireLock(key, tx)).await?;
923 match rx.await? {
924 Ok(_) => Ok(()),
925 Result::Err(e) => Err(ConnectionError::ServerResponse(Box::new(e))),
926 }
927 }
928
929 pub async fn locked<T>(
930 &self,
931 key: Key,
932 task: impl AsyncFnOnce() -> T + Send,
933 ) -> ConnectionResult<T> {
934 self.acquire_lock(key.clone()).await?;
935 let result = task().await;
936 self.release_lock(key).await?;
937 Ok(result)
938 }
939
940 pub async fn try_locked<T, E>(
941 &self,
942 key: Key,
943 task: impl AsyncFnOnce() -> std::result::Result<T, E> + Send,
944 ) -> ConnectionResult<std::result::Result<T, E>> {
945 self.lock(key.clone()).await?;
946 let result = task().await;
947 self.release_lock(key).await?;
948 Ok(result)
949 }
950
951 #[instrument(skip(self), err)]
952 pub async fn release_lock(&self, key: Key) -> ConnectionResult<()> {
953 let (tx, rx) = oneshot::channel();
954 self.commands.send(Command::ReleaseLock(key, tx)).await?;
955 rx.await??;
956 Ok(())
957 }
958
959 #[instrument(skip(self), err)]
960 pub async fn release_lock_async(&self, key: Key) -> ConnectionResult<TransactionId> {
961 let (tx, rx) = oneshot::channel();
962 self.commands
963 .send(Command::ReleaseLockAsync(key, tx))
964 .await?;
965 let res = rx.await?;
966 Ok(res)
967 }
968
969 #[instrument(skip(self, seed, update), err)]
970 pub async fn update<T: DeserializeOwned + Serialize + Debug, S: Fn() -> T, F: Fn(&mut T)>(
971 &self,
972 key: Key,
973 seed: S,
974 update: F,
975 ) -> ConnectionResult<()> {
976 let f = |mut set: Option<T>| match set.take() {
977 Some(mut it) => {
978 update(&mut it);
979 it
980 }
981 None => {
982 let mut it = seed();
983 update(&mut it);
984 it
985 }
986 };
987
988 self.try_update(key, f, 0).await
989 }
990
991 #[instrument(skip(self, swap), err)]
992 pub async fn swap<T: DeserializeOwned + Debug, V: Serialize, F: Fn(Option<T>) -> V>(
993 &self,
994 key: Key,
995 swap: F,
996 ) -> ConnectionResult<()> {
997 self.try_update(key, swap, 0).await
998 }
999
1000 #[instrument(skip(self, transform), level = "debug", err)]
1001 async fn try_update<T: DeserializeOwned + Debug, V: Serialize, F: Fn(Option<T>) -> V>(
1002 &self,
1003 key: Key,
1004 transform: F,
1005 counter: usize,
1006 ) -> ConnectionResult<()> {
1007 if counter >= 100 {
1008 return Err(ConnectionError::Timeout(Box::new(
1009 "could not update, value keeps being changed by another instance".to_owned(),
1010 )));
1011 }
1012
1013 let (new_val, version) = match self.cget::<T>(key.clone()).await? {
1014 Some((val, version)) => (transform(Some(val)), version),
1015 None => (transform(None), 0),
1016 };
1017
1018 if let Err(e) = self.cset(key.clone(), new_val, version).await {
1019 match e {
1020 ConnectionError::ServerResponse(_) => {
1021 tracing::debug!(
1022 "value has changed in the mean time, re-fetching and trying again"
1023 );
1024 Box::pin(self.try_update(key, transform, counter + 1)).await
1025 }
1026 e => Err(e),
1027 }
1028 } else {
1029 Ok(())
1030 }
1031 }
1032
1033 #[instrument(skip(self))]
1034 pub async fn send_buffer(&self, delay: Duration) -> SendBuffer {
1035 SendBuffer::new(self.commands.clone(), delay).await
1036 }
1037
1038 #[instrument(skip(self), err)]
1039 pub async fn close(&self) -> ConnectionResult<()> {
1040 let (tx, rx) = oneshot::channel();
1041 self.stop.send(tx).await?;
1042 rx.await.ok();
1043 Ok(())
1044 }
1045
1046 #[instrument(skip(self))]
1047 pub async fn all_messages(&self) -> ConnectionResult<mpsc::UnboundedReceiver<ServerMessage>> {
1048 let (tx, rx) = mpsc::unbounded_channel();
1049 self.commands.send(Command::AllMessages(tx)).await?;
1050 Ok(rx)
1051 }
1052
1053 pub fn client_id(&self) -> &str {
1054 &self.client_id
1055 }
1056}
1057
1058#[instrument(level = "trace")]
1059async fn deserialize_values<T: DeserializeOwned + Send + 'static>(
1060 mut val_rx: mpsc::UnboundedReceiver<Option<Value>>,
1061 typed_val_tx: mpsc::UnboundedSender<Option<T>>,
1062) {
1063 while let Some(val) = val_rx.recv().await {
1064 match val {
1065 Some(val) => match json::from_value(val) {
1066 Ok(typed_val) => {
1067 if typed_val_tx.send(typed_val).is_err() {
1068 break;
1069 }
1070 }
1071 Err(e) => {
1072 error!("could not deserialize json value to requested type: {e}");
1073 break;
1074 }
1075 },
1076 None => {
1077 if typed_val_tx.send(None).is_err() {
1078 break;
1079 }
1080 }
1081 };
1082 }
1083}
1084
1085#[instrument(level = "trace")]
1086async fn deserialize_events<T: DeserializeOwned + Debug + Send + 'static>(
1087 mut event_rx: mpsc::UnboundedReceiver<PStateEvent>,
1088 typed_event_tx: mpsc::UnboundedSender<TypedPStateEvent<T>>,
1089) {
1090 while let Some(evt) = event_rx.recv().await {
1091 match deserialize_pstate_event(evt) {
1092 Ok(typed_event) => {
1093 if typed_event_tx.send(typed_event).is_err() {
1094 break;
1095 }
1096 }
1097 Result::Err(e) => {
1098 error!("could not deserialize json to requested type: {e}");
1099 break;
1100 }
1101 }
1102 }
1103}
1104
/// Listener registered via `Command::AllMessages`; receives forwarded server messages.
type GenericCallback = mpsc::UnboundedSender<ServerMessage>;
/// Receives the transaction id assigned to an `…Async` command.
type AsyncTicket = oneshot::Sender<TransactionId>;
/// One-shot reply for commands answered with an [`Ack`] or an [`Err`].
type AckCallback = oneshot::Sender<Result<Ack, Err>>;
/// One-shot reply for single-key reads/deletes ([`State`] or [`Err`]).
type StateCallback = oneshot::Sender<Result<State, Err>>;
/// One-shot reply for versioned reads (`CState` or [`Err`]).
type CStateCallback = oneshot::Sender<Result<CState, Err>>;
/// One-shot reply for pattern reads/deletes ([`PState`] or [`Err`]).
type PStateCallback = oneshot::Sender<Result<PState, Err>>;
/// One-shot reply for `ls`-style requests ([`LsState`] or [`Err`]).
type LsStateCallback = oneshot::Sender<Result<LsState, Err>>;

/// All registered generic listeners.
type GenericCallbacks = Vec<GenericCallback>;
/// Pending ack replies, keyed by transaction id.
type AckCallbacks = HashMap<TransactionId, AckCallback>;
/// Pending state replies, keyed by transaction id.
type StateCallbacks = HashMap<TransactionId, StateCallback>;
/// Pending versioned-state replies, keyed by transaction id.
type CStateCallbacks = HashMap<TransactionId, CStateCallback>;
/// Pending pattern-state replies, keyed by transaction id.
type PStateCallbacks = HashMap<TransactionId, PStateCallback>;
/// Pending `ls` replies, keyed by transaction id.
type LsStateCallbacks = HashMap<TransactionId, LsStateCallback>;
/// Value streams of single-key subscriptions, keyed by transaction id.
type SubCallbacks = HashMap<TransactionId, mpsc::UnboundedSender<Option<Value>>>;
/// Event streams of pattern subscriptions, keyed by transaction id.
type PSubCallbacks = HashMap<TransactionId, mpsc::UnboundedSender<PStateEvent>>;
/// Child-list streams of `ls` subscriptions, keyed by transaction id.
type SubLsCallbacks = HashMap<TransactionId, mpsc::UnboundedSender<Vec<RegularKeySegment>>>;
1122
1123#[derive(Default)]
1124struct Callbacks {
1125 generic: GenericCallbacks,
1126 ack: AckCallbacks,
1127 state: StateCallbacks,
1128 cstate: CStateCallbacks,
1129 pstate: PStateCallbacks,
1130 lsstate: LsStateCallbacks,
1131 sub: SubCallbacks,
1132 psub: PSubCallbacks,
1133 subls: SubLsCallbacks,
1134}
1135
1136struct TransactionIds {
1137 next_transaction_id: TransactionId,
1138}
1139
1140impl Default for TransactionIds {
1141 fn default() -> Self {
1142 TransactionIds {
1143 next_transaction_id: 1,
1144 }
1145 }
1146}
1147
1148impl TransactionIds {
1149 pub fn next(&mut self) -> TransactionId {
1150 let tid = self.next_transaction_id;
1151 self.next_transaction_id += 1;
1152 tid
1153 }
1154}
1155
1156pub struct OnDisconnect {
1157 rx: oneshot::Receiver<()>,
1158}
1159
1160impl Future for OnDisconnect {
1161 type Output = ();
1162
1163 fn poll(
1164 mut self: std::pin::Pin<&mut Self>,
1165 cx: &mut std::task::Context<'_>,
1166 ) -> std::task::Poll<Self::Output> {
1167 self.rx.poll_unpin(cx).map(|_| ())
1168 }
1169}
1170
1171#[instrument(, err)]
1172pub async fn connect_with_default_config() -> ConnectionResult<(Worterbuch, OnDisconnect, Config)> {
1173 let config = Config::new();
1174 let (conn, disconnected) = connect(config.clone()).await?;
1175 Ok((conn, disconnected, config))
1176}
1177
1178#[instrument(, err)]
1179pub async fn connect(config: Config) -> ConnectionResult<(Worterbuch, OnDisconnect)> {
1180 let mut err = None;
1181 for addr in &config.servers {
1182 info!("Trying to connect to server {addr} …");
1183 match try_connect(config.clone(), *addr).await {
1184 Ok(con) => {
1185 info!("Successfully connected to server {addr}");
1186 return Ok(con);
1187 }
1188 Err(e) => {
1189 warn!("Could not connect to server {addr}: {e}");
1190 err = Some(e);
1191 }
1192 }
1193 }
1194 if let Some(e) = err {
1195 Err(e)
1196 } else {
1197 Err(ConnectionError::NoServerAddressesConfigured)
1198 }
1199}
1200
1201pub fn local_client_wrapper(api: impl WbApi + Send + Sync + 'static) -> Worterbuch {
1202 let (commands_tx, cmd_rx) = mpsc::channel(1);
1203 let (stop_tx, stop_rx) = mpsc::channel(1);
1204 let (disco_tx, disco_rx) = oneshot::channel();
1205
1206 let (ctx, crx) = mpsc::unbounded_channel();
1207 let (stx, srx) = mpsc::unbounded_channel();
1208
1209 LocalClientSocket::spawn_api_forward_loop(api, crx, stx);
1210
1211 let local_socket = LocalClientSocket::new(ctx, srx, disco_rx);
1212 let client_socket = ClientSocket::Local(local_socket);
1213 let config = Config::new();
1214
1215 spawn(async move {
1216 if let Err(e) = run(cmd_rx, client_socket, stop_rx, config).await {
1217 error!("Connection closed with error: {e}");
1218 } else {
1219 debug!("Connection closed.");
1220 }
1221 disco_tx.send(()).ok();
1222 });
1223
1224 Worterbuch {
1225 client_id: "internal".to_owned(),
1226 commands: commands_tx,
1227 stop: stop_tx,
1228 }
1229}
1230
/// Connects to a single server at `host_addr` using the transport named by `config.proto`.
///
/// Transport selection:
/// - `"tcp"` selects `connect_tcp`, which dials `host_addr` directly.
/// - `"unix"` selects `connect_unix` on `config.socket_path`, or
///   `/tmp/worterbuch.socket` when that is unset.
/// - Any other value is used as a websocket scheme. `connect_ws` receives the URL
///   `{proto}://{host_addr}/ws`.
///
/// Returns the client handle together with an `OnDisconnect` future. That future completes
/// once the connection task signals (or drops) the disconnect sender created here.
///
/// # Errors
///
/// Propagates any `ConnectionError` from the selected transport's connect function.
///
/// # Panics
///
/// Panics if `config.proto` selects a transport that was not compiled in. This happens when
/// the `tcp`, `unix`, or `ws`/`wasm` feature is missing, or for `unix` on a non-Unix target.
#[instrument(skip(config), err(level = Level::WARN))]
pub async fn try_connect(
    config: Config,
    host_addr: SocketAddr,
) -> ConnectionResult<(Worterbuch, OnDisconnect)> {
    let proto = &config.proto;
    let tcp = proto == "tcp";
    let unix = proto == "unix";
    // Only websocket endpoints live under a path; for TCP the URL is merely logged.
    let path = if tcp { "" } else { "/ws" };
    // For the unix transport the "url" is the filesystem path of the socket.
    #[cfg(target_family = "unix")]
    let url = if unix {
        config
            .socket_path
            .clone()
            .unwrap_or_else(|| "/tmp/worterbuch.socket".into())
            .to_string_lossy()
            .to_string()
    } else {
        format!("{proto}://{host_addr}{path}")
    };
    #[cfg(not(target_family = "unix"))]
    let url = format!("{proto}://{host_addr}{path}");

    debug!("Got server url from config: {url}");

    let (disco_tx, disco_rx) = oneshot::channel();

    // In each branch exactly one of the two cfg-gated statements survives compilation; a
    // surviving `panic!` diverges, so the branch still type-checks as `Worterbuch`.
    let wb = if tcp {
        #[cfg(not(feature = "tcp"))]
        panic!("tcp not supported, binary was compiled without the tcp feature flag");
        #[cfg(feature = "tcp")]
        connect_tcp(host_addr, disco_tx, config).await?
    } else if unix {
        #[cfg(not(all(target_family = "unix", feature = "unix")))]
        panic!(
            "not supported, binary was compile without the unix feature flag or for non-unix operating systems"
        );
        #[cfg(all(target_family = "unix", feature = "unix"))]
        connect_unix(url, disco_tx, config).await?
    } else {
        #[cfg(not(any(feature = "ws", feature = "wasm")))]
        panic!("websocket not supported, binary was compiled without the ws feature flag");
        #[cfg(any(feature = "ws", feature = "wasm"))]
        connect_ws(url, host_addr, disco_tx, config).await?
    };

    let disconnected = OnDisconnect { rx: disco_rx };

    Ok((wb, disconnected))
}
1281
/// Connects to a Worterbuch server over a websocket and performs the client handshake.
///
/// Steps, in order:
/// 1. Open the websocket at `url`. With the `ws` feature, `config.auth_token` (if set) is
///    also sent as a `Bearer` `Authorization` header on the upgrade request.
/// 2. Read the server's `Welcome` message, which carries the client id and server info.
/// 3. Pick the first server-supported protocol version compatible with `PROTOCOL_VERSION`,
///    request a switch to it, and expect an `Ack`.
/// 4. If the server requires authorization, send `config.auth_token` in an
///    `AuthorizationRequest` and expect `Authorized`.
/// 5. Hand the socket to `connected`, which spawns the connection loop.
///
/// `host` is not used in the body; it is only recorded as a field of the tracing span.
/// None of the steps here is bounded by `config.connection_timeout`.
///
/// # Errors
///
/// Returns a `ConnectionError` in any of these cases:
/// - the websocket cannot be opened;
/// - the server closes the connection or sends something unexpected during the handshake;
/// - no compatible protocol version exists;
/// - the server rejects the protocol switch or the credentials;
/// - authorization is required but no auth token is configured.
#[cfg(any(feature = "ws", feature = "wasm"))]
#[instrument(skip(config, on_disconnect), level = Level::INFO)]
async fn connect_ws(
    url: String,
    host: SocketAddr,
    on_disconnect: oneshot::Sender<()>,
    config: Config,
) -> Result<Worterbuch, ConnectionError> {
    debug!("Connecting to server {url} over websocket …");

    #[cfg(feature = "ws")]
    let mut websocket = {
        use tokio_tungstenite::tungstenite::{client::IntoClientRequest, http::HeaderValue};

        let auth_token = config.auth_token.clone();
        let mut request = url.into_client_request()?;

        if let Some(auth_token) = auth_token {
            let header_value = HeaderValue::from_str(&format!("Bearer {auth_token}"))?;
            request.headers_mut().insert("Authorization", header_value);
        }

        let (websocket, _) = connect_async_with_config(request, None, true).await?;

        websocket
    };

    #[cfg(feature = "wasm")]
    let mut websocket = connect_wasm(url).await?;

    debug!("Connected to server.");

    // The server greets every new client with a `Welcome` message.
    let Welcome { client_id, info } = match websocket.next().await {
        Some(Ok(msg)) => match msg.to_text() {
            Ok(data) => match json::from_str::<SM>(data) {
                Ok(SM::Welcome(welcome)) => {
                    debug!("Welcome message received: {welcome:?}");
                    welcome
                }
                Ok(msg) => {
                    return Err(ConnectionError::IoError(Box::new(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("server sent invalid welcome message: {msg:?}"),
                    ))));
                }
                Err(e) => {
                    return Err(ConnectionError::IoError(Box::new(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("error parsing welcome message '{data}': {e}"),
                    ))));
                }
            },
            Err(e) => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("invalid welcome message '{msg:?}': {e}"),
                ))));
            }
        },
        Some(Err(e)) => return Err(e.into()),
        None => {
            return Err(ConnectionError::IoError(Box::new(io::Error::new(
                io::ErrorKind::ConnectionAborted,
                "connection closed before welcome message",
            ))));
        }
    };

    // Use the first server-advertised protocol version this client can speak.
    let Some(proto_version) = info
        .supported_protocol_versions
        .iter()
        .find(|v| PROTOCOL_VERSION.is_compatible_with_server(v))
    else {
        return Err(ConnectionError::WorterbuchError(Box::new(
            WorterbuchError::ProtocolNegotiationFailed(PROTOCOL_VERSION.major()),
        )));
    };

    debug!("Found compatible protocol version {proto_version}.");

    let proto_switch = ProtocolSwitchRequest {
        version: proto_version.major(),
    };
    let msg = json::to_string(&CM::ProtocolSwitchRequest(proto_switch))?;
    debug!("Sending protocol switch message: {msg}");
    websocket.send(Message::Text(msg.into())).await?;

    match websocket.next().await {
        Some(msg) => match msg? {
            Message::Text(msg) => match serde_json::from_str::<SM>(&msg) {
                Ok(SM::Ack(_)) => {
                    debug!("Protocol switched to v{}.", proto_version.major());
                }
                Ok(SM::Err(e)) => {
                    error!("Protocol switch failed: {e}");
                    return Err(ConnectionError::WorterbuchError(Box::new(
                        WorterbuchError::ServerResponse(e),
                    )));
                }
                Ok(msg) => {
                    return Err(ConnectionError::IoError(Box::new(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("server sent invalid protocol switch response: {msg:?}"),
                    ))));
                }
                Err(e) => {
                    return Err(ConnectionError::IoError(Box::new(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("error receiving protocol switch response: {e}"),
                    ))));
                }
            },
            msg => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("received unexpected message from server: {msg:?}"),
                ))));
            }
        },
        None => {
            warn!("Server closed the connection");
            // The welcome message has already been received at this point.
            return Err(ConnectionError::IoError(Box::new(io::Error::new(
                io::ErrorKind::ConnectionReset,
                "connection closed before handshake",
            ))));
        }
    }

    if info.authorization_required {
        let Some(auth_token) = config.auth_token.clone() else {
            return Err(ConnectionError::AuthorizationError(Box::new(
                "Server requires authorization but no auth token was provided.".to_owned(),
            )));
        };
        let handshake = AuthorizationRequest { auth_token };
        let msg = json::to_string(&CM::AuthorizationRequest(handshake))?;
        // The serialized request contains the auth token, so its content is not logged.
        debug!("Sending authorization message …");
        websocket.send(Message::Text(msg.into())).await?;

        match websocket.next().await {
            Some(Err(e)) => return Err(e.into()),
            Some(Ok(Message::Text(msg))) => match serde_json::from_str::<SM>(&msg) {
                Ok(SM::Authorized(_)) => debug!("Authorization accepted."),
                Ok(SM::Err(e)) => {
                    error!("Authorization failed: {e}");
                    return Err(ConnectionError::WorterbuchError(Box::new(
                        WorterbuchError::ServerResponse(e),
                    )));
                }
                Ok(msg) => {
                    return Err(ConnectionError::IoError(Box::new(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("server sent invalid authetication response: {msg:?}"),
                    ))));
                }
                Err(e) => {
                    return Err(ConnectionError::IoError(Box::new(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("error receiving authorization response: {e}"),
                    ))));
                }
            },
            Some(Ok(msg)) => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("received unexpected message from server: {msg:?}"),
                ))));
            }
            None => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::ConnectionReset,
                    "connection closed before handshake",
                ))));
            }
        }
    }

    connected(
        ClientSocket::Ws(WsClientSocket::new(websocket)),
        on_disconnect,
        config,
        client_id,
    )
}
1470
/// Connects to a Worterbuch server over plain TCP and performs the client handshake.
///
/// The wire format is newline-delimited JSON: every outgoing message gets a trailing `\n`,
/// and incoming data is read line by line.
///
/// Steps, in order:
/// 1. Open the TCP connection and read the server's `Welcome` message. Each of these two
///    steps is bounded by `config.connection_timeout`.
/// 2. Switch to the first server-supported protocol version compatible with
///    `PROTOCOL_VERSION`.
/// 3. If the server requires it, authorize with `config.auth_token`.
/// 4. Hand the stream to `connected`.
///
/// The protocol-switch and authorization responses are awaited without a timeout.
///
/// # Errors
///
/// Returns a `ConnectionError` in any of these cases:
/// - a timeout or I/O failure;
/// - the server closes the connection or sends an unexpected reply during the handshake;
/// - no compatible protocol version exists;
/// - the server rejects the switch or the credentials;
/// - authorization is required but no auth token is configured.
#[cfg(feature = "tcp")]
#[instrument(skip(config, on_disconnect))]
async fn connect_tcp(
    host_addr: SocketAddr,
    on_disconnect: oneshot::Sender<()>,
    config: Config,
) -> Result<Worterbuch, ConnectionError> {
    let timeout = config.connection_timeout;
    debug!(
        "Connecting to server tcp://{host_addr} (timeout: {} ms) …",
        timeout.as_millis()
    );

    let stream = select! {
        conn = TcpStream::connect(host_addr) => conn,
        _ = sleep(timeout) => {
            return Err(ConnectionError::Timeout(Box::new("Timeout while waiting for TCP connection.".to_owned())));
        },
    }?;
    debug!("Connected to tcp://{host_addr}.");
    let (tcp_rx, mut tcp_tx) = stream.into_split();
    let mut tcp_rx = BufReader::new(tcp_rx).lines();

    debug!("Connected to server.");

    let Welcome { client_id, info } = select! {
        line = tcp_rx.next_line() => match line {
            Ok(None) => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::ConnectionReset,
                    "connection closed before welcome message",
                ))))
            }
            Ok(Some(line)) => {
                let msg = json::from_str::<SM>(&line);
                match msg {
                    Ok(SM::Welcome(welcome)) => {
                        debug!("Welcome message received: {welcome:?}");
                        welcome
                    }
                    Ok(msg) => {
                        return Err(ConnectionError::IoError(Box::new(io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!("server sent invalid welcome message: {msg:?}"),
                        ))))
                    }
                    Err(e) => {
                        return Err(ConnectionError::IoError(Box::new(io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!("error parsing welcome message '{line}': {e}"),
                        ))))
                    }
                }
            }
            Err(e) => return Err(ConnectionError::IoError(Box::new(e))),
        },
        _ = sleep(timeout) => {
            return Err(ConnectionError::Timeout(Box::new("Timeout while waiting for welcome message.".to_owned())));
        },
    };

    // Use the first server-advertised protocol version this client can speak.
    let Some(proto_version) = info
        .supported_protocol_versions
        .iter()
        .find(|v| PROTOCOL_VERSION.is_compatible_with_server(v))
    else {
        return Err(ConnectionError::WorterbuchError(Box::new(
            WorterbuchError::ProtocolNegotiationFailed(PROTOCOL_VERSION.major()),
        )));
    };

    debug!("Found compatible protocol version {proto_version}.");

    let proto_switch = ProtocolSwitchRequest {
        version: proto_version.major(),
    };
    let mut msg = json::to_string(&CM::ProtocolSwitchRequest(proto_switch))?;
    msg.push('\n');
    debug!("Sending protocol switch message: {msg}");
    tcp_tx.write_all(msg.as_bytes()).await?;

    match tcp_rx.next_line().await {
        Ok(None) => {
            return Err(ConnectionError::IoError(Box::new(io::Error::new(
                io::ErrorKind::ConnectionReset,
                "connection closed before handshake",
            ))));
        }
        Ok(Some(line)) => match serde_json::from_str::<SM>(&line) {
            Ok(SM::Ack(_)) => {
                debug!("Protocol switched to v{}.", proto_version.major());
            }
            Ok(SM::Err(e)) => {
                error!("Protocol switch failed: {e}");
                return Err(ConnectionError::WorterbuchError(Box::new(
                    WorterbuchError::ServerResponse(e),
                )));
            }
            Ok(msg) => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("server sent invalid protocol switch response: {msg:?}"),
                ))));
            }
            Err(e) => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("error receiving protocol switch response: {e}"),
                ))));
            }
        },
        Err(e) => {
            // A read error, not an orderly close (that is the `Ok(None)` case above).
            warn!("Could not read protocol switch response: {e}");
            return Err(ConnectionError::IoError(Box::new(e)));
        }
    }

    if info.authorization_required {
        let Some(auth_token) = config.auth_token.clone() else {
            return Err(ConnectionError::AuthorizationError(Box::new(
                "Server requires authorization but no auth token was provided.".to_owned(),
            )));
        };
        let handshake = AuthorizationRequest { auth_token };
        let mut msg = json::to_string(&CM::AuthorizationRequest(handshake))?;
        msg.push('\n');
        // The serialized request contains the auth token, so its content is not logged.
        debug!("Sending authorization message …");
        tcp_tx.write_all(msg.as_bytes()).await?;

        match tcp_rx.next_line().await {
            Ok(None) => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::ConnectionReset,
                    "connection closed before handshake",
                ))));
            }
            Ok(Some(line)) => match json::from_str::<SM>(&line) {
                Ok(SM::Authorized(_)) => debug!("Authorization accepted."),
                Ok(SM::Err(e)) => {
                    error!("Authorization failed: {e}");
                    return Err(ConnectionError::WorterbuchError(Box::new(
                        WorterbuchError::ServerResponse(e),
                    )));
                }
                Ok(msg) => {
                    return Err(ConnectionError::IoError(Box::new(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("server sent invalid authetication response: {msg:?}"),
                    ))));
                }
                Err(e) => {
                    return Err(ConnectionError::IoError(Box::new(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("error receiving authorization response: {e}"),
                    ))));
                }
            },
            Err(e) => return Err(ConnectionError::IoError(Box::new(e))),
        }
    }

    connected(
        ClientSocket::Tcp(
            TcpClientSocket::new(
                tcp_tx,
                tcp_rx,
                config.send_timeout,
                config.channel_buffer_size,
            )
            .await,
        ),
        on_disconnect,
        config,
        client_id,
    )
}
1663
/// Connects to a Worterbuch server over the Unix domain socket at `path` and performs the
/// client handshake.
///
/// The wire format is newline-delimited JSON: every outgoing message gets a trailing `\n`,
/// and incoming data is read line by line.
///
/// Steps, in order:
/// 1. Connect to the socket and read the server's `Welcome` message. Each of these two
///    steps is bounded by `config.connection_timeout`.
/// 2. Switch to the first server-supported protocol version compatible with
///    `PROTOCOL_VERSION`.
/// 3. If the server requires it, authorize with `config.auth_token`.
/// 4. Hand the stream to `connected`.
///
/// The protocol-switch and authorization responses are awaited without a timeout.
///
/// # Errors
///
/// Returns a `ConnectionError` in any of these cases:
/// - a timeout or I/O failure;
/// - the server closes the connection or sends an unexpected reply during the handshake;
/// - no compatible protocol version exists;
/// - the server rejects the switch or the credentials;
/// - authorization is required but no auth token is configured.
#[cfg(all(target_family = "unix", feature = "unix"))]
#[instrument(skip(config, on_disconnect), err(level = Level::WARN))]
async fn connect_unix(
    path: String,
    on_disconnect: oneshot::Sender<()>,
    config: Config,
) -> Result<Worterbuch, ConnectionError> {
    let timeout = config.connection_timeout;
    debug!(
        "Connecting to server socket {path} (timeout: {} ms) …",
        timeout.as_millis()
    );

    let stream = select! {
        conn = UnixStream::connect(&path) => conn,
        _ = sleep(timeout) => {
            return Err(ConnectionError::Timeout(Box::new("Timeout while waiting for Unix socket connection.".to_owned())));
        },
    }?;
    debug!("Connected to {path}.");
    let (sock_rx, mut sock_tx) = stream.into_split();
    let mut sock_rx = BufReader::new(sock_rx).lines();

    debug!("Connected to server.");

    let Welcome { client_id, info } = select! {
        line = sock_rx.next_line() => match line {
            Ok(None) => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::ConnectionReset,
                    "connection closed before welcome message",
                ))))
            }
            Ok(Some(line)) => {
                let msg = json::from_str::<SM>(&line);
                match msg {
                    Ok(SM::Welcome(welcome)) => {
                        debug!("Welcome message received: {welcome:?}");
                        welcome
                    }
                    Ok(msg) => {
                        return Err(ConnectionError::IoError(Box::new(io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!("server sent invalid welcome message: {msg:?}"),
                        ))))
                    }
                    Err(e) => {
                        return Err(ConnectionError::IoError(Box::new(io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!("error parsing welcome message '{line}': {e}"),
                        ))))
                    }
                }
            }
            Err(e) => return Err(ConnectionError::IoError(Box::new(e))),
        },
        _ = sleep(timeout) => {
            return Err(ConnectionError::Timeout(Box::new("Timeout while waiting for welcome message.".to_owned())));
        },
    };

    // Use the first server-advertised protocol version this client can speak.
    let Some(proto_version) = info
        .supported_protocol_versions
        .iter()
        .find(|v| PROTOCOL_VERSION.is_compatible_with_server(v))
    else {
        return Err(ConnectionError::WorterbuchError(Box::new(
            WorterbuchError::ProtocolNegotiationFailed(PROTOCOL_VERSION.major()),
        )));
    };

    debug!("Found compatible protocol version {proto_version}.");

    let proto_switch = ProtocolSwitchRequest {
        version: proto_version.major(),
    };
    let mut msg = json::to_string(&CM::ProtocolSwitchRequest(proto_switch))?;
    msg.push('\n');
    debug!("Sending protocol switch message: {msg}");
    sock_tx.write_all(msg.as_bytes()).await?;

    match sock_rx.next_line().await {
        Ok(None) => {
            return Err(ConnectionError::IoError(Box::new(io::Error::new(
                io::ErrorKind::ConnectionReset,
                "connection closed before handshake",
            ))));
        }
        Ok(Some(line)) => match serde_json::from_str::<SM>(&line) {
            Ok(SM::Ack(_)) => {
                debug!("Protocol switched to v{}.", proto_version.major());
            }
            Ok(SM::Err(e)) => {
                error!("Protocol switch failed: {e}");
                return Err(ConnectionError::WorterbuchError(Box::new(
                    WorterbuchError::ServerResponse(e),
                )));
            }
            Ok(msg) => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("server sent invalid protocol switch response: {msg:?}"),
                ))));
            }
            Err(e) => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("error receiving protocol switch response: {e}"),
                ))));
            }
        },
        Err(e) => {
            // A read error, not an orderly close (that is the `Ok(None)` case above).
            warn!("Could not read protocol switch response: {e}");
            return Err(ConnectionError::IoError(Box::new(e)));
        }
    }

    if info.authorization_required {
        let Some(auth_token) = config.auth_token.clone() else {
            return Err(ConnectionError::AuthorizationError(Box::new(
                "Server requires authorization but no auth token was provided.".to_owned(),
            )));
        };
        let handshake = AuthorizationRequest { auth_token };
        let mut msg = json::to_string(&CM::AuthorizationRequest(handshake))?;
        msg.push('\n');
        // The serialized request contains the auth token, so its content is not logged.
        debug!("Sending authorization message …");
        sock_tx.write_all(msg.as_bytes()).await?;

        match sock_rx.next_line().await {
            Ok(None) => {
                return Err(ConnectionError::IoError(Box::new(io::Error::new(
                    io::ErrorKind::ConnectionReset,
                    "connection closed before handshake",
                ))));
            }
            Ok(Some(line)) => match json::from_str::<SM>(&line) {
                Ok(SM::Authorized(_)) => debug!("Authorization accepted."),
                Ok(SM::Err(e)) => {
                    error!("Authorization failed: {e}");
                    return Err(ConnectionError::WorterbuchError(Box::new(
                        WorterbuchError::ServerResponse(e),
                    )));
                }
                Ok(msg) => {
                    return Err(ConnectionError::IoError(Box::new(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("server sent invalid authetication response: {msg:?}"),
                    ))));
                }
                Err(e) => {
                    return Err(ConnectionError::IoError(Box::new(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("error receiving authorization response: {e}"),
                    ))));
                }
            },
            Err(e) => return Err(ConnectionError::IoError(Box::new(e))),
        }
    }

    connected(
        ClientSocket::Unix(
            UnixClientSocket::new(sock_tx, sock_rx, config.channel_buffer_size).await,
        ),
        on_disconnect,
        config,
        client_id,
    )
}
1849
1850#[instrument(skip(client_socket, on_disconnect, config))]
1851fn connected(
1852 client_socket: ClientSocket,
1853 on_disconnect: oneshot::Sender<()>,
1854 config: Config,
1855 client_id: String,
1856) -> Result<Worterbuch, ConnectionError> {
1857 let (stop_tx, stop_rx) = mpsc::channel(1);
1858 let (cmd_tx, cmd_rx) = mpsc::channel(1);
1859
1860 spawn(async move {
1861 if let Err(e) = run(cmd_rx, client_socket, stop_rx, config).await {
1862 error!("Connection closed with error: {e}");
1863 } else {
1864 debug!("Connection closed.");
1865 }
1866 on_disconnect.send(()).ok();
1867 });
1868
1869 Ok(Worterbuch::new(cmd_tx, stop_tx, client_id))
1870}
1871
/// Drives one client connection until it ends.
///
/// Multiplexes three sources with `select!`:
/// - a shutdown request on `stop_rx`, or that channel closing;
/// - incoming server messages, dispatched via `process_incoming_server_message`;
/// - API commands on `cmd_rx`, translated by `process_incoming_command` and sent to the
///   server.
///
/// The loop ends in any of these cases: a shutdown request, either handler returning
/// `ControlFlow::Break` or an error, or a failed send to the server. The socket is then
/// closed. Only after a successful close is the sender that arrived with the shutdown
/// request (if any) notified, so `close()` callers resume after the socket is closed. If
/// closing fails, the function returns early and that sender is dropped without an
/// explicit reply.
///
/// # Errors
///
/// Only an error from closing the socket is returned. Errors inside the loop are logged and
/// end the loop without being propagated.
#[instrument(skip(cmd_rx, client_socket, stop_rx, config), err)]
async fn run(
    mut cmd_rx: mpsc::Receiver<Command>,
    mut client_socket: ClientSocket,
    mut stop_rx: mpsc::Receiver<oneshot::Sender<()>>,
    config: Config,
) -> ConnectionResult<()> {
    let mut callbacks = Callbacks::default();
    let mut transaction_ids = TransactionIds::default();

    // Acknowledgement sender from a shutdown request; `None` if the loop ends for any other
    // reason or if `stop_rx` closed without a request.
    let mut stop_tx = None;

    loop {
        trace!("loop: wait for command / ws message / shutdown request");
        select! {
            // `recv` is `None` when every stop sender has been dropped; that is treated as a
            // shutdown as well.
            recv = stop_rx.recv() => {
                debug!("Shutdown request received.");
                stop_tx = recv;
                break;
            },
            ws_msg = client_socket.receive_msg() => {
                match process_incoming_server_message(ws_msg, &mut callbacks).await {
                    Ok(ControlFlow::Break(_)) => break,
                    Err(e) => {
                        error!("Error processing server message: {e}");
                        break;
                    },
                    _ => trace!("websocket message processing done")
                }
            },
            cmd = cmd_rx.recv() => {
                match process_incoming_command(cmd, &mut callbacks, &mut transaction_ids).await {
                    // `None` means the command was handled locally and nothing is sent.
                    Ok(ControlFlow::Continue(msg)) => if let Some(msg) = msg
                        && let Err(e) = client_socket.send_msg(msg, config.use_backpressure).await {
                        error!("Error sending message to server: {e}");
                        break;
                    },
                    Ok(ControlFlow::Break(_)) => break,
                    Err(e) => {
                        error!("Error processing command: {e}");
                        break;
                    },
                }
            }
        }
    }

    client_socket.close().await?;
    // Acknowledge the shutdown only after the socket has been closed.
    if let Some(tx) = stop_tx {
        tx.send(()).ok();
    }

    Ok(())
}
1926
1927#[instrument(skip(callbacks, transaction_ids), level = "trace", err)]
1928async fn process_incoming_command(
1929 cmd: Option<Command>,
1930 callbacks: &mut Callbacks,
1931 transaction_ids: &mut TransactionIds,
1932) -> ConnectionResult<ControlFlow<(), Option<CM>>> {
1933 if let Some(command) = cmd {
1934 debug!("Processing command: {command:?}");
1935 let transaction_id = transaction_ids.next();
1936 let cm = match command {
1937 Command::Set(key, value, callback) => {
1938 callbacks.ack.insert(transaction_id, callback);
1939 Some(CM::Set(Set {
1940 transaction_id,
1941 key,
1942 value,
1943 }))
1944 }
1945 Command::SetAsync(key, value, callback) => {
1946 callback.send(transaction_id).ok();
1947 Some(CM::Set(Set {
1948 transaction_id,
1949 key,
1950 value,
1951 }))
1952 }
1953 Command::CSet(key, value, version, callback) => {
1954 callbacks.ack.insert(transaction_id, callback);
1955 Some(CM::CSet(CSet {
1956 transaction_id,
1957 key,
1958 value,
1959 version,
1960 }))
1961 }
1962 Command::CSetAsync(key, value, version, callback) => {
1963 callback.send(transaction_id).ok();
1964 Some(CM::CSet(CSet {
1965 transaction_id,
1966 key,
1967 value,
1968 version,
1969 }))
1970 }
1971 Command::SPubInit(key, callback) => {
1972 callbacks.ack.insert(transaction_id, callback);
1973 Some(CM::SPubInit(SPubInit {
1974 transaction_id,
1975 key,
1976 }))
1977 }
1978 Command::SPubInitAsync(key, callback) => {
1979 callback.send(transaction_id).ok();
1980 Some(CM::SPubInit(SPubInit {
1981 transaction_id,
1982 key,
1983 }))
1984 }
1985 Command::SPub(transaction_id, value, callback) => {
1986 callbacks.ack.insert(transaction_id, callback);
1987 Some(CM::SPub(SPub {
1988 transaction_id,
1989 value,
1990 }))
1991 }
1992 Command::SPubAsync(transaction_id, value, callback) => {
1993 callback.send(transaction_id).ok();
1994 Some(CM::SPub(SPub {
1995 transaction_id,
1996 value,
1997 }))
1998 }
1999 Command::Publish(key, value, callback) => {
2000 callbacks.ack.insert(transaction_id, callback);
2001 Some(CM::Publish(Publish {
2002 transaction_id,
2003 key,
2004 value,
2005 }))
2006 }
2007 Command::PublishAsync(key, value, callback) => {
2008 callback.send(transaction_id).ok();
2009 Some(CM::Publish(Publish {
2010 transaction_id,
2011 key,
2012 value,
2013 }))
2014 }
2015 Command::Get(key, callback) => {
2016 callbacks.state.insert(transaction_id, callback);
2017 Some(CM::Get(Get {
2018 transaction_id,
2019 key,
2020 }))
2021 }
2022 Command::GetAsync(key, callback) => {
2023 callback.send(transaction_id).ok();
2024 Some(CM::Get(Get {
2025 transaction_id,
2026 key,
2027 }))
2028 }
2029 Command::CGet(key, callback) => {
2030 callbacks.cstate.insert(transaction_id, callback);
2031 Some(CM::CGet(Get {
2032 transaction_id,
2033 key,
2034 }))
2035 }
2036 Command::CGetAsync(key, callback) => {
2037 callback.send(transaction_id).ok();
2038 Some(CM::CGet(Get {
2039 transaction_id,
2040 key,
2041 }))
2042 }
2043 Command::PGet(request_pattern, callback) => {
2044 callbacks.pstate.insert(transaction_id, callback);
2045 Some(CM::PGet(PGet {
2046 transaction_id,
2047 request_pattern,
2048 }))
2049 }
2050 Command::PGetAsync(request_pattern, callback) => {
2051 callback.send(transaction_id).ok();
2052 Some(CM::PGet(PGet {
2053 transaction_id,
2054 request_pattern,
2055 }))
2056 }
2057 Command::Delete(key, callback) => {
2058 callbacks.state.insert(transaction_id, callback);
2059 Some(CM::Delete(Delete {
2060 transaction_id,
2061 key,
2062 }))
2063 }
2064 Command::DeleteAsync(key, callback) => {
2065 callback.send(transaction_id).ok();
2066 Some(CM::Delete(Delete {
2067 transaction_id,
2068 key,
2069 }))
2070 }
2071 Command::PDelete(request_pattern, quiet, callback) => {
2072 callbacks.pstate.insert(transaction_id, callback);
2073 Some(CM::PDelete(PDelete {
2074 transaction_id,
2075 request_pattern,
2076 quiet: Some(quiet),
2077 }))
2078 }
2079 Command::PDeleteAsync(request_pattern, quiet, callback) => {
2080 callback.send(transaction_id).ok();
2081 Some(CM::PDelete(PDelete {
2082 transaction_id,
2083 request_pattern,
2084 quiet: Some(quiet),
2085 }))
2086 }
2087 Command::Ls(parent, callback) => {
2088 callbacks.lsstate.insert(transaction_id, callback);
2089 Some(CM::Ls(Ls {
2090 transaction_id,
2091 parent,
2092 }))
2093 }
2094 Command::LsAsync(parent, callback) => {
2095 callback.send(transaction_id).ok();
2096 Some(CM::Ls(Ls {
2097 transaction_id,
2098 parent,
2099 }))
2100 }
2101 Command::PLs(parent_pattern, callback) => {
2102 callbacks.lsstate.insert(transaction_id, callback);
2103 Some(CM::PLs(PLs {
2104 transaction_id,
2105 parent_pattern,
2106 }))
2107 }
2108 Command::PLsAsync(parent_pattern, callback) => {
2109 callback.send(transaction_id).ok();
2110 Some(CM::PLs(PLs {
2111 transaction_id,
2112 parent_pattern,
2113 }))
2114 }
2115 Command::Subscribe(key, unique, tid_callback, value_callback, live_only) => {
2116 callbacks.sub.insert(transaction_id, value_callback);
2117 callbacks.ack.insert(transaction_id, tid_callback);
2118 Some(CM::Subscribe(Subscribe {
2119 transaction_id,
2120 key,
2121 unique,
2122 live_only: Some(live_only),
2123 }))
2124 }
2125 Command::SubscribeAsync(key, unique, callback, live_only) => {
2126 callback.send(transaction_id).ok();
2127 Some(CM::Subscribe(Subscribe {
2128 transaction_id,
2129 key,
2130 unique,
2131 live_only: Some(live_only),
2132 }))
2133 }
2134 Command::PSubscribe(
2135 request_pattern,
2136 unique,
2137 tid_callback,
2138 event_callback,
2139 aggregate_events,
2140 live_only,
2141 ) => {
2142 callbacks.psub.insert(transaction_id, event_callback);
2143 callbacks.ack.insert(transaction_id, tid_callback);
2144 Some(CM::PSubscribe(PSubscribe {
2145 transaction_id,
2146 request_pattern,
2147 unique,
2148 aggregate_events,
2149 live_only: Some(live_only),
2150 }))
2151 }
2152 Command::PSubscribeAsync(
2153 request_pattern,
2154 unique,
2155 callback,
2156 aggregate_events,
2157 live_only,
2158 ) => {
2159 callback.send(transaction_id).ok();
2160 Some(CM::PSubscribe(PSubscribe {
2161 transaction_id,
2162 request_pattern,
2163 unique,
2164 aggregate_events,
2165 live_only: Some(live_only),
2166 }))
2167 }
2168 Command::Unsubscribe(transaction_id, callback) => {
2169 callbacks.ack.insert(transaction_id, callback);
2170 callbacks.sub.remove(&transaction_id);
2171 callbacks.psub.remove(&transaction_id);
2172 Some(CM::Unsubscribe(Unsubscribe { transaction_id }))
2173 }
2174 Command::UnsubscribeAsync(transaction_id, callback) => {
2175 callbacks.sub.remove(&transaction_id);
2176 callbacks.psub.remove(&transaction_id);
2177 callback.send(transaction_id).ok();
2178 Some(CM::Unsubscribe(Unsubscribe { transaction_id }))
2179 }
2180 Command::SubscribeLs(parent, tid_callback, children_callback) => {
2181 callbacks.subls.insert(transaction_id, children_callback);
2182 callbacks.ack.insert(transaction_id, tid_callback);
2183 Some(CM::SubscribeLs(SubscribeLs {
2184 transaction_id,
2185 parent,
2186 }))
2187 }
2188 Command::SubscribeLsAsync(parent, callback) => {
2189 callback.send(transaction_id).ok();
2190 Some(CM::SubscribeLs(SubscribeLs {
2191 transaction_id,
2192 parent,
2193 }))
2194 }
2195 Command::UnsubscribeLs(transaction_id, callback) => {
2196 callbacks.ack.insert(transaction_id, callback);
2197 callbacks.subls.remove(&transaction_id);
2198 Some(CM::UnsubscribeLs(UnsubscribeLs { transaction_id }))
2199 }
2200 Command::UnsubscribeLsAsync(transaction_id, callback) => {
2201 callbacks.subls.remove(&transaction_id);
2202 callback.send(transaction_id).ok();
2203 Some(CM::Unsubscribe(Unsubscribe { transaction_id }))
2204 }
2205 Command::Lock(key, callback) => {
2206 callbacks.ack.insert(transaction_id, callback);
2207 Some(CM::Lock(Lock {
2208 transaction_id,
2209 key,
2210 }))
2211 }
2212 Command::LockAsync(key, callback) => {
2213 callback.send(transaction_id).ok();
2214 Some(CM::Lock(Lock {
2215 transaction_id,
2216 key,
2217 }))
2218 }
2219 Command::AcquireLock(key, callback) => {
2220 callbacks.ack.insert(transaction_id, callback);
2221 Some(CM::AcquireLock(Lock {
2222 transaction_id,
2223 key,
2224 }))
2225 }
2226 Command::ReleaseLock(key, callback) => {
2227 callbacks.ack.insert(transaction_id, callback);
2228 Some(CM::ReleaseLock(Lock {
2229 transaction_id,
2230 key,
2231 }))
2232 }
2233 Command::ReleaseLockAsync(key, callback) => {
2234 callback.send(transaction_id).ok();
2235 Some(CM::ReleaseLock(Lock {
2236 transaction_id,
2237 key,
2238 }))
2239 }
2240 Command::AllMessages(tx) => {
2241 callbacks.generic.push(tx);
2242 None
2243 }
2244 };
2245 Ok(ControlFlow::Continue(cm))
2246 } else {
2247 debug!("No more commands");
2248 Ok(ControlFlow::Break(()))
2249 }
2250}
2251
2252#[instrument(skip(callbacks), level = "trace", err)]
2253async fn process_incoming_server_message(
2254 msg: ConnectionResult<Option<ServerMessage>>,
2255 callbacks: &mut Callbacks,
2256) -> ConnectionResult<ControlFlow<()>> {
2257 match msg {
2258 Ok(Some(msg)) => {
2259 deliver_generic(&msg, callbacks);
2260 match msg {
2261 SM::State(state) => deliver_state(state, callbacks).await?,
2262 SM::CState(state) => deliver_cstate(state, callbacks).await?,
2263 SM::PState(pstate) => deliver_pstate(pstate, callbacks).await?,
2264 SM::LsState(ls) => deliver_ls(ls, callbacks).await?,
2265 SM::Err(err) => deliver_err(err, callbacks).await,
2266 SM::Ack(ack) => deliver_ack(ack, callbacks).await,
2267 SM::Welcome(_) | SM::Authorized(_) => (),
2268 }
2269 Ok(ControlFlow::Continue(()))
2270 }
2271 Ok(None) => {
2272 warn!("Connection closed.");
2273 Ok(ControlFlow::Break(()))
2274 }
2275 Err(e) => {
2276 error!("Error receiving message: {e}");
2277 Ok(ControlFlow::Break(()))
2278 }
2279 }
2280}
2281
2282#[instrument(skip(callbacks), level = "trace", ret)]
2283fn deliver_generic(msg: &ServerMessage, callbacks: &mut Callbacks) {
2284 callbacks.generic.retain(|tx| match tx.send(msg.clone()) {
2285 Ok(_) => true,
2286 Err(e) => {
2287 error!("Removing callback due to failure to deliver message to receiver: {e}");
2288 false
2289 }
2290 });
2291}
2292
2293#[instrument(skip(callbacks), level = "trace", err)]
2294async fn deliver_state(state: State, callbacks: &mut Callbacks) -> ConnectionResult<()> {
2295 if let Some(cb) = callbacks.state.remove(&state.transaction_id) {
2296 cb.send(Ok(state.clone())).ok();
2297 }
2298
2299 if let Some(cb) = callbacks.sub.get(&state.transaction_id) {
2300 let value = match state.event {
2301 StateEvent::Value(v) => Some(v),
2302 StateEvent::Deleted(_) => None,
2303 };
2304 cb.send(value)?;
2305 }
2306 Ok(())
2307}
2308
2309#[instrument(skip(callbacks), level = "trace", err)]
2310async fn deliver_cstate(state: CState, callbacks: &mut Callbacks) -> ConnectionResult<()> {
2311 if let Some(cb) = callbacks.cstate.remove(&state.transaction_id) {
2312 cb.send(Ok(state)).ok();
2313 }
2314 Ok(())
2315}
2316
2317#[instrument(skip(callbacks), level = "trace", err)]
2318async fn deliver_pstate(pstate: PState, callbacks: &mut Callbacks) -> ConnectionResult<()> {
2319 if let Some(cb) = callbacks.pstate.remove(&pstate.transaction_id) {
2320 cb.send(Ok(pstate.clone())).ok();
2321 }
2322
2323 if let Some(cb) = callbacks.psub.get(&pstate.transaction_id) {
2324 cb.send(pstate.event)?;
2325 }
2326 Ok(())
2327}
2328
2329#[instrument(skip(callbacks), level = "trace", err)]
2330async fn deliver_ls(ls: LsState, callbacks: &mut Callbacks) -> ConnectionResult<()> {
2331 if let Some(cb) = callbacks.lsstate.remove(&ls.transaction_id) {
2332 cb.send(Ok(ls.clone())).ok();
2333 }
2334
2335 if let Some(cb) = callbacks.subls.get(&ls.transaction_id) {
2336 cb.send(ls.children)?;
2337 }
2338
2339 Ok(())
2340}
2341
2342#[instrument(skip(callbacks), level = "trace", ret)]
2343async fn deliver_ack(ack: Ack, callbacks: &mut Callbacks) {
2344 if let Some(cb) = callbacks.ack.remove(&ack.transaction_id) {
2345 cb.send(Ok(ack)).ok();
2346 }
2347}
2348
2349#[instrument(skip(callbacks), level = "trace", ret)]
2350async fn deliver_err(err: Err, callbacks: &mut Callbacks) {
2351 if let Some(cb) = callbacks.ack.remove(&err.transaction_id) {
2352 cb.send(Err(err.clone())).ok();
2353 }
2354 if let Some(cb) = callbacks.state.remove(&err.transaction_id) {
2355 cb.send(Err(err.clone())).ok();
2356 }
2357 if let Some(cb) = callbacks.cstate.remove(&err.transaction_id) {
2358 cb.send(Err(err.clone())).ok();
2359 }
2360 if let Some(cb) = callbacks.pstate.remove(&err.transaction_id) {
2361 cb.send(Err(err.clone())).ok();
2362 }
2363 if let Some(cb) = callbacks.lsstate.remove(&err.transaction_id) {
2364 cb.send(Err(err.clone())).ok();
2365 }
2366}
2367
2368#[instrument(level = "trace", err)]
2369fn deserialize_key_value_pairs<T: DeserializeOwned + Debug>(
2370 kvps: KeyValuePairs,
2371) -> Result<TypedKeyValuePairs<T>, ConnectionError> {
2372 let mut typed = TypedKeyValuePairs::new();
2373 for kvp in kvps {
2374 typed.push(kvp.try_into()?);
2375 }
2376 Ok(typed)
2377}
2378
2379#[instrument(level = "trace", err)]
2380fn deserialize_pstate_event<T: DeserializeOwned + Debug>(
2381 pstate: PStateEvent,
2382) -> Result<TypedPStateEvent<T>, SubscriptionError> {
2383 Ok(pstate.try_into()?)
2384}