zerodds_websocket_bridge/
permessage_deflate.rs1use alloc::string::{String, ToString};
21use alloc::vec::Vec;
22
/// Four-byte trailer `00 00 ff ff` that ends a DEFLATE sync flush (an empty
/// stored block).
///
/// RFC 7692 removes it from each compressed message before sending
/// (§7.2.1) and re-appends it before inflating a received message (§7.2.2);
/// see [`strip_tail`] and [`append_tail`].
pub const DEFLATE_TAIL: [u8; 4] = [0x00, 0x00, 0xFF, 0xFF];
26
/// Parameters of a `permessage-deflate` extension offer or acceptance
/// (RFC 7692 §7.1).
///
/// Produced by [`parse_offer`] and serialised by [`render_accept`]. Window
/// sizes are base-2 logarithms of the LZ77 window; `parse_offer` only admits
/// values in `8..=15`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PermessageDeflateParams {
    /// Set when the offer carries a bare `server_no_context_takeover`.
    pub server_no_takeover: bool,
    /// Set when the offer carries a bare `client_no_context_takeover`.
    pub client_no_takeover: bool,
    /// Value of `server_max_window_bits`; 15 when the parameter is absent.
    pub server_max_window_bits: u8,
    /// Value of `client_max_window_bits`; 15 when absent or given without a value.
    pub client_max_window_bits: u8,
}

impl Default for PermessageDeflateParams {
    /// No context-takeover restrictions and a 15-bit window in both directions,
    /// i.e. what an offer with no parameters at all means.
    fn default() -> Self {
        let widest_window = 15;
        Self {
            server_no_takeover: false,
            client_no_takeover: false,
            server_max_window_bits: widest_window,
            client_max_window_bits: widest_window,
        }
    }
}
50
/// Reasons [`parse_offer`] rejects a `permessage-deflate` offer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NegotiationError {
    /// A parameter name the parser does not recognise; carries that name.
    UnknownParam(String),
    /// A window-bits value outside `8..=15`; carries the value, or `0` when
    /// the value could not be parsed as a number at all.
    InvalidWindowBits(u8),
    /// A flag parameter (`server_no_context_takeover` /
    /// `client_no_context_takeover`) that was written with `=value`.
    BooleanWithValue(String),
}

impl core::fmt::Display for NegotiationError {
    /// Writes a short, human-readable description of the rejection.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            NegotiationError::UnknownParam(p) => write!(f, "unknown parameter: {p}"),
            NegotiationError::InvalidWindowBits(b) => write!(f, "invalid window_bits: {b}"),
            NegotiationError::BooleanWithValue(p) => {
                write!(f, "boolean param `{p}` has value")
            }
        }
    }
}

