wireframe_testing/
helpers.rs

1//! Helper utilities for driving `WireframeApp` instances in tests.
2//!
3//! These functions spin up an application on an in-memory duplex stream and
4//! collect the bytes written back by the app for assertions.
5
6use std::io;
7
8use bincode::config;
9use bytes::BytesMut;
10use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, duplex};
11use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
12use wireframe::{
13    app::{Envelope, Packet, WireframeApp},
14    frame::{FrameMetadata, LengthFormat},
15    serializer::Serializer,
16};
17
18pub trait TestSerializer:
19    Serializer + FrameMetadata<Frame = Envelope> + Send + Sync + 'static
20{
21}
22
23impl<T> TestSerializer for T where
24    T: Serializer + FrameMetadata<Frame = Envelope> + Send + Sync + 'static
25{
26}
27
28/// Run `server_fn` against a duplex stream, writing each `frame` to the client
29/// half and returning the bytes produced by the server.
30///
31/// The server function receives the server half of a `tokio::io::duplex`
32/// connection. Every provided frame is written to the client side in order and
33/// the collected output is returned once the server task completes. If the
34/// server panics, the panic message is surfaced as an `io::Error` beginning
35/// with `"server task failed"`.
36///
37/// ```rust
38/// use tokio::io::{AsyncWriteExt, DuplexStream};
39/// use wireframe_testing::helpers::drive_internal;
40///
41/// async fn echo(mut server: DuplexStream) { let _ = server.write_all(&[1, 2]).await; }
42///
43/// # async fn demo() -> std::io::Result<()> {
44/// let bytes = drive_internal(echo, vec![vec![0]], 64).await?;
45/// assert_eq!(bytes, [1, 2]);
46/// # Ok(())
47/// # }
48/// ```
49async fn drive_internal<F, Fut>(
50    server_fn: F,
51    frames: Vec<Vec<u8>>,
52    capacity: usize,
53) -> io::Result<Vec<u8>>
54where
55    F: FnOnce(DuplexStream) -> Fut,
56    Fut: std::future::Future<Output = ()> + Send,
57{
58    let (mut client, server) = duplex(capacity);
59
60    let server_fut = async {
61        use futures::FutureExt as _;
62        let result = std::panic::AssertUnwindSafe(server_fn(server))
63            .catch_unwind()
64            .await;
65        match result {
66            Ok(_) => Ok(()),
67            Err(panic) => {
68                let panic_msg = wireframe::panic::format_panic(&panic);
69                Err(io::Error::new(
70                    io::ErrorKind::Other,
71                    format!("server task failed: {panic_msg}"),
72                ))
73            }
74        }
75    };
76
77    let client_fut = async {
78        for frame in &frames {
79            client.write_all(frame).await?;
80        }
81        client.shutdown().await?;
82
83        let mut buf = Vec::new();
84        client.read_to_end(&mut buf).await?;
85        io::Result::Ok(buf)
86    };
87
88    let ((), buf) = tokio::try_join!(server_fut, client_fut)?;
89    Ok(buf)
90}
91
/// Default size, in bytes, of the in-memory duplex buffer used by the
/// `drive_*` helpers and by `run_app` when no capacity is supplied.
const DEFAULT_CAPACITY: usize = 4096;
/// Largest duplex buffer `run_app` accepts: 10 MiB.
const MAX_CAPACITY: usize = 10 * 1024 * 1024;
/// Duplex buffer size used by `run_with_duplex_server`, which sends no frames.
pub(crate) const EMPTY_SERVER_CAPACITY: usize = 64;
/// Shared frame cap used by helpers and tests to avoid drift.
///
/// Deliberately tied to the default duplex capacity.
pub const TEST_MAX_FRAME: usize = DEFAULT_CAPACITY;
97
98#[inline]
99pub fn new_test_codec(max_len: usize) -> LengthDelimitedCodec {
100    let mut builder = LengthDelimitedCodec::builder();
101    builder.max_frame_length(max_len);
102    builder.new_codec()
103}
104
105/// Decode all length-prefixed `frames` using a test codec and assert no bytes remain.
106///
107/// This helper constructs a [`LengthDelimitedCodec`] capped at [`TEST_MAX_FRAME`]
108/// and decodes each frame in `bytes`, ensuring the buffer is fully consumed.
109///
110/// ```rust
111/// # use wireframe_testing::decode_frames;
112/// let frames = decode_frames(vec![0, 0, 0, 1, 42]);
113/// assert_eq!(frames, vec![vec![42]]);
114/// ```
115#[must_use]
116pub fn decode_frames(bytes: Vec<u8>) -> Vec<Vec<u8>> {
117    decode_frames_with_max(bytes, TEST_MAX_FRAME)
118}
119
120/// Decode `bytes` into frames using a codec capped at `max_len`.
121///
122/// Asserts that no trailing bytes remain after all frames are decoded.
123#[must_use]
124pub fn decode_frames_with_max(bytes: Vec<u8>, max_len: usize) -> Vec<Vec<u8>> {
125    let mut codec = new_test_codec(max_len);
126    let mut buf = BytesMut::from(&bytes[..]);
127    let mut frames = Vec::new();
128    while let Some(frame) = codec.decode(&mut buf).expect("decode failed") {
129        frames.push(frame.to_vec());
130    }
131    assert!(buf.is_empty(), "unexpected trailing bytes after decode");
132    frames
133}
134
/// Generate an `async fn` wrapper that forwards to `$target`, appending
/// `DEFAULT_CAPACITY` as the final argument.
///
/// The generated function is generic over `<S, C, E>` with the same bounds as
/// the hand-written helpers, so the application type may mention them.
macro_rules! forward_default {
    (
        $(#[$attr:meta])* $visibility:vis fn $fn_name:ident(
            $app_param:ident : $app_type:ty,
            $input_param:ident : $input_type:ty
        ) -> $output:ty
        => $target:ident($app_arg:ident, $input_arg:expr)
    ) => {
        $(#[$attr])*
        $visibility async fn $fn_name<S: TestSerializer, C: Send + 'static, E: Packet>(
            $app_param: $app_type,
            $input_param: $input_type,
        ) -> $output {
            $target($app_arg, $input_arg, DEFAULT_CAPACITY).await
        }
    };
}
157
/// Generate an `async fn` wrapper that takes an explicit `capacity` and
/// forwards it, unchanged, as the final argument of `$target`.
///
/// The generated function is generic over `<S, C, E>` with the same bounds as
/// the hand-written helpers, so the application type may mention them.
macro_rules! forward_with_capacity {
    (
        $(#[$attr:meta])* $visibility:vis fn $fn_name:ident(
            $app_param:ident : $app_type:ty,
            $input_param:ident : $input_type:ty,
            capacity: usize
        ) -> $output:ty
        => $target:ident($app_arg:ident, $input_arg:expr, capacity)
    ) => {
        $(#[$attr])*
        $visibility async fn $fn_name<S: TestSerializer, C: Send + 'static, E: Packet>(
            $app_param: $app_type,
            $input_param: $input_type,
            capacity: usize,
        ) -> $output {
            $target($app_arg, $input_arg, capacity).await
        }
    };
}
182
183/// Drive `app` with a single length-prefixed `frame` and return the bytes
184/// produced by the server.
185///
186/// The app runs on an in-memory duplex stream so tests need not open real
187/// sockets.
188///
189/// # Errors
190///
191/// Returns any I/O errors encountered while interacting with the in-memory
192/// duplex stream.
193///
194/// ```rust
195/// # use wireframe_testing::drive_with_frame;
196/// # use wireframe::app::WireframeApp;
197/// # async fn demo() -> std::io::Result<()> {
198/// let app = WireframeApp::new().expect("failed to initialize app");
199/// let bytes = drive_with_frame(app, vec![1, 2, 3]).await?;
200/// # Ok(())
201/// # }
202/// ```
203pub async fn drive_with_frame<S, C, E>(
204    app: WireframeApp<S, C, E>,
205    frame: Vec<u8>,
206) -> io::Result<Vec<u8>>
207where
208    S: TestSerializer,
209    C: Send + 'static,
210    E: Packet,
211{
212    drive_with_frame_with_capacity(app, frame, DEFAULT_CAPACITY).await
213}
214
215forward_with_capacity! {
216    /// Drive `app` with a single frame using a duplex buffer of `capacity` bytes.
217    ///
218    /// Adjusting the buffer size helps exercise edge cases such as small channels.
219    ///
220    /// ```rust
221    /// # use wireframe_testing::drive_with_frame_with_capacity;
222    /// # use wireframe::app::WireframeApp;
223    /// # async fn demo() -> std::io::Result<()> {
224    /// let app = WireframeApp::new().expect("failed to initialize app");
225    /// let bytes = drive_with_frame_with_capacity(app, vec![0], 512).await?;
226    /// # Ok(())
227    /// # }
228    /// ```
229    pub fn drive_with_frame_with_capacity(app: WireframeApp<S, C, E>, frame: Vec<u8>, capacity: usize) -> io::Result<Vec<u8>>
230    => drive_with_frames_with_capacity(app, vec![frame], capacity)
231}
232
233forward_default! {
234    /// Drive `app` with a sequence of frames using the default buffer size.
235    ///
236    /// Each frame is written to the duplex stream in order.
237    ///
238    /// ```rust
239    /// # use wireframe_testing::drive_with_frames;
240    /// # use wireframe::app::WireframeApp;
241    /// # async fn demo() -> std::io::Result<()> {
242    /// let app = WireframeApp::new().expect("failed to initialize app");
243    /// let out = drive_with_frames(app, vec![vec![1], vec![2]]).await?;
244    /// # Ok(())
245    /// # }
246    /// ```
247    pub fn drive_with_frames(app: WireframeApp<S, C, E>, frames: Vec<Vec<u8>>) -> io::Result<Vec<u8>>
248    => drive_with_frames_with_capacity(app, frames)
249}
250
251/// Drive `app` with multiple frames using a duplex buffer of `capacity` bytes.
252///
253/// This variant exposes the buffer size for fine-grained control in tests.
254///
255/// ```rust
256/// # use wireframe_testing::drive_with_frames_with_capacity;
257/// # use wireframe::app::WireframeApp;
258/// # async fn demo() -> std::io::Result<()> {
259/// let app = WireframeApp::new().expect("failed to initialize app");
260/// let out = drive_with_frames_with_capacity(app, vec![vec![1], vec![2]], 1024).await?;
261/// # Ok(())
262/// # }
263/// ```
264pub async fn drive_with_frames_with_capacity<S, C, E>(
265    app: WireframeApp<S, C, E>,
266    frames: Vec<Vec<u8>>,
267    capacity: usize,
268) -> io::Result<Vec<u8>>
269where
270    S: TestSerializer,
271    C: Send + 'static,
272    E: Packet,
273{
274    drive_internal(
275        |server| async move { app.handle_connection(server).await },
276        frames,
277        capacity,
278    )
279    .await
280}
281
282forward_default! {
283    /// Feed a single frame into a mutable `app`, allowing the instance to be reused
284    /// across calls.
285    ///
286    /// ```rust
287    /// # use wireframe_testing::drive_with_frame_mut;
288    /// # use wireframe::app::WireframeApp;
289    /// # async fn demo() -> std::io::Result<()> {
290    /// let mut app = WireframeApp::new().expect("failed to initialize app");
291    /// let bytes = drive_with_frame_mut(&mut app, vec![1]).await?;
292    /// # Ok(())
293    /// # }
294    /// ```
295    pub fn drive_with_frame_mut(app: &mut WireframeApp<S, C, E>, frame: Vec<u8>) -> io::Result<Vec<u8>>
296    => drive_with_frame_with_capacity_mut(app, frame)
297}
298
299forward_with_capacity! {
300    /// Feed a single frame into `app` using a duplex buffer of `capacity` bytes.
301    ///
302    /// ```rust
303    /// # use wireframe_testing::drive_with_frame_with_capacity_mut;
304    /// # use wireframe::app::WireframeApp;
305    /// # async fn demo() -> std::io::Result<()> {
306    /// let mut app = WireframeApp::new().expect("failed to initialize app");
307    /// let bytes = drive_with_frame_with_capacity_mut(&mut app, vec![1], 256).await?;
308    /// # Ok(())
309    /// # }
310    /// ```
311    pub fn drive_with_frame_with_capacity_mut(app: &mut WireframeApp<S, C, E>, frame: Vec<u8>, capacity: usize) -> io::Result<Vec<u8>>
312    => drive_with_frames_with_capacity_mut(app, vec![frame], capacity)
313}
314
315forward_default! {
316    /// Feed multiple frames into a mutable `app`.
317    ///
318    /// ```rust
319    /// # use wireframe_testing::drive_with_frames_mut;
320    /// # use wireframe::app::WireframeApp;
321    /// # async fn demo() -> std::io::Result<()> {
322    /// let mut app = WireframeApp::new().expect("failed to initialize app");
323    /// let out = drive_with_frames_mut(&mut app, vec![vec![1], vec![2]]).await?;
324    /// # Ok(())
325    /// # }
326    /// ```
327    pub fn drive_with_frames_mut(app: &mut WireframeApp<S, C, E>, frames: Vec<Vec<u8>>) -> io::Result<Vec<u8>>
328    => drive_with_frames_with_capacity_mut(app, frames)
329}
330
331/// Feed multiple frames into `app` with a duplex buffer of `capacity` bytes.
332///
333/// ```rust
334/// # use wireframe_testing::drive_with_frames_with_capacity_mut;
335/// # use wireframe::app::WireframeApp;
336/// # async fn demo() -> std::io::Result<()> {
337/// let mut app = WireframeApp::new().expect("failed to initialize app");
338/// let out = drive_with_frames_with_capacity_mut(&mut app, vec![vec![1], vec![2]], 64).await?;
339/// # Ok(())
340/// # }
341/// ```
342pub async fn drive_with_frames_with_capacity_mut<S, C, E>(
343    app: &mut WireframeApp<S, C, E>,
344    frames: Vec<Vec<u8>>,
345    capacity: usize,
346) -> io::Result<Vec<u8>>
347where
348    S: TestSerializer,
349    C: Send + 'static,
350    E: Packet,
351{
352    drive_internal(
353        |server| async { app.handle_connection(server).await },
354        frames,
355        capacity,
356    )
357    .await
358}
359
360/// Encode `msg` using bincode, frame it and drive `app`.
361///
362/// ```rust
363/// # use wireframe_testing::drive_with_bincode;
364/// # use wireframe::app::WireframeApp;
365/// #[derive(bincode::Encode)]
366/// struct Ping(u8);
367/// # async fn demo() -> std::io::Result<()> {
368/// let app = WireframeApp::new().expect("failed to initialize app");
369/// let bytes = drive_with_bincode(app, Ping(1)).await?;
370/// # Ok(())
371/// # }
372/// ```
373pub async fn drive_with_bincode<M, S, C, E>(
374    app: WireframeApp<S, C, E>,
375    msg: M,
376) -> io::Result<Vec<u8>>
377where
378    M: bincode::Encode,
379    S: TestSerializer,
380    C: Send + 'static,
381    E: Packet,
382{
383    let bytes = bincode::encode_to_vec(msg, config::standard()).map_err(|e| {
384        io::Error::new(
385            io::ErrorKind::InvalidData,
386            format!("bincode encode failed: {e}"),
387        )
388    })?;
389    let mut codec = new_test_codec(DEFAULT_CAPACITY);
390    let mut framed = BytesMut::with_capacity(bytes.len() + 4);
391    codec.encode(bytes.into(), &mut framed)?;
392    drive_with_frame(app, framed.to_vec()).await
393}
394
395/// Run `app` with input `frames` using an optional duplex buffer `capacity`.
396///
397/// When `capacity` is `None`, a buffer of [`DEFAULT_CAPACITY`] bytes is used.
398/// Frames are written to the client side in order and the bytes emitted by the
399/// server are collected for inspection.
400///
401/// # Errors
402///
403/// Returns an error if `capacity` is zero or exceeds [`MAX_CAPACITY`]. Any
404/// panic in the application task or I/O error on the duplex stream is also
405/// surfaced as an error.
406///
407/// ```rust
408/// # use wireframe_testing::run_app;
409/// # use wireframe::app::WireframeApp;
410/// # async fn demo() -> std::io::Result<()> {
411/// let app = WireframeApp::new().expect("failed to initialize app");
412/// let out = run_app(app, vec![vec![1]], None).await?;
413/// # Ok(())
414/// # }
415/// ```
416
417/// Encode bytes with a length-delimited `codec`, preallocating the prefix.
418///
419/// Panics if encoding fails.
420#[must_use]
421pub fn encode_frame(codec: &mut LengthDelimitedCodec, bytes: Vec<u8>) -> Vec<u8> {
422    let header_len = LengthFormat::default().bytes();
423    let mut buf = BytesMut::with_capacity(bytes.len() + header_len);
424    codec.encode(bytes.into(), &mut buf).expect("encode failed");
425    buf.to_vec()
426}
427
428pub async fn run_app<S, C, E>(
429    app: WireframeApp<S, C, E>,
430    frames: Vec<Vec<u8>>,
431    capacity: Option<usize>,
432) -> io::Result<Vec<u8>>
433where
434    S: TestSerializer,
435    C: Send + 'static,
436    E: Packet,
437{
438    let capacity = capacity.unwrap_or(DEFAULT_CAPACITY);
439    if capacity == 0 {
440        return Err(io::Error::new(
441            io::ErrorKind::InvalidInput,
442            "capacity must be greater than zero",
443        ));
444    }
445    if capacity > MAX_CAPACITY {
446        return Err(io::Error::new(
447            io::ErrorKind::InvalidInput,
448            format!("capacity must not exceed {MAX_CAPACITY} bytes"),
449        ));
450    }
451
452    let (mut client, server) = duplex(capacity);
453    let server_task = tokio::spawn(async move { app.handle_connection(server).await });
454
455    for frame in &frames {
456        client.write_all(frame).await?;
457    }
458    client.shutdown().await?;
459
460    let mut buf = Vec::new();
461    client.read_to_end(&mut buf).await?;
462
463    if let Err(e) = server_task.await {
464        return Err(io::Error::new(
465            io::ErrorKind::Other,
466            format!("server task failed: {e}"),
467        ));
468    }
469
470    Ok(buf)
471}
472
#[cfg(test)]
mod tests {
    use wireframe::app::WireframeApp;

