xwt_tests/tests/
closed_bi_send_stream.rs1use std::num::NonZeroUsize;
2
3use xwt_core::prelude::*;
4
/// Ways in which [`run`] can fail.
///
/// Each variant corresponds to one step of [`run`]:
/// - connecting;
/// - opening the bidirectional stream;
/// - reading from the stream's receive half;
/// - checking the error code that the read is expected to fail with.
#[derive(Debug, thiserror::Error)]
pub enum Error<Endpoint>
where
    Endpoint: xwt_core::endpoint::Connect + std::fmt::Debug,
    Endpoint::Connecting: std::fmt::Debug,
    ConnectSessionFor<Endpoint>: xwt_core::session::stream::OpenBi + std::fmt::Debug,
{
    // NOTE(review): the variants below both interpolate the wrapped error into
    // the message (`{0}`) and expose it via `#[source]`, so error-chain
    // printers will show the inner message twice. Kept as-is because the
    // Display text may be relied upon.
    /// Establishing the session via `crate::utils::connect` failed.
    #[error("connect: {0}")]
    Connect(#[source] xwt_error::Connect<Endpoint>),
    /// Calling `open_bi` on the connected session failed.
    #[error("open bi stream: {0}")]
    OpenBiStream(#[source] BiStreamOpenErrorFor<ConnectSessionFor<Endpoint>>),
    /// Waiting for the opened bidirectional stream (`wait_bi`) failed.
    #[error("opening bi stream: {0}")]
    OpeningBiStream(#[source] BiStreamOpeningErrorFor<ConnectSessionFor<Endpoint>>),
    /// The read failed, but its error carried no error code
    /// (`as_error_code` returned `None`).
    #[error("read stream: {0}")]
    ReadStream(#[source] ReadErrorFor<RecvStreamFor<ConnectSessionFor<Endpoint>>>),
    /// The read succeeded instead of failing; holds the number of bytes read.
    #[error("a read was successful while we expected it to abort (read {0} bytes)")]
    ReadDidNotFail(NonZeroUsize),
    /// The error code carried by the read error could not be converted to `u32`.
    #[error("error code conversion to u32 failed")]
    ErrorCodeConversion,
    /// The read failed with a code other than the expected one; holds the
    /// code that was actually received.
    #[error("error code mismatch: got code {0}")]
    ErrorCodeMismatch(u32),
}
27
28pub async fn run<Endpoint>(
29 endpoint: Endpoint,
30 url: &str,
31 expected_error_code: u32,
32) -> Result<(), Error<Endpoint>>
33where
34 Endpoint: xwt_core::endpoint::Connect + std::fmt::Debug,
35 Endpoint::Connecting: std::fmt::Debug,
36 ConnectSessionFor<Endpoint>: xwt_core::session::stream::OpenBi + std::fmt::Debug,
37 RecvStreamFor<ConnectSessionFor<Endpoint>>: xwt_core::stream::Read,
38{
39 let session = crate::utils::connect(&endpoint, url)
40 .await
41 .map_err(Error::Connect)?;
42
43 let opening = session.open_bi().await.map_err(Error::OpenBiStream)?;
44 let (_send_stream, mut recv_stream) =
45 opening.wait_bi().await.map_err(Error::OpeningBiStream)?;
46
47 let mut buf = [0u8; 1];
48
49 let error_code = match recv_stream.read(&mut buf).await {
50 Ok(len) => return Err(Error::ReadDidNotFail(len)),
51 Err(err) => match err.as_error_code() {
52 Some(error_code) => error_code
53 .try_into()
54 .map_err(|_| Error::ErrorCodeConversion)?,
55 None => return Err(Error::ReadStream(err)),
56 },
57 };
58
59 if error_code != expected_error_code {
60 return Err(Error::ErrorCodeMismatch(error_code));
61 }
62
63 Ok(())
64}