tokio_websockets/upgrade/
client_request.rs1use std::str::FromStr;
3
4use base64::{Engine, engine::general_purpose::STANDARD};
5use bytes::{Buf, BytesMut};
6use http::{HeaderMap, header::SET_COOKIE};
7use httparse::Request;
8use tokio_util::codec::Decoder;
9
10use crate::{sha::digest, upgrade::Error};
11
/// Fixed head of the server's `101 Switching Protocols` handshake response.
///
/// It ends right after `Sec-WebSocket-Accept: `; the encoded accept key, its
/// CRLF, any extra response headers and the terminating blank line are
/// appended by the decoder.
const SWITCHING_PROTOCOLS_BODY: &[u8] = b"HTTP/1.1 101 Switching Protocols\r\n\
    Upgrade: websocket\r\n\
    Connection: Upgrade\r\n\
    Sec-WebSocket-Accept: ";
15
/// Reports whether `needle` appears as a contiguous run inside `haystack`,
/// treating ASCII letters case-insensitively (non-ASCII bytes must match
/// exactly).
///
/// An empty `needle` is considered present in every haystack, including an
/// empty one. A `needle` longer than `haystack` never matches.
fn contains_ignore_ascii_case(haystack: &[u8], needle: &[u8]) -> bool {
    // The empty check must come first: `windows(0)` panics.
    needle.is_empty()
        || haystack
            .windows(needle.len())
            .any(|window| window.eq_ignore_ascii_case(needle))
}
33
34struct ClientRequest {
36 ws_accept: [u8; 20],
38}
39
40impl ClientRequest {
41 pub fn parse<'a, F>(header: F) -> Result<Self, Error>
48 where
49 F: Fn(&'static str) -> Option<&'a str> + 'a,
50 {
51 let find_header = |name| header(name).ok_or(super::Error::MissingHeader(name));
52
53 let check_header = |name, expected, err| {
54 let actual = find_header(name)?;
55 if actual.eq_ignore_ascii_case(expected) {
56 Ok(())
57 } else {
58 Err(err)
59 }
60 };
61
62 let check_header_contains = |name, expected: &str, err| {
63 let actual = find_header(name)?;
64 if contains_ignore_ascii_case(actual.as_bytes(), expected.as_bytes()) {
65 Ok(())
66 } else {
67 Err(err)
68 }
69 };
70
71 check_header("Upgrade", "websocket", Error::UpgradeNotWebSocket)?;
72 check_header_contains("Connection", "Upgrade", Error::ConnectionNotUpgrade)?;
73 check_header(
74 "Sec-WebSocket-Version",
75 "13",
76 Error::UnsupportedWebSocketVersion,
77 )?;
78
79 let key = find_header("Sec-WebSocket-Key")?;
80 let ws_accept = digest(key.as_bytes());
81 Ok(Self { ws_accept })
82 }
83
84 #[must_use]
87 pub fn ws_accept(&self) -> String {
88 STANDARD.encode(self.ws_accept)
89 }
90}
91
92pub struct Codec<'a> {
99 pub response_headers: &'a HeaderMap,
101}
102
103impl Decoder for Codec<'_> {
104 type Error = crate::Error;
105 type Item = (http::Request<()>, Vec<u8>);
106
107 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
108 let mut headers = [httparse::EMPTY_HEADER; 64];
109 let mut request = Request::new(&mut headers);
110 let status = request.parse(src).map_err(Error::Parsing)?;
111
112 if !status.is_complete() {
113 return Ok(None);
114 }
115
116 let request_len = status.unwrap();
117
118 let mut builder = http::request::Builder::new();
119 if let Some(m) = request.method {
120 let method =
121 http::method::Method::from_bytes(m.as_bytes()).expect("httparse method is valid");
122 builder = builder.method(method);
123 }
124
125 if let Some(uri) = request.path {
126 builder = builder.uri(uri);
127 }
128
129 match request.version {
130 Some(0) => builder = builder.version(http::Version::HTTP_10),
131 Some(1) => builder = builder.version(http::Version::HTTP_11),
132 _ => Err(Error::Parsing(httparse::Error::Version))?,
133 }
134
135 let mut header_map = http::HeaderMap::with_capacity(request.headers.len());
136
137 for header in request.headers {
138 let name = http::HeaderName::from_str(header.name)
139 .map_err(|_| Error::Parsing(httparse::Error::HeaderName))?;
140 let value = http::HeaderValue::from_bytes(header.value)
141 .map_err(|_| Error::Parsing(httparse::Error::HeaderValue))?;
142
143 header_map.insert(name, value);
144 }
145
146 let mut request = builder
148 .body(())
149 .expect("httparse sees the request as valid");
150 *request.headers_mut() = header_map;
151
152 let ws_accept =
153 ClientRequest::parse(|name| request.headers().get(name).and_then(|h| h.to_str().ok()))?
154 .ws_accept();
155
156 src.advance(request_len);
157
158 let mut resp = Vec::with_capacity(SWITCHING_PROTOCOLS_BODY.len() + ws_accept.len() + 4);
160
161 resp.extend_from_slice(SWITCHING_PROTOCOLS_BODY);
162 resp.extend_from_slice(ws_accept.as_bytes());
163 resp.extend_from_slice(b"\r\n");
164
165 for name in self.response_headers.keys() {
166 let values = self.response_headers.get_all(name).iter();
167
168 if name == SET_COOKIE {
169 for value in values {
172 resp.extend_from_slice(name.as_str().as_bytes());
173 resp.extend_from_slice(b": ");
174 resp.extend_from_slice(value.as_bytes());
175 resp.extend_from_slice(b"\r\n");
176 }
177 } else {
178 resp.extend_from_slice(name.as_str().as_bytes());
180 resp.extend_from_slice(b": ");
181
182 let mut values = values.peekable();
183 while let Some(value) = values.next() {
184 resp.extend_from_slice(value.as_bytes());
185
186 if values.peek().is_some() {
187 resp.push(b',');
188 }
189 }
190
191 resp.extend_from_slice(b"\r\n");
192 }
193 }
194
195 resp.extend_from_slice(b"\r\n");
196
197 Ok(Some((request, resp)))
198 }
199}