ws_tool/codec/deflate/
mod.rs

1use http;
2use core::slice;
3use std::{
4    ffi::{c_char, c_int, c_uint},
5    mem::{self, transmute, MaybeUninit},
6};
/// Extension token used in `Sec-WebSocket-Extensions` for per-message DEFLATE (RFC 7692).
pub const EXT_ID: &str = "permessage-deflate";
/// Name of the `server_no_context_takeover` extension parameter.
pub const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
/// Name of the `client_no_context_takeover` extension parameter.
pub const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
/// Name of the `server_max_window_bits` extension parameter.
pub const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
/// Name of the `client_max_window_bits` extension parameter.
pub const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";

/// NUL-terminated zlib version string passed to zlib's `inflateInit2_` / `deflateInit2_`.
pub const ZLIB_VERSION: &str = "1.2.13\0";
20
21#[cfg(feature = "sync")]
22mod blocking;
23#[cfg(feature = "sync")]
24pub use blocking::*;
25use libz_sys::{Z_BUF_ERROR, Z_NO_FLUSH, Z_OK, Z_SYNC_FLUSH};
26
27#[cfg(feature = "async")]
28mod non_blocking;
29#[cfg(feature = "async")]
30pub use non_blocking::*;
31
32use crate::{errors::WsError, frame::OpCode};
33
34use super::{
35    default_handshake_handler, FrameConfig, FrameReadState, FrameWriteState, ValidateUtf8Policy,
36};
37
/// permessage-deflate window bit: the base-2 logarithm of the LZ77 sliding-window size
/// carried by `server_max_window_bits` / `client_max_window_bits`, restricted to `8..=15`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(i8)]
#[allow(missing_docs)]
pub enum WindowBit {
    Eight = 8,
    Nine = 9,
    Ten = 10,
    Eleven = 11,
    Twelve = 12,
    Thirteen = 13,
    Fourteen = 14,
    Fifteen = 15,
}

impl TryFrom<u8> for WindowBit {
    /// The rejected input value.
    type Error = u8;

