1use std::{convert::TryInto, pin::Pin, time::Duration};
2
3use futures::{future::join, Future, FutureExt};
4use tokio::{
5 sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
6 task::yield_now,
7};
8
/// Marker bound for the error type carried through the session channels.
///
/// It has no methods; it only bundles the `Debug + Send + 'static`
/// requirements so they do not have to be repeated on every generic item.
pub trait ErrorType: std::fmt::Debug + Send + 'static {}

/// Blanket implementation: every `Debug + Send + 'static` type is an
/// [`ErrorType`].
impl<T> ErrorType for T where T: std::fmt::Debug + Send + 'static {}
12
13pub type InSender<E> = UnboundedSender<Result<Vec<u8>, E>>;
14pub type InReceiver<E> = UnboundedReceiver<Result<Vec<u8>, E>>;
15
16pub type OutReceiver = UnboundedReceiver<Vec<u8>>;
17pub type OutSender = UnboundedSender<Vec<u8>>;
18
19pub trait SessionFactory<E: ErrorType>:
21 FnOnce(InReceiver<E>, OutSender) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>>
22{
23}
24
25impl<
26 E: ErrorType,
27 F: FnOnce(InReceiver<E>, OutSender) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>>,
28 > SessionFactory<E> for F
29{
30}
31
32pub trait BehaviourFunction<E: ErrorType>:
39 for<'a> FnOnce(
40 &'a mut InSender<E>,
41 &'a mut OutReceiver,
42 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
43 + Send
44{
45}
46
47impl<E: ErrorType, T> BehaviourFunction<E> for T where
48 for<'a> T: FnOnce(
49 &'a mut InSender<E>,
50 &'a mut OutReceiver,
51 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
52 + Send
53{
54}
55
56pub trait Chainable<E: ErrorType>: BehaviourFunction<E> + Sized {
58 fn then<S: BehaviourFunction<E> + 'static>(
60 self,
61 followed_by: S,
62 ) -> Box<dyn BehaviourFunction<E>>;
63}
64
65impl<E: ErrorType, T: BehaviourFunction<E> + 'static> Chainable<E> for T {
66 fn then<S: BehaviourFunction<E> + 'static>(
67 self,
68 followed_by: S,
69 ) -> Box<dyn BehaviourFunction<E>> {
70 Box::new(|sender: &mut InSender<E>, receiver: &mut OutReceiver| {
71 async move {
72 self(sender, receiver).await;
73 followed_by(sender, receiver).await
74 }
75 .boxed()
76 })
77 }
78}
79
80pub fn into_behaviour<E, C>(closure: C) -> impl BehaviourFunction<E>
82where
83 E: ErrorType,
84 C: for<'a> FnOnce(
85 &'a mut InSender<E>,
86 &'a mut OutReceiver,
87 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
88 + Send
89 + 'static,
90{
91 closure
92}
93
94pub async fn assert_behaviour<
99 E: ErrorType,
100 F: SessionFactory<E>,
101 T: BehaviourFunction<E> + 'static,
102>(
103 session_factory: F,
104 behaviour: T,
105) {
106 let (in_sender, in_receiver) = unbounded_channel::<Result<Vec<u8>, E>>();
107 let (out_sender, out_receiver) = unbounded_channel();
108
109 let session_future = session_factory(in_receiver, out_sender);
110
111 let other_future = tokio::task::spawn(async move {
112 let mut out_receiver = out_receiver;
113 let mut in_sender = in_sender;
114
115 behaviour(&mut in_sender, &mut out_receiver).await;
116 out_receiver.close();
117 drop(in_sender);
118 ()
119 });
120
121 let results = join(session_future, other_future).await;
122
123 results.0.expect("Session returned error");
124 results.1.expect("Behaviour returned error");
125}
126
127pub fn send<T: Into<Vec<u8>> + Send + 'static, E: ErrorType>(
129 data: T,
130) -> impl BehaviourFunction<E> {
131 into_behaviour(move |in_sender: &mut InSender<E>, _: &mut OutReceiver| {
132 send_data(in_sender, data);
133 yield_now().boxed()
134 })
135}
136
137pub fn send_data<T: Into<Vec<u8>>, E: ErrorType>(
140 sender: &InSender<E>,
141 data: T,
142) {
143 sender
144 .send(Ok(data.into()))
145 .expect("Send failed");
146}
147
148pub fn assert_receive<T: FnOnce(Vec<u8>) -> bool>(out_receiver: &mut OutReceiver, predicate: T) {
151 let response = out_receiver.recv().now_or_never();
152
153 assert!(predicate(
154 response
155 .expect("No message from session") .expect("Session closed") ));
158}
159
160pub fn receive<E: ErrorType, T: FnOnce(Vec<u8>) -> bool + Send + 'static>(
163 predicate: T,
164) -> impl BehaviourFunction<E> {
165 into_behaviour(|_: &mut InSender<E>, out_receiver: &mut OutReceiver| {
166 async move {
167 yield_now().await;
168 assert_receive(out_receiver, predicate)
169 }
170 .boxed()
171 })
172}
173
174pub fn sleep_in_pause(millis: u64) -> impl Future<Output = ()> {
177 tokio::time::pause();
178 tokio::time::sleep(Duration::from_millis(millis)).inspect(|_| tokio::time::resume())
179}
180
181pub fn wait_for_disconnect<'a, E: ErrorType>(
182 _: &'a mut InSender<E>,
183 out_receiver: &'a mut OutReceiver,
184) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
185 async move {
186 sleep_in_pause(5050).await;
187
188 assert!(matches!(out_receiver.recv().now_or_never(), Some(None)));
189 ()
190 }
191 .boxed()
192}
193
194#[cfg(test)]
195mod test {
196
197 use std::{
198 any::Any,
199 convert::Infallible,
200 panic::AssertUnwindSafe,
201 sync::{
202 atomic::{AtomicBool, Ordering},
203 Arc,
204 },
205 };
206
207 use std::panic;
208
209 use tokio::join;
210
211 use super::*;
212
213 #[tokio::test]
214 async fn chaining_works() {
215 let (mut tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
216 let (_, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
217
218 let behaviour_a = into_behaviour(|sender: &mut InSender<()>, _: &mut OutReceiver| {
219 async move {
220 sender.send(Ok(b"foo".to_vec())).expect("Send failed");
221 ()
222 }
223 .boxed()
224 });
225
226 let behaviour_b = into_behaviour(|sender: &mut InSender<()>, _: &mut OutReceiver| {
227 async move {
228 sender.send(Ok(b"bar".to_vec())).expect("Send failed");
229 ()
230 }
231 .boxed()
232 });
233
234 behaviour_a.then(behaviour_b)(&mut tx, &mut out_rx).await;
236
237 drop(tx);
239
240 assert_eq!(
241 "foo",
242 String::from_utf8(
243 rx.recv()
244 .now_or_never()
245 .unwrap()
246 .unwrap()
247 .expect("recv failed"),
248 )
249 .expect("Parse failed"),
250 );
251 assert_eq!(
252 "bar",
253 String::from_utf8(
254 rx.recv()
255 .now_or_never()
256 .unwrap()
257 .unwrap()
258 .expect("recv failed"),
259 )
260 .expect("Parse failed"),
261 );
262
263 assert_eq!(None, rx.recv().now_or_never().unwrap());
264 }
265
/// Unit error type used as the session error throughout these tests.
#[derive(Debug, PartialEq, Eq)]
struct TestError;

/// `Infallible` has no values, so this conversion can never actually run.
impl From<Infallible> for TestError {
    fn from(never: Infallible) -> Self {
        match never {}
    }
}
274
275 #[tokio::test]
276 async fn send_works() {
277 let (mut tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
278 let (_, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
279
280 let behaviour = send::<String, TestError>("Hello".to_owned());
281
282 behaviour(&mut tx, &mut out_rx);
283
284 drop(tx);
286
287 assert_eq!(
288 "Hello",
289 String::from_utf8(
290 rx.recv()
291 .now_or_never()
292 .unwrap()
293 .unwrap()
294 .expect("recv failed"),
295 )
296 .expect("Parse failed"),
297 );
298
299 assert_eq!(None, rx.recv().now_or_never().unwrap());
300 }
301
302 struct TestData;
303
304 #[derive(Debug, PartialEq, Eq)]
305 struct TestError2;
306
307 impl From<TestError2> for TestError {
308 fn from(_: TestError2) -> Self {
309 TestError
310 }
311 }
312 impl Into<Vec<u8>> for TestData {
313
314 fn into(self) -> Vec<u8> {
315 vec![]
316 }
317 }
318
319 #[tokio::test]
320 async fn send_handles_conversion_error() {
321 let (mut tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
322 let (_, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
323
324 let behaviour = send::<TestData, TestError>(TestData);
325
326 behaviour(&mut tx, &mut out_rx);
327
328 drop(tx);
330
331 assert_eq!(Err(TestError), rx.recv().now_or_never().unwrap().unwrap());
332
333 assert_eq!(None, rx.recv().now_or_never().unwrap());
334 }
335
336 #[tokio::test]
337 async fn send_yields() {
338 let (mut tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Result<Vec<u8>, TestError>>();
339 let (_, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
340
341 let handle = tokio::task::spawn(async move {
342 let first = rx.recv().await;
343
344 assert_eq!(
345 "Hello",
346 String::from_utf8(first.unwrap().expect("recv failed"),).expect("Parse failed"),
347 );
348
349 assert_eq!(None, rx.recv().now_or_never());
351
352 let second = rx.recv().await;
353
354 assert_eq!(
355 "world",
356 String::from_utf8(second.unwrap().expect("recv failed"),).expect("Parse failed"),
357 );
358 });
359
360 let behaviour = send::<String, TestError>("Hello".to_owned())
361 .then(send::<String, TestError>("world".to_owned()));
362
363 let x = join!(handle, behaviour(&mut tx, &mut out_rx));
364
365 assert!(x.0.is_ok());
366 }
367
368 #[tokio::test]
369 async fn send_data_sends() {
370 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Result<Vec<u8>, TestError>>();
371
372 send_data(&tx, b"1".to_vec());
373
374 assert_eq!(
375 b"1".to_vec(),
376 rx.recv()
377 .await
378 .expect("Should be Some")
379 .expect("Should be ok")
380 );
381 }
382
383 #[tokio::test]
384 async fn test_expectations_executes_behaviour() {
385 let called = Arc::new(AtomicBool::new(false));
386
387 let session_factory = |_: InReceiver<TestError>, _: OutSender| async { Ok(()) }.boxed();
388 let client_behaviour = into_behaviour({
389 let called = called.clone();
390 |_, _| async move { called.store(true, Ordering::Release) }.boxed()
391 });
392
393 assert_behaviour(session_factory, client_behaviour).await;
394
395 assert_eq!(true, called.load(Ordering::Acquire));
396 }
397
398 #[tokio::test]
399 async fn test_expectations_starts_session() {
400 let called = Arc::new(AtomicBool::new(false));
401
402 let session_factory = {
403 let called = called.clone();
404 |_: InReceiver<TestError>, _: OutSender| {
405 async move { Ok(called.store(true, Ordering::Release)) }.boxed()
406 }
407 };
408 let client_behaviour = into_behaviour(|_, _| async {}.boxed());
409
410 assert_behaviour(session_factory, client_behaviour).await;
411
412 assert_eq!(true, called.load(Ordering::Acquire));
413 }
414
415 #[tokio::test]
416 async fn test_expectations_sends_in_to_session() {
417 let session_factory = {
418 |mut receiver: InReceiver<TestError>, _: OutSender| {
419 async move {
420 match receiver.recv().await {
421 Some(Ok(x)) if x == vec![42u8] => Ok(()),
422 _ => Err(TestError),
423 }
424 }
425 .boxed()
426 }
427 };
428 let client_behaviour = into_behaviour(|in_sender, _| {
429 async move {
430 in_sender.send(Ok(vec![42u8])).expect("Send failed");
431 }
432 .boxed()
433 });
434
435 assert_behaviour(session_factory, client_behaviour).await;
436 }
437
438 #[tokio::test]
439 async fn test_expectations_receives_out_from_session() {
440 let session_factory = {
441 |_, sender: OutSender| {
442 async move { sender.send(vec![42u8]).map_err(|_| TestError) }.boxed()
443 }
444 };
445 let client_behaviour =
446 into_behaviour::<TestError, _>(|_, out_receiver: &mut OutReceiver| {
447 async move {
448 assert!(matches!(out_receiver.recv().await, Some(x) if x == vec![42u8]));
449 }
450 .boxed()
451 });
452
453 assert_behaviour(session_factory, client_behaviour).await;
454 }
455
456 fn assert_unwind_safe<O, F: Future<Output = O>>(
457 future: F,
458 ) -> impl Future<Output = Result<O, Box<dyn Any + std::marker::Send>>> {
459 panic::set_hook(Box::new(|_info| {
460 }));
462
463 AssertUnwindSafe(future).catch_unwind()
464 }
465
466 async fn assert_beaviour_test_result<
467 E: ErrorType,
468 F: SessionFactory<E> + 'static,
469 B: BehaviourFunction<E> + 'static,
470 >(
471 session_factory: F,
472 behaviour: B,
473 expect_err: bool,
474 ) {
475 panic::set_hook(Box::new(|_info| {
476 }));
478 let result = assert_unwind_safe(assert_behaviour(session_factory, behaviour)).await;
479
480 assert_eq!(expect_err, result.is_err());
481 }
482
483 async fn assert_beaviour_test_fails<
484 E: ErrorType,
485 F: SessionFactory<E> + 'static,
486 B: BehaviourFunction<E> + 'static,
487 >(
488 session_factory: F,
489 behaviour: B,
490 ) {
491 assert_beaviour_test_result(session_factory, behaviour, true).await;
492 }
493
494 async fn assert_beaviour_test_succeeds<
495 E: ErrorType,
496 F: SessionFactory<E> + 'static,
497 B: BehaviourFunction<E> + 'static,
498 >(
499 session_factory: F,
500 behaviour: B,
501 ) {
502 assert_beaviour_test_result(session_factory, behaviour, false).await;
503 }
504
505 #[tokio::test]
506 async fn test_expectations_fails_if_error_in_session() {
507 let session_factory = { |_, _| async move { Err(TestError) }.boxed() };
508 let client_behaviour = into_behaviour::<TestError, _>(|_, _| async move {}.boxed());
509
510 assert_beaviour_test_fails(session_factory, client_behaviour).await;
511 }
512
513 #[tokio::test]
514 async fn test_expectations_fails_if_error_in_behaviour() {
515 let session_factory = |_, _| async move { Ok(()) }.boxed();
516 let client_behaviour = into_behaviour::<TestError, _>(|_, _| {
517 async move {
518 assert!(false);
519 }
520 .boxed()
521 });
522
523 assert_beaviour_test_fails(session_factory, client_behaviour).await;
524 }
525
526 #[tokio::test]
527 async fn test_expectations_succeeds_if_empty() {
528 let session_factory = |_, _| async move { Ok(()) }.boxed();
529 let client_behaviour = into_behaviour::<TestError, _>(|_, _| async move {}.boxed());
530
531 assert_beaviour_test_succeeds(session_factory, client_behaviour).await;
532 }
533
534 #[tokio::test]
535 async fn assert_receive_fails_if_no_message() {
536 let (_, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
537
538 let result = assert_unwind_safe(async { assert_receive(&mut out_rx, |_| true) }).await;
539
540 assert!(result.is_err());
541 }
542
543 #[tokio::test]
544 async fn assert_receive_succeeds_if_message_matches() {
545 let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
546
547 out_tx.send(vec![0u8]).expect("Should succeed");
548 let result = assert_unwind_safe(async { assert_receive(&mut out_rx, |_| true) }).await;
549
550 assert!(result.is_ok());
551 }
552
553 #[tokio::test]
554 async fn assert_receive_succeeds_if_session_closed() {
555 let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
556
557 out_tx.send(vec![0u8]).expect("Should succeed");
558 out_rx.close();
559
560 let result = assert_unwind_safe(async { assert_receive(&mut out_rx, |_| true) }).await;
561
562 assert!(result.is_ok());
563
564 let result = assert_unwind_safe(async { assert_receive(&mut out_rx, |_| true) }).await;
565 assert!(result.is_err());
566 }
567
568 #[tokio::test]
569 async fn assert_receive_fails_if_predicate_false() {
570 let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
571
572 out_tx.send(vec![0u8]).expect("Should succeed");
573
574 let result = assert_unwind_safe(async { assert_receive(&mut out_rx, |_| false) }).await;
575
576 assert!(result.is_err());
577 }
578
579 #[tokio::test]
580 async fn assert_receive_passes_message_to_predicate() {
581 let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
582
583 out_tx.send(vec![0u8]).expect("Should succeed");
584
585 let result =
586 assert_unwind_safe(async { assert_receive(&mut out_rx, |data| data == vec![0u8]) })
587 .await;
588
589 assert!(result.is_ok());
590 }
591
592 #[tokio::test]
593 async fn receive_succeeds_with_message() {
594 let (mut tx, _) = tokio::sync::mpsc::unbounded_channel();
595 let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
596
597 let behaviour = receive::<TestError, _>(|bytes| bytes == vec![42u8]);
598
599 out_tx.send(vec![42u8]).expect("Send Failed");
600
601 behaviour(&mut tx, &mut out_rx).await;
602 }
603
604 #[tokio::test]
605 async fn receive_fails_with_incorrect_message() {
606 let (mut tx, _) = tokio::sync::mpsc::unbounded_channel();
607 let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
608
609 let behaviour = receive::<TestError, _>(|bytes| bytes == vec![43u8]);
610
611 out_tx.send(vec![42u8]).expect("Send Failed");
612
613 assert_unwind_safe(behaviour(&mut tx, &mut out_rx))
614 .await
615 .expect_err("Behaviour passed");
616 }
617
618 #[tokio::test]
619 async fn sleep_in_pause_passes_time() {
620 let session_factory = |_, sender: OutSender| {
621 async move {
622 tokio::time::sleep(Duration::from_millis(2000)).await;
623 sender.send(vec![1]).expect("Send failed");
624 Ok(())
625 }
626 .boxed()
627 };
628
629 let client_behaviour =
630 into_behaviour::<TestError, _>(|_, out_receiver: &mut OutReceiver| {
631 async move {
632 assert!(out_receiver.recv().now_or_never().is_none());
633 sleep_in_pause(3000).await;
634 assert_receive(out_receiver, |bytes| bytes == vec![1]);
635 }
636 .boxed()
637 });
638
639 assert_beaviour_test_succeeds(session_factory, client_behaviour).await;
640 }
641
642 #[tokio::test]
643 async fn wait_for_disconnect_fails_if_not_disconnected() {
644 let session_factory = |_, unused: OutSender| {
645 async move {
646 tokio::time::sleep(Duration::from_millis(5060)).await;
647 drop(unused); Ok(())
649 }
650 .boxed()
651 };
652
653 assert_beaviour_test_fails(session_factory, wait_for_disconnect::<TestError>).await;
654 }
655
656 #[tokio::test]
657 async fn wait_for_disconnect_succeeds_if_disconnected() {
658 let session_factory = |_, unused| {
659 async move {
660 tokio::time::sleep(Duration::from_millis(5000)).await;
661 drop(unused); Ok(())
663 }
664 .boxed()
665 };
666
667 assert_beaviour_test_succeeds(session_factory, wait_for_disconnect::<TestError>).await;
668 }
669}