websocket_util/
subscribe.rs

// Copyright (C) 2021-2023 Daniel Mueller (deso@posteo.net)
// SPDX-License-Identifier: GPL-3.0-or-later

//! A module providing low-level building blocks for controlling a
//! WebSocket stream with an embedded control channel through an
//! external subscription object.
7
8use std::convert::Infallible;
9use std::fmt::Debug;
10use std::marker::PhantomData;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::sync::Mutex;
14
15use futures::channel::oneshot::channel;
16use futures::channel::oneshot::Canceled;
17use futures::channel::oneshot::Sender;
18use futures::select_biased;
19use futures::sink::SinkExt as _;
20use futures::stream::FusedStream;
21use futures::task::Context;
22use futures::task::Poll;
23use futures::Future;
24use futures::FutureExt as _;
25use futures::Sink;
26use futures::Stream;
27use futures::StreamExt as _;
28
29
/// The outcome of classifying a message received over the stream.
#[derive(Debug)]
pub enum Classification<U, C> {
  /// The message is meant for the user. The [`MessageStream`] emits
  /// it as is. The associated [`Subscription`] is only notified (with
  /// a payload-less error marker) if [`Message::is_error`] flags the
  /// message as an error.
  UserMessage(U),
  /// The message is a control message. It is handed to the
  /// [`Subscription`], provided one is currently waiting, and it is
  /// never emitted by the [`MessageStream`].
  ControlMessage(C),
}
41
42
43/// A trait allowing our stream and subscription infrastructure to work
44/// with messages.
45pub trait Message {
46  /// A message that is relevant to the user.
47  type UserMessage;
48  /// An internally used control message.
49  type ControlMessage;
50
51  /// Classify a message as a user-relevant message or a control
52  /// message.
53  fn classify(self) -> Classification<Self::UserMessage, Self::ControlMessage>;
54
55  /// Check whether a user message is considered an error. Erroneous
56  /// messages cause any ongoing [`Subscription::send`] or
57  /// [`Subscription::read`] requests to result in an error.
58  fn is_error(user_message: &Self::UserMessage) -> bool;
59}
60
61
62/// State shared between the message stream and the subscription.
63///
64/// This state is an optional one-shot channel that is set by the
65/// subscription whenever it expects to receive a control message.
66type SharedState<M> = Arc<Mutex<Option<Sender<Option<Result<M, ()>>>>>>;
67
68
69/// A stream of messages that is associated with a [`Subscription`].
70#[derive(Debug)]
71pub struct MessageStream<S, M>
72where
73  M: Message,
74{
75  /// The internally used stream.
76  stream: S,
77  /// State shared between the message stream and the subscription.
78  shared: SharedState<M::ControlMessage>,
79}
80
81impl<S, M> MessageStream<S, M>
82where
83  M: Message,
84{
85  /// Inform the associated [`Subscription`] about a message, if it is
86  /// registered to receive such notifications.
87  fn inform_subscription(
88    shared: &SharedState<M::ControlMessage>,
89    message: Option<Result<M::ControlMessage, ()>>,
90  ) {
91    let sender = shared
92      .lock()
93      .map_err(|err| err.into_inner())
94      .unwrap_or_else(|err| err)
95      .take();
96
97    if let Some(sender) = sender {
98      // If the `Subscription` registered a `Sender`, use it to send the
99      // provided message. If delivery failed the `Subscription` object
100      // has already been dropped. That is fine and we will just ignore
101      // the error.
102      let _ = sender.send(message);
103    }
104  }
105}
106
107impl<S, M> Stream for MessageStream<S, M>
108where
109  S: Stream<Item = M> + Unpin,
110  M: Message,
111{
112  type Item = M::UserMessage;
113
114  fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
115    let shared = self.shared.clone();
116    let this = self.get_mut();
117
118    loop {
119      match this.stream.poll_next_unpin(ctx) {
120        Poll::Pending => {
121          // No new data is available yet. There is nothing to do for us
122          // except bubble up this result.
123          break Poll::Pending
124        },
125        Poll::Ready(None) => {
126          // The connection got terminated. We need to convey that to
127          // the `Subscription`.
128          Self::inform_subscription(&shared, None);
129          break Poll::Ready(None)
130        },
131        Poll::Ready(Some(message)) => {
132          match message.classify() {
133            Classification::UserMessage(user_message) => {
134              if M::is_error(&user_message) {
135                // If message classification deduced an error, inform
136                // the subscription about that fact (to unblock
137                // requests) but yield the actual error via the message
138                // stream (note that we cannot assume that errors are
139                // cloneable and that's why we only indicate *that* an
140                // error occurred to the subscription).
141                Self::inform_subscription(&shared, Some(Err(())));
142              }
143              // The `Subscription` is oblivious to user messages,
144              // so just return it.
145              break Poll::Ready(Some(user_message))
146            },
147            Classification::ControlMessage(control_message) => {
148              // We encountered a control message. Push it to the
149              // subscription and then just continue polling.
150              // Clients of the message stream do not care about
151              // these.
152              Self::inform_subscription(&shared, Some(Ok(control_message)));
153            },
154          }
155        },
156      }
157    }
158  }
159
160  fn size_hint(&self) -> (usize, Option<usize>) {
161    Stream::size_hint(&self.stream)
162  }
163}
164
165impl<S, M> FusedStream for MessageStream<S, M>
166where
167  S: FusedStream<Item = M> + Unpin,
168  M: Message,
169{
170  #[inline]
171  fn is_terminated(&self) -> bool {
172    self.stream.is_terminated()
173  }
174}
175
176
177/// A subscription associated with a [`MessageStream`] that allows for
178/// sending and receiving control messages over it.
179///
180/// # Notes
181/// - in order for any [`send`][Subscription::send] or
182///   [`read`][Subscription::read] operations to resolve, the associated
183///   [`MessageStream`] stream needs to be polled; that is necessary
184///   because this operation expects a control message response and that
185///   control message comes through the regular stream.
186#[derive(Debug)]
187pub struct Subscription<S, M, I>
188where
189  M: Message,
190{
191  /// A sink to which we can send control messages.
192  sink: S,
193  /// State shared between the subscription and the message stream.
194  shared: SharedState<M::ControlMessage>,
195  /// Phantom data for our sink's item type, which does not have to be
196  /// the actual control message type.
197  _phantom: PhantomData<I>,
198}
199
200impl<S, M, I> Subscription<S, M, I>
201where
202  S: Sink<I> + Unpin,
203  M: Message,
204{
205  /// Install a one-shot channel, run a function being passed our sink,
206  /// and then wait for a message being received through the channel.
207  async fn with_channel<'slf, F, G, E>(
208    &'slf mut self,
209    f: F,
210  ) -> Result<Option<Result<M::ControlMessage, ()>>, E>
211  where
212    F: FnOnce(&'slf mut S) -> G,
213    G: Future<Output = Result<(), E>>,
214  {
215    // Create a one-shot channel and register it with the message stream
216    // via our shared state.
217    let (sender, receiver) = channel();
218    let _prev = self
219      .shared
220      .lock()
221      .map_err(|err| err.into_inner())
222      .unwrap_or_else(|err| err)
223      .replace(sender);
224    debug_assert!(_prev.is_none());
225
226    if let Err(err) = f(&mut self.sink).await {
227      // We are about to exit early, because we failed to send our
228      // message over the control channel. Make sure to clean up the
229      // shared state that we installed earlier, so that the invariant
230      // that we always enter this function with a `None` in `shared` is
231      // preserved.
232      let _prev = self
233        .shared
234        .lock()
235        .map_err(|err| err.into_inner())
236        .unwrap_or_else(|err| err)
237        .take();
238      debug_assert!(_prev.is_some());
239      return Err(err)
240    }
241
242    let result = receiver.await;
243    // Our `MessageStream` type should make sure to "take" the sender
244    // that we put in.
245    debug_assert!(self
246      .shared
247      .lock()
248      .map_err(|err| err.into_inner())
249      .unwrap_or_else(|err| err)
250      .is_none());
251    // The only reason for getting back an `Err` here is if the sender
252    // got dropped. That should never happen (because we control it),
253    // but we map that to a `None`, just in case.
254    Ok(Result::<_, Canceled>::unwrap_or(result, None))
255  }
256
257  /// Send a message over the internal control channel and wait a
258  /// control message response.
259  ///
260  /// The method returns the following errors:
261  /// - `Err(..)` when the sink failed to send an item
262  /// - `Ok(None)` when the message stream got closed
263  /// - `Ok(Some(Err(())))` when message classification reported an
264  ///    error; the actual error still manifests through the message
265  ///    stream
266  pub async fn send(&mut self, item: I) -> Result<Option<Result<M::ControlMessage, ()>>, S::Error> {
267    self
268      .with_channel(|sink| async move { sink.send(item).await })
269      .await
270  }
271
272  /// Wait for a control message to arrive.
273  pub async fn read(&mut self) -> Option<Result<M::ControlMessage, ()>> {
274    let result = self.with_channel(|_sink| async { Ok(()) }).await;
275
276    // It's fine to unwrap here because we statically guarantee that an
277    // error can never occur.
278    Result::<_, Infallible>::unwrap(result)
279  }
280}
281
282
283/// Wrap a stream and an associated control channel into a connected
284/// ([`MessageStream`], [`Subscription`]) pair, in which the
285/// subscription can be used to send and receive control messages over
286/// the stream.
287pub fn subscribe<M, I, St, Si>(
288  stream: St,
289  control_channel: Si,
290) -> (MessageStream<St, M>, Subscription<Si, M, I>)
291where
292  M: Message,
293  St: Stream<Item = M>,
294  Si: Sink<I>,
295{
296  let shared = Arc::new(Mutex::new(None));
297
298  let subscription = Subscription {
299    sink: control_channel,
300    shared: shared.clone(),
301    _phantom: PhantomData,
302  };
303  let message_stream = MessageStream { stream, shared };
304
305  (message_stream, subscription)
306}
307
308
309/// Helper function to drive a [`Subscription`] related future to
310/// completion. The function makes sure to poll the provided stream,
311/// which is assumed to be associated with the `Subscription` that the
312/// future belongs to, so that control messages can be received. Errors
313/// reported by the stream (identified via [`Message::is_error`]) short
314/// circuit and fail the operation immediately.
315pub async fn drive<M, F, S>(future: F, stream: &mut S) -> Result<F::Output, M::UserMessage>
316where
317  M: Message,
318  F: Future + Unpin,
319  S: FusedStream<Item = M::UserMessage> + Unpin,
320{
321  let mut future = future.fuse();
322
323  'l: loop {
324    select_biased! {
325      output = future => break 'l Ok(output),
326      user_message = stream.next() => {
327        if let Some(user_message) = user_message {
328          if M::is_error(&user_message) {
329            break 'l Err(user_message)
330          }
331        }
332      },
333    }
334  }
335}
336
337
#[cfg(test)]
mod tests {
  use super::*;