    /// Convert a raw bit count into a [`WindowBit`].
    ///
    /// # Errors
    /// Returns `value` unchanged when it lies outside `8..=15`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        // An explicit table instead of `transmute`: same mapping, no `unsafe` to audit.
        match value {
            8 => Ok(Self::Eight),
            9 => Ok(Self::Nine),
            10 => Ok(Self::Ten),
            11 => Ok(Self::Eleven),
            12 => Ok(Self::Twelve),
            13 => Ok(Self::Thirteen),
            14 => Ok(Self::Fourteen),
            15 => Ok(Self::Fifteen),
            other => Err(other),
        }
    }
}
65
66/// permessage-deflate req handler
67pub fn deflate_handshake_handler(
68    req: http::Request<()>,
69) -> Result<(http::Request<()>, http::Response<String>), (http::Response<String>, WsError)> {
70    let (req, mut resp) = default_handshake_handler(req)?;
71    let mut configs: Vec<PMDConfig> = vec![];
72    for (k, v) in req.headers() {
73        if k.as_str().to_lowercase() == "sec-websocket-extensions" {
74            if let Ok(s) = v.to_str() {
75                match PMDConfig::parse_str(s) {
76                    Ok(mut conf) => {
77                        configs.append(&mut conf);
78                    }
79                    Err(e) => {
80                        let resp = http::Response::builder()
81                            .version(http::Version::HTTP_11)
82                            .status(http::StatusCode::BAD_REQUEST)
83                            .header("Content-Type", "text/html")
84                            .body(e.clone())
85                            .unwrap();
86                        return Err((resp, WsError::HandShakeFailed(e)));
87                    }
88                }
89            }
90        }
91    }
92    if let Some(config) = configs.pop() {
93        resp.headers_mut().insert(
94            "sec-websocket-extensions",
95            http::HeaderValue::from_str(&config.ext_string()).unwrap(),
96        );
97    }
98    Ok((req, resp))
99}
100
101fn gen_low_level_config(conf: &FrameConfig) -> FrameConfig {
102    FrameConfig {
103        mask_send_frame: conf.mask_send_frame,
104        check_rsv: false,
105        auto_fragment_size: conf.auto_fragment_size,
106        merge_frame: false,
107        validate_utf8: ValidateUtf8Policy::Off,
108        ..Default::default()
109    }
110}
111
112/// helper struct to handler com stream
113pub struct WriteStreamHandler {
114    /// permessage deflate config
115    pub config: PMDConfig,
116    /// compressor
117    pub com: ZLibCompressStream,
118}
119
120/// helper struct to handle de stream
121pub struct ReadStreamHandler {
122    /// permessage deflate config
123    pub config: PMDConfig,
124    /// decompressor
125    pub de: ZLibDeCompressStream,
126}
127
128/// permessage-deflate
129#[allow(missing_docs)]
130#[derive(Debug, Clone)]
131pub struct PMDConfig {
132    pub server_no_context_takeover: bool,
133    pub client_no_context_takeover: bool,
134    pub server_max_window_bits: WindowBit,
135    pub client_max_window_bits: WindowBit,
136}
137
138impl Default for PMDConfig {
139    fn default() -> Self {
140        Self {
141            server_no_context_takeover: false,
142            client_no_context_takeover: false,
143            server_max_window_bits: WindowBit::Fifteen,
144            client_max_window_bits: WindowBit::Fifteen,
145        }
146    }
147}
148
149impl PMDConfig {
150    /// get extension string
151    pub fn ext_string(&self) -> String {
152        let mut s = format!("{EXT_ID};");
153        if self.client_no_context_takeover {
154            s.push_str(CLIENT_NO_CONTEXT_TAKEOVER);
155            s.push(';');
156            s.push(' ');
157        }
158        if self.server_no_context_takeover {
159            s.push_str(SERVER_NO_CONTEXT_TAKEOVER);
160            s.push(';');
161            s.push(' ');
162        }
163        s.push_str(&format!(
164            "{CLIENT_MAX_WINDOW_BITS}={};",
165            self.client_max_window_bits as u8
166        ));
167        s.push_str(&format!(
168            "{SERVER_MAX_WINDOW_BITS}={}",
169            self.server_max_window_bits as u8
170        ));
171        s
172    }
173
174    /// helper function to build multi permessage deflate config header
175    pub fn multi_ext_string(configs: &[PMDConfig]) -> String {
176        configs
177            .iter()
178            .map(|conf| conf.ext_string())
179            .collect::<Vec<String>>()
180            .join(", ")
181    }
182}
183
184///
185pub struct ZLibDeCompressStream {
186    stream: Box<libz_sys::z_stream>,
187}
188
189unsafe impl Send for ZLibDeCompressStream {}
190unsafe impl Sync for ZLibDeCompressStream {}
191
192impl Drop for ZLibDeCompressStream {
193    fn drop(&mut self) {
194        match unsafe { libz_sys::inflateEnd(self.stream.as_mut()) } {
195            libz_sys::Z_STREAM_ERROR => {
196                tracing::trace!("decompression stream encountered bad state.")
197            }
198            // Ignore discarded data error because we are raw
199            libz_sys::Z_OK | libz_sys::Z_DATA_ERROR => {
200                tracing::trace!("deallocated compression context.")
201            }
202            code => tracing::trace!("bad zlib status encountered: {}", code),
203        }
204    }
205}
206
207impl ZLibDeCompressStream {
208    /// construct new compress stream
209    pub fn new(window: WindowBit) -> Self {
210        let mut stream: Box<MaybeUninit<libz_sys::z_stream>> = Box::new(MaybeUninit::zeroed());
211        let result = unsafe {
212            libz_sys::inflateInit2_(
213                stream.as_mut_ptr(),
214                -(window as i8) as c_int,
215                ZLIB_VERSION.as_ptr() as *const c_char,
216                mem::size_of::<libz_sys::z_stream>() as c_int,
217            )
218        };
219        assert!(result == libz_sys::Z_OK, "Failed to initialize compresser.");
220        Self {
221            stream: unsafe { Box::from_raw(Box::into_raw(stream) as *mut libz_sys::z_stream) },
222        }
223    }
224
225    /// construct with custom stream
226    pub fn with(stream: Box<libz_sys::z_stream>) -> Self {
227        Self { stream }
228    }
229
230    /// decompress data
231    pub fn de_compress(&mut self, inputs: &[&[u8]], output: &mut Vec<u8>) -> Result<(), c_int> {
232        let total_input: usize = inputs.iter().map(|i| i.len()).sum();
233        if total_input > output.capacity() * 2 + 4 {
234            output.resize(total_input * 2 + 4, 0);
235        }
236        let mut write_idx = 0;
237        let before = self.stream.total_out;
238        for i in inputs {
239            let mut iter_read_idx = 0;
240            loop {
241                unsafe {
242                    self.stream.next_in = i.as_ptr().add(iter_read_idx) as *mut _;
243                }
244                self.stream.avail_in = (i.len() - iter_read_idx) as c_uint;
245                if output.capacity() - output.len() <= 0 {
246                    output.resize(output.capacity() * 2, 0);
247                }
248                let out_slice = unsafe {
249                    slice::from_raw_parts_mut(
250                        output.as_mut_ptr().add(write_idx),
251                        output.capacity() - write_idx,
252                    )
253                };
254                self.stream.next_out = out_slice.as_mut_ptr();
255                self.stream.avail_out = out_slice.len() as c_uint;
256
257                match unsafe { libz_sys::inflate(*&mut self.stream.as_mut(), Z_NO_FLUSH) } {
258                    Z_OK | Z_BUF_ERROR => {}
259                    code => return Err(code),
260                };
261                iter_read_idx = i.len() - self.stream.avail_in as usize;
262                write_idx = (self.stream.total_out - before) as usize;
263                if self.stream.avail_in == 0 {
264                    break;
265                }
266            }
267        }
268        unsafe {
269            match libz_sys::inflate(*&mut self.stream.as_mut(), Z_SYNC_FLUSH) {
270                Z_OK | Z_BUF_ERROR => {}
271                code => return Err(code),
272            }
273            output.set_len((self.stream.total_out - before) as usize);
274        };
275        Ok(())
276    }
277
278    /// reset stream state
279    pub fn reset(&mut self) -> Result<(), c_int> {
280        let code = unsafe { libz_sys::inflateReset(self.stream.as_mut()) };
281        match code {
282            Z_OK => Ok(()),
283            code => Err(code),
284        }
285    }
286}
287
288/// zlib compress stream
289pub struct ZLibCompressStream {
290    stream: Box<libz_sys::z_stream>,
291}
292
293unsafe impl Send for ZLibCompressStream {}
294unsafe impl Sync for ZLibCompressStream {}
295
296impl Drop for ZLibCompressStream {
297    fn drop(&mut self) {
298        match unsafe { libz_sys::deflateEnd(self.stream.as_mut()) } {
299            libz_sys::Z_STREAM_ERROR => {
300                tracing::trace!("compression stream encountered bad state.")
301            }
302            // Ignore discarded data error because we are raw
303            libz_sys::Z_OK | libz_sys::Z_DATA_ERROR => {
304                tracing::trace!("deallocated compression context.")
305            }
306            code => tracing::trace!("bad zlib status encountered: {}", code),
307        }
308    }
309}
310
311impl ZLibCompressStream {
312    /// construct with window bit
313    pub fn new(window: WindowBit) -> Self {
314        let mut stream: Box<MaybeUninit<libz_sys::z_stream>> = Box::new(MaybeUninit::zeroed());
315        let result = unsafe {
316            libz_sys::deflateInit2_(
317                stream.as_mut_ptr(),
318                9,
319                libz_sys::Z_DEFLATED,
320                -(window as i8) as c_int,
321                9,
322                libz_sys::Z_DEFAULT_STRATEGY,
323                ZLIB_VERSION.as_ptr() as *const c_char,
324                mem::size_of::<libz_sys::z_stream>() as c_int,
325            )
326        };
327        assert!(result == libz_sys::Z_OK, "Failed to initialize compresser.");
328        Self {
329            stream: unsafe { Box::from_raw(Box::into_raw(stream) as *mut libz_sys::z_stream) },
330        }
331    }
332
333    /// construct with custom stream
334    pub fn with(stream: Box<libz_sys::z_stream>) -> Self {
335        Self { stream }
336    }
337
338    /// compress data
339    pub fn compress(&mut self, inputs: &[&[u8]], output: &mut Vec<u8>) -> Result<(), c_int> {
340        let total_input: usize = inputs.iter().map(|i| i.len()).sum();
341        if total_input > output.capacity() * 2 + 4 {
342            output.resize(total_input * 2 + 4, 0);
343        }
344        let mut write_idx = 0;
345        let mut total_remain = total_input;
346        let before = self.stream.total_out;
347        for i in inputs {
348            let mut iter_read_idx = 0;
349            loop {
350                unsafe {
351                    self.stream.next_in = i.as_ptr().add(iter_read_idx) as *mut _;
352                }
353                self.stream.avail_in = (i.len() - iter_read_idx) as c_uint;
354                if output.capacity() - output.len() <= 0 {
355                    output.resize(output.len() + total_remain * 2, 0)
356                }
357                let out_slice = unsafe {
358                    slice::from_raw_parts_mut(
359                        output.as_mut_ptr().add(write_idx),
360                        output.capacity() - write_idx,
361                    )
362                };
363                self.stream.next_out = out_slice.as_mut_ptr();
364                self.stream.avail_out = out_slice.len() as c_uint;
365
366                match unsafe { libz_sys::deflate(*&mut self.stream.as_mut(), Z_NO_FLUSH) } {
367                    libz_sys::Z_OK => {}
368                    code => return Err(code),
369                };
370                iter_read_idx = i.len() - self.stream.avail_in as usize;
371                write_idx = (self.stream.total_out - before) as usize;
372                if self.stream.avail_in == 0 {
373                    break;
374                }
375            }
376            total_remain -= iter_read_idx;
377        }
378        unsafe {
379            match libz_sys::deflate(*&mut self.stream.as_mut(), Z_SYNC_FLUSH) {
380                Z_OK => {}
381                code => return Err(code),
382            }
383            output.set_len((self.stream.total_out - before) as usize);
384        };
385        Ok(())
386    }
387
388    /// reset stream state
389    pub fn reset(&mut self) -> Result<(), c_int> {
390        let code = unsafe { libz_sys::deflateReset(self.stream.as_mut()) };
391        match code {
392            Z_OK => Ok(()),
393            code => Err(code),
394        }
395    }
396}
397
/// Records which permessage-deflate parameters have already appeared while parsing one
/// offer, so that duplicates can be rejected. Every flag starts out `false`.
#[derive(Default)]
struct PMDParamCounter {
    /// `server_no_context_takeover` already seen.
    server_no_context_takeover: bool,
    /// `client_no_context_takeover` already seen.
    client_no_context_takeover: bool,
    /// `server_max_window_bits` already seen.
    server_max_window_bits: bool,
    /// `client_max_window_bits` already seen.
    client_max_window_bits: bool,
}
405
406impl PMDConfig {
407    /// case-insensitive parse one line header
408    pub fn parse_str(source: &str) -> Result<Vec<Self>, String> {
409        let lines = source.split("\r\n").count();
410        if lines > 2 {
411            return Err("should not contain multi line".to_string());
412        }
413        let mut configs = vec![];
414        for part in source.split(',') {
415            if part.trim_start().to_lowercase().starts_with(EXT_ID) {
416                let mut conf = Self::default();
417                let mut counter = PMDParamCounter::default();
418                for param in part.split(';').skip(1) {
419                    let lower = param.trim().to_lowercase();
420                    if lower.starts_with(SERVER_NO_CONTEXT_TAKEOVER) {
421                        if counter.server_no_context_takeover {
422                            return Err(format!(
423                                "got multiple {SERVER_NO_CONTEXT_TAKEOVER} params"
424                            ));
425                        }
426                        if lower.len() != SERVER_NO_CONTEXT_TAKEOVER.len() {
427                            return Err(format!(
428                                "{SERVER_NO_CONTEXT_TAKEOVER} does not expect param"
429                            ));
430                        }
431                        conf.server_no_context_takeover = true;
432                        counter.server_no_context_takeover = true;
433                        continue;
434                    }
435
436                    if lower.starts_with(CLIENT_NO_CONTEXT_TAKEOVER) {
437                        if counter.client_no_context_takeover {
438                            return Err(format!(
439                                "got multiple {CLIENT_NO_CONTEXT_TAKEOVER} params"
440                            ));
441                        }
442                        if lower.len() != CLIENT_NO_CONTEXT_TAKEOVER.len() {
443                            return Err(format!(
444                                "{CLIENT_NO_CONTEXT_TAKEOVER} does not expect param"
445                            ));
446                        }
447                        conf.client_no_context_takeover = true;
448                        counter.client_no_context_takeover = true;
449                        continue;
450                    }
451
452                    if lower.starts_with(SERVER_MAX_WINDOW_BITS) {
453                        if counter.server_max_window_bits {
454                            return Err(format!("got multiple {SERVER_MAX_WINDOW_BITS} params"));
455                        }
456
457                        if lower != SERVER_MAX_WINDOW_BITS {
458                            let remain = lower.trim_start_matches(SERVER_MAX_WINDOW_BITS);
459                            if !remain.trim_start().starts_with('=') {
460                                return Err("invalid param value".to_string());
461                            }
462                            let remain = remain.trim_start().trim_matches('=');
463                            let size = match remain.parse::<u8>() {
464                                Ok(size) => WindowBit::try_from(size)
465                                    .map_err(|e| format!("invalid param value {e}"))?,
466                                Err(e) => return Err(format!("invalid param value {e}")),
467                            };
468                            conf.server_max_window_bits = size;
469                        }
470                        counter.server_max_window_bits = true;
471                        continue;
472                    }
473
474                    if lower.starts_with(CLIENT_MAX_WINDOW_BITS) {
475                        if counter.client_max_window_bits {
476                            return Err(format!("got multiple {CLIENT_MAX_WINDOW_BITS} params"));
477                        }
478
479                        if lower != CLIENT_MAX_WINDOW_BITS {
480                            let remain = lower.trim_start_matches(CLIENT_MAX_WINDOW_BITS);
481                            if !remain.trim_start().starts_with('=') {
482                                return Err("invalid param value".to_string());
483                            }
484                            let remain = remain.trim_start().trim_matches('=');
485                            let size = match remain.parse::<u8>() {
486                                Ok(size) => WindowBit::try_from(size)
487                                    .map_err(|e| format!("invalid param value {e}"))?,
488                                Err(e) => return Err(format!("invalid param value {e}")),
489                            };
490                            conf.client_max_window_bits = size;
491                        }
492                        counter.client_max_window_bits = true;
493                        continue;
494                    }
495                    return Err(format!("unknown param {param}"));
496                }
497                configs.push(conf);
498            }
499        }
500        Ok(configs)
501    }
502}
503
504/// deflate frame write state
505pub struct DeflateWriteState {
506    write_state: FrameWriteState,
507    com: Option<WriteStreamHandler>,
508    config: FrameConfig,
509    header_buf: [u8; 14],
510    is_server: bool,
511}
512
513impl DeflateWriteState {
514    /// construct with config
515    pub fn with_config(
516        frame_config: FrameConfig,
517        pmd_config: Option<PMDConfig>,
518        is_server: bool,
519    ) -> Self {
520        let low_level_config = gen_low_level_config(&frame_config);
521        let write_state = FrameWriteState::with_config(low_level_config);
522        let com = if let Some(config) = pmd_config {
523            let com_size = if is_server {
524                config.client_max_window_bits
525            } else {
526                config.server_max_window_bits
527            };
528            let com = ZLibCompressStream::new(com_size);
529            Some(WriteStreamHandler { config, com })
530        } else {
531            None
532        };
533        Self {
534            write_state,
535            com,
536            config: frame_config,
537            header_buf: [0; 14],
538            is_server,
539        }
540    }
541}
542
543/// deflate frame read state
544pub struct DeflateReadState {
545    read_state: FrameReadState,
546    de: Option<ReadStreamHandler>,
547    config: FrameConfig,
548    fragmented: bool,
549    fragmented_data: Vec<u8>,
550    control_buf: Vec<u8>,
551    fragmented_type: OpCode,
552    is_server: bool,
553}
554
555impl DeflateReadState {
556    /// construct with config
557    pub fn with_config(
558        frame_config: FrameConfig,
559        pmd_config: Option<PMDConfig>,
560        is_server: bool,
561    ) -> Self {
562        let low_level_config = gen_low_level_config(&frame_config);
563        let read_state = FrameReadState::with_config(low_level_config);
564        let de = if let Some(config) = pmd_config {
565            let de_size = if is_server {
566                config.client_max_window_bits
567            } else {
568                config.server_max_window_bits
569            };
570            let de = ZLibDeCompressStream::new(de_size);
571            Some(ReadStreamHandler { config, de })
572        } else {
573            None
574        };
575        Self {
576            read_state,
577            de,
578            config: frame_config,
579            fragmented: false,
580            fragmented_data: vec![],
581            control_buf: vec![],
582            fragmented_type: OpCode::Binary,
583            is_server,
584        }
585    }
586}