1use std::convert::Infallible;
9use std::fmt::Debug;
10use std::marker::PhantomData;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::sync::Mutex;
14
15use futures::channel::oneshot::channel;
16use futures::channel::oneshot::Canceled;
17use futures::channel::oneshot::Sender;
18use futures::select_biased;
19use futures::sink::SinkExt as _;
20use futures::stream::FusedStream;
21use futures::task::Context;
22use futures::task::Poll;
23use futures::Future;
24use futures::FutureExt as _;
25use futures::Sink;
26use futures::Stream;
27use futures::StreamExt as _;
28
29
/// The outcome of classifying a [`Message`]: either a message meant for the
/// user or a control message meant for a pending [`Subscription`] operation.
#[derive(Debug)]
pub enum Classification<User, Control> {
  /// A message that is yielded to the user by the message stream.
  UserMessage(User),
  /// A message that is routed to the subscription instead of the user.
  ControlMessage(Control),
}
41
42
43pub trait Message {
46 type UserMessage;
48 type ControlMessage;
50
51 fn classify(self) -> Classification<Self::UserMessage, Self::ControlMessage>;
54
55 fn is_error(user_message: &Self::UserMessage) -> bool;
59}
60
61
62type SharedState<M> = Arc<Mutex<Option<Sender<Option<Result<M, ()>>>>>>;
67
68
69#[derive(Debug)]
71pub struct MessageStream<S, M>
72where
73 M: Message,
74{
75 stream: S,
77 shared: SharedState<M::ControlMessage>,
79}
80
81impl<S, M> MessageStream<S, M>
82where
83 M: Message,
84{
85 fn inform_subscription(
88 shared: &SharedState<M::ControlMessage>,
89 message: Option<Result<M::ControlMessage, ()>>,
90 ) {
91 let sender = shared
92 .lock()
93 .map_err(|err| err.into_inner())
94 .unwrap_or_else(|err| err)
95 .take();
96
97 if let Some(sender) = sender {
98 let _ = sender.send(message);
103 }
104 }
105}
106
107impl<S, M> Stream for MessageStream<S, M>
108where
109 S: Stream<Item = M> + Unpin,
110 M: Message,
111{
112 type Item = M::UserMessage;
113
114 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
115 let shared = self.shared.clone();
116 let this = self.get_mut();
117
118 loop {
119 match this.stream.poll_next_unpin(ctx) {
120 Poll::Pending => {
121 break Poll::Pending
124 },
125 Poll::Ready(None) => {
126 Self::inform_subscription(&shared, None);
129 break Poll::Ready(None)
130 },
131 Poll::Ready(Some(message)) => {
132 match message.classify() {
133 Classification::UserMessage(user_message) => {
134 if M::is_error(&user_message) {
135 Self::inform_subscription(&shared, Some(Err(())));
142 }
143 break Poll::Ready(Some(user_message))
146 },
147 Classification::ControlMessage(control_message) => {
148 Self::inform_subscription(&shared, Some(Ok(control_message)));
153 },
154 }
155 },
156 }
157 }
158 }
159
160 fn size_hint(&self) -> (usize, Option<usize>) {
161 Stream::size_hint(&self.stream)
162 }
163}
164
165impl<S, M> FusedStream for MessageStream<S, M>
166where
167 S: FusedStream<Item = M> + Unpin,
168 M: Message,
169{
170 #[inline]
171 fn is_terminated(&self) -> bool {
172 self.stream.is_terminated()
173 }
174}
175
176
177#[derive(Debug)]
187pub struct Subscription<S, M, I>
188where
189 M: Message,
190{
191 sink: S,
193 shared: SharedState<M::ControlMessage>,
195 _phantom: PhantomData<I>,
198}
199
200impl<S, M, I> Subscription<S, M, I>
201where
202 S: Sink<I> + Unpin,
203 M: Message,
204{
205 async fn with_channel<'slf, F, G, E>(
208 &'slf mut self,
209 f: F,
210 ) -> Result<Option<Result<M::ControlMessage, ()>>, E>
211 where
212 F: FnOnce(&'slf mut S) -> G,
213 G: Future<Output = Result<(), E>>,
214 {
215 let (sender, receiver) = channel();
218 let _prev = self
219 .shared
220 .lock()
221 .map_err(|err| err.into_inner())
222 .unwrap_or_else(|err| err)
223 .replace(sender);
224 debug_assert!(_prev.is_none());
225
226 if let Err(err) = f(&mut self.sink).await {
227 let _prev = self
233 .shared
234 .lock()
235 .map_err(|err| err.into_inner())
236 .unwrap_or_else(|err| err)
237 .take();
238 debug_assert!(_prev.is_some());
239 return Err(err)
240 }
241
242 let result = receiver.await;
243 debug_assert!(self
246 .shared
247 .lock()
248 .map_err(|err| err.into_inner())
249 .unwrap_or_else(|err| err)
250 .is_none());
251 Ok(Result::<_, Canceled>::unwrap_or(result, None))
255 }
256
257 pub async fn send(&mut self, item: I) -> Result<Option<Result<M::ControlMessage, ()>>, S::Error> {
267 self
268 .with_channel(|sink| async move { sink.send(item).await })
269 .await
270 }
271
272 pub async fn read(&mut self) -> Option<Result<M::ControlMessage, ()>> {
274 let result = self.with_channel(|_sink| async { Ok(()) }).await;
275
276 Result::<_, Infallible>::unwrap(result)
279 }
280}
281
282
283pub fn subscribe<M, I, St, Si>(
288 stream: St,
289 control_channel: Si,
290) -> (MessageStream<St, M>, Subscription<Si, M, I>)
291where
292 M: Message,
293 St: Stream<Item = M>,
294 Si: Sink<I>,
295{
296 let shared = Arc::new(Mutex::new(None));
297
298 let subscription = Subscription {
299 sink: control_channel,
300 shared: shared.clone(),
301 _phantom: PhantomData,
302 };
303 let message_stream = MessageStream { stream, shared };
304
305 (message_stream, subscription)
306}
307
308
309pub async fn drive<M, F, S>(future: F, stream: &mut S) -> Result<F::Output, M::UserMessage>
316where
317 M: Message,
318 F: Future + Unpin,
319 S: FusedStream<Item = M::UserMessage> + Unpin,
320{
321 let mut future = future.fuse();
322
323 'l: loop {
324 select_biased! {
325 output = future => break 'l Ok(output),
326 user_message = stream.next() => {
327 if let Some(user_message) = user_message {
328 if M::is_error(&user_message) {
329 break 'l Err(user_message)
330 }
331 }
332 },
333 }
334 }
335}
336
337
#[cfg(test)]
mod tests {
  use super::*;

