1use http;
2use core::slice;
3use std::{
4 ffi::{c_char, c_int, c_uint},
5 mem::{self, transmute, MaybeUninit},
6};
/// Extension token of the permessage-deflate WebSocket extension, as it appears in the
/// `Sec-WebSocket-Extensions` header.
pub const EXT_ID: &str = "permessage-deflate";
/// Flag parameter name `server_no_context_takeover` (takes no value).
pub const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
/// Flag parameter name `client_no_context_takeover` (takes no value).
pub const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
/// Parameter name `server_max_window_bits` (optional `=N` value, N in 8..=15).
pub const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
/// Parameter name `client_max_window_bits` (optional `=N` value, N in 8..=15).
pub const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";

/// zlib version string handed to `inflateInit2_` / `deflateInit2_`.
///
/// It is NUL-terminated because it is passed to C as a `*const c_char`.
/// NOTE(review): presumably it only needs to be compatible with the linked zlib rather than
/// match it exactly — confirm against the zlib version libz_sys links.
pub const ZLIB_VERSION: &str = "1.2.13\0";
20
21#[cfg(feature = "sync")]
22mod blocking;
23#[cfg(feature = "sync")]
24pub use blocking::*;
25use libz_sys::{Z_BUF_ERROR, Z_NO_FLUSH, Z_OK, Z_SYNC_FLUSH};
26
27#[cfg(feature = "async")]
28mod non_blocking;
29#[cfg(feature = "async")]
30pub use non_blocking::*;
31
32use crate::{errors::WsError, frame::OpCode};
33
34use super::{
35 default_handshake_handler, FrameConfig, FrameReadState, FrameWriteState, ValidateUtf8Policy,
36};
37
/// Base-2 logarithm of an LZ77 sliding-window size, restricted to the values 8..=15 that
/// permessage-deflate allows for `*_max_window_bits`.
///
/// The numeric value is passed (negated) to zlib as `windowBits` when a stream is created.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(i8)]
#[allow(missing_docs)]
pub enum WindowBit {
    Eight = 8,
    Nine = 9,
    Ten = 10,
    Eleven = 11,
    Twelve = 12,
    Thirteen = 13,
    Fourteen = 14,
    Fifteen = 15,
}

impl TryFrom<u8> for WindowBit {
    /// The rejected input value.
    type Error = u8;