  use futures::channel::mpsc;
  use futures::stream::iter;

  use test_log::test;


  /// A "dummy" message type used for testing.
  #[derive(Debug)]
  enum MockMessage<T> {
    /// A message that is visible to the user.
    Value(T),
    /// A "control" message.
    Close(u8),
  }

  impl<T> Message for MockMessage<T> {
    type UserMessage = T;
    type ControlMessage = u8;

    fn classify(self) -> Classification<Self::UserMessage, Self::ControlMessage> {
      match self {
        Self::Value(value) => Classification::UserMessage(value),
        Self::Close(code) => Classification::ControlMessage(code),
      }
    }

    #[inline]
    fn is_error(_user_message: &Self::UserMessage) -> bool {
      // This message type knows no errors.
      false
    }
  }

  /// A message type with two layers of errors: the outer one simulates
  /// WebSocket errors, the inner one, e.g., JSON errors.
  type NestedMessage = Result<Result<MockMessage<u64>, String>, u64>;

  impl<T> Message for Result<Result<MockMessage<T>, String>, u64> {
    type UserMessage = Result<Result<T, String>, u64>;
    type ControlMessage = u8;

    fn classify(self) -> Classification<Self::UserMessage, Self::ControlMessage> {
      match self {
        Ok(Ok(MockMessage::Close(code))) => Classification::ControlMessage(code),
        Ok(Ok(MockMessage::Value(value))) => Classification::UserMessage(Ok(Ok(value))),
        // Inner errors are reported to the user directly.
        Ok(Err(inner)) => Classification::UserMessage(Ok(Err(inner))),
        // Outer errors are pushed through as user messages as well.
        Err(outer) => Classification::UserMessage(Err(outer)),
      }
    }