  use futures::channel::mpsc::channel;
  use futures::stream::iter;

  use test_log::test;


  /// A simple mock message: `Value`s are user messages and `Close` codes
  /// are control messages.
  #[derive(Debug)]
  enum MockMessage<T> {
    /// A user message carrying a value.
    Value(T),
    /// A control message carrying a code.
    Close(u8),
  }

  impl<T> Message for MockMessage<T> {
    type UserMessage = T;
    type ControlMessage = u8;

    fn classify(self) -> Classification<Self::UserMessage, Self::ControlMessage> {
      match self {
        MockMessage::Value(x) => Classification::UserMessage(x),
        MockMessage::Close(x) => Classification::ControlMessage(x),
      }
    }

    #[inline]
    fn is_error(_user_message: &Self::UserMessage) -> bool {
      // No mock user message is considered an error.
      false
    }
  }


  /// Check that `Subscription::send` resolves with the first control
  /// message seen on the stream while `drive` discards user messages.
  #[test(tokio::test)]
  async fn send_recv() {
    let mut it = iter([
      MockMessage::Value(1u64),
      MockMessage::Value(2u64),
      MockMessage::Value(3u64),
      MockMessage::Close(200),
      MockMessage::Close(201),
      MockMessage::Value(4u64),
    ])
    .map(Ok);

    // The same channel acts as the message source and as the control sink.
    // Pre-fill it with the mock messages.
    let (mut send, recv) = channel::<MockMessage<u64>>(16);
    let () = send.send_all(&mut it).await.unwrap();

    let (mut message_stream, mut subscription) = subscribe(recv, send);
    let close = subscription.send(MockMessage::Close(42)).boxed_local();
    let message = drive::<MockMessage<u64>, _, _>(close, &mut message_stream)
      .await
      .unwrap()
      .unwrap();

    // `Close(200)` is the first control message on the stream.
    assert_eq!(message, Some(Ok(200)));
  }

  /// Check that `Subscription::read` resolves with the first control message
  /// without sending anything.
  #[test(tokio::test)]
  async fn read() {
    let mut it = iter([
      MockMessage::Value(1u64),
      MockMessage::Value(2u64),
      MockMessage::Value(3u64),
      MockMessage::Close(200),
      MockMessage::Close(201),
      MockMessage::Value(4u64),
    ])
    .map(Ok);

    let (mut send, recv) = channel::<MockMessage<u64>>(16);
    let () = send.send_all(&mut it).await.unwrap();

    let (mut message_stream, mut subscription) = subscribe(recv, send);
    let close = subscription.read().boxed_local();
    let message = drive::<MockMessage<u64>, _, _>(close, &mut message_stream)
      .await
      .unwrap();

    assert_eq!(message, Some(Ok(200)));
  }

  /// Check that sending fails once the message stream, which owns the
  /// channel's receiving end, has been dropped. A second attempt checks that
  /// the failed first attempt leaves the subscription usable.
  #[test(tokio::test)]
  async fn stream_drop() {
    let (send, recv) = channel::<MockMessage<u64>>(1);

    let (message_stream, mut subscription) = subscribe(recv, send);
    drop(message_stream);

    let result = subscription.send(MockMessage::Close(42)).await;
    assert!(result.is_err());

    let result = subscription.send(MockMessage::Close(41)).await;
    assert!(result.is_err());
  }