// `std::error::Error` is only available when the crate is built with `std`.
#[cfg(feature = "std")]
impl std::error::Error for NegotiationError {}
74
75pub fn parse_offer(offer: &str) -> Result<PermessageDeflateParams, NegotiationError> {
82 let mut params = PermessageDeflateParams::default();
83 for part in offer.split(';').skip(1) {
84 let part = part.trim();
85 if part.is_empty() {
86 continue;
87 }
88 if let Some((k, v)) = part.split_once('=') {
89 let k = k.trim();
90 let v = v.trim().trim_matches('"');
91 match k {
92 "server_max_window_bits" => {
93 let bits: u8 = v
94 .parse()
95 .map_err(|_| NegotiationError::InvalidWindowBits(0))?;
96 if !(8..=15).contains(&bits) {
97 return Err(NegotiationError::InvalidWindowBits(bits));
98 }
99 params.server_max_window_bits = bits;
100 }
101 "client_max_window_bits" => {
102 let bits: u8 = v
103 .parse()
104 .map_err(|_| NegotiationError::InvalidWindowBits(0))?;
105 if !(8..=15).contains(&bits) {
106 return Err(NegotiationError::InvalidWindowBits(bits));
107 }
108 params.client_max_window_bits = bits;
109 }
110 "server_no_context_takeover" | "client_no_context_takeover" => {
111 return Err(NegotiationError::BooleanWithValue(k.to_string()));
112 }
113 other => return Err(NegotiationError::UnknownParam(other.to_string())),
114 }
115 } else {
116 match part {
117 "server_no_context_takeover" => params.server_no_takeover = true,
118 "client_no_context_takeover" => params.client_no_takeover = true,
119 "client_max_window_bits" => {
120 params.client_max_window_bits = 15;
122 }
123 other => return Err(NegotiationError::UnknownParam(other.to_string())),
124 }
125 }
126 }
127 Ok(params)
128}
129
130#[must_use]
132pub fn render_accept(params: &PermessageDeflateParams) -> String {
133 let mut s = String::from("permessage-deflate");
134 if params.server_no_takeover {
135 s.push_str("; server_no_context_takeover");
136 }
137 if params.client_no_takeover {
138 s.push_str("; client_no_context_takeover");
139 }
140 if params.server_max_window_bits != 15 {
141 s.push_str(&alloc::format!(
142 "; server_max_window_bits={}",
143 params.server_max_window_bits
144 ));
145 }
146 if params.client_max_window_bits != 15 {
147 s.push_str(&alloc::format!(
148 "; client_max_window_bits={}",
149 params.client_max_window_bits
150 ));
151 }
152 s
153}
154
155#[must_use]
157pub fn append_tail(payload: &[u8]) -> Vec<u8> {
158 let mut out = Vec::with_capacity(payload.len() + 4);
159 out.extend_from_slice(payload);
160 out.extend_from_slice(&DEFLATE_TAIL);
161 out
162}
163
164#[must_use]
166pub fn strip_tail(payload: &[u8]) -> &[u8] {
167 if payload.ends_with(&DEFLATE_TAIL) {
168 &payload[..payload.len() - DEFLATE_TAIL.len()]
169 } else {
170 payload
171 }
172}
173
#[cfg(test)]
// A panic (via unwrap/expect/panic!) is how a failing test reports, so these
// production-code lints are relaxed for the test module only.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn parse_no_params_yields_defaults() {
        let p = parse_offer("permessage-deflate").unwrap();
        assert_eq!(p, PermessageDeflateParams::default());
    }

    #[test]
    fn parse_no_takeover_flags() {
        let p = parse_offer(
            "permessage-deflate; server_no_context_takeover; client_no_context_takeover",
        )
        .unwrap();
        assert!(p.server_no_takeover);
        assert!(p.client_no_takeover);
    }

    #[test]
    fn parse_window_bits() {
        let p =
            parse_offer("permessage-deflate; server_max_window_bits=12; client_max_window_bits=10")
                .unwrap();
        assert_eq!(p.server_max_window_bits, 12);
        assert_eq!(p.client_max_window_bits, 10);
    }

    #[test]
    fn parse_quoted_window_bits() {
        let p = parse_offer("permessage-deflate; server_max_window_bits=\"12\"").unwrap();
        assert_eq!(p.server_max_window_bits, 12);
    }

    #[test]
    fn parse_skips_whitespace_and_empty_elements() {
        let p = parse_offer("permessage-deflate ;  client_no_context_takeover ;; ").unwrap();
        assert!(p.client_no_takeover);
        assert!(!p.server_no_takeover);
    }

    #[test]
    fn rejects_invalid_window_bits() {
        // Pin the exact error so a regression to a different variant or payload fails.
        assert_eq!(
            parse_offer("permessage-deflate; server_max_window_bits=7"),
            Err(NegotiationError::InvalidWindowBits(7))
        );
        assert_eq!(
            parse_offer("permessage-deflate; server_max_window_bits=16"),
            Err(NegotiationError::InvalidWindowBits(16))
        );
        assert_eq!(
            parse_offer("permessage-deflate; client_max_window_bits=7"),
            Err(NegotiationError::InvalidWindowBits(7))
        );
        // Non-numeric values carry no meaningful number and are reported as 0.
        assert_eq!(
            parse_offer("permessage-deflate; client_max_window_bits=abc"),
            Err(NegotiationError::InvalidWindowBits(0))
        );
    }

    #[test]
    fn rejects_server_max_window_bits_without_value() {
        assert!(parse_offer("permessage-deflate; server_max_window_bits").is_err());
    }

    #[test]
    fn rejects_unknown_param() {
        assert!(matches!(
            parse_offer("permessage-deflate; foo"),
            Err(NegotiationError::UnknownParam(_))
        ));
    }

    #[test]
    fn rejects_boolean_with_value() {
        assert!(matches!(
            parse_offer("permessage-deflate; server_no_context_takeover=yes"),
            Err(NegotiationError::BooleanWithValue(_))
        ));
    }

    #[test]
    fn render_default_is_bare_extension_name() {
        let s = render_accept(&PermessageDeflateParams::default());
        assert_eq!(s, "permessage-deflate");
    }

    #[test]
    fn render_includes_params() {
        let p = PermessageDeflateParams {
            server_no_takeover: true,
            client_no_takeover: false,
            server_max_window_bits: 12,
            client_max_window_bits: 15,
        };
        let s = render_accept(&p);
        assert!(s.contains("server_no_context_takeover"));
        assert!(s.contains("server_max_window_bits=12"));
        assert!(!s.contains("client_max_window_bits"));
    }

    #[test]
    fn render_round_trips_through_parse() {
        let offer = "permessage-deflate; server_no_context_takeover; server_max_window_bits=10";
        let p = parse_offer(offer).unwrap();
        assert_eq!(render_accept(&p), offer);
    }

    #[test]
    fn tail_round_trip() {
        let raw = b"hello";
        let with_tail = append_tail(raw);
        assert_eq!(with_tail, b"hello\x00\x00\xff\xff");
        let stripped = strip_tail(&with_tail);
        assert_eq!(stripped, raw);
    }

    #[test]
    fn tail_on_empty_payload() {
        assert_eq!(append_tail(b""), DEFLATE_TAIL);
        assert!(strip_tail(&DEFLATE_TAIL).is_empty());
    }

    #[test]
    fn strip_tail_no_op_when_absent() {
        assert_eq!(strip_tail(b"hello"), b"hello");
        // Shorter than the tail, and a partial tail: both must pass through untouched.
        assert_eq!(strip_tail(b""), b"");
        assert_eq!(strip_tail(b"\xff\xff"), b"\xff\xff");
    }

    #[test]
    fn parameterless_client_max_window_bits_accepted() {
        let p = parse_offer("permessage-deflate; client_max_window_bits").unwrap();
        assert_eq!(p.client_max_window_bits, 15);
    }
}