    fn is_error(user_message: &Self::UserMessage) -> bool {
      // For the sake of testing, only inner errors count as errors.
      matches!(user_message, Ok(Err(_)))
    }
  }


  /// Create a channel and enqueue the provided messages in it.
  ///
  /// Note that `capacity` must be greater or equal to the number of
  /// messages, otherwise sending them may just deadlock, because
  /// nobody is draining the receiver yet.
  async fn preloaded<T, I>(messages: I, capacity: usize) -> (mpsc::Sender<T>, mpsc::Receiver<T>)
  where
    I: IntoIterator<Item = T>,
  {
    let (mut send, recv) = mpsc::channel::<T>(capacity);
    let mut source = iter(messages).map(Ok);
    let () = send.send_all(&mut source).await.unwrap();
    (send, recv)
  }


  /// Check that we can send a message through a `Subscription` and
  /// receive back the expected control message response.
  #[test(tokio::test)]
  async fn send_recv() {
    let (send, recv) = preloaded(
      [
        MockMessage::Value(1u64),
        MockMessage::Value(2u64),
        MockMessage::Value(3u64),
        MockMessage::Close(200),
        MockMessage::Close(201),
        MockMessage::Value(4u64),
      ],
      16,
    )
    .await;

    let (mut message_stream, mut subscription) = subscribe(recv, send);
    let request = subscription.send(MockMessage::Close(42)).boxed_local();
    let response = drive::<MockMessage<u64>, _, _>(request, &mut message_stream)
      .await
      .unwrap()
      .unwrap();

    // The response is the first "control" message that we fed into
    // the stream, i.e., the one with payload 200.
    assert_eq!(response, Some(Ok(200)));
  }