    /// Converts `8..=15` into the matching variant; any other value is returned as `Err`.
    ///
    /// Uses an exhaustive safe `match` rather than `transmute`, so no `unsafe` is needed and
    /// the accepted range cannot drift from the enum's discriminants.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let bits = match value {
            8 => WindowBit::Eight,
            9 => WindowBit::Nine,
            10 => WindowBit::Ten,
            11 => WindowBit::Eleven,
            12 => WindowBit::Twelve,
            13 => WindowBit::Thirteen,
            14 => WindowBit::Fourteen,
            15 => WindowBit::Fifteen,
            other => return Err(other),
        };
        Ok(bits)
    }
}
65
66pub fn deflate_handshake_handler(
68 req: http::Request<()>,
69) -> Result<(http::Request<()>, http::Response<String>), (http::Response<String>, WsError)> {
70 let (req, mut resp) = default_handshake_handler(req)?;
71 let mut configs: Vec<PMDConfig> = vec![];
72 for (k, v) in req.headers() {
73 if k.as_str().to_lowercase() == "sec-websocket-extensions" {
74 if let Ok(s) = v.to_str() {
75 match PMDConfig::parse_str(s) {
76 Ok(mut conf) => {
77 configs.append(&mut conf);
78 }
79 Err(e) => {
80 let resp = http::Response::builder()
81 .version(http::Version::HTTP_11)
82 .status(http::StatusCode::BAD_REQUEST)
83 .header("Content-Type", "text/html")
84 .body(e.clone())
85 .unwrap();
86 return Err((resp, WsError::HandShakeFailed(e)));
87 }
88 }
89 }
90 }
91 }
92 if let Some(config) = configs.pop() {
93 resp.headers_mut().insert(
94 "sec-websocket-extensions",
95 http::HeaderValue::from_str(&config.ext_string()).unwrap(),
96 );
97 }
98 Ok((req, resp))
99}
100
101fn gen_low_level_config(conf: &FrameConfig) -> FrameConfig {
102 FrameConfig {
103 mask_send_frame: conf.mask_send_frame,
104 check_rsv: false,
105 auto_fragment_size: conf.auto_fragment_size,
106 merge_frame: false,
107 validate_utf8: ValidateUtf8Policy::Off,
108 ..Default::default()
109 }
110}
111
/// Outgoing-direction compression state: the negotiated parameters plus the zlib compressor
/// used by `DeflateWriteState`.
pub struct WriteStreamHandler {
    /// Negotiated permessage-deflate parameters.
    pub config: PMDConfig,
    /// zlib compressor for outgoing message payloads.
    pub com: ZLibCompressStream,
}
119
/// Incoming-direction decompression state: the negotiated parameters plus the zlib
/// decompressor used by `DeflateReadState`.
pub struct ReadStreamHandler {
    /// Negotiated permessage-deflate parameters.
    pub config: PMDConfig,
    /// zlib decompressor for incoming message payloads.
    pub de: ZLibDeCompressStream,
}
127
/// Parameters of one permessage-deflate offer/response.
///
/// Parsed from a `Sec-WebSocket-Extensions` header by `PMDConfig::parse_str` and rendered back
/// with `PMDConfig::ext_string`. Defaults: no "no context takeover" flags, 15-bit windows.
#[allow(missing_docs)]
#[derive(Debug, Clone)]
pub struct PMDConfig {
    // True when the `server_no_context_takeover` parameter was present.
    pub server_no_context_takeover: bool,
    // True when the `client_no_context_takeover` parameter was present.
    pub client_no_context_takeover: bool,
    // Value of `server_max_window_bits` (15 when absent or given without a value).
    pub server_max_window_bits: WindowBit,
    // Value of `client_max_window_bits` (15 when absent or given without a value).
    pub client_max_window_bits: WindowBit,
}
137
138impl Default for PMDConfig {
139 fn default() -> Self {
140 Self {
141 server_no_context_takeover: false,
142 client_no_context_takeover: false,
143 server_max_window_bits: WindowBit::Fifteen,
144 client_max_window_bits: WindowBit::Fifteen,
145 }
146 }
147}
148
149impl PMDConfig {
150 pub fn ext_string(&self) -> String {
152 let mut s = format!("{EXT_ID};");
153 if self.client_no_context_takeover {
154 s.push_str(CLIENT_NO_CONTEXT_TAKEOVER);
155 s.push(';');
156 s.push(' ');
157 }
158 if self.server_no_context_takeover {
159 s.push_str(SERVER_NO_CONTEXT_TAKEOVER);
160 s.push(';');
161 s.push(' ');
162 }
163 s.push_str(&format!(
164 "{CLIENT_MAX_WINDOW_BITS}={};",
165 self.client_max_window_bits as u8
166 ));
167 s.push_str(&format!(
168 "{SERVER_MAX_WINDOW_BITS}={}",
169 self.server_max_window_bits as u8
170 ));
171 s
172 }
173
174 pub fn multi_ext_string(configs: &[PMDConfig]) -> String {
176 configs
177 .iter()
178 .map(|conf| conf.ext_string())
179 .collect::<Vec<String>>()
180 .join(", ")
181 }
182}
183
/// Owning wrapper around a zlib `z_stream` used for inflating (decompressing) payloads.
///
/// The stream lives in a `Box`, so it does not move after initialisation. It is released with
/// `inflateEnd` when dropped.
/// NOTE(review): zlib's internal state keeps a pointer back to its `z_stream`, so keeping it
/// boxed (never moved) appears to be required — confirm before changing the storage.
pub struct ZLibDeCompressStream {
    stream: Box<libz_sys::z_stream>,
}

