workflow_websocket/client/
native.rs1use super::{
2 error::Error, message::Message, result::Result, Ack, ConnectOptions, ConnectResult,
3 ConnectStrategy, Handshake, Resolver, WebSocketConfig,
4};
5use futures::{
6 select_biased,
7 stream::{SplitSink, SplitStream},
8 FutureExt,
9};
10use futures_util::{SinkExt, StreamExt};
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, Mutex};
13#[allow(unused_imports)]
14use std::time::Instant;
15use tokio::net::TcpStream;
16use tokio::time::timeout;
17use tokio_tungstenite::{
18 connect_async_with_config, tungstenite::protocol::Message as TsMessage, MaybeTlsStream,
19 WebSocketStream,
20};
21use tungstenite::protocol::WebSocketConfig as TsWebSocketConfig;
22pub use workflow_core as core;
23use workflow_core::channel::*;
24pub use workflow_log::*;
25
26impl From<Message> for tungstenite::Message {
27 fn from(message: Message) -> Self {
28 match message {
29 Message::Text(text) => text.into(),
30 Message::Binary(data) => data.into(),
31 _ => {
32 panic!("From<Message> for tungstenite::Message - invalid message type: {message:?}",)
33 }
34 }
35 }
36}
37
38impl From<tungstenite::Message> for Message {
39 fn from(message: tungstenite::Message) -> Self {
40 match message {
41 TsMessage::Text(text) => Message::Text(text),
42 TsMessage::Binary(data) => Message::Binary(data),
43 TsMessage::Close(_) => Message::Close,
44 _ => panic!(
45 "TryFrom<tungstenite::Message> for Message - invalid message type: {message:?}",
46 ),
47 }
48 }
49}
50
51impl From<WebSocketConfig> for TsWebSocketConfig {
52 fn from(config: WebSocketConfig) -> Self {
53 TsWebSocketConfig {
54 write_buffer_size: config.write_buffer_size,
55 max_write_buffer_size: config.max_write_buffer_size,
56 max_message_size: config.max_message_size,
57 max_frame_size: config.max_frame_size,
58 accept_unmasked_frames: config.accept_unmasked_frames,
59 ..Default::default()
60 }
61 }
62}
63
/// URLs tracked by the websocket interface; both start out unset.
#[derive(Default)]
struct Settings {
    /// URL used when the connect options do not carry one.
    default_url: Option<String>,
    /// URL most recently resolved for a connection attempt (or set explicitly).
    current_url: Option<String>,
}
69
70pub struct WebSocketInterface {
71 settings: Mutex<Settings>,
72 config: Mutex<WebSocketConfig>,
73 reconnect: AtomicBool,
74 is_connected: AtomicBool,
75 receiver_channel: Channel<Message>,
76 sender_channel: Channel<(Message, Ack)>,
77 shutdown: DuplexChannel<()>,
78}
79
80impl WebSocketInterface {
81 pub fn new(
82 url: Option<&str>,
83 config: Option<WebSocketConfig>,
84 sender_channel: Channel<(Message, Ack)>,
85 receiver_channel: Channel<Message>,
86 ) -> Result<WebSocketInterface> {
87 let settings = Settings {
88 default_url: url.map(String::from),
89 ..Default::default()
90 };
91
92 let iface = WebSocketInterface {
93 settings: Mutex::new(settings),
94 config: Mutex::new(config.unwrap_or_default()),
95 receiver_channel,
96 sender_channel,
97 reconnect: AtomicBool::new(true),
98 is_connected: AtomicBool::new(false),
99 shutdown: DuplexChannel::unbounded(),
100 };
101
102 Ok(iface)
103 }
104
105 pub fn default_url(self: &Arc<Self>) -> Option<String> {
106 self.settings.lock().unwrap().default_url.clone()
107 }
108
109 pub fn current_url(self: &Arc<Self>) -> Option<String> {
110 self.settings.lock().unwrap().current_url.clone()
111 }
112
113 pub fn set_default_url(self: &Arc<Self>, url: &str) {
114 self.settings
115 .lock()
116 .unwrap()
117 .default_url
118 .replace(url.to_string());
119 }
120
121 pub fn set_current_url(self: &Arc<Self>, url: &str) {
122 self.settings
123 .lock()
124 .unwrap()
125 .current_url
126 .replace(url.to_string());
127 }
128
129 pub fn is_connected(self: &Arc<Self>) -> bool {
130 self.is_connected.load(Ordering::SeqCst)
131 }
132
133 fn resolver(&self) -> Option<Arc<dyn Resolver>> {
134 self.config.lock().unwrap().resolver.clone()
135 }
136
137 fn handshake(&self) -> Option<Arc<dyn Handshake>> {
138 self.config.lock().unwrap().handshake.clone()
139 }
140
141 pub fn configure(&self, config: WebSocketConfig) {
142 *self.config.lock().unwrap() = config;
143 }
144
145 fn config(&self) -> WebSocketConfig {
146 self.config.lock().unwrap().clone()
147 }
148
149 async fn resolve_url(self: &Arc<Self>, options: &ConnectOptions) -> Result<String> {
150 let url = if let Some(url) = options.url.as_ref().or(self.default_url().as_ref()) {
151 url.clone()
152 } else if let Some(resolver) = self.resolver() {
153 resolver.resolve_url().await?
154 } else {
155 return Err(Error::MissingUrl);
156 };
157 self.set_current_url(&url);
158 Ok(url)
159 }
160
161 pub async fn connect(self: &Arc<Self>, options: ConnectOptions) -> ConnectResult<Error> {
162 let this = self.clone();
163
164 if self.is_connected.load(Ordering::SeqCst) {
165 return Err(Error::AlreadyConnected);
166 }
167
168 let (connect_trigger, connect_listener) = oneshot::<Result<()>>();
169 let mut connect_trigger = Some(connect_trigger);
170
171 this.reconnect.store(true, Ordering::SeqCst);
172
173 let block_async_connect = options.block_async_connect;
174 let ts_websocket_config = Some(self.config().into());
175
176 core::task::spawn(async move {
177 'outer: loop {
178 match this.resolve_url(&options).await {
179 Ok(url) => {
180 let connect_future =
181 connect_async_with_config(&url, ts_websocket_config, false);
182 let timeout_future = timeout(options.connect_timeout(), connect_future);
183
184 match timeout_future.await {
185 Ok(Ok(stream)) => {
187 this.is_connected.store(true, Ordering::SeqCst);
190 let (mut ws_stream, _) = stream;
191
192 if connect_trigger.is_some() {
193 connect_trigger.take().unwrap().try_send(Ok(())).ok();
194 }
195
196 if let Err(err) = this.dispatcher(&mut ws_stream, &options).await {
197 log_trace!("WebSocket dispatcher error: {}", err);
198 }
199
200 this.is_connected.store(false, Ordering::SeqCst);
201 }
202 Ok(Err(e)) => {
204 log_trace!("WebSocket failed to connect to {}: {}", url, e);
205 if matches!(options.strategy, ConnectStrategy::Fallback) {
206 if options.block_async_connect && connect_trigger.is_some() {
207 connect_trigger
208 .take()
209 .unwrap()
210 .try_send(Err(e.into()))
211 .ok();
212 }
213 break;
214 }
215 workflow_core::task::sleep(options.retry_interval()).await;
216 }
217 Err(_) => {
219 log_trace!(
220 "WebSocket connection timeout while connecting to {}",
221 url
222 );
223 if matches!(options.strategy, ConnectStrategy::Fallback) {
224 if options.block_async_connect && connect_trigger.is_some() {
225 connect_trigger
226 .take()
227 .unwrap()
228 .try_send(Err(Error::ConnectionTimeout))
229 .ok();
230 }
231 break;
232 }
233 workflow_core::task::sleep(options.retry_interval()).await;
234 }
235 };
236
237 if !this.reconnect.load(Ordering::SeqCst) {
238 break 'outer;
239 };
240 }
241 Err(err) => {
242 log_trace!("WebSocket failed to get session URL: {}", err);
243 if !this.reconnect.load(Ordering::SeqCst) {
244 break 'outer;
245 } else {
246 workflow_core::task::sleep(options.retry_interval()).await;
247 }
248 }
249 }
250 }
251 });
252
253 match block_async_connect {
254 true => match connect_listener.recv().await? {
255 Ok(_) => Ok(None),
256 Err(e) => Err(e),
257 },
258 false => Ok(Some(connect_listener)),
259 }
260 }
261
262 async fn handshake_impl(
263 self: &Arc<Self>,
264 ws_sender: &mut SplitSink<&mut WebSocketStream<MaybeTlsStream<TcpStream>>, TsMessage>,
265 ws_receiver: &mut SplitStream<&mut WebSocketStream<MaybeTlsStream<TcpStream>>>,
266 ) -> Result<()> {
267 if let Some(handshake) = self.handshake() {
268 let (sender_tx, sender_rx) = unbounded();
269 let (receiver_tx, receiver_rx) = unbounded();
270 let (accept_tx, accept_rx) = oneshot();
271
272 core::task::spawn(async move {
273 accept_tx
274 .send(handshake.handshake(&sender_tx, &receiver_rx).await)
275 .await
276 .unwrap_or_else(|err| {
277 log_trace!("WebSocket handshake unable to send completion: `{}`", err)
278 });
279 });
280
281 loop {
282 select_biased! {
283 result = accept_rx.recv().fuse() => {
284 return result?;
285 },
286 msg = sender_rx.recv().fuse() => {
287 if let Ok(msg) = msg {
288 ws_sender.send(msg.into()).await?;
289 }
290 },
291 msg = ws_receiver.next().fuse() => {
292 if let Some(Ok(msg)) = msg {
293 receiver_tx.send(msg.into()).await?;
294 } else {
295 return Err(Error::NegotiationFailure);
296 }
297 }
298 }
299 }
300 }
301
302 Ok(())
303 }
304
305 async fn dispatcher(
306 self: &Arc<Self>,
307 ws_stream: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
308 _options: &ConnectOptions,
309 ) -> Result<()> {
310 let (mut ws_sender, mut ws_receiver) = ws_stream.split();
311
312 self.handshake_impl(&mut ws_sender, &mut ws_receiver)
313 .await?;
314
315 #[cfg(feature = "delay-reconnect")]
316 let connection_start = Instant::now();
317 #[cfg(feature = "delay-reconnect")]
318 let mut closed_ungracefully = false;
319
320 self.receiver_channel.send(Message::Open).await?;
321
322 loop {
323 select_biased! {
324 dispatch = self.sender_channel.recv().fuse() => {
325 if let Ok((msg,ack)) = dispatch {
326 if let Some(ack_sender) = ack {
327 let result = ws_sender.send(msg.into()).await
328 .map(Arc::new)
329 .map_err(|err|Arc::new(err.into()));
330 ack_sender.send(result).await?;
331 } else {
332 ws_sender.send(msg.into()).await?;
333 }
334 }
335 }
336 msg = ws_receiver.next().fuse() => {
337 match msg {
338 Some(Ok(msg)) => {
339 match msg {
340 TsMessage::Binary(_) | TsMessage::Text(_) | TsMessage::Close(_) => {
341 self
342 .receiver_channel
343 .send(msg.into())
344 .await?;
345 }
346 TsMessage::Ping(data) => {
347 ws_sender.send(TsMessage::Pong(data)).await?;
348 },
349 TsMessage::Pong(_) => { },
350 TsMessage::Frame(_frame) => { },
351 }
352 }
353 Some(Err(e)) => {
354 self.receiver_channel.send(Message::Close).await?;
355 log_trace!("WebSocket error: {}", e);
356 #[cfg(feature = "delay-reconnect")] {
357 closed_ungracefully = true;
358 }
359 break;
360 }
361 None => {
362 self.receiver_channel.send(Message::Close).await?;
363 log_trace!("WebSocket connection closed");
364 #[cfg(feature = "delay-reconnect")] {
365 closed_ungracefully = true;
366 }
367 break;
368 }
369 }
370 }
371 _ = self.shutdown.request.receiver.recv().fuse() => {
372 self.receiver_channel.send(Message::Close).await?;
373 self.shutdown.response.sender.send(()).await?;
374 break;
375 }
376 }
377 }
378
379 #[cfg(feature = "delay-reconnect")]
381 if closed_ungracefully && connection_start.elapsed().as_millis() < 1_000 {
382 workflow_core::task::sleep(_options.retry_interval()).await;
383 }
384
385 Ok(())
386 }
387
388 pub async fn close(self: &Arc<Self>) -> Result<()> {
389 if self.is_connected.load(Ordering::SeqCst) {
391 self.shutdown
393 .request
394 .sender
395 .send(())
396 .await
397 .unwrap_or_else(|err| {
398 log_error!("Unable to signal WebSocket dispatcher shutdown: {}", err)
399 });
400 self.shutdown
401 .response
402 .receiver
403 .recv()
404 .await
405 .unwrap_or_else(|err| {
406 log_error!("Unable to receive WebSocket dispatcher shutdown: {}", err)
407 });
408 }
409
410 Ok(())
411 }
412
413 pub async fn disconnect(self: &Arc<Self>) -> Result<()> {
414 self.reconnect.store(false, Ordering::SeqCst);
415 self.close().await?;
416 Ok(())
417 }
418
419 pub fn trigger_abort(self: &Arc<Self>) -> Result<()> {
420 if self.is_connected.load(Ordering::SeqCst) {
421 self.receiver_channel.try_send(Message::Close)?;
422 }
423 Ok(())
424 }
425}