  /// Check that we can wait for a control message without sending
  /// anything beforehand.
  #[test(tokio::test)]
  async fn read() {
    let (send, recv) = preloaded(
      [
        MockMessage::Value(1u64),
        MockMessage::Value(2u64),
        MockMessage::Value(3u64),
        MockMessage::Close(200),
        MockMessage::Close(201),
        MockMessage::Value(4u64),
      ],
      16,
    )
    .await;

    let (mut message_stream, mut subscription) = subscribe(recv, send);
    let request = subscription.read().boxed_local();
    let response = drive::<MockMessage<u64>, _, _>(request, &mut message_stream)
      .await
      .unwrap();

    assert_eq!(response, Some(Ok(200)));
  }

  /// Check that `Subscription::send` behaves correctly if the
  /// associated message stream gets dropped.
  #[test(tokio::test)]
  async fn stream_drop() {
    let (send, recv) = mpsc::channel::<MockMessage<u64>>(1);

    let (message_stream, mut subscription) = subscribe(recv, send);
    drop(message_stream);

    let first = subscription.send(MockMessage::Close(42)).await;
    assert!(first.is_err());

    // Try another time to make sure that the failed attempt did not
    // leave any invariant-violating state behind.
    let second = subscription.send(MockMessage::Close(41)).await;
    assert!(second.is_err());
  }