// SAFETY (review): `z_stream` holds raw pointers, so these impls are asserted manually. Every
// zlib call on the stream in this file goes through `&mut self`, so shared `&Self` access never
// touches zlib state — confirm this stays true if `&self` methods are added.
unsafe impl Send for ZLibDeCompressStream {}
unsafe impl Sync for ZLibDeCompressStream {}
191
192impl Drop for ZLibDeCompressStream {
193 fn drop(&mut self) {
194 match unsafe { libz_sys::inflateEnd(self.stream.as_mut()) } {
195 libz_sys::Z_STREAM_ERROR => {
196 tracing::trace!("decompression stream encountered bad state.")
197 }
198 libz_sys::Z_OK | libz_sys::Z_DATA_ERROR => {
200 tracing::trace!("deallocated compression context.")
201 }
202 code => tracing::trace!("bad zlib status encountered: {}", code),
203 }
204 }
205}
206
207impl ZLibDeCompressStream {
208 pub fn new(window: WindowBit) -> Self {
210 let mut stream: Box<MaybeUninit<libz_sys::z_stream>> = Box::new(MaybeUninit::zeroed());
211 let result = unsafe {
212 libz_sys::inflateInit2_(
213 stream.as_mut_ptr(),
214 -(window as i8) as c_int,
215 ZLIB_VERSION.as_ptr() as *const c_char,
216 mem::size_of::<libz_sys::z_stream>() as c_int,
217 )
218 };
219 assert!(result == libz_sys::Z_OK, "Failed to initialize compresser.");
220 Self {
221 stream: unsafe { Box::from_raw(Box::into_raw(stream) as *mut libz_sys::z_stream) },
222 }
223 }
224
225 pub fn with(stream: Box<libz_sys::z_stream>) -> Self {
227 Self { stream }
228 }
229
230 pub fn de_compress(&mut self, inputs: &[&[u8]], output: &mut Vec<u8>) -> Result<(), c_int> {
232 let total_input: usize = inputs.iter().map(|i| i.len()).sum();
233 if total_input > output.capacity() * 2 + 4 {
234 output.resize(total_input * 2 + 4, 0);
235 }
236 let mut write_idx = 0;
237 let before = self.stream.total_out;
238 for i in inputs {
239 let mut iter_read_idx = 0;
240 loop {
241 unsafe {
242 self.stream.next_in = i.as_ptr().add(iter_read_idx) as *mut _;
243 }
244 self.stream.avail_in = (i.len() - iter_read_idx) as c_uint;
245 if output.capacity() - output.len() <= 0 {
246 output.resize(output.capacity() * 2, 0);
247 }
248 let out_slice = unsafe {
249 slice::from_raw_parts_mut(
250 output.as_mut_ptr().add(write_idx),
251 output.capacity() - write_idx,
252 )
253 };
254 self.stream.next_out = out_slice.as_mut_ptr();
255 self.stream.avail_out = out_slice.len() as c_uint;
256
257 match unsafe { libz_sys::inflate(*&mut self.stream.as_mut(), Z_NO_FLUSH) } {
258 Z_OK | Z_BUF_ERROR => {}
259 code => return Err(code),
260 };
261 iter_read_idx = i.len() - self.stream.avail_in as usize;
262 write_idx = (self.stream.total_out - before) as usize;
263 if self.stream.avail_in == 0 {
264 break;
265 }
266 }
267 }
268 unsafe {
269 match libz_sys::inflate(*&mut self.stream.as_mut(), Z_SYNC_FLUSH) {
270 Z_OK | Z_BUF_ERROR => {}
271 code => return Err(code),
272 }
273 output.set_len((self.stream.total_out - before) as usize);
274 };
275 Ok(())
276 }
277
278 pub fn reset(&mut self) -> Result<(), c_int> {
280 let code = unsafe { libz_sys::inflateReset(self.stream.as_mut()) };
281 match code {
282 Z_OK => Ok(()),
283 code => Err(code),
284 }
285 }
286}
287
/// Owning wrapper around a zlib `z_stream` used for deflating (compressing) payloads.
///
/// The stream lives in a `Box`, so it does not move after initialisation. It is released with
/// `deflateEnd` when dropped.
/// NOTE(review): zlib's internal state keeps a pointer back to its `z_stream`, so keeping it
/// boxed (never moved) appears to be required — confirm before changing the storage.
pub struct ZLibCompressStream {
    stream: Box<libz_sys::z_stream>,
}

