// stomp_test_utils/framework.rs

1use std::{convert::TryInto, pin::Pin, time::Duration};
2
3use futures::{future::join, Future, FutureExt};
4use tokio::{
5    sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
6    task::yield_now,
7};
8
/// Marker bound for the error type carried on the inbound channel ([`InSender`]/[`InReceiver`]).
///
/// Requiring `Send + Debug + 'static` lets errors travel into spawned tasks and be formatted by
/// `expect` when a test fails.
pub trait ErrorType
where
    Self: Send + std::fmt::Debug + 'static,
{
}

/// Blanket implementation: any type satisfying the bounds is an [`ErrorType`].
impl<T> ErrorType for T where T: Send + std::fmt::Debug + 'static {}
12
13pub type InSender<E> = UnboundedSender<Result<Vec<u8>, E>>;
14pub type InReceiver<E> = UnboundedReceiver<Result<Vec<u8>, E>>;
15
16pub type OutReceiver = UnboundedReceiver<Vec<u8>>;
17pub type OutSender = UnboundedSender<Vec<u8>>;
18
19/// Creates a session which receives and sends messages on the provided receiver and sender respectively.
20pub trait SessionFactory<E: ErrorType>:
21    FnOnce(InReceiver<E>, OutSender) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>>
22{
23}
24
25impl<
26        E: ErrorType,
27        F: FnOnce(InReceiver<E>, OutSender) -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>>,
28    > SessionFactory<E> for F
29{
30}
31
32/// A `BehaviourFunction` can send messages to the provided sender and check responses on the provided receiver,
33/// thereby testing expected behaviour. It returns the channels it received as inputs in order to faciliate
34/// further checks downstream, and enable chaining.
35///
36/// DOes
37/// A blanket implementation for any Function with the appropriate Signature is provided.
38pub trait BehaviourFunction<E: ErrorType>:
39    for<'a> FnOnce(
40        &'a mut InSender<E>,
41        &'a mut OutReceiver,
42    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
43    + Send
44{
45}
46
47impl<E: ErrorType, T> BehaviourFunction<E> for T where
48    for<'a> T: FnOnce(
49            &'a mut InSender<E>,
50            &'a mut OutReceiver,
51        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
52        + Send
53{
54}
55
56/// Enables chaining of [`BehaviourFunction`]s.
57pub trait Chainable<E: ErrorType>: BehaviourFunction<E> + Sized {
58    /// Constructs a new [`BehaviourFunction`] which will first execute `self`, and then `followed_by`.
59    fn then<S: BehaviourFunction<E> + 'static>(
60        self,
61        followed_by: S,
62    ) -> Box<dyn BehaviourFunction<E>>;
63}
64
65impl<E: ErrorType, T: BehaviourFunction<E> + 'static> Chainable<E> for T {
66    fn then<S: BehaviourFunction<E> + 'static>(
67        self,
68        followed_by: S,
69    ) -> Box<dyn BehaviourFunction<E>> {
70        Box::new(|sender: &mut InSender<E>, receiver: &mut OutReceiver| {
71            async move {
72                self(sender, receiver).await;
73                followed_by(sender, receiver).await
74            }
75            .boxed()
76        })
77    }
78}
79
80/// Help the compiler to assign appropriate lifetimes to inputs and outputs of BehaviourFunction-equivalent closure.
81pub fn into_behaviour<E, C>(closure: C) -> impl BehaviourFunction<E>
82where
83    E: ErrorType,
84    C: for<'a> FnOnce(
85            &'a mut InSender<E>,
86            &'a mut OutReceiver,
87        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
88        + Send
89        + 'static,
90{
91    closure
92}
93
94/// The core function of this crate, it can be used to test the session returned by `session_factory` by executing
95/// `behaviour` against it - in other words, the [`InSender`] provided to `behaviour` will send messages to the created
96/// session, and the [`OutReceiver`] will receive messages the session sends to its 'counterparty'. Any errors
97/// returned by either side will result in a panic.
98pub async fn assert_behaviour<
99    E: ErrorType,
100    F: SessionFactory<E>,
101    T: BehaviourFunction<E> + 'static,
102>(
103    session_factory: F,
104    behaviour: T,
105) {
106    let (in_sender, in_receiver) = unbounded_channel::<Result<Vec<u8>, E>>();
107    let (out_sender, out_receiver) = unbounded_channel();
108
109    let session_future = session_factory(in_receiver, out_sender);
110
111    let other_future = tokio::task::spawn(async move {
112        let mut out_receiver = out_receiver;
113        let mut in_sender = in_sender;
114
115        behaviour(&mut in_sender, &mut out_receiver).await;
116        out_receiver.close();
117        drop(in_sender);
118        ()
119    });
120
121    let results = join(session_future, other_future).await;
122
123    results.0.expect("Session returned error");
124    results.1.expect("Behaviour returned error");
125}
126
127/// Returns a [`BehaviourFuntion`] which sends the provided `data`, and then yields.
128pub fn send<T: Into<Vec<u8>> + Send + 'static, E: ErrorType>(
129    data: T,
130) -> impl BehaviourFunction<E> {
131    into_behaviour(move |in_sender: &mut InSender<E>, _: &mut OutReceiver| {
132        send_data(in_sender, data);
133        yield_now().boxed()
134    })
135}
136
137/// Sends `data` via `sender`, after converting it to bytes and transforming any error using `from`. Panics
138/// it the send fails.
139pub fn send_data<T: Into<Vec<u8>>, E: ErrorType>(
140    sender: &InSender<E>,
141    data: T,
142) {
143    sender
144        .send(Ok(data.into()))
145        .expect("Send failed");
146}
147
148/// Asserts that the receiver can _immediately_ provide a message which passes
149// the provided `predicate`. Thus the message must already have been sent by the session being tested.
150pub fn assert_receive<T: FnOnce(Vec<u8>) -> bool>(out_receiver: &mut OutReceiver, predicate: T) {
151    let response = out_receiver.recv().now_or_never();
152
153    assert!(predicate(
154        response
155            .expect("No message from session") // Now or never was 'never'
156            .expect("Session closed") // Message on channel was 'None'
157    ));
158}
159
160/// Returns a [`BehaviourFunction`] which will assert that a message is waiting, and that that
161/// message matches the provided `predicate`.
162pub fn receive<E: ErrorType, T: FnOnce(Vec<u8>) -> bool + Send + 'static>(
163    predicate: T,
164) -> impl BehaviourFunction<E> {
165    into_behaviour(|_: &mut InSender<E>, out_receiver: &mut OutReceiver| {
166        async move {
167            yield_now().await;
168            assert_receive(out_receiver, predicate)
169        }
170        .boxed()
171    })
172}
173
174/// Pauses tokio, sleeps for `millis` milliseconds, and then resumes tokio. Allows testing of actions that occur
175/// after some time, such as heartbeats, without actually having to wait for that amount of time.
176pub fn sleep_in_pause(millis: u64) -> impl Future<Output = ()> {
177    tokio::time::pause();
178    tokio::time::sleep(Duration::from_millis(millis)).inspect(|_| tokio::time::resume())
179}
180
181pub fn wait_for_disconnect<'a, E: ErrorType>(
182    _: &'a mut InSender<E>,
183    out_receiver: &'a mut OutReceiver,
184) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
185    async move {
186        sleep_in_pause(5050).await;
187
188        assert!(matches!(out_receiver.recv().now_or_never(), Some(None)));
189        ()
190    }
191    .boxed()
192}
193
#[cfg(test)]
mod test {