  /// Check that `Subscription::send` behaves correctly if the
  /// underlying control channel gets closed.
  #[test(tokio::test)]
  async fn control_channel_closed() {
    let (mut send, recv) = mpsc::channel::<MockMessage<u64>>(1);
    send.close_channel();

    let (_message_stream, mut subscription) = subscribe(recv, send);

    let first = subscription.send(MockMessage::Close(42)).await;
    assert!(first.is_err());

    let second = subscription.send(MockMessage::Close(41)).await;
    assert!(second.is_err());
  }

  /// Check that a `MessageStream` behaves correctly if its associated
  /// `Subscription` has been dropped.
  #[test(tokio::test)]
  async fn stream_processing_with_dropped_subscription() {
    let (send, recv) = preloaded(
      [
        MockMessage::Value(1u64),
        MockMessage::Close(200),
        MockMessage::Value(4u64),
      ],
      4,
    )
    .await;

    let (message_stream, subscription) = subscribe(recv, send);
    drop(subscription);

    let user_messages = message_stream.collect::<Vec<_>>().await;
    assert_eq!(user_messages, vec![1u64, 4u64]);
  }

  /// Make sure that sending and receiving type checks and works even
  /// with messages that carry nested errors.
  #[test(tokio::test)]
  async fn send_recv_with_errors() {
    let (send, recv) = preloaded::<NestedMessage, _>(
      [
        Ok(Ok(MockMessage::Value(1u64))),
        Ok(Ok(MockMessage::Value(2u64))),
        Ok(Ok(MockMessage::Value(3u64))),
        Ok(Ok(MockMessage::Close(200))),
        Ok(Ok(MockMessage::Close(201))),
        Ok(Ok(MockMessage::Value(4u64))),
      ],
      16,
    )
    .await;

    let (mut message_stream, mut subscription) = subscribe(recv, send);
    let request = subscription
      .send(Ok(Ok(MockMessage::Close(42))))
      .boxed_local();
    let response = drive::<NestedMessage, _, _>(request, &mut message_stream)
      .await
      .unwrap()
      .unwrap();

    // The response is the first "control" message that we fed into
    // the stream, i.e., the one with payload 200.
    assert_eq!(response, Some(Ok(200)));
  }

  /// Check that inner errors are pushed to the subscription properly.
  #[test(tokio::test)]
  async fn inner_error() {
    let (send, recv) = preloaded::<NestedMessage, _>(
      [
        Ok(Ok(MockMessage::Value(1u64))),
        Ok(Err("error".to_string())),
        Ok(Ok(MockMessage::Close(200))),
      ],
      16,
    )
    .await;

    let (mut message_stream, mut subscription) = subscribe(recv, send);
    let request = subscription
      .send(Ok(Ok(MockMessage::Close(42))))
      .boxed_local();
    let failure = drive::<NestedMessage, _, _>(request, &mut message_stream)
      .await
      .unwrap_err()
      .unwrap();

    assert_eq!(failure, Err("error".to_string()));
  }

  /// Check that outer errors are ignored by the subscription.
  #[test(tokio::test)]
  async fn outer_error() {
    let (send, recv) = preloaded::<NestedMessage, _>(
      [
        Ok(Ok(MockMessage::Value(1u64))),
        Err(1337),
        Ok(Ok(MockMessage::Close(200))),
      ],
      16,
    )
    .await;

    let (mut message_stream, mut subscription) = subscribe(recv, send);
    let request = subscription
      .send(Ok(Ok(MockMessage::Close(42))))
      .boxed_local();
    let response = drive::<NestedMessage, _, _>(request, &mut message_stream)
      .await
      .unwrap()
      .unwrap();

    assert_eq!(response, Some(Ok(200)));
  }
}