xwt_tests/tests/closed_bi_send_stream.rs

1use std::num::NonZeroUsize;
2
3use xwt_core::prelude::*;
4
5#[derive(Debug, thiserror::Error)]
6pub enum Error<Endpoint>
7where
8    Endpoint: xwt_core::endpoint::Connect + std::fmt::Debug,
9    Endpoint::Connecting: std::fmt::Debug,
10    ConnectSessionFor<Endpoint>: xwt_core::session::stream::OpenBi + std::fmt::Debug,
11{
12    #[error("connect: {0}")]
13    Connect(#[source] xwt_error::Connect<Endpoint>),
14    #[error("open bi stream: {0}")]
15    OpenBiStream(#[source] BiStreamOpenErrorFor<ConnectSessionFor<Endpoint>>),
16    #[error("opening bi stream: {0}")]
17    OpeningBiStream(#[source] BiStreamOpeningErrorFor<ConnectSessionFor<Endpoint>>),
18    #[error("read stream: {0}")]
19    ReadStream(#[source] ReadErrorFor<RecvStreamFor<ConnectSessionFor<Endpoint>>>),
20    #[error("a read was successful while we expected it to abort (read {0} bytes)")]
21    ReadDidNotFail(NonZeroUsize),
22    #[error("error code mismatch: got code {0}")]
23    ErrorCodeMismatch(xwt_core::stream::ErrorCode),
24}
25
26pub async fn run<Endpoint>(
27    endpoint: Endpoint,
28    url: &str,
29    expected_error_code: xwt_core::stream::ErrorCode,
30) -> Result<(), Error<Endpoint>>
31where
32    Endpoint: xwt_core::endpoint::Connect + std::fmt::Debug,
33    Endpoint::Connecting: std::fmt::Debug,
34    ConnectSessionFor<Endpoint>: xwt_core::session::stream::OpenBi + std::fmt::Debug,
35    RecvStreamFor<ConnectSessionFor<Endpoint>>: xwt_core::stream::Read,
36{
37    let session = crate::utils::connect(&endpoint, url)
38        .await
39        .map_err(Error::Connect)?;
40
41    let opening = session.open_bi().await.map_err(Error::OpenBiStream)?;
42    let (_send_stream, mut recv_stream) =
43        opening.wait_bi().await.map_err(Error::OpeningBiStream)?;
44
45    let mut buf = [0u8; 1];
46
47    let error_code = match recv_stream.read(&mut buf).await {
48        Ok(len) => return Err(Error::ReadDidNotFail(len)),
49        Err(err) => match err.as_error_code() {
50            Some(error_code) => error_code,
51            None => return Err(Error::ReadStream(err)),
52        },
53    };
54
55    if error_code != expected_error_code {
56        return Err(Error::ErrorCodeMismatch(error_code));
57    }
58
59    Ok(())
60}