1use std::collections::HashMap;
17use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
18use std::sync::mpsc;
19use std::sync::Arc;
20use std::sync::Mutex as StdMutex;
21use std::thread;
22
23use futures_util::{SinkExt, StreamExt};
24use serde_json::Value;
25use tokio::sync::oneshot;
26use tokio::sync::{mpsc as async_mpsc, Mutex as TokioMutex};
27use tokio_tungstenite::{connect_async, tungstenite::Message};
28
29use log::{debug, info, warn};
30
31use crate::wamp;
32
/// WebSocket endpoint used by the `connect()` constructors when no URL is supplied.
const DEFAULT_WAAPI_URL: &str = "ws://localhost:8080/waapi";

/// WAMP realm requested in the HELLO message during the connection handshake.
const DEFAULT_REALM: &str = "realm1";
46
/// Errors returned by the asynchronous and blocking WAAPI clients in this module.
#[derive(Debug, thiserror::Error)]
pub enum WaapiError {
    /// There is no usable connection: the client was never connected or was
    /// already disconnected, or the connection dropped before a reply arrived.
    #[error("client already disconnected")]
    Disconnected,
    /// A WAMP-level failure: an ERROR reply from the router (carrying the error
    /// string it reported), or an unexpected message during the handshake.
    #[error("WAMP error: {0}")]
    Wamp(String),
    /// Transport failure reported by tungstenite (stored boxed).
    #[error("WebSocket error: {0}")]
    WebSocket(#[from] Box<tokio_tungstenite::tungstenite::Error>),
    /// JSON (de)serialization failure, converted from `serde_json::Error`.
    #[error("{0}")]
    Serde(#[from] serde_json::Error),
    /// I/O failure, e.g. when building the Tokio runtime for the blocking client.
    #[error("{0}")]
    Io(#[from] std::io::Error),
}
80
/// Outcome delivered to a pending CALL: the RESULT's `kwargs`, or an error.
type CallResult = Result<Option<Value>, WaapiError>;
/// Outcome delivered to a pending SUBSCRIBE: the router-assigned subscription id, or an error.
type SubResult = Result<u64, WaapiError>;
/// Outcome delivered to a pending UNSUBSCRIBE.
type UnsubResult = Result<(), WaapiError>;

/// One event as forwarded to a subscriber: `(publication id, kwargs)`.
pub type EventPayload = (u64, Option<Value>);
89
/// Write half of the (possibly TLS) WebSocket connection, produced by `StreamExt::split`.
type WsSink = futures_util::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<
        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
    >,
    Message,
>;
98
/// Connection state shared (via `Arc`) by the client, subscription handles, and the
/// background event loop.
struct WampConn {
    /// Write half of the WebSocket; a Tokio mutex because it is held across `.await`.
    ws_tx: TokioMutex<WsSink>,
    /// CALL request id -> waiter for the matching RESULT or ERROR.
    pending_calls: StdMutex<HashMap<u64, oneshot::Sender<CallResult>>>,
    /// SUBSCRIBE request id -> waiter for the matching SUBSCRIBED or ERROR.
    pending_subs: StdMutex<HashMap<u64, oneshot::Sender<SubResult>>>,
    /// UNSUBSCRIBE request id -> waiter for the matching UNSUBSCRIBED or ERROR.
    pending_unsubs: StdMutex<HashMap<u64, oneshot::Sender<UnsubResult>>>,
    /// Subscription id -> channel that receives that subscription's EVENTs.
    event_senders: StdMutex<HashMap<u64, async_mpsc::UnboundedSender<EventPayload>>>,
    /// Counter handing out request ids (starts at 1).
    next_id: AtomicU64,
}
107
108impl WampConn {
109 fn new(sink: WsSink) -> Self {
110 Self {
111 ws_tx: TokioMutex::new(sink),
112 pending_calls: StdMutex::new(HashMap::new()),
113 pending_subs: StdMutex::new(HashMap::new()),
114 pending_unsubs: StdMutex::new(HashMap::new()),
115 event_senders: StdMutex::new(HashMap::new()),
116 next_id: AtomicU64::new(1),
117 }
118 }
119
120 fn next_id(&self) -> u64 {
121 self.next_id.fetch_add(1, Ordering::Relaxed)
122 }
123
124 async fn send(&self, text: String) -> Result<(), WaapiError> {
125 self.ws_tx
126 .lock()
127 .await
128 .send(Message::Text(text.into()))
129 .await
130 .map_err(|e| WaapiError::WebSocket(Box::new(e)))
131 }
132}
133
/// Full (possibly TLS) WebSocket stream returned by `connect_async`; its read half
/// feeds `read_welcome` and `run_event_loop`.
type WsStream = tokio_tungstenite::WebSocketStream<
    tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>;
139
140async fn run_event_loop(
141 conn: Arc<WampConn>,
142 mut ws_rx: futures_util::stream::SplitStream<WsStream>,
143 connected: Arc<AtomicBool>,
144) {
145 while let Some(msg) = ws_rx.next().await {
146 match msg {
147 Ok(Message::Text(text)) => {
148 if let Some(wamp_msg) = wamp::parse(&text) {
149 dispatch(&conn, wamp_msg);
150 }
151 }
152 Ok(Message::Close(_)) | Err(_) => break,
153 _ => {}
154 }
155 }
156 connected.store(false, Ordering::Release);
157 drain_pending(&conn);
159}
160
161fn dispatch(conn: &WampConn, msg: wamp::WampMessage) {
162 match msg {
163 wamp::WampMessage::Result { request_id, kwargs } => {
164 if let Some(tx) = conn
165 .pending_calls
166 .lock()
167 .unwrap_or_else(|e| e.into_inner())
168 .remove(&request_id)
169 {
170 let _ = tx.send(Ok(kwargs));
171 }
172 }
173 wamp::WampMessage::Error {
174 request_type,
175 request_id,
176 error,
177 } => {
178 let err_str = error.clone();
179 if request_type == 48 {
181 if let Some(tx) = conn
182 .pending_calls
183 .lock()
184 .unwrap_or_else(|e| e.into_inner())
185 .remove(&request_id)
186 {
187 let _ = tx.send(Err(WaapiError::Wamp(err_str)));
188 return;
189 }
190 }
191 if request_type == 32 {
193 if let Some(tx) = conn
194 .pending_subs
195 .lock()
196 .unwrap_or_else(|e| e.into_inner())
197 .remove(&request_id)
198 {
199 let _ = tx.send(Err(WaapiError::Wamp(error)));
200 return;
201 }
202 }
203 if request_type == 34 {
205 if let Some(tx) = conn
206 .pending_unsubs
207 .lock()
208 .unwrap_or_else(|e| e.into_inner())
209 .remove(&request_id)
210 {
211 let _ = tx.send(Err(WaapiError::Wamp(error)));
212 }
213 }
214 }
215 wamp::WampMessage::Subscribed {
216 request_id,
217 sub_id,
218 } => {
219 if let Some(tx) = conn
220 .pending_subs
221 .lock()
222 .unwrap_or_else(|e| e.into_inner())
223 .remove(&request_id)
224 {
225 let _ = tx.send(Ok(sub_id));
226 }
227 }
228 wamp::WampMessage::Unsubscribed { request_id } => {
229 if let Some(tx) = conn
230 .pending_unsubs
231 .lock()
232 .unwrap_or_else(|e| e.into_inner())
233 .remove(&request_id)
234 {
235 let _ = tx.send(Ok(()));
236 }
237 }
238 wamp::WampMessage::Event {
239 sub_id,
240 pub_id,
241 kwargs,
242 } => {
243 let senders = conn
244 .event_senders
245 .lock()
246 .unwrap_or_else(|e| e.into_inner());
247 if let Some(tx) = senders.get(&sub_id) {
248 let _ = tx.send((pub_id, kwargs));
249 }
250 }
251 wamp::WampMessage::Goodbye | wamp::WampMessage::Welcome { .. } => {}
252 }
253}
254
255fn drain_pending(conn: &WampConn) {
257 let calls: Vec<_> = conn
258 .pending_calls
259 .lock()
260 .unwrap_or_else(|e| e.into_inner())
261 .drain()
262 .collect();
263 for (_, tx) in calls {
264 let _ = tx.send(Err(WaapiError::Disconnected));
265 }
266 let subs: Vec<_> = conn
267 .pending_subs
268 .lock()
269 .unwrap_or_else(|e| e.into_inner())
270 .drain()
271 .collect();
272 for (_, tx) in subs {
273 let _ = tx.send(Err(WaapiError::Disconnected));
274 }
275 let unsubs: Vec<_> = conn
276 .pending_unsubs
277 .lock()
278 .unwrap_or_else(|e| e.into_inner())
279 .drain()
280 .collect();
281 for (_, tx) in unsubs {
282 let _ = tx.send(Err(WaapiError::Disconnected));
283 }
284}
285
286async fn read_welcome(
290 ws_rx: &mut futures_util::stream::SplitStream<WsStream>,
291) -> Result<u64, WaapiError> {
292 loop {
293 match ws_rx.next().await {
294 Some(Ok(Message::Text(text))) => {
295 if let Some(wamp::WampMessage::Welcome { session_id }) = wamp::parse(&text) {
296 return Ok(session_id);
297 }
298 return Err(WaapiError::Wamp(format!("expected WELCOME, got: {text}")));
299 }
300 Some(Ok(_)) => continue, Some(Err(e)) => return Err(WaapiError::WebSocket(Box::new(e))),
302 None => return Err(WaapiError::Disconnected),
303 }
304 }
305}
306
307pub struct SubscriptionHandle {
316 sub_id: u64,
317 conn: Arc<WampConn>,
318 subscription_ids: Arc<StdMutex<Vec<u64>>>,
319 recv_task: Option<tokio::task::JoinHandle<()>>,
320 is_unsubscribed: bool,
321}
322
/// Sets `flag` and reports whether this call was the one that set it.
///
/// Returns `true` only on the first call for a given flag, so callers can use it to
/// gate the network unsubscribe to a single attempt per handle.
fn mark_unsubscribed(flag: &mut bool) -> bool {
    let was_already_set = std::mem::replace(flag, true);
    !was_already_set
}
331
332impl SubscriptionHandle {
333 pub async fn unsubscribe(mut self) -> Result<(), WaapiError> {
339 debug!("Unsubscribing sub_id={}", self.sub_id);
340 if let Some(task) = self.recv_task.take() {
341 task.abort();
342 }
343 self.subscription_ids
344 .lock()
345 .unwrap_or_else(|e| e.into_inner())
346 .retain(|&id| id != self.sub_id);
347 self.conn
349 .event_senders
350 .lock()
351 .unwrap_or_else(|e| e.into_inner())
352 .remove(&self.sub_id);
353 if !mark_unsubscribed(&mut self.is_unsubscribed) {
354 return Ok(());
355 }
356 do_network_unsubscribe(&self.conn, self.sub_id).await
357 }
358}
359
360async fn do_network_unsubscribe(conn: &WampConn, sub_id: u64) -> Result<(), WaapiError> {
361 let id = conn.next_id();
362 let (tx, rx) = oneshot::channel();
363 conn.pending_unsubs
364 .lock()
365 .unwrap_or_else(|e| e.into_inner())
366 .insert(id, tx);
367 conn.send(wamp::unsubscribe_msg(id, sub_id)).await?;
368 rx.await.unwrap_or(Err(WaapiError::Disconnected))
369}
370
371impl Drop for SubscriptionHandle {
372 fn drop(&mut self) {
373 let sub_id = self.sub_id;
374 let conn = Arc::clone(&self.conn);
375 let subscription_ids = Arc::clone(&self.subscription_ids);
376 if let Some(task) = self.recv_task.take() {
377 task.abort();
378 }
379 subscription_ids
380 .lock()
381 .unwrap_or_else(|e| e.into_inner())
382 .retain(|&id| id != sub_id);
383 conn.event_senders
385 .lock()
386 .unwrap_or_else(|e| e.into_inner())
387 .remove(&sub_id);
388 if !mark_unsubscribed(&mut self.is_unsubscribed) {
389 return;
390 }
391 if let Ok(rt) = tokio::runtime::Handle::try_current() {
392 debug!("SubscriptionHandle dropped, spawning unsubscribe for sub_id={sub_id}");
393 rt.spawn(async move {
394 let _ = do_network_unsubscribe(&conn, sub_id).await;
395 });
396 } else {
397 warn!("SubscriptionHandle dropped without runtime, skipping network unsubscribe for sub_id={sub_id}");
398 }
399 }
400}
401
/// Asynchronous WAAPI client over a single WAMP-over-WebSocket connection.
///
/// Use `disconnect` for an orderly shutdown. Dropping the client instead schedules a
/// similar cleanup on the current Tokio runtime when one is available.
pub struct WaapiClient {
    /// Shared connection state; `None` once the client has been cleaned up.
    conn: Option<Arc<WampConn>>,
    /// Background task running `run_event_loop`; aborted during cleanup.
    event_loop_handle: Option<tokio::task::JoinHandle<()>>,
    /// Ids of subscriptions made through this client, unsubscribed during cleanup.
    subscription_ids: Arc<StdMutex<Vec<u64>>>,
    /// Cleared by the event loop when the socket goes away, and by cleanup / drop.
    connected: Arc<AtomicBool>,
}
421
422impl WaapiClient {
423 pub async fn connect() -> Result<Self, WaapiError> {
431 Self::connect_with_url(DEFAULT_WAAPI_URL).await
432 }
433
434 pub async fn connect_with_url(url: &str) -> Result<Self, WaapiError> {
440 info!("Connecting to WAAPI at {url}");
441 let (ws_stream, _) = connect_async(url).await.map_err(|e| WaapiError::WebSocket(Box::new(e)))?;
442 let (ws_tx, mut ws_rx) = ws_stream.split();
443
444 let conn = Arc::new(WampConn::new(ws_tx));
445
446 conn.send(wamp::hello_msg(DEFAULT_REALM)).await?;
448 let _session_id = read_welcome(&mut ws_rx).await?;
449
450 let connected = Arc::new(AtomicBool::new(true));
451 let connected_flag = Arc::clone(&connected);
452 let conn_for_loop = Arc::clone(&conn);
453 let handle = tokio::spawn(async move {
454 run_event_loop(conn_for_loop, ws_rx, connected_flag).await;
455 });
456
457 info!("Connected to WAAPI at {url}");
458 Ok(Self {
459 conn: Some(conn),
460 event_loop_handle: Some(handle),
461 subscription_ids: Arc::new(StdMutex::new(Vec::new())),
462 connected,
463 })
464 }
465
466 pub async fn call(
482 &self,
483 uri: &str,
484 args: Option<Value>,
485 options: Option<Value>,
486 ) -> Result<Option<Value>, WaapiError> {
487 let conn = self.conn.as_ref().ok_or(WaapiError::Disconnected)?;
488 let id = conn.next_id();
489 let (tx, rx) = oneshot::channel();
490 conn.pending_calls
491 .lock()
492 .unwrap_or_else(|e| e.into_inner())
493 .insert(id, tx);
494 debug!("Calling WAAPI: {uri} (id={id})");
495 conn.send(wamp::call_msg(id, uri, args.as_ref(), options.as_ref()))
496 .await?;
497 rx.await.unwrap_or(Err(WaapiError::Disconnected))
498 }
499
500 pub(crate) async fn subscribe_inner(
502 &self,
503 topic: &str,
504 options: Option<Value>,
505 ) -> Result<
506 (
507 SubscriptionHandle,
508 async_mpsc::UnboundedReceiver<EventPayload>,
509 ),
510 WaapiError,
511 > {
512 let conn = self.conn.as_ref().ok_or(WaapiError::Disconnected)?;
513 let id = conn.next_id();
514 let (tx, rx) = oneshot::channel();
515 conn.pending_subs
516 .lock()
517 .unwrap_or_else(|e| e.into_inner())
518 .insert(id, tx);
519 conn.send(wamp::subscribe_msg(id, topic, options.as_ref()))
520 .await?;
521 let sub_id = rx.await.unwrap_or(Err(WaapiError::Disconnected))?;
522 debug!("Subscribed to {topic} (sub_id={sub_id})");
523
524 let (event_tx, event_rx) = async_mpsc::unbounded_channel();
525 conn.event_senders
526 .lock()
527 .unwrap_or_else(|e| e.into_inner())
528 .insert(sub_id, event_tx);
529 self.subscription_ids
530 .lock()
531 .unwrap_or_else(|e| e.into_inner())
532 .push(sub_id);
533
534 let handle = SubscriptionHandle {
535 sub_id,
536 conn: Arc::clone(conn),
537 subscription_ids: Arc::clone(&self.subscription_ids),
538 recv_task: None,
539 is_unsubscribed: false,
540 };
541 Ok((handle, event_rx))
542 }
543
544 pub async fn subscribe<F>(
559 &self,
560 topic: &str,
561 options: Option<Value>,
562 callback: F,
563 ) -> Result<SubscriptionHandle, WaapiError>
564 where
565 F: Fn(Option<Value>) + Send + Sync + 'static,
566 {
567 let (mut handle, mut event_rx) = self.subscribe_inner(topic, options).await?;
568 let recv_task = tokio::spawn(async move {
569 while let Some((_pub_id, kwargs)) = event_rx.recv().await {
570 callback(kwargs);
571 }
572 });
573 handle.recv_task = Some(recv_task);
574 Ok(handle)
575 }
576
577 #[must_use]
583 pub fn is_connected(&self) -> bool {
584 self.conn.is_some() && self.connected.load(Ordering::Acquire)
585 }
586
587 pub async fn disconnect(mut self) {
595 info!("Disconnecting from WAAPI");
596 self.cleanup().await;
597 info!("Disconnected from WAAPI");
598 }
599
600 async fn cleanup(&mut self) {
601 self.connected.store(false, Ordering::Release);
602 if let Some(conn) = self.conn.take() {
603 let ids: Vec<u64> = {
605 let mut guard = self.subscription_ids.lock().unwrap_or_else(|e| e.into_inner());
606 std::mem::take(&mut *guard)
607 };
608 for sub_id in ids {
609 let id = conn.next_id();
610 let (tx, rx) = oneshot::channel();
611 conn.pending_unsubs
612 .lock()
613 .unwrap_or_else(|e| e.into_inner())
614 .insert(id, tx);
615 if conn.send(wamp::unsubscribe_msg(id, sub_id)).await.is_ok() {
616 let _ = rx.await;
617 }
618 }
619 let _ = conn.send(wamp::goodbye_msg()).await;
621 let _ = conn.ws_tx.lock().await.close().await;
623 }
624 if let Some(handle) = self.event_loop_handle.take() {
625 handle.abort();
626 }
627 }
628}
629
630impl Drop for WaapiClient {
631 fn drop(&mut self) {
632 if self.conn.is_some() || self.event_loop_handle.is_some() {
633 let conn = self.conn.take();
634 let event_loop = self.event_loop_handle.take();
635 let subscription_ids = Arc::clone(&self.subscription_ids);
636 let connected = Arc::clone(&self.connected);
637 connected.store(false, Ordering::Release);
638 if let Ok(rt) = tokio::runtime::Handle::try_current() {
639 debug!("WaapiClient dropped, spawning async cleanup");
640 rt.spawn(async move {
641 if let Some(conn) = conn {
642 let ids: Vec<u64> = {
643 let mut guard =
644 subscription_ids.lock().unwrap_or_else(|e| e.into_inner());
645 std::mem::take(&mut *guard)
646 };
647 for sub_id in ids {
648 let id = conn.next_id();
649 let (tx, rx) = oneshot::channel::<UnsubResult>();
650 conn.pending_unsubs
651 .lock()
652 .unwrap_or_else(|e| e.into_inner())
653 .insert(id, tx);
654 if conn.send(wamp::unsubscribe_msg(id, sub_id)).await.is_ok() {
655 let _ = rx.await;
656 }
657 }
658 let _ = conn.send(wamp::goodbye_msg()).await;
659 let _ = conn.ws_tx.lock().await.close().await;
660 }
661 if let Some(h) = event_loop {
662 h.abort();
663 }
664 });
665 } else {
666 warn!("WaapiClient dropped without runtime, skipping graceful cleanup");
667 if let Some(h) = event_loop {
668 h.abort();
669 }
670 }
671 }
672 }
673}
674
675pub struct SubscriptionHandleSync {
686 runtime: Arc<tokio::runtime::Runtime>,
687 inner: Option<SubscriptionHandle>,
688 bridge_join: Option<thread::JoinHandle<()>>,
689 bridge_thread_id: Option<thread::ThreadId>,
690}
691
692impl SubscriptionHandleSync {
693 pub fn unsubscribe(mut self) -> Result<(), WaapiError> {
699 let inner = self.inner.take();
700 let bridge_join = self.bridge_join.take();
701 if let Some(h) = inner {
702 self.runtime.block_on(h.unsubscribe())?;
703 }
704 if let Some(jh) = bridge_join {
705 let _ = jh.join();
706 }
707 Ok(())
708 }
709}
710
711impl Drop for SubscriptionHandleSync {
712 fn drop(&mut self) {
713 let is_bridge_thread = self.bridge_thread_id.as_ref() == Some(&thread::current().id());
714 let inner = self.inner.take();
715 let bridge_join = self.bridge_join.take();
716 let runtime = Arc::clone(&self.runtime);
717 if let Some(h) = inner {
718 if tokio::runtime::Handle::try_current().is_ok() {
719 warn!("SubscriptionHandleSync dropped inside async context, falling back to spawn");
720 runtime.handle().spawn(async move {
721 let _ = h.unsubscribe().await;
722 });
723 } else {
724 let _ = runtime.block_on(h.unsubscribe());
725 }
726 }
727 if !is_bridge_thread {
728 if let Some(jh) = bridge_join {
729 let _ = jh.join();
730 }
731 }
732 }
733}
734
/// Blocking wrapper around `WaapiClient` that drives it on a private multi-threaded
/// Tokio runtime.
pub struct WaapiClientSync {
    /// Runtime built in `connect_with_url`; shared with subscription handles and bridge threads.
    runtime: Arc<tokio::runtime::Runtime>,
    /// The wrapped async client; `None` after `disconnect`.
    client: Option<WaapiClient>,
}
749
750impl WaapiClientSync {
751 pub fn connect() -> Result<Self, WaapiError> {
757 Self::connect_with_url(DEFAULT_WAAPI_URL)
758 }
759
760 pub fn connect_with_url(url: &str) -> Result<Self, WaapiError> {
766 info!("Connecting to WAAPI (sync) at {url}");
767 let runtime = Arc::new(
768 tokio::runtime::Builder::new_multi_thread()
769 .enable_all()
770 .build()?,
771 );
772 let client = runtime.block_on(WaapiClient::connect_with_url(url))?;
773 info!("Connected to WAAPI (sync) at {url}");
774 Ok(Self {
775 runtime,
776 client: Some(client),
777 })
778 }
779
780 pub fn call(
786 &self,
787 uri: &str,
788 args: Option<Value>,
789 options: Option<Value>,
790 ) -> Result<Option<Value>, WaapiError> {
791 let client = self.client.as_ref().ok_or(WaapiError::Disconnected)?;
792 self.runtime.block_on(client.call(uri, args, options))
793 }
794
795 pub fn subscribe<F>(
805 &self,
806 topic: &str,
807 options: Option<Value>,
808 callback: F,
809 ) -> Result<SubscriptionHandleSync, WaapiError>
810 where
811 F: Fn(Option<Value>) + Send + Sync + 'static,
812 {
813 let client = self.client.as_ref().ok_or(WaapiError::Disconnected)?;
814 let (inner, mut async_rx) = self
815 .runtime
816 .block_on(client.subscribe_inner(topic, options))?;
817 let (id_tx, id_rx) = mpsc::channel();
818 let runtime = Arc::clone(&self.runtime);
819 let bridge_join = thread::spawn(move || {
820 let _ = id_tx.send(thread::current().id());
821 while let Some((_pub_id, kwargs)) = runtime.block_on(async_rx.recv()) {
822 callback(kwargs);
823 }
824 });
825 let bridge_thread_id = id_rx.recv().ok();
826 Ok(SubscriptionHandleSync {
827 runtime: Arc::clone(&self.runtime),
828 inner: Some(inner),
829 bridge_join: Some(bridge_join),
830 bridge_thread_id,
831 })
832 }
833
834 #[must_use]
840 pub fn is_connected(&self) -> bool {
841 self.client.as_ref().is_some_and(|c| c.is_connected())
842 }
843
844 pub fn disconnect(mut self) {
850 info!("Disconnecting from WAAPI (sync)");
851 if let Some(client) = self.client.take() {
852 self.runtime.block_on(client.disconnect());
853 }
854 info!("Disconnected from WAAPI (sync)");
855 }
856}
857
858impl Drop for WaapiClientSync {
859 fn drop(&mut self) {
860 if let Some(client) = self.client.take() {
861 if tokio::runtime::Handle::try_current().is_ok() {
862 warn!("WaapiClientSync dropped inside async context, offloading cleanup to a dedicated thread");
863 let runtime = Arc::clone(&self.runtime);
864 let _ = thread::Builder::new()
865 .name("waapi-sync-drop-cleanup".to_string())
866 .spawn(move || {
867 runtime.block_on(client.disconnect());
868 });
869 } else {
870 self.runtime.block_on(client.disconnect());
871 }
872 }
873 }
874}
875
#[cfg(test)]
mod tests {
    use super::*;

    /// A client with no connection, equivalent to one left behind after cleanup.
    fn disconnected_client() -> WaapiClient {
        WaapiClient {
            conn: None,
            event_loop_handle: None,
            subscription_ids: Arc::new(StdMutex::new(Vec::new())),
            connected: Arc::new(AtomicBool::new(false)),
        }
    }

    #[test]
    fn test_mark_unsubscribed_is_idempotent() {
        let mut is_unsubscribed = false;
        assert!(mark_unsubscribed(&mut is_unsubscribed));
        assert!(!mark_unsubscribed(&mut is_unsubscribed));
    }

    #[tokio::test]
    async fn test_sync_client_drop_inside_async_context_is_safe() {
        // NOTE(review): `WaapiClientSync::drop` hands an `Arc` of `runtime` to a cleanup
        // thread. If that thread finishes before the struct's own field is released,
        // the last `Arc<Runtime>` is dropped inside this Tokio test context, which Tokio
        // rejects with a panic. The test may therefore be timing-dependent; confirm.
        let runtime = Arc::new(
            tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()
                .expect("failed to create runtime"),
        );
        let sync_client = WaapiClientSync {
            runtime,
            client: Some(disconnected_client()),
        };
        drop(sync_client);
    }

    #[tokio::test]
    async fn test_async_client_without_connection_reports_disconnected() {
        let client = disconnected_client();
        assert!(!client.is_connected());
        assert!(matches!(
            client.call("ak.wwise.core.getInfo", None, None).await,
            Err(WaapiError::Disconnected)
        ));
        assert!(matches!(
            client
                .subscribe("ak.wwise.core.object.nameChanged", None, |_| {})
                .await,
            Err(WaapiError::Disconnected)
        ));
    }

    #[test]
    fn test_sync_client_without_connection_reports_disconnected() {
        // Plain #[test]: the runtime is dropped outside any async context.
        let runtime = Arc::new(
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("failed to create runtime"),
        );
        let sync_client = WaapiClientSync {
            runtime,
            client: None,
        };
        assert!(!sync_client.is_connected());
        assert!(matches!(
            sync_client.call("ak.wwise.core.getInfo", None, None),
            Err(WaapiError::Disconnected)
        ));
    }
}