xwt_tests/tests/closed_bi_send_stream.rs

1use std::num::NonZeroUsize;
2
3use xwt_core::prelude::*;
4
5#[derive(Debug, thiserror::Error)]
6pub enum Error<Endpoint>
7where
8    Endpoint: xwt_core::endpoint::Connect + std::fmt::Debug,
9    Endpoint::Connecting: std::fmt::Debug,
10    ConnectSessionFor<Endpoint>: xwt_core::session::stream::OpenBi + std::fmt::Debug,
11{
12    #[error("connect: {0}")]
13    Connect(#[source] xwt_error::Connect<Endpoint>),
14    #[error("open bi stream: {0}")]
15    OpenBiStream(#[source] BiStreamOpenErrorFor<ConnectSessionFor<Endpoint>>),
16    #[error("opening bi stream: {0}")]
17    OpeningBiStream(#[source] BiStreamOpeningErrorFor<ConnectSessionFor<Endpoint>>),
18    #[error("read stream: {0}")]
19    ReadStream(#[source] ReadErrorFor<RecvStreamFor<ConnectSessionFor<Endpoint>>>),
20    #[error("a read was successful while we expected it to abort (read {0} bytes)")]
21    ReadDidNotFail(NonZeroUsize),
22    #[error("error code conversion to u32 failed")]
23    ErrorCodeConversion,
24    #[error("error code mismatch: got code {0}")]
25    ErrorCodeMismatch(u32),
26}
27
28pub async fn run<Endpoint>(
29    endpoint: Endpoint,
30    url: &str,
31    expected_error_code: u32,
32) -> Result<(), Error<Endpoint>>
33where
34    Endpoint: xwt_core::endpoint::Connect + std::fmt::Debug,
35    Endpoint::Connecting: std::fmt::Debug,
36    ConnectSessionFor<Endpoint>: xwt_core::session::stream::OpenBi + std::fmt::Debug,
37    RecvStreamFor<ConnectSessionFor<Endpoint>>: xwt_core::stream::Read,
38{
39    let session = crate::utils::connect(&endpoint, url)
40        .await
41        .map_err(Error::Connect)?;
42
43    let opening = session.open_bi().await.map_err(Error::OpenBiStream)?;
44    let (_send_stream, mut recv_stream) =
45        opening.wait_bi().await.map_err(Error::OpeningBiStream)?;
46
47    let mut buf = [0u8; 1];
48
49    let error_code = match recv_stream.read(&mut buf).await {
50        Ok(len) => return Err(Error::ReadDidNotFail(len)),
51        Err(err) => match err.as_error_code() {
52            Some(error_code) => error_code
53                .try_into()
54                .map_err(|_| Error::ErrorCodeConversion)?,
55            None => return Err(Error::ReadStream(err)),
56        },
57    };
58
59    if error_code != expected_error_code {
60        return Err(Error::ErrorCodeMismatch(error_code));
61    }
62
63    Ok(())
64}