wireframe_testing/helpers.rs
1//! Helper utilities for driving `WireframeApp` instances in tests.
2//!
3//! These functions spin up an application on an in-memory duplex stream and
4//! collect the bytes written back by the app for assertions.
5
6use std::io;
7
8use bincode::config;
9use bytes::BytesMut;
10use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, duplex};
11use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
12use wireframe::{
13 app::{Envelope, Packet, WireframeApp},
14 frame::{FrameMetadata, LengthFormat},
15 serializer::Serializer,
16};
17
18pub trait TestSerializer:
19 Serializer + FrameMetadata<Frame = Envelope> + Send + Sync + 'static
20{
21}
22
23impl<T> TestSerializer for T where
24 T: Serializer + FrameMetadata<Frame = Envelope> + Send + Sync + 'static
25{
26}
27
28/// Run `server_fn` against a duplex stream, writing each `frame` to the client
29/// half and returning the bytes produced by the server.
30///
31/// The server function receives the server half of a `tokio::io::duplex`
32/// connection. Every provided frame is written to the client side in order and
33/// the collected output is returned once the server task completes. If the
34/// server panics, the panic message is surfaced as an `io::Error` beginning
35/// with `"server task failed"`.
36///
37/// ```rust
38/// use tokio::io::{AsyncWriteExt, DuplexStream};
39/// use wireframe_testing::helpers::drive_internal;
40///
41/// async fn echo(mut server: DuplexStream) { let _ = server.write_all(&[1, 2]).await; }
42///
43/// # async fn demo() -> std::io::Result<()> {
44/// let bytes = drive_internal(echo, vec![vec![0]], 64).await?;
45/// assert_eq!(bytes, [1, 2]);
46/// # Ok(())
47/// # }
48/// ```
49async fn drive_internal<F, Fut>(
50 server_fn: F,
51 frames: Vec<Vec<u8>>,
52 capacity: usize,
53) -> io::Result<Vec<u8>>
54where
55 F: FnOnce(DuplexStream) -> Fut,
56 Fut: std::future::Future<Output = ()> + Send,
57{
58 let (mut client, server) = duplex(capacity);
59
60 let server_fut = async {
61 use futures::FutureExt as _;
62 let result = std::panic::AssertUnwindSafe(server_fn(server))
63 .catch_unwind()
64 .await;
65 match result {
66 Ok(_) => Ok(()),
67 Err(panic) => {
68 let panic_msg = wireframe::panic::format_panic(&panic);
69 Err(io::Error::new(
70 io::ErrorKind::Other,
71 format!("server task failed: {panic_msg}"),
72 ))
73 }
74 }
75 };
76
77 let client_fut = async {
78 for frame in &frames {
79 client.write_all(frame).await?;
80 }
81 client.shutdown().await?;
82
83 let mut buf = Vec::new();
84 client.read_to_end(&mut buf).await?;
85 io::Result::Ok(buf)
86 };
87
88 let ((), buf) = tokio::try_join!(server_fut, client_fut)?;
89 Ok(buf)
90}
91
/// Duplex buffer size, in bytes, used when a helper is not given an explicit
/// capacity.
const DEFAULT_CAPACITY: usize = 4096;
/// Largest duplex buffer size accepted by [`run_app`] (10 MiB).
const MAX_CAPACITY: usize = 10 * 1024 * 1024;
/// Duplex buffer size for [`run_with_duplex_server`], which writes no frames.
pub(crate) const EMPTY_SERVER_CAPACITY: usize = 64;
/// Shared frame cap used by helpers and tests to avoid drift.
pub const TEST_MAX_FRAME: usize = DEFAULT_CAPACITY;
97
98#[inline]
99pub fn new_test_codec(max_len: usize) -> LengthDelimitedCodec {
100 let mut builder = LengthDelimitedCodec::builder();
101 builder.max_frame_length(max_len);
102 builder.new_codec()
103}
104
105/// Decode all length-prefixed `frames` using a test codec and assert no bytes remain.
106///
107/// This helper constructs a [`LengthDelimitedCodec`] capped at [`TEST_MAX_FRAME`]
108/// and decodes each frame in `bytes`, ensuring the buffer is fully consumed.
109///
110/// ```rust
111/// # use wireframe_testing::decode_frames;
112/// let frames = decode_frames(vec![0, 0, 0, 1, 42]);
113/// assert_eq!(frames, vec![vec![42]]);
114/// ```
115#[must_use]
116pub fn decode_frames(bytes: Vec<u8>) -> Vec<Vec<u8>> {
117 decode_frames_with_max(bytes, TEST_MAX_FRAME)
118}
119
120/// Decode `bytes` into frames using a codec capped at `max_len`.
121///
122/// Asserts that no trailing bytes remain after all frames are decoded.
123#[must_use]
124pub fn decode_frames_with_max(bytes: Vec<u8>, max_len: usize) -> Vec<Vec<u8>> {
125 let mut codec = new_test_codec(max_len);
126 let mut buf = BytesMut::from(&bytes[..]);
127 let mut frames = Vec::new();
128 while let Some(frame) = codec.decode(&mut buf).expect("decode failed") {
129 frames.push(frame.to_vec());
130 }
131 assert!(buf.is_empty(), "unexpected trailing bytes after decode");
132 frames
133}
134
// Generate a public async helper that forwards its two arguments to
// `$target`, appending `DEFAULT_CAPACITY` as the final (capacity) argument.
//
// The generated function is generic over `<S, C, E>` with the same bounds as
// the hand-written helpers in this module. The signature tokens supplied at
// the call site may therefore mention `S`, `C` and `E`.
macro_rules! forward_default {
    (
        $(#[$attr:meta])*
        $visibility:vis fn $fn_name:ident(
            $first:ident : $first_ty:ty,
            $second:ident : $second_ty:ty
        ) -> $output:ty
        => $target:ident($first_arg:ident, $second_arg:expr)
    ) => {
        $(#[$attr])*
        $visibility async fn $fn_name<S: TestSerializer, C: Send + 'static, E: Packet>(
            $first: $first_ty,
            $second: $second_ty,
        ) -> $output {
            $target($first_arg, $second_arg, DEFAULT_CAPACITY).await
        }
    };
}
157
// Generate a public async helper taking two arguments plus an explicit
// `capacity: usize`, and forward all three to `$target`.
//
// The literal `capacity` tokens in the pattern are matched verbatim. The
// generated parameter and the forwarded argument both come from this macro's
// own expansion, so they always refer to each other.
macro_rules! forward_with_capacity {
    (
        $(#[$attr:meta])*
        $visibility:vis fn $fn_name:ident(
            $first:ident : $first_ty:ty,
            $second:ident : $second_ty:ty,
            capacity: usize
        ) -> $output:ty
        => $target:ident($first_arg:ident, $second_arg:expr, capacity)
    ) => {
        $(#[$attr])*
        $visibility async fn $fn_name<S: TestSerializer, C: Send + 'static, E: Packet>(
            $first: $first_ty,
            $second: $second_ty,
            capacity: usize,
        ) -> $output {
            $target($first_arg, $second_arg, capacity).await
        }
    };
}
182
183/// Drive `app` with a single length-prefixed `frame` and return the bytes
184/// produced by the server.
185///
186/// The app runs on an in-memory duplex stream so tests need not open real
187/// sockets.
188///
189/// # Errors
190///
191/// Returns any I/O errors encountered while interacting with the in-memory
192/// duplex stream.
193///
194/// ```rust
195/// # use wireframe_testing::drive_with_frame;
196/// # use wireframe::app::WireframeApp;
197/// # async fn demo() -> std::io::Result<()> {
198/// let app = WireframeApp::new().expect("failed to initialize app");
199/// let bytes = drive_with_frame(app, vec![1, 2, 3]).await?;
200/// # Ok(())
201/// # }
202/// ```
203pub async fn drive_with_frame<S, C, E>(
204 app: WireframeApp<S, C, E>,
205 frame: Vec<u8>,
206) -> io::Result<Vec<u8>>
207where
208 S: TestSerializer,
209 C: Send + 'static,
210 E: Packet,
211{
212 drive_with_frame_with_capacity(app, frame, DEFAULT_CAPACITY).await
213}
214
215forward_with_capacity! {
216 /// Drive `app` with a single frame using a duplex buffer of `capacity` bytes.
217 ///
218 /// Adjusting the buffer size helps exercise edge cases such as small channels.
219 ///
220 /// ```rust
221 /// # use wireframe_testing::drive_with_frame_with_capacity;
222 /// # use wireframe::app::WireframeApp;
223 /// # async fn demo() -> std::io::Result<()> {
224 /// let app = WireframeApp::new().expect("failed to initialize app");
225 /// let bytes = drive_with_frame_with_capacity(app, vec![0], 512).await?;
226 /// # Ok(())
227 /// # }
228 /// ```
229 pub fn drive_with_frame_with_capacity(app: WireframeApp<S, C, E>, frame: Vec<u8>, capacity: usize) -> io::Result<Vec<u8>>
230 => drive_with_frames_with_capacity(app, vec![frame], capacity)
231}
232
233forward_default! {
234 /// Drive `app` with a sequence of frames using the default buffer size.
235 ///
236 /// Each frame is written to the duplex stream in order.
237 ///
238 /// ```rust
239 /// # use wireframe_testing::drive_with_frames;
240 /// # use wireframe::app::WireframeApp;
241 /// # async fn demo() -> std::io::Result<()> {
242 /// let app = WireframeApp::new().expect("failed to initialize app");
243 /// let out = drive_with_frames(app, vec![vec![1], vec![2]]).await?;
244 /// # Ok(())
245 /// # }
246 /// ```
247 pub fn drive_with_frames(app: WireframeApp<S, C, E>, frames: Vec<Vec<u8>>) -> io::Result<Vec<u8>>
248 => drive_with_frames_with_capacity(app, frames)
249}
250
251/// Drive `app` with multiple frames using a duplex buffer of `capacity` bytes.
252///
253/// This variant exposes the buffer size for fine-grained control in tests.
254///
255/// ```rust
256/// # use wireframe_testing::drive_with_frames_with_capacity;
257/// # use wireframe::app::WireframeApp;
258/// # async fn demo() -> std::io::Result<()> {
259/// let app = WireframeApp::new().expect("failed to initialize app");
260/// let out = drive_with_frames_with_capacity(app, vec![vec![1], vec![2]], 1024).await?;
261/// # Ok(())
262/// # }
263/// ```
264pub async fn drive_with_frames_with_capacity<S, C, E>(
265 app: WireframeApp<S, C, E>,
266 frames: Vec<Vec<u8>>,
267 capacity: usize,
268) -> io::Result<Vec<u8>>
269where
270 S: TestSerializer,
271 C: Send + 'static,
272 E: Packet,
273{
274 drive_internal(
275 |server| async move { app.handle_connection(server).await },
276 frames,
277 capacity,
278 )
279 .await
280}
281
282forward_default! {
283 /// Feed a single frame into a mutable `app`, allowing the instance to be reused
284 /// across calls.
285 ///
286 /// ```rust
287 /// # use wireframe_testing::drive_with_frame_mut;
288 /// # use wireframe::app::WireframeApp;
289 /// # async fn demo() -> std::io::Result<()> {
290 /// let mut app = WireframeApp::new().expect("failed to initialize app");
291 /// let bytes = drive_with_frame_mut(&mut app, vec![1]).await?;
292 /// # Ok(())
293 /// # }
294 /// ```
295 pub fn drive_with_frame_mut(app: &mut WireframeApp<S, C, E>, frame: Vec<u8>) -> io::Result<Vec<u8>>
296 => drive_with_frame_with_capacity_mut(app, frame)
297}
298
299forward_with_capacity! {
300 /// Feed a single frame into `app` using a duplex buffer of `capacity` bytes.
301 ///
302 /// ```rust
303 /// # use wireframe_testing::drive_with_frame_with_capacity_mut;
304 /// # use wireframe::app::WireframeApp;
305 /// # async fn demo() -> std::io::Result<()> {
306 /// let mut app = WireframeApp::new().expect("failed to initialize app");
307 /// let bytes = drive_with_frame_with_capacity_mut(&mut app, vec![1], 256).await?;
308 /// # Ok(())
309 /// # }
310 /// ```
311 pub fn drive_with_frame_with_capacity_mut(app: &mut WireframeApp<S, C, E>, frame: Vec<u8>, capacity: usize) -> io::Result<Vec<u8>>
312 => drive_with_frames_with_capacity_mut(app, vec![frame], capacity)
313}
314
315forward_default! {
316 /// Feed multiple frames into a mutable `app`.
317 ///
318 /// ```rust
319 /// # use wireframe_testing::drive_with_frames_mut;
320 /// # use wireframe::app::WireframeApp;
321 /// # async fn demo() -> std::io::Result<()> {
322 /// let mut app = WireframeApp::new().expect("failed to initialize app");
323 /// let out = drive_with_frames_mut(&mut app, vec![vec![1], vec![2]]).await?;
324 /// # Ok(())
325 /// # }
326 /// ```
327 pub fn drive_with_frames_mut(app: &mut WireframeApp<S, C, E>, frames: Vec<Vec<u8>>) -> io::Result<Vec<u8>>
328 => drive_with_frames_with_capacity_mut(app, frames)
329}
330
331/// Feed multiple frames into `app` with a duplex buffer of `capacity` bytes.
332///
333/// ```rust
334/// # use wireframe_testing::drive_with_frames_with_capacity_mut;
335/// # use wireframe::app::WireframeApp;
336/// # async fn demo() -> std::io::Result<()> {
337/// let mut app = WireframeApp::new().expect("failed to initialize app");
338/// let out = drive_with_frames_with_capacity_mut(&mut app, vec![vec![1], vec![2]], 64).await?;
339/// # Ok(())
340/// # }
341/// ```
342pub async fn drive_with_frames_with_capacity_mut<S, C, E>(
343 app: &mut WireframeApp<S, C, E>,
344 frames: Vec<Vec<u8>>,
345 capacity: usize,
346) -> io::Result<Vec<u8>>
347where
348 S: TestSerializer,
349 C: Send + 'static,
350 E: Packet,
351{
352 drive_internal(
353 |server| async { app.handle_connection(server).await },
354 frames,
355 capacity,
356 )
357 .await
358}
359
360/// Encode `msg` using bincode, frame it and drive `app`.
361///
362/// ```rust
363/// # use wireframe_testing::drive_with_bincode;
364/// # use wireframe::app::WireframeApp;
365/// #[derive(bincode::Encode)]
366/// struct Ping(u8);
367/// # async fn demo() -> std::io::Result<()> {
368/// let app = WireframeApp::new().expect("failed to initialize app");
369/// let bytes = drive_with_bincode(app, Ping(1)).await?;
370/// # Ok(())
371/// # }
372/// ```
373pub async fn drive_with_bincode<M, S, C, E>(
374 app: WireframeApp<S, C, E>,
375 msg: M,
376) -> io::Result<Vec<u8>>
377where
378 M: bincode::Encode,
379 S: TestSerializer,
380 C: Send + 'static,
381 E: Packet,
382{
383 let bytes = bincode::encode_to_vec(msg, config::standard()).map_err(|e| {
384 io::Error::new(
385 io::ErrorKind::InvalidData,
386 format!("bincode encode failed: {e}"),
387 )
388 })?;
389 let mut codec = new_test_codec(DEFAULT_CAPACITY);
390 let mut framed = BytesMut::with_capacity(bytes.len() + 4);
391 codec.encode(bytes.into(), &mut framed)?;
392 drive_with_frame(app, framed.to_vec()).await
393}
394
395/// Run `app` with input `frames` using an optional duplex buffer `capacity`.
396///
397/// When `capacity` is `None`, a buffer of [`DEFAULT_CAPACITY`] bytes is used.
398/// Frames are written to the client side in order and the bytes emitted by the
399/// server are collected for inspection.
400///
401/// # Errors
402///
403/// Returns an error if `capacity` is zero or exceeds [`MAX_CAPACITY`]. Any
404/// panic in the application task or I/O error on the duplex stream is also
405/// surfaced as an error.
406///
407/// ```rust
408/// # use wireframe_testing::run_app;
409/// # use wireframe::app::WireframeApp;
410/// # async fn demo() -> std::io::Result<()> {
411/// let app = WireframeApp::new().expect("failed to initialize app");
412/// let out = run_app(app, vec![vec![1]], None).await?;
413/// # Ok(())
414/// # }
415/// ```
416
417/// Encode bytes with a length-delimited `codec`, preallocating the prefix.
418///
419/// Panics if encoding fails.
420#[must_use]
421pub fn encode_frame(codec: &mut LengthDelimitedCodec, bytes: Vec<u8>) -> Vec<u8> {
422 let header_len = LengthFormat::default().bytes();
423 let mut buf = BytesMut::with_capacity(bytes.len() + header_len);
424 codec.encode(bytes.into(), &mut buf).expect("encode failed");
425 buf.to_vec()
426}
427
428pub async fn run_app<S, C, E>(
429 app: WireframeApp<S, C, E>,
430 frames: Vec<Vec<u8>>,
431 capacity: Option<usize>,
432) -> io::Result<Vec<u8>>
433where
434 S: TestSerializer,
435 C: Send + 'static,
436 E: Packet,
437{
438 let capacity = capacity.unwrap_or(DEFAULT_CAPACITY);
439 if capacity == 0 {
440 return Err(io::Error::new(
441 io::ErrorKind::InvalidInput,
442 "capacity must be greater than zero",
443 ));
444 }
445 if capacity > MAX_CAPACITY {
446 return Err(io::Error::new(
447 io::ErrorKind::InvalidInput,
448 format!("capacity must not exceed {MAX_CAPACITY} bytes"),
449 ));
450 }
451
452 let (mut client, server) = duplex(capacity);
453 let server_task = tokio::spawn(async move { app.handle_connection(server).await });
454
455 for frame in &frames {
456 client.write_all(frame).await?;
457 }
458 client.shutdown().await?;
459
460 let mut buf = Vec::new();
461 client.read_to_end(&mut buf).await?;
462
463 if let Err(e) = server_task.await {
464 return Err(io::Error::new(
465 io::ErrorKind::Other,
466 format!("server task failed: {e}"),
467 ));
468 }
469
470 Ok(buf)
471}
472
#[cfg(test)]
mod tests {
    use wireframe::app::WireframeApp;

    use super::*;

    /// Call `run_app` with `capacity` and assert that it is rejected as
    /// invalid input, using `reason` as the failure message otherwise.
    async fn assert_capacity_rejected(capacity: usize, reason: &str) {
        let app = WireframeApp::new().expect("failed to create app");
        let err = run_app(app, Vec::new(), Some(capacity))
            .await
            .expect_err(reason);
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[tokio::test]
    async fn run_app_rejects_zero_capacity() {
        assert_capacity_rejected(0, "capacity of zero should error").await;
    }

    #[tokio::test]
    async fn run_app_rejects_excess_capacity() {
        assert_capacity_rejected(MAX_CAPACITY + 1, "capacity beyond max should error").await;
    }
}
497
498/// Run `app` against an empty duplex stream.
499///
500/// This helper drives the connection lifecycle without sending any frames,
501/// ensuring setup and teardown callbacks execute.
502///
503/// # Panics
504///
505/// Panics if `handle_connection` fails.
506///
507/// ```rust
508/// # use wireframe_testing::run_with_duplex_server;
509/// # use wireframe::app::WireframeApp;
510/// # async fn demo() {
511/// let app = WireframeApp::new()
512/// .expect("failed to initialize app");
513/// run_with_duplex_server(app).await;
514/// }
515/// ```
516pub async fn run_with_duplex_server<S, C, E>(app: WireframeApp<S, C, E>)
517where
518 S: TestSerializer,
519 C: Send + 'static,
520 E: Packet,
521{
522 let (_, server) = duplex(EMPTY_SERVER_CAPACITY); // discard client half
523 app.handle_connection(server).await;
524}
525
/// Await the future `$fut` and unwrap its output, panicking with the call
/// site on failure.
///
/// The output is unwrapped with `.expect(...)`, so it is typically a `Result`
/// (whose error is then appended via `Debug`) or an `Option`. The panic
/// message always includes the invoking file and line, in every build
/// profile:
///
/// - `"push failed at <file>:<line>"` by default;
/// - `"<msg> at <file>:<line>"` when a custom `$msg` is supplied.
///
/// Must be invoked inside an `async` context.
#[macro_export]
macro_rules! push_expect {
    ($fut:expr) => {{
        $fut.await.expect(::std::concat!(
            "push failed at ",
            ::std::file!(),
            ":",
            ::std::line!()
        ))
    }};
    ($fut:expr, $msg:expr) => {{
        // Paths are fully qualified so caller-side shadowing cannot break the
        // exported macro.
        let m = ::std::format!(
            "{msg} at {}:{}",
            ::std::file!(),
            ::std::line!(),
            msg = $msg
        );
        $fut.await.expect(&m)
    }};
}
541
/// Await the future `$fut` and unwrap its output, panicking with the call
/// site on failure.
///
/// The output is unwrapped with `.expect(...)`, so it is typically a `Result`
/// (whose error is then appended via `Debug`) or an `Option`. The panic
/// message always includes the invoking file and line, in every build
/// profile:
///
/// - `"recv failed at <file>:<line>"` by default;
/// - `"<msg> at <file>:<line>"` when a custom `$msg` is supplied.
///
/// Must be invoked inside an `async` context.
#[macro_export]
macro_rules! recv_expect {
    ($fut:expr) => {{
        $fut.await.expect(::std::concat!(
            "recv failed at ",
            ::std::file!(),
            ":",
            ::std::line!()
        ))
    }};
    ($fut:expr, $msg:expr) => {{
        // Paths are fully qualified so caller-side shadowing cannot break the
        // exported macro.
        let m = ::std::format!(
            "{msg} at {}:{}",
            ::std::file!(),
            ::std::line!(),
            msg = $msg
        );
        $fut.await.expect(&m)
    }};
}