    use super::*;

    /// Build a fresh app and check that `run_app` refuses `capacity` with
    /// `InvalidInput`; `reason` is the panic message if it unexpectedly succeeds.
    async fn assert_capacity_rejected(capacity: usize, reason: &str) {
        let app = WireframeApp::new().expect("failed to create app");
        let err = run_app(app, vec![], Some(capacity))
            .await
            .expect_err(reason);
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[tokio::test]
    async fn run_app_rejects_zero_capacity() {
        assert_capacity_rejected(0, "capacity of zero should error").await;
    }

    #[tokio::test]
    async fn run_app_rejects_excess_capacity() {
        assert_capacity_rejected(MAX_CAPACITY + 1, "capacity beyond max should error").await;
    }
}
497
498/// Run `app` against an empty duplex stream.
499///
500/// This helper drives the connection lifecycle without sending any frames,
501/// ensuring setup and teardown callbacks execute.
502///
503/// # Panics
504///
505/// Panics if `handle_connection` fails.
506///
507/// ```rust
508/// # use wireframe_testing::run_with_duplex_server;
509/// # use wireframe::app::WireframeApp;
510/// # async fn demo() {
511/// let app = WireframeApp::new()
512///     .expect("failed to initialize app");
513/// run_with_duplex_server(app).await;
514/// }
515/// ```
516pub async fn run_with_duplex_server<S, C, E>(app: WireframeApp<S, C, E>)
517where
518    S: TestSerializer,
519    C: Send + 'static,
520    E: Packet,
521{
522    let (_, server) = duplex(EMPTY_SERVER_CAPACITY); // discard client half
523    app.handle_connection(server).await;
524}
525
/// Await a future (intended for push operations) and unwrap its `Result`,
/// panicking with the call site on failure.
///
/// The call site (`file:line`) is included in every build profile, not only
/// in debug builds:
///
/// - `push_expect!(fut)` panics with a message starting
///   `"push failed at <file>:<line>"`.
/// - `push_expect!(fut, msg)` panics with a message starting
///   `"<msg> at <file>:<line>"`. `msg` may be any expression implementing
///   `Display`.
///
/// In both forms `expect` appends the `Debug` rendering of the error. The
/// macro must be used inside an `async` context, and the future's output must
/// be a `Result` whose error type implements `Debug`.
#[macro_export]
macro_rules! push_expect {
    ($fut:expr) => {{
        $fut.await
            .expect(concat!("push failed at ", file!(), ":", line!()))
    }};
    ($fut:expr, $msg:expr) => {{
        let m = ::std::format!("{msg} at {}:{}", file!(), line!(), msg = $msg);
        $fut.await.expect(&m)
    }};
}
541
/// Await a future (intended for receive operations) and unwrap its `Result`,
/// panicking with the call site on failure.
///
/// The call site (`file:line`) is included in every build profile, not only
/// in debug builds:
///
/// - `recv_expect!(fut)` panics with a message starting
///   `"recv failed at <file>:<line>"`.
/// - `recv_expect!(fut, msg)` panics with a message starting
///   `"<msg> at <file>:<line>"`. `msg` may be any expression implementing
///   `Display`.
///
/// In both forms `expect` appends the `Debug` rendering of the error. The
/// macro must be used inside an `async` context, and the future's output must
/// be a `Result` whose error type implements `Debug`.
#[macro_export]
macro_rules! recv_expect {
    ($fut:expr) => {{
        $fut.await
            .expect(concat!("recv failed at ", file!(), ":", line!()))
    }};
    ($fut:expr, $msg:expr) => {{
        let m = ::std::format!("{msg} at {}:{}", file!(), line!(), msg = $msg);
        $fut.await.expect(&m)
    }};
}
556}