  /// Check that sending fails, repeatedly, when the control channel sink has
  /// been closed.
  #[test(tokio::test)]
  async fn control_channel_closed() {
    let (mut send, recv) = channel::<MockMessage<u64>>(1);
    send.close_channel();

    let (_message_stream, mut subscription) = subscribe(recv, send);

    let result = subscription.send(MockMessage::Close(42)).await;
    assert!(result.is_err());

    let result = subscription.send(MockMessage::Close(41)).await;
    assert!(result.is_err());
  }

  /// Check that the message stream keeps yielding user messages, and skips
  /// control messages, after the subscription has been dropped.
  #[test(tokio::test)]
  async fn stream_processing_with_dropped_subscription() {
    let mut it = iter([
      MockMessage::Value(1u64),
      MockMessage::Close(200),
      MockMessage::Value(4u64),
    ])
    .map(Ok);

    let (mut send, recv) = channel::<MockMessage<u64>>(4);
    let () = send.send_all(&mut it).await.unwrap();

    let (message_stream, subscription) = subscribe(recv, send);
    drop(subscription);

    let vec = message_stream.collect::<Vec<_>>().await;
    assert_eq!(vec, vec![1u64, 4u64]);
  }


  /// A message type with two error layers: an inner `String` error and an
  /// outer `u64` error. Only inner errors are flagged by `is_error`.
  impl<T> Message for Result<Result<MockMessage<T>, String>, u64> {
    type UserMessage = Result<Result<T, String>, u64>;
    type ControlMessage = u8;

    fn classify(self) -> Classification<Self::UserMessage, Self::ControlMessage> {
      match self {
        Ok(Ok(MockMessage::Value(x))) => Classification::UserMessage(Ok(Ok(x))),
        Ok(Ok(MockMessage::Close(x))) => Classification::ControlMessage(x),
        // Inner errors are passed on to the user.
        Ok(Err(err)) => Classification::UserMessage(Ok(Err(err))),
        // Outer errors are passed on to the user as well.
        Err(err) => Classification::UserMessage(Err(err)),
      }
    }

    fn is_error(user_message: &Self::UserMessage) -> bool {
      // Only `Ok(Err(_))`, i.e., an inner error, counts as an error.
      user_message
        .as_ref()
        .map(|inner| inner.is_err())
        .unwrap_or(false)
    }
  }


  /// Check `send` with the nested-result message type when no errors are
  /// present.
  #[test(tokio::test)]
  async fn send_recv_with_errors() {
    let mut it = iter([
      Ok(Ok(MockMessage::Value(1u64))),
      Ok(Ok(MockMessage::Value(2u64))),
      Ok(Ok(MockMessage::Value(3u64))),
      Ok(Ok(MockMessage::Close(200))),
      Ok(Ok(MockMessage::Close(201))),
      Ok(Ok(MockMessage::Value(4u64))),
    ])
    .map(Ok);

    let (mut send, recv) = channel::<Result<Result<MockMessage<u64>, String>, u64>>(16);
    let () = send.send_all(&mut it).await.unwrap();

    let (mut message_stream, mut subscription) = subscribe(recv, send);
    let close = subscription
      .send(Ok(Ok(MockMessage::Close(42))))
      .boxed_local();
    let message =
      drive::<Result<Result<MockMessage<u64>, String>, u64>, _, _>(close, &mut message_stream)
        .await
        .unwrap()
        .unwrap();

    assert_eq!(message, Some(Ok(200)));
  }

  /// Check that `drive` returns early with an inner error message that
  /// precedes the control message.
  #[test(tokio::test)]
  async fn inner_error() {
    let mut it = iter([
      Ok(Ok(MockMessage::Value(1u64))),
      Ok(Err("error".to_string())),
      Ok(Ok(MockMessage::Close(200))),
    ])
    .map(Ok);

    let (mut send, recv) = channel::<Result<Result<MockMessage<u64>, String>, u64>>(16);
    let () = send.send_all(&mut it).await.unwrap();

    let (mut message_stream, mut subscription) = subscribe(recv, send);
    let close = subscription
      .send(Ok(Ok(MockMessage::Close(42))))
      .boxed_local();
    let message =
      drive::<Result<Result<MockMessage<u64>, String>, u64>, _, _>(close, &mut message_stream)
        .await
        .unwrap_err()
        .unwrap();

    assert_eq!(message, Err("error".to_string()));
  }

  /// Check that outer errors are not treated as errors, so `drive` keeps
  /// going until the control message arrives.
  #[test(tokio::test)]
  async fn outer_error() {
    let mut it = iter([
      Ok(Ok(MockMessage::Value(1u64))),
      Err(1337),
      Ok(Ok(MockMessage::Close(200))),
    ])
    .map(Ok);

    let (mut send, recv) = channel::<Result<Result<MockMessage<u64>, String>, u64>>(16);
    let () = send.send_all(&mut it).await.unwrap();

    let (mut message_stream, mut subscription) = subscribe(recv, send);
    let close = subscription
      .send(Ok(Ok(MockMessage::Close(42))))
      .boxed_local();
    let message =
      drive::<Result<Result<MockMessage<u64>, String>, u64>, _, _>(close, &mut message_stream)
        .await
        .unwrap()
        .unwrap();

    assert_eq!(message, Some(Ok(200)));
  }
}