    use std::{
        any::Any,
        panic::AssertUnwindSafe,
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
    };

    use tokio::join;

    use super::*;

    #[tokio::test]
    async fn chaining_works() {
        let (mut tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let (_, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        let behaviour_a = into_behaviour(|sender: &mut InSender<()>, _: &mut OutReceiver| {
            async move {
                sender.send(Ok(b"foo".to_vec())).expect("Send failed");
            }
            .boxed()
        });

        let behaviour_b = into_behaviour(|sender: &mut InSender<()>, _: &mut OutReceiver| {
            async move {
                sender.send(Ok(b"bar".to_vec())).expect("Send failed");
            }
            .boxed()
        });

        // a then b...
        behaviour_a.then(behaviour_b)(&mut tx, &mut out_rx).await;

        // ... then close the channel
        drop(tx);

        assert_eq!(
            "foo",
            String::from_utf8(
                rx.recv()
                    .now_or_never()
                    .unwrap()
                    .unwrap()
                    .expect("recv failed"),
            )
            .expect("Parse failed"),
        );
        assert_eq!(
            "bar",
            String::from_utf8(
                rx.recv()
                    .now_or_never()
                    .unwrap()
                    .unwrap()
                    .expect("recv failed"),
            )
            .expect("Parse failed"),
        );

        assert_eq!(None, rx.recv().now_or_never().unwrap());
    }

    #[derive(Debug, PartialEq, Eq)]
    struct TestError;

    #[tokio::test]
    async fn send_works() {
        let (mut tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let (_, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        let behaviour = send::<String, TestError>("Hello".to_owned());

        behaviour(&mut tx, &mut out_rx).await;

        // ... then close the channel
        drop(tx);

        assert_eq!(
            "Hello",
            String::from_utf8(
                rx.recv()
                    .now_or_never()
                    .unwrap()
                    .unwrap()
                    .expect("recv failed"),
            )
            .expect("Parse failed"),
        );

        assert_eq!(None, rx.recv().now_or_never().unwrap());
    }

    /// A payload type whose byte conversion is user-defined, to check that `send` uses it.
    struct TestData;

    impl From<TestData> for Vec<u8> {
        fn from(_: TestData) -> Self {
            vec![]
        }
    }

    #[tokio::test]
    async fn send_uses_into_conversion() {
        let (mut tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let (_, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        let behaviour = send::<TestData, TestError>(TestData);

        behaviour(&mut tx, &mut out_rx).await;

        // ... then close the channel
        drop(tx);

        // `TestData` converts to an empty byte vector, which `send` always wraps in `Ok`.
        assert_eq!(Ok(vec![]), rx.recv().now_or_never().unwrap().unwrap());

        assert_eq!(None, rx.recv().now_or_never().unwrap());
    }

    #[tokio::test]
    async fn send_yields() {
        let (mut tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Result<Vec<u8>, TestError>>();
        let (_, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        let handle = tokio::task::spawn(async move {
            let first = rx.recv().await;

            assert_eq!(
                "Hello",
                String::from_utf8(first.unwrap().expect("recv failed")).expect("Parse failed"),
            );

            // Because the first send yielded to this task, the second one has not been sent yet
            assert_eq!(None, rx.recv().now_or_never());

            let second = rx.recv().await;

            assert_eq!(
                "world",
                String::from_utf8(second.unwrap().expect("recv failed")).expect("Parse failed"),
            );
        });

        let behaviour = send::<String, TestError>("Hello".to_owned())
            .then(send::<String, TestError>("world".to_owned()));

        let x = join!(handle, behaviour(&mut tx, &mut out_rx));

        assert!(x.0.is_ok());
    }

    #[tokio::test]
    async fn send_data_sends() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Result<Vec<u8>, TestError>>();

        send_data(&tx, b"1".to_vec());

        assert_eq!(
            b"1".to_vec(),
            rx.recv()
                .await
                .expect("Should be Some")
                .expect("Should be ok")
        );
    }

    #[tokio::test]
    async fn test_expectations_executes_behaviour() {
        let called = Arc::new(AtomicBool::new(false));

        let session_factory = |_: InReceiver<TestError>, _: OutSender| async { Ok(()) }.boxed();
        let client_behaviour = into_behaviour({
            let called = called.clone();
            |_, _| async move { called.store(true, Ordering::Release) }.boxed()
        });

        assert_behaviour(session_factory, client_behaviour).await;

        assert!(called.load(Ordering::Acquire));
    }

    #[tokio::test]
    async fn test_expectations_starts_session() {
        let called = Arc::new(AtomicBool::new(false));

        let session_factory = {
            let called = called.clone();
            |_: InReceiver<TestError>, _: OutSender| {
                async move {
                    called.store(true, Ordering::Release);
                    Ok(())
                }
                .boxed()
            }
        };
        let client_behaviour = into_behaviour(|_, _| async {}.boxed());

        assert_behaviour(session_factory, client_behaviour).await;

        assert!(called.load(Ordering::Acquire));
    }

    #[tokio::test]
    async fn test_expectations_sends_in_to_session() {
        let session_factory = {
            |mut receiver: InReceiver<TestError>, _: OutSender| {
                async move {
                    match receiver.recv().await {
                        Some(Ok(x)) if x == vec![42u8] => Ok(()),
                        _ => Err(TestError),
                    }
                }
                .boxed()
            }
        };
        let client_behaviour = into_behaviour(|in_sender, _| {
            async move {
                in_sender.send(Ok(vec![42u8])).expect("Send failed");
            }
            .boxed()
        });

        assert_behaviour(session_factory, client_behaviour).await;
    }

    #[tokio::test]
    async fn test_expectations_receives_out_from_session() {
        let session_factory = {
            |_, sender: OutSender| {
                async move { sender.send(vec![42u8]).map_err(|_| TestError) }.boxed()
            }
        };
        let client_behaviour =
            into_behaviour::<TestError, _>(|_, out_receiver: &mut OutReceiver| {
                async move {
                    assert!(matches!(out_receiver.recv().await, Some(x) if x == vec![42u8]));
                }
                .boxed()
            });

        assert_behaviour(session_factory, client_behaviour).await;
    }

    /// Runs `future`, converting a panic inside it into an `Err` carrying the panic payload.
    ///
    /// No no-op panic hook is installed here on purpose. The hook is process-global, so replacing
    /// it would silence the panic messages of every concurrently running test, including genuine
    /// failures. The test harness already captures the output of passing tests.
    fn assert_unwind_safe<O, F: Future<Output = O>>(
        future: F,
    ) -> impl Future<Output = Result<O, Box<dyn Any + Send>>> {
        AssertUnwindSafe(future).catch_unwind()
    }

    /// Runs [`assert_behaviour`] and checks whether it panicked, as indicated by `expect_err`.
    async fn assert_behaviour_test_result<
        E: ErrorType,
        F: SessionFactory<E> + 'static,
        B: BehaviourFunction<E> + 'static,
    >(
        session_factory: F,
        behaviour: B,
        expect_err: bool,
    ) {
        let result = assert_unwind_safe(assert_behaviour(session_factory, behaviour)).await;

        assert_eq!(expect_err, result.is_err());
    }

    async fn assert_behaviour_test_fails<
        E: ErrorType,
        F: SessionFactory<E> + 'static,
        B: BehaviourFunction<E> + 'static,
    >(
        session_factory: F,
        behaviour: B,
    ) {
        assert_behaviour_test_result(session_factory, behaviour, true).await;
    }

    async fn assert_behaviour_test_succeeds<
        E: ErrorType,
        F: SessionFactory<E> + 'static,
        B: BehaviourFunction<E> + 'static,
    >(
        session_factory: F,
        behaviour: B,
    ) {
        assert_behaviour_test_result(session_factory, behaviour, false).await;
    }

    #[tokio::test]
    async fn test_expectations_fails_if_error_in_session() {
        let session_factory = { |_, _| async move { Err(TestError) }.boxed() };
        let client_behaviour = into_behaviour::<TestError, _>(|_, _| async move {}.boxed());

        assert_behaviour_test_fails(session_factory, client_behaviour).await;
    }

    #[tokio::test]
    async fn test_expectations_fails_if_error_in_behaviour() {
        let session_factory = |_, _| async move { Ok(()) }.boxed();
        let client_behaviour = into_behaviour::<TestError, _>(|_, _| {
            async move {
                panic!("Behaviour failed");
            }
            .boxed()
        });

        assert_behaviour_test_fails(session_factory, client_behaviour).await;
    }

    #[tokio::test]
    async fn test_expectations_succeeds_if_empty() {
        let session_factory = |_, _| async move { Ok(()) }.boxed();
        let client_behaviour = into_behaviour::<TestError, _>(|_, _| async move {}.boxed());

        assert_behaviour_test_succeeds(session_factory, client_behaviour).await;
    }

    #[tokio::test]
    async fn assert_receive_fails_if_no_message() {
        let (_, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        let result = assert_unwind_safe(async { assert_receive(&mut out_rx, |_| true) }).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn assert_receive_succeeds_if_message_matches() {
        let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        out_tx.send(vec![0u8]).expect("Should succeed");
        let result = assert_unwind_safe(async { assert_receive(&mut out_rx, |_| true) }).await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn assert_receive_succeeds_if_session_closed() {
        let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        out_tx.send(vec![0u8]).expect("Should succeed");
        out_rx.close();

        let result = assert_unwind_safe(async { assert_receive(&mut out_rx, |_| true) }).await;

        assert!(result.is_ok());

        let result = assert_unwind_safe(async { assert_receive(&mut out_rx, |_| true) }).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn assert_receive_fails_if_predicate_false() {
        let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        out_tx.send(vec![0u8]).expect("Should succeed");

        let result = assert_unwind_safe(async { assert_receive(&mut out_rx, |_| false) }).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn assert_receive_passes_message_to_predicate() {
        let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        out_tx.send(vec![0u8]).expect("Should succeed");

        let result =
            assert_unwind_safe(async { assert_receive(&mut out_rx, |data| data == vec![0u8]) })
                .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn receive_succeeds_with_message() {
        let (mut tx, _) = tokio::sync::mpsc::unbounded_channel();
        let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        let behaviour = receive::<TestError, _>(|bytes| bytes == vec![42u8]);

        out_tx.send(vec![42u8]).expect("Send Failed");

        behaviour(&mut tx, &mut out_rx).await;
    }

    #[tokio::test]
    async fn receive_fails_with_incorrect_message() {
        let (mut tx, _) = tokio::sync::mpsc::unbounded_channel();
        let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();

        let behaviour = receive::<TestError, _>(|bytes| bytes == vec![43u8]);

        out_tx.send(vec![42u8]).expect("Send Failed");

        assert_unwind_safe(behaviour(&mut tx, &mut out_rx))
            .await
            .expect_err("Behaviour passed");
    }

    #[tokio::test]
    async fn sleep_in_pause_passes_time() {
        let session_factory = |_, sender: OutSender| {
            async move {
                tokio::time::sleep(Duration::from_millis(2000)).await;
                sender.send(vec![1]).expect("Send failed");
                Ok(())
            }
            .boxed()
        };

        let client_behaviour =
            into_behaviour::<TestError, _>(|_, out_receiver: &mut OutReceiver| {
                async move {
                    assert!(out_receiver.recv().now_or_never().is_none());
                    sleep_in_pause(3000).await;
                    assert_receive(out_receiver, |bytes| bytes == vec![1]);
                }
                .boxed()
            });

        assert_behaviour_test_succeeds(session_factory, client_behaviour).await;
    }

    #[tokio::test]
    async fn wait_for_disconnect_fails_if_not_disconnected() {
        let session_factory = |_, unused: OutSender| {
            async move {
                tokio::time::sleep(Duration::from_millis(5060)).await;
                drop(unused); // disconnects, but too late
                Ok(())
            }
            .boxed()
        };

        assert_behaviour_test_fails(session_factory, wait_for_disconnect::<TestError>).await;
    }

    #[tokio::test]
    async fn wait_for_disconnect_succeeds_if_disconnected() {
        let session_factory = |_, unused| {
            async move {
                tokio::time::sleep(Duration::from_millis(5000)).await;
                drop(unused); // closes
                Ok(())
            }
            .boxed()
        };

        assert_behaviour_test_succeeds(session_factory, wait_for_disconnect::<TestError>).await;
    }
}