simploxide_ws_core/
dispatcher.rs1use std::{sync::Arc, task::Poll};
4
5use crate::{WsIn, router::ResponseRouter};
6use futures::{Stream, StreamExt};
7use serde::Deserialize;
8use tokio::sync::mpsc;
9use tokio_tungstenite::tungstenite::Message;
10use tokio_util::sync::CancellationToken;
11
12use super::{Event, RequestId, Result};
13
14type EventSender = mpsc::UnboundedSender<Result<Event>>;
15pub type EventReceiver = mpsc::UnboundedReceiver<Result<Event>>;
16
17pub fn init(ws_in: WsIn, router: ResponseRouter, token: CancellationToken) -> EventQueue {
18 let (events_tx, receiver) = mpsc::unbounded_channel::<Result<Event>>();
19 tokio::spawn(event_dispatcher_task(ws_in, events_tx, router, token));
20
21 EventQueue { receiver }
22}
23
24pub struct EventQueue {
27 receiver: EventReceiver,
28}
29
30impl EventQueue {
31 pub async fn next_event(&mut self) -> Option<Result<Event>> {
35 self.receiver.recv().await
36 }
37
38 pub fn into_receiver(self) -> EventReceiver {
40 self.receiver
41 }
42}
43
44async fn event_dispatcher_task(
45 mut ws_in: WsIn,
46 mut event_queue: EventSender,
47 router: ResponseRouter,
48 token: CancellationToken,
49) {
50 loop {
51 tokio::select! {
52 biased;
58
59 _ = token.cancelled() => {
60 tokio::task::yield_now().await;
63
64 let mut ws_in = Closed(ws_in);
65 while let Some(ev) = ws_in.next().await {
66 match ev {
67 Ok(msg) => {
68 process_raw_event(None, &mut event_queue, msg);
69 }
70 Err(e) => {
71 let _ = event_queue.send(Err(Arc::new(e)));
72 break;
73 }
74 }
75 }
76
77 break;
78 }
79
80 ev = ws_in.next() => {
81 match ev {
82 Some(Ok(msg)) => {
83 process_raw_event(Some(&router), &mut event_queue, msg);
84 }
85 Some(Err(e)) => {
86 let e = Arc::new(e);
87 let _ = event_queue.send(Err(Arc::clone(&e)));
88 router.shutdown(e);
89
90 break;
91 }
92 None => unreachable!("Must receive an error before connection drops")
93
94 }
95 }
96 }
97 }
98
99 log::debug!("Dispatcher task finished");
100}
101
102fn process_raw_event(router: Option<&ResponseRouter>, event_queue: &mut EventSender, msg: Message) {
109 let event = match msg {
110 Message::Text(utf8bytes) => utf8bytes.to_string(),
111 unexpected => {
112 log::warn!("Ignoring event in unexpecetd format: {unexpected:#?}");
113 return;
114 }
115 };
116
117 let header: EventHeader = match serde_json::from_str(&event) {
118 Ok(header) => header,
119 Err(e) => {
120 log::error!("Got invalid JSON form the server\n{event:?}\n{e}");
121 return;
122 }
123 };
124
125 if let Some(corr_id) = header.corr_id {
126 let id: RequestId = match corr_id.parse() {
127 Ok(id) => id,
128 Err(e) => {
129 log::error!("Failed to parse corr_id: {corr_id}\n{e}");
130 return;
131 }
132 };
133
134 match router {
135 Some(router) => router.deliver(id, event),
136 None => {
137 log::warn!("Dropping response because router task already finished\n{event}");
138 }
139 }
140 } else {
141 let _ = event_queue.send(Ok(event));
142 }
143}
144
145#[derive(Deserialize)]
147struct EventHeader<'a> {
148 #[serde(rename = "corrId")]
149 #[serde(borrow)]
150 corr_id: Option<&'a str>,
151}
152
153struct Closed<S>(S);
156
157impl<S> Stream for Closed<S>
158where
159 S: Stream + Unpin,
160{
161 type Item = S::Item;
162
163 fn poll_next(
164 mut self: std::pin::Pin<&mut Self>,
165 cx: &mut std::task::Context<'_>,
166 ) -> Poll<Option<Self::Item>> {
167 match self.0.poll_next_unpin(cx) {
168 Poll::Ready(v) => Poll::Ready(v),
169 Poll::Pending => Poll::Ready(None),
170 }
171 }
172}