// SAFETY (review): `z_stream` holds raw pointers, so these impls are asserted manually. Every
// zlib call on the stream in this file goes through `&mut self`, so shared `&Self` access never
// touches zlib state — confirm this stays true if `&self` methods are added.
unsafe impl Send for ZLibCompressStream {}
unsafe impl Sync for ZLibCompressStream {}
295
296impl Drop for ZLibCompressStream {
297 fn drop(&mut self) {
298 match unsafe { libz_sys::deflateEnd(self.stream.as_mut()) } {
299 libz_sys::Z_STREAM_ERROR => {
300 tracing::trace!("compression stream encountered bad state.")
301 }
302 libz_sys::Z_OK | libz_sys::Z_DATA_ERROR => {
304 tracing::trace!("deallocated compression context.")
305 }
306 code => tracing::trace!("bad zlib status encountered: {}", code),
307 }
308 }
309}
310
311impl ZLibCompressStream {
312 pub fn new(window: WindowBit) -> Self {
314 let mut stream: Box<MaybeUninit<libz_sys::z_stream>> = Box::new(MaybeUninit::zeroed());
315 let result = unsafe {
316 libz_sys::deflateInit2_(
317 stream.as_mut_ptr(),
318 9,
319 libz_sys::Z_DEFLATED,
320 -(window as i8) as c_int,
321 9,
322 libz_sys::Z_DEFAULT_STRATEGY,
323 ZLIB_VERSION.as_ptr() as *const c_char,
324 mem::size_of::<libz_sys::z_stream>() as c_int,
325 )
326 };
327 assert!(result == libz_sys::Z_OK, "Failed to initialize compresser.");
328 Self {
329 stream: unsafe { Box::from_raw(Box::into_raw(stream) as *mut libz_sys::z_stream) },
330 }
331 }
332
333 pub fn with(stream: Box<libz_sys::z_stream>) -> Self {
335 Self { stream }
336 }
337
338 pub fn compress(&mut self, inputs: &[&[u8]], output: &mut Vec<u8>) -> Result<(), c_int> {
340 let total_input: usize = inputs.iter().map(|i| i.len()).sum();
341 if total_input > output.capacity() * 2 + 4 {
342 output.resize(total_input * 2 + 4, 0);
343 }
344 let mut write_idx = 0;
345 let mut total_remain = total_input;
346 let before = self.stream.total_out;
347 for i in inputs {
348 let mut iter_read_idx = 0;
349 loop {
350 unsafe {
351 self.stream.next_in = i.as_ptr().add(iter_read_idx) as *mut _;
352 }
353 self.stream.avail_in = (i.len() - iter_read_idx) as c_uint;
354 if output.capacity() - output.len() <= 0 {
355 output.resize(output.len() + total_remain * 2, 0)
356 }
357 let out_slice = unsafe {
358 slice::from_raw_parts_mut(
359 output.as_mut_ptr().add(write_idx),
360 output.capacity() - write_idx,
361 )
362 };
363 self.stream.next_out = out_slice.as_mut_ptr();
364 self.stream.avail_out = out_slice.len() as c_uint;
365
366 match unsafe { libz_sys::deflate(*&mut self.stream.as_mut(), Z_NO_FLUSH) } {
367 libz_sys::Z_OK => {}
368 code => return Err(code),
369 };
370 iter_read_idx = i.len() - self.stream.avail_in as usize;
371 write_idx = (self.stream.total_out - before) as usize;
372 if self.stream.avail_in == 0 {
373 break;
374 }
375 }
376 total_remain -= iter_read_idx;
377 }
378 unsafe {
379 match libz_sys::deflate(*&mut self.stream.as_mut(), Z_SYNC_FLUSH) {
380 Z_OK => {}
381 code => return Err(code),
382 }
383 output.set_len((self.stream.total_out - before) as usize);
384 };
385 Ok(())
386 }
387
388 pub fn reset(&mut self) -> Result<(), c_int> {
390 let code = unsafe { libz_sys::deflateReset(self.stream.as_mut()) };
391 match code {
392 Z_OK => Ok(()),
393 code => Err(code),
394 }
395 }
396}
397
/// Records which parameters have already appeared within one permessage-deflate element, so
/// that `PMDConfig::parse_str` can reject duplicates.
#[derive(Default)]
struct PMDParamCounter {
    server_no_context_takeover: bool,
    client_no_context_takeover: bool,
    server_max_window_bits: bool,
    client_max_window_bits: bool,
}
405
406impl PMDConfig {
407 pub fn parse_str(source: &str) -> Result<Vec<Self>, String> {
409 let lines = source.split("\r\n").count();
410 if lines > 2 {
411 return Err("should not contain multi line".to_string());
412 }
413 let mut configs = vec![];
414 for part in source.split(',') {
415 if part.trim_start().to_lowercase().starts_with(EXT_ID) {
416 let mut conf = Self::default();
417 let mut counter = PMDParamCounter::default();
418 for param in part.split(';').skip(1) {
419 let lower = param.trim().to_lowercase();
420 if lower.starts_with(SERVER_NO_CONTEXT_TAKEOVER) {
421 if counter.server_no_context_takeover {
422 return Err(format!(
423 "got multiple {SERVER_NO_CONTEXT_TAKEOVER} params"
424 ));
425 }
426 if lower.len() != SERVER_NO_CONTEXT_TAKEOVER.len() {
427 return Err(format!(
428 "{SERVER_NO_CONTEXT_TAKEOVER} does not expect param"
429 ));
430 }
431 conf.server_no_context_takeover = true;
432 counter.server_no_context_takeover = true;
433 continue;
434 }
435
436 if lower.starts_with(CLIENT_NO_CONTEXT_TAKEOVER) {
437 if counter.client_no_context_takeover {
438 return Err(format!(
439 "got multiple {CLIENT_NO_CONTEXT_TAKEOVER} params"
440 ));
441 }
442 if lower.len() != CLIENT_NO_CONTEXT_TAKEOVER.len() {
443 return Err(format!(
444 "{CLIENT_NO_CONTEXT_TAKEOVER} does not expect param"
445 ));
446 }
447 conf.client_no_context_takeover = true;
448 counter.client_no_context_takeover = true;
449 continue;
450 }
451
452 if lower.starts_with(SERVER_MAX_WINDOW_BITS) {
453 if counter.server_max_window_bits {
454 return Err(format!("got multiple {SERVER_MAX_WINDOW_BITS} params"));
455 }
456
457 if lower != SERVER_MAX_WINDOW_BITS {
458 let remain = lower.trim_start_matches(SERVER_MAX_WINDOW_BITS);
459 if !remain.trim_start().starts_with('=') {
460 return Err("invalid param value".to_string());
461 }
462 let remain = remain.trim_start().trim_matches('=');
463 let size = match remain.parse::<u8>() {
464 Ok(size) => WindowBit::try_from(size)
465 .map_err(|e| format!("invalid param value {e}"))?,
466 Err(e) => return Err(format!("invalid param value {e}")),
467 };
468 conf.server_max_window_bits = size;
469 }
470 counter.server_max_window_bits = true;
471 continue;
472 }
473
474 if lower.starts_with(CLIENT_MAX_WINDOW_BITS) {
475 if counter.client_max_window_bits {
476 return Err(format!("got multiple {CLIENT_MAX_WINDOW_BITS} params"));
477 }
478
479 if lower != CLIENT_MAX_WINDOW_BITS {
480 let remain = lower.trim_start_matches(CLIENT_MAX_WINDOW_BITS);
481 if !remain.trim_start().starts_with('=') {
482 return Err("invalid param value".to_string());
483 }
484 let remain = remain.trim_start().trim_matches('=');
485 let size = match remain.parse::<u8>() {
486 Ok(size) => WindowBit::try_from(size)
487 .map_err(|e| format!("invalid param value {e}"))?,
488 Err(e) => return Err(format!("invalid param value {e}")),
489 };
490 conf.client_max_window_bits = size;
491 }
492 counter.client_max_window_bits = true;
493 continue;
494 }
495 return Err(format!("unknown param {param}"));
496 }
497 configs.push(conf);
498 }
499 }
500 Ok(configs)
501 }
502}
503
/// Write-side state for the permessage-deflate layer, wrapping a low-level frame writer.
pub struct DeflateWriteState {
    // Inner frame writer, configured by `gen_low_level_config` from `config`.
    write_state: FrameWriteState,
    // Compressor and negotiated parameters; `None` when no permessage-deflate config was given.
    com: Option<WriteStreamHandler>,
    // The caller's original frame configuration.
    config: FrameConfig,
    // NOTE(review): presumably scratch space for an outgoing frame header (14 bytes is the
    // largest WebSocket header); it is not used in this part of the file.
    header_buf: [u8; 14],
    // Whether this endpoint is the server side of the connection.
    is_server: bool,
}
512
513impl DeflateWriteState {
514 pub fn with_config(
516 frame_config: FrameConfig,
517 pmd_config: Option<PMDConfig>,
518 is_server: bool,
519 ) -> Self {
520 let low_level_config = gen_low_level_config(&frame_config);
521 let write_state = FrameWriteState::with_config(low_level_config);
522 let com = if let Some(config) = pmd_config {
523 let com_size = if is_server {
524 config.client_max_window_bits
525 } else {
526 config.server_max_window_bits
527 };
528 let com = ZLibCompressStream::new(com_size);
529 Some(WriteStreamHandler { config, com })
530 } else {
531 None
532 };
533 Self {
534 write_state,
535 com,
536 config: frame_config,
537 header_buf: [0; 14],
538 is_server,
539 }
540 }
541}
542
/// Read-side state for the permessage-deflate layer, wrapping a low-level frame reader.
pub struct DeflateReadState {
    // Inner frame reader, configured by `gen_low_level_config` from `config`.
    read_state: FrameReadState,
    // Decompressor and negotiated parameters; `None` when no permessage-deflate config was given.
    de: Option<ReadStreamHandler>,
    // The caller's original frame configuration.
    config: FrameConfig,
    // NOTE(review): the fields below look like reassembly state for fragmented messages
    // (initialised to false / empty / `OpCode::Binary`); they are not used in this part of
    // the file — confirm against the reader implementation.
    fragmented: bool,
    fragmented_data: Vec<u8>,
    control_buf: Vec<u8>,
    fragmented_type: OpCode,
    // Whether this endpoint is the server side of the connection.
    is_server: bool,
}
554
555impl DeflateReadState {
556 pub fn with_config(
558 frame_config: FrameConfig,
559 pmd_config: Option<PMDConfig>,
560 is_server: bool,
561 ) -> Self {
562 let low_level_config = gen_low_level_config(&frame_config);
563 let read_state = FrameReadState::with_config(low_level_config);
564 let de = if let Some(config) = pmd_config {
565 let de_size = if is_server {
566 config.client_max_window_bits
567 } else {
568 config.server_max_window_bits
569 };
570 let de = ZLibDeCompressStream::new(de_size);
571 Some(ReadStreamHandler { config, de })
572 } else {
573 None
574 };
575 Self {
576 read_state,
577 de,
578 config: frame_config,
579 fragmented: false,
580 fragmented_data: vec![],
581 control_buf: vec![],
582 fragmented_type: OpCode::Binary,
583 is_server,
584 }
585 }
586}