1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use super::*;
use http::FmtHeaderField;

/// Splits a WebSocket URI into `(secure, authority, path)`.
///
/// - `secure` is `false` for the `ws` scheme and `true` for `wss` (case-insensitive).
/// - `authority` is everything after `://` up to (not including) the first `/`.
/// - `path` is everything after that first `/`, without the slash itself; it is
///   empty when the URI has no `/` after the authority.
///
/// # Errors
/// Returns `"Invalid Websocket URI"` when `://` is missing or the scheme is
/// neither `ws` nor `wss`.
fn parse_ws_uri(uri: &str) -> std::result::Result<(bool, &str, &str), &'static str> {
    const INVALID: &str = "Invalid Websocket URI";

    let (scheme, rest) = match uri.split_once("://") {
        Some(parts) => parts,
        None => return Err(INVALID),
    };

    // Table of accepted schemes and whether each one is the "secure" variant.
    let secure = [("ws", false), ("wss", true)]
        .iter()
        .find(|(name, _)| scheme.eq_ignore_ascii_case(name))
        .map(|&(_, is_tls)| is_tls)
        .ok_or(INVALID)?;

    let (authority, path) = match rest.find('/') {
        Some(slash) => (&rest[..slash], &rest[slash + 1..]),
        None => (rest, ""),
    };

    Ok((secure, authority, path))
}

impl WebSocket<CLIENT> {
    /// Opens a client WebSocket connection to `uri` with no extra handshake headers.
    ///
    /// Equivalent to [`WebSocket::connect_with_headers`] with an empty header list.
    pub async fn connect(uri: impl AsRef<str>) -> Result<Self> {
        Self::connect_with_headers(uri, [("", ""); 0]).await
    }

    /// Opens a client WebSocket connection to `uri`, sending `headers` along with the
    /// opening handshake request.
    ///
    /// `uri` must use the `ws://` or `wss://` scheme. When the authority carries no
    /// explicit port, `:80` (`ws`) or `:443` (`wss`) is appended before connecting.
    /// Bracketed IPv6 literals such as `ws://[::1]/` and `ws://[::1]:9000/` are
    /// handled correctly.
    ///
    /// NOTE(review): this only opens a plain `TcpStream`; nothing visible here
    /// performs TLS for `wss://`, where `secure` only selects the default port.
    /// Confirm this is intended.
    ///
    /// # Errors
    /// - The URI cannot be parsed (wrapped with `invalid_input`).
    /// - Connecting, writing the request, or reading the response fails (I/O error).
    /// - The response cannot be parsed, or lacks the accept key (`invalid_data`).
    /// - The status code is not `101`, or the accept key does not match
    ///   (`proto_err`).
    pub async fn connect_with_headers(
        uri: impl AsRef<str>,
        headers: impl IntoIterator<Item = impl FmtHeaderField>,
    ) -> Result<Self> {
        let (secure, addr, path) = parse_ws_uri(uri.as_ref()).map_err(invalid_input)?;

        // Inside a bracketed IPv6 literal (e.g. `[::1]`) colons are part of the
        // address. Only a colon after the closing `]` introduces a port.
        let has_port = match addr.strip_prefix('[') {
            Some(rest) => {
                matches!(rest.split_once(']'), Some((_, tail)) if tail.starts_with(':'))
            }
            None => addr.contains(':'),
        };
        let port = match (has_port, secure) {
            (true, _) => "",
            (false, true) => ":443",
            (false, false) => ":80",
        };

        let mut stream = BufReader::new(TcpStream::connect(format!("{addr}{port}")).await?);

        let (request, sec_key) = handshake::request(addr, path, headers);
        stream.get_mut().write_all(request.as_bytes()).await?;

        // NOTE(review): the response is parsed from whatever a single `fill_buf`
        // returns. If the server's response headers arrive split across several
        // reads, parsing fails. Confirm whether that can happen here.
        let mut bytes = stream.fill_buf().await?;
        let total_len = bytes.len();

        let header = http::Record::from_raw(&mut bytes).map_err(invalid_data)?;

        // The reason phrase after the status code is free-form and may be empty,
        // so only the protocol version and the `101` status code are checked.
        let status_line: &[u8] = &header.schema;
        let switching_protocols =
            status_line == b"HTTP/1.1 101" || status_line.starts_with(b"HTTP/1.1 101 ");
        if !switching_protocols {
            return proto_err("Invalid HTTP response");
        }

        if header
            .get_sec_ws_accept()
            .ok_or_else(|| invalid_data("Couldn't get `Accept-Key` from response"))?
            != handshake::accept_key_from(sec_key).as_bytes()
        {
            return proto_err("Invalid accept key");
        }

        // `from_raw` took `&mut bytes` and (presumably) advanced it past the parsed
        // response. Consume only that prefix, so any bytes that followed it stay
        // buffered in `stream`.
        let remaining = bytes.len();
        stream.consume(total_len - remaining);

        Ok(Self {
            stream,
            len: 0,
            fin: true,
            on_event: Box::new(|_| Ok(())),
        })
    }

    /// Reads the next data frame header and returns a [`Data`] handle through which
    /// the payload can be read.
    ///
    /// If reading the header fails, the `cls_if_err!` macro is applied to this
    /// connection before the error is returned (see that macro for what it does).
    pub async fn recv(&mut self) -> Result<Data> {
        let ty = cls_if_err!(self, self.read_data_frame_header().await)?;
        Ok(client::Data { ty, ws: self })
    }
}

/// Handle to the payload of a data message received on a client connection.
///
/// Returned by `WebSocket::<CLIENT>::recv`. It holds an exclusive (`&mut`) borrow
/// of the socket, so the connection cannot be used for anything else while this
/// handle is alive.
pub struct Data<'a> {
    /// Data type returned by `read_data_frame_header` when this message was received.
    pub ty: DataType,
    /// Exclusive borrow of the connection the payload is read from.
    pub(crate) ws: &'a mut WebSocket<CLIENT>,
}

impl Data<'_> {
    /// Reads the header of the next fragment by delegating to
    /// `WebSocket::read_fragmented_header` on the borrowed connection.
    async fn _read_next_frag(&mut self) -> Result<()> {
        self.ws.read_fragmented_header().await
    }

    /// Reads payload bytes into `buf` and returns how many bytes were copied.
    ///
    /// At most `min(buf.len(), self.ws.len)` bytes are requested. `self.ws.len`
    /// (presumably the unread payload length of the current frame) is then
    /// decremented by the amount read.
    ///
    /// # Panics
    /// Panics if `read_bytes` passes the callback more bytes than `buf` can hold.
    /// The previous raw-pointer copy would have been undefined behavior in that
    /// case instead.
    #[inline]
    async fn _read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let amt = read_bytes(
            &mut self.ws.stream,
            buf.len().min(self.ws.len),
            |bytes| {
                // Bounds-checked copy. It compiles to the same memcpy as
                // `copy_nonoverlapping`, but without requiring `unsafe`.
                buf[..bytes.len()].copy_from_slice(bytes);
            },
        )
        .await?;
        self.ws.len -= amt;
        Ok(amt)
    }
}

// Expands the shared `Data` implementation generated by the `default_impl_for_data!`
// macro (defined elsewhere in the crate).
// NOTE(review): presumably this builds the public read API on top of the private
// `_read` / `_read_next_frag` helpers in this file. Confirm against the macro
// definition.
default_impl_for_data!();