sockudo_ws/
deflate.rs

1//! Per-Message Deflate Extension (RFC 7692)
2//!
3//! This module implements the permessage-deflate WebSocket extension,
4//! which compresses message payloads using the DEFLATE algorithm.
5
6use bytes::{Bytes, BytesMut};
7use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status};
8
9use crate::error::{Error, Result};
10
/// Four-byte tail (`00 00 ff ff`) that is stripped from compressed output
/// and re-appended to incoming payloads before they are inflated
const DEFLATE_TRAILER: [u8; 4] = [0, 0, 0xFF, 0xFF];

/// Default LZ77 window size exponent: the largest one (2^15 = 32KB)
pub const DEFAULT_WINDOW_BITS: u8 = MAX_WINDOW_BITS;

/// Smallest LZ77 window size exponent (2^8 = 256 bytes)
pub const MIN_WINDOW_BITS: u8 = 8;

/// Largest LZ77 window size exponent (2^15 = 32KB)
pub const MAX_WINDOW_BITS: u8 = 15;
22
/// Configuration for permessage-deflate extension
///
/// Holds the negotiated (or locally chosen) parameters for both directions
/// of a connection plus two purely local tuning knobs. The type is comparable
/// so negotiated configurations can be checked against expectations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeflateConfig {
    /// Server's maximum LZ77 window bits (for compression when server, decompression when client)
    pub server_max_window_bits: u8,
    /// Client's maximum LZ77 window bits (for compression when client, decompression when server)
    pub client_max_window_bits: u8,
    /// If true, server must reset compression context after each message
    pub server_no_context_takeover: bool,
    /// If true, client must reset compression context after each message
    pub client_no_context_takeover: bool,
    /// Compression level (0-9, where 0 is no compression, 9 is max)
    pub compression_level: u32,
    /// Minimum message size in bytes to compress (smaller messages may not benefit)
    pub compression_threshold: usize,
}
39
40impl Default for DeflateConfig {
41    fn default() -> Self {
42        Self {
43            server_max_window_bits: DEFAULT_WINDOW_BITS,
44            client_max_window_bits: DEFAULT_WINDOW_BITS,
45            server_no_context_takeover: false,
46            client_no_context_takeover: false,
47            compression_level: 6,      // Default zlib compression level
48            compression_threshold: 32, // Don't compress tiny messages
49        }
50    }
51}
52
53impl DeflateConfig {
54    /// Create config optimized for low memory usage
55    pub fn low_memory() -> Self {
56        Self {
57            server_max_window_bits: 10, // 1KB window
58            client_max_window_bits: 10,
59            server_no_context_takeover: true,
60            client_no_context_takeover: true,
61            compression_level: 1, // Fast compression
62            compression_threshold: 64,
63        }
64    }
65
66    /// Create config optimized for best compression
67    pub fn best_compression() -> Self {
68        Self {
69            server_max_window_bits: MAX_WINDOW_BITS,
70            client_max_window_bits: MAX_WINDOW_BITS,
71            server_no_context_takeover: false,
72            client_no_context_takeover: false,
73            compression_level: 9,
74            compression_threshold: 16,
75        }
76    }
77
78    /// Parse extension parameters from handshake
79    pub fn from_params(params: &[(&str, Option<&str>)]) -> Result<Self> {
80        let mut config = Self::default();
81
82        for (name, value) in params {
83            match *name {
84                "server_no_context_takeover" => {
85                    if value.is_some() {
86                        return Err(Error::HandshakeFailed(
87                            "server_no_context_takeover must not have a value",
88                        ));
89                    }
90                    config.server_no_context_takeover = true;
91                }
92                "client_no_context_takeover" => {
93                    if value.is_some() {
94                        return Err(Error::HandshakeFailed(
95                            "client_no_context_takeover must not have a value",
96                        ));
97                    }
98                    config.client_no_context_takeover = true;
99                    // When client uses no_context_takeover, server should too
100                    // to ensure decompression works correctly on client side
101                    config.server_no_context_takeover = true;
102                }
103                "server_max_window_bits" => {
104                    if let Some(v) = value {
105                        let bits: u8 = v.parse().map_err(|_| {
106                            Error::HandshakeFailed("invalid server_max_window_bits value")
107                        })?;
108                        if !(MIN_WINDOW_BITS..=MAX_WINDOW_BITS).contains(&bits) {
109                            return Err(Error::HandshakeFailed(
110                                "server_max_window_bits out of range (8-15)",
111                            ));
112                        }
113                        config.server_max_window_bits = bits;
114                    }
115                }
116                "client_max_window_bits" => {
117                    if let Some(v) = value {
118                        let bits: u8 = v.parse().map_err(|_| {
119                            Error::HandshakeFailed("invalid client_max_window_bits value")
120                        })?;
121                        if !(MIN_WINDOW_BITS..=MAX_WINDOW_BITS).contains(&bits) {
122                            return Err(Error::HandshakeFailed(
123                                "client_max_window_bits out of range (8-15)",
124                            ));
125                        }
126                        config.client_max_window_bits = bits;
127                    }
128                    // If no value, client just indicates support
129                }
130                _ => {
131                    return Err(Error::HandshakeFailed(
132                        "unknown permessage-deflate parameter",
133                    ));
134                }
135            }
136        }
137
138        Ok(config)
139    }
140
141    /// Generate extension response header value for server
142    pub fn to_response_header(&self) -> String {
143        let mut parts = vec!["permessage-deflate".to_string()];
144
145        if self.server_no_context_takeover {
146            parts.push("server_no_context_takeover".to_string());
147        }
148        if self.client_no_context_takeover {
149            parts.push("client_no_context_takeover".to_string());
150        }
151        if self.server_max_window_bits < MAX_WINDOW_BITS {
152            parts.push(format!(
153                "server_max_window_bits={}",
154                self.server_max_window_bits
155            ));
156        }
157        if self.client_max_window_bits < MAX_WINDOW_BITS {
158            parts.push(format!(
159                "client_max_window_bits={}",
160                self.client_max_window_bits
161            ));
162        }
163
164        parts.join("; ")
165    }
166}
167
168/// Deflate compressor for outgoing messages
169pub struct DeflateEncoder {
170    compress: Compress,
171    no_context_takeover: bool,
172    #[allow(dead_code)]
173    window_bits: u8,
174    #[allow(dead_code)]
175    compression_level: Compression,
176    threshold: usize,
177}
178
179impl DeflateEncoder {
180    /// Create a new encoder
181    pub fn new(window_bits: u8, no_context_takeover: bool, level: u32, threshold: usize) -> Self {
182        let compression_level = Compression::new(level);
183        // Use the negotiated window_bits for compression
184        // This ensures the compressed data can be decompressed by clients with smaller windows
185        let compress = Compress::new_with_window_bits(compression_level, false, window_bits);
186
187        Self {
188            compress,
189            no_context_takeover,
190            window_bits,
191            compression_level,
192            threshold,
193        }
194    }
195
196    /// Compress a message payload
197    ///
198    /// Returns None if the message is too small to benefit from compression
199    /// or if compression would make it larger.
200    pub fn compress(&mut self, data: &[u8]) -> Result<Option<Bytes>> {
201        if data.len() < self.threshold {
202            return Ok(None);
203        }
204
205        // Reset context if required
206        if self.no_context_takeover {
207            self.compress.reset();
208        }
209
210        // Estimate output size (compressed data is often smaller, but we need headroom)
211        let max_output = data.len() + 64;
212        let mut output = BytesMut::with_capacity(max_output);
213
214        // Compress the data
215        let mut total_in: usize = 0;
216        let mut iterations = 0u32;
217
218        loop {
219            iterations += 1;
220            if iterations > 100_000 {
221                return Err(Error::Compression(
222                    "compression took too many iterations".into(),
223                ));
224            }
225
226            // Ensure we have space in output buffer
227            let available = output.capacity() - output.len();
228            if available == 0 {
229                output.reserve(4096);
230            }
231
232            let input = &data[total_in..];
233            let before_out = self.compress.total_out();
234            let before_in = self.compress.total_in();
235
236            // Get writable slice using spare_capacity_mut to avoid UB with uninitialized memory.
237            // We get the spare capacity, compress into it, then only set_len for bytes actually written.
238            let out_start = output.len();
239            let spare = output.spare_capacity_mut();
240
241            // SAFETY: We're creating a &mut [u8] from MaybeUninit<u8> slice.
242            // flate2's compress() will write to this buffer and tell us how many bytes were written.
243            // We only call set_len() for the bytes that were actually initialized by compress().
244            let spare_slice = unsafe {
245                std::slice::from_raw_parts_mut(spare.as_mut_ptr() as *mut u8, spare.len())
246            };
247
248            let status = self
249                .compress
250                .compress(input, spare_slice, FlushCompress::Sync)
251                .map_err(|e| Error::Compression(format!("deflate error: {}", e)))?;
252
253            let consumed = (self.compress.total_in() - before_in) as usize;
254            let produced = (self.compress.total_out() - before_out) as usize;
255
256            total_in += consumed;
257
258            // SAFETY: compress() wrote exactly `produced` bytes to spare_slice.
259            // We're only extending the length by the number of bytes that were initialized.
260            unsafe {
261                output.set_len(out_start + produced);
262            }
263
264            match status {
265                Status::Ok | Status::BufError => {
266                    if total_in >= data.len() {
267                        break;
268                    }
269                }
270                Status::StreamEnd => break,
271            }
272        }
273
274        // Per RFC 7692: Remove trailing 0x00 0x00 0xff 0xff
275        if output.len() >= 4 && output.ends_with(&DEFLATE_TRAILER) {
276            output.truncate(output.len() - 4);
277        }
278
279        // Only use compression if it actually reduces size
280        if output.len() >= data.len() {
281            return Ok(None);
282        }
283
284        Ok(Some(output.freeze()))
285    }
286
287    /// Reset the compression context (for no_context_takeover)
288    pub fn reset(&mut self) {
289        self.compress.reset();
290    }
291}
292
293/// Deflate decompressor for incoming messages
294pub struct DeflateDecoder {
295    decompress: Decompress,
296    no_context_takeover: bool,
297    #[allow(dead_code)]
298    window_bits: u8,
299}
300
301impl DeflateDecoder {
302    /// Create a new decoder
303    pub fn new(window_bits: u8, no_context_takeover: bool) -> Self {
304        // Use raw deflate (no zlib header) with the negotiated window_bits
305        let decompress = Decompress::new_with_window_bits(false, window_bits);
306
307        Self {
308            decompress,
309            no_context_takeover,
310            window_bits,
311        }
312    }
313
314    /// Decompress a message payload
315    pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Bytes> {
316        // Reset context if required
317        if self.no_context_takeover {
318            self.decompress.reset(false);
319        }
320
321        // Per RFC 7692: Append 0x00 0x00 0xff 0xff before decompressing
322        let mut input = BytesMut::with_capacity(data.len() + 4);
323        input.extend_from_slice(data);
324        input.extend_from_slice(&DEFLATE_TRAILER);
325
326        // Start with reasonable output buffer (at least 1KB or 4x input)
327        let initial_cap = std::cmp::max(1024, data.len() * 4);
328        let mut output = BytesMut::with_capacity(initial_cap);
329        let mut total_in: usize = 0;
330        let mut iterations = 0u32;
331
332        loop {
333            iterations += 1;
334            // Safety check to prevent infinite loops
335            if iterations > 100_000 {
336                return Err(Error::Compression(
337                    "decompression took too many iterations".into(),
338                ));
339            }
340
341            // Check size limit
342            if output.len() > max_size {
343                return Err(Error::MessageTooLarge);
344            }
345
346            // Ensure we have space in output buffer
347            let available = output.capacity() - output.len();
348            if available == 0 {
349                if output.capacity() >= max_size {
350                    return Err(Error::MessageTooLarge);
351                }
352                // At least double or add 4KB, whichever is larger
353                let additional = std::cmp::max(output.capacity(), 4096);
354                output.reserve(additional);
355            }
356
357            let before_out = self.decompress.total_out();
358            let before_in = self.decompress.total_in();
359
360            // Get writable slice using spare_capacity_mut to avoid UB with uninitialized memory.
361            let out_start = output.len();
362            let spare = output.spare_capacity_mut();
363
364            // SAFETY: We're creating a &mut [u8] from MaybeUninit<u8> slice.
365            // flate2's decompress() will write to this buffer and tell us how many bytes were written.
366            // We only call set_len() for the bytes that were actually initialized by decompress().
367            let spare_slice = unsafe {
368                std::slice::from_raw_parts_mut(spare.as_mut_ptr() as *mut u8, spare.len())
369            };
370
371            let status = self
372                .decompress
373                .decompress(&input[total_in..], spare_slice, FlushDecompress::Sync)
374                .map_err(|e| Error::Compression(format!("inflate error: {}", e)))?;
375
376            let consumed = (self.decompress.total_in() - before_in) as usize;
377            let produced = (self.decompress.total_out() - before_out) as usize;
378
379            total_in += consumed;
380
381            // SAFETY: decompress() wrote exactly `produced` bytes to spare_slice.
382            // We're only extending the length by the number of bytes that were initialized.
383            unsafe {
384                output.set_len(out_start + produced);
385            }
386
387            match status {
388                Status::Ok => {
389                    if total_in >= input.len() {
390                        break;
391                    }
392                }
393                Status::StreamEnd => break,
394                Status::BufError => {
395                    // Need more output space - will be handled at top of loop
396                }
397            }
398        }
399
400        Ok(output.freeze())
401    }
402
403    /// Reset the decompression context (for no_context_takeover)
404    pub fn reset(&mut self) {
405        self.decompress.reset(false);
406    }
407}
408
409/// Combined compressor/decompressor context for a WebSocket connection
410pub struct DeflateContext {
411    /// Encoder for outgoing messages
412    pub encoder: DeflateEncoder,
413    /// Decoder for incoming messages
414    pub decoder: DeflateDecoder,
415    /// Configuration
416    pub config: DeflateConfig,
417}
418
419impl DeflateContext {
420    /// Create context for server role
421    pub fn server(config: DeflateConfig) -> Self {
422        let encoder = DeflateEncoder::new(
423            config.server_max_window_bits,
424            config.server_no_context_takeover,
425            config.compression_level,
426            config.compression_threshold,
427        );
428        let decoder = DeflateDecoder::new(
429            config.client_max_window_bits,
430            config.client_no_context_takeover,
431        );
432
433        Self {
434            encoder,
435            decoder,
436            config,
437        }
438    }
439
440    /// Create context for client role
441    pub fn client(config: DeflateConfig) -> Self {
442        let encoder = DeflateEncoder::new(
443            config.client_max_window_bits,
444            config.client_no_context_takeover,
445            config.compression_level,
446            config.compression_threshold,
447        );
448        let decoder = DeflateDecoder::new(
449            config.server_max_window_bits,
450            config.server_no_context_takeover,
451        );
452
453        Self {
454            encoder,
455            decoder,
456            config,
457        }
458    }
459
460    /// Compress a message if beneficial
461    pub fn compress(&mut self, data: &[u8]) -> Result<Option<Bytes>> {
462        self.encoder.compress(data)
463    }
464
465    /// Decompress a message
466    pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Bytes> {
467        self.decoder.decompress(data, max_size)
468    }
469}
470
/// Split a `permessage-deflate` extension entry into its parameters
///
/// Returns `None` when `value` does not name the permessage-deflate
/// extension, or when text follows the extension name without a leading
/// `;`. Otherwise returns the `(name, value)` pairs in order of appearance:
/// a bare parameter yields `None` as its value, and `name=value` yields the
/// value with surrounding whitespace and double quotes stripped. Empty
/// segments between semicolons are skipped.
pub fn parse_deflate_offer(value: &str) -> Option<Vec<(&str, Option<&str>)>> {
    const EXTENSION_NAME: &str = "permessage-deflate";

    // Anything that does not begin with the extension name is not our offer.
    let tail = value.trim().strip_prefix(EXTENSION_NAME)?.trim_start();
    if tail.is_empty() {
        return Some(Vec::new());
    }

    // Parameters, if any, must be introduced by a semicolon.
    let param_list = tail.strip_prefix(';')?;

    let params = param_list
        .split(';')
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .map(|segment| match segment.split_once('=') {
            Some((key, raw)) => (key.trim(), Some(raw.trim().trim_matches('"'))),
            None => (segment, None),
        })
        .collect();

    Some(params)
}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// A compressible message survives a compress/decompress round trip.
    #[test]
    fn test_compress_decompress() {
        let config = DeflateConfig::default();
        let mut ctx = DeflateContext::server(config);

        let original = b"Hello, World! This is a test message that should be compressed.";

        // Compress
        let compressed = ctx.compress(original).unwrap();
        assert!(compressed.is_some());
        let compressed = compressed.unwrap();
        assert!(compressed.len() < original.len());

        // Decompress
        let decompressed = ctx.decompress(&compressed, 1024).unwrap();
        assert_eq!(&decompressed[..], &original[..]);
    }

    /// Several messages in a row round-trip while both sides keep their context.
    #[test]
    fn test_roundtrip_multiple_messages() {
        let mut ctx = DeflateContext::server(DeflateConfig::default());

        let messages: Vec<Vec<u8>> = vec![
            vec![b'a'; 256],
            vec![b'b'; 256],
            b"Hello, World! ".repeat(16),
        ];

        for original in &messages {
            let compressed = ctx
                .compress(original)
                .unwrap()
                .expect("repetitive payload should compress");
            let decompressed = ctx.decompress(&compressed, 1024).unwrap();
            assert_eq!(&decompressed[..], &original[..]);
        }
    }

    /// Payloads below the threshold are left uncompressed.
    #[test]
    fn test_small_message_not_compressed() {
        let config = DeflateConfig {
            compression_threshold: 100,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let small = b"tiny";
        let result = ctx.compress(small).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_context_takeover() {
        let config = DeflateConfig {
            server_no_context_takeover: false,
            compression_threshold: 0,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let msg = b"Hello, World! Hello, World! Hello, World!";

        // First compression
        let first = ctx.compress(msg).unwrap().unwrap();

        // Second compression should benefit from context
        let second = ctx.compress(msg).unwrap().unwrap();

        // With context takeover, second should be smaller or equal
        // (references previous data in LZ77 window)
        assert!(second.len() <= first.len());
    }

    #[test]
    fn test_no_context_takeover() {
        let config = DeflateConfig {
            server_no_context_takeover: true,
            compression_threshold: 0,
            ..Default::default()
        };
        let mut ctx = DeflateContext::server(config);

        let msg = b"Hello, World! Hello, World! Hello, World!";

        // The stream is reset before each message, so identical input must
        // yield identical bytes, not merely the same length.
        let first = ctx.compress(msg).unwrap().unwrap();
        let second = ctx.compress(msg).unwrap().unwrap();

        assert_eq!(first, second);
    }

    #[test]
    fn test_parse_deflate_offer() {
        // Simple offer
        let params = parse_deflate_offer("permessage-deflate").unwrap();
        assert!(params.is_empty());

        // With parameters
        let params = parse_deflate_offer(
            "permessage-deflate; server_no_context_takeover; server_max_window_bits=10",
        )
        .unwrap();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], ("server_no_context_takeover", None));
        assert_eq!(params[1], ("server_max_window_bits", Some("10")));

        // Quoted values are unwrapped
        let params =
            parse_deflate_offer("permessage-deflate; client_max_window_bits=\"12\"").unwrap();
        assert_eq!(params, vec![("client_max_window_bits", Some("12"))]);

        // Not a deflate offer
        assert!(parse_deflate_offer("some-other-extension").is_none());

        // Parameters must be introduced by a semicolon
        assert!(parse_deflate_offer("permessage-deflate client_max_window_bits").is_none());
    }

    #[test]
    fn test_config_from_params() {
        let params = vec![
            ("server_no_context_takeover", None),
            ("client_max_window_bits", Some("12")),
        ];

        let config = DeflateConfig::from_params(&params).unwrap();
        assert!(config.server_no_context_takeover);
        assert!(!config.client_no_context_takeover);
        assert_eq!(config.client_max_window_bits, 12);
        assert_eq!(config.server_max_window_bits, DEFAULT_WINDOW_BITS);
    }

    /// Malformed parameters are rejected rather than silently ignored.
    #[test]
    fn test_config_from_params_rejects_invalid() {
        let unknown: [(&str, Option<&str>); 1] = [("unknown_param", None)];
        assert!(DeflateConfig::from_params(&unknown).is_err());

        let too_large: [(&str, Option<&str>); 1] = [("server_max_window_bits", Some("16"))];
        assert!(DeflateConfig::from_params(&too_large).is_err());

        let too_small: [(&str, Option<&str>); 1] = [("client_max_window_bits", Some("7"))];
        assert!(DeflateConfig::from_params(&too_small).is_err());

        let not_a_number: [(&str, Option<&str>); 1] = [("client_max_window_bits", Some("abc"))];
        assert!(DeflateConfig::from_params(&not_a_number).is_err());

        let flag_with_value: [(&str, Option<&str>); 1] =
            [("server_no_context_takeover", Some("1"))];
        assert!(DeflateConfig::from_params(&flag_with_value).is_err());
    }

    #[test]
    fn test_response_header() {
        let config = DeflateConfig {
            server_no_context_takeover: true,
            server_max_window_bits: 12,
            ..Default::default()
        };

        let header = config.to_response_header();
        assert!(header.contains("permessage-deflate"));
        assert!(header.contains("server_no_context_takeover"));
        assert!(header.contains("server_max_window_bits=12"));
    }
}