// socket_flow/extensions.rs

/// Extension token of the permessage-deflate WebSocket extension (RFC 7692).
const PERMESSAGE_DEFLATE: &str = "permessage-deflate";
/// permessage-deflate parameter asking the client not to reuse its compression context.
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
/// permessage-deflate parameter asking the server not to reuse its compression context.
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
/// permessage-deflate parameter limiting the client's compression window size.
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
/// permessage-deflate parameter limiting the server's compression window size.
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";

/// Negotiable parameters of the permessage-deflate WebSocket extension.
///
/// Note that the parameters do not all act on the same side of the connection:
/// - `client_no_context_takeover` / `server_no_context_takeover` control whether a side
///   resets its compression context after every message, and the peer's decompressor
///   has to reset in step, so they affect both compression and decompression.
/// - `client_max_window_bits` / `server_max_window_bits` limit the sliding-window size a
///   side compresses with.
///
/// Keeping the context between messages improves compression but costs more memory.
/// Larger windows (closer to 15 bits) give better compression ratios but are slower and
/// use more memory; smaller windows (closer to 8 bits) are faster but compress worse.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Extensions {
    /// Whether permessage-deflate compression is enabled.
    pub permessage_deflate: bool,
    /// When `Some(true)`, the client resets its compression context after compressing
    /// each message; if the server accepts, it must reset its decompression context for
    /// every message it receives from the client. This only concerns client-to-server
    /// traffic, not the opposite direction.
    pub client_no_context_takeover: Option<bool>,
    /// When `Some(true)`, the server resets its compression context after compressing
    /// each message; if the server accepts a client's request for this, the client must
    /// reset its decompression context for every message it receives from the server.
    /// This only concerns server-to-client traffic, not the opposite direction.
    pub server_no_context_takeover: Option<bool>,
    /// Requested base-2 logarithm of the client's compression window size.
    pub client_max_window_bits: Option<u8>,
    /// Requested base-2 logarithm of the server's compression window size.
    pub server_max_window_bits: Option<u8>,
}

42// In first stage server will accept all the client extension configs, and
43// will reply the handshake request with everything that came from client
44// on a second stage, the end-user will set the default extension settings when calling
45// accept_async_with_config, and the server will read the client settings from the handshake
46// and will merge with the default settings, prioritizing what is default
47pub fn parse_extensions(extensions_header_value: String) -> Option<Extensions> {
48    let extensions_str = extensions_header_value.split(';');
49    let mut extensions = Extensions::default();
50
51    for extension_str in extensions_str.into_iter() {
52        if extension_str.trim() == PERMESSAGE_DEFLATE {
53            extensions.permessage_deflate = true;
54        } else if extension_str.trim().starts_with(CLIENT_NO_CONTEXT_TAKEOVER) {
55            extensions.client_no_context_takeover = Some(true);
56        } else if extension_str.trim().starts_with(SERVER_NO_CONTEXT_TAKEOVER) {
57            extensions.server_no_context_takeover = Some(true);
58        } else if extension_str.trim().starts_with(CLIENT_MAX_WINDOW_BITS) {
59            if !extension_str.contains('=') {
60                extensions.client_max_window_bits = Some(15);
61            } else {
62                extensions.client_max_window_bits =
63                    extension_str.trim().split('=').last()?.parse::<u8>().ok();
64            }
65        } else if extension_str.trim().starts_with(SERVER_MAX_WINDOW_BITS) {
66            if !extension_str.contains('=') {
67                extensions.server_max_window_bits = Some(15);
68            } else {
69                extensions.server_max_window_bits =
70                    extension_str.trim().split('=').last()?.parse::<u8>().ok();
71            }
72        }
73    }
74    if !extensions.permessage_deflate {
75        return None;
76    }
77
78    Some(extensions)
79}
80
81// Server will use this function, to merge the client requested extensions,
82// with server configured extensions.
83// It will ensure that both parties are capable of supporting the last
84// agreed extensions
85pub fn merge_extensions(
86    server_extensions: Option<Extensions>,
87    client_extensions: Option<Extensions>,
88) -> Option<Extensions> {
89    let server_ext = server_extensions?;
90    let client_ext = client_extensions?;
91    let merged_extensions = Extensions {
92        permessage_deflate: client_ext.permessage_deflate && server_ext.permessage_deflate,
93        client_no_context_takeover: server_ext
94            .client_no_context_takeover
95            .and(client_ext.client_no_context_takeover),
96        server_no_context_takeover: server_ext
97            .server_no_context_takeover
98            .and(client_ext.server_no_context_takeover),
99        client_max_window_bits: match (
100            server_ext.client_max_window_bits,
101            client_ext.client_max_window_bits,
102        ) {
103            (Some(server_bits), Some(client_bits)) => Some(std::cmp::min(server_bits, client_bits)),
104            (Some(server_bits), None) => Some(server_bits),
105            (None, Some(client_bits)) => Some(client_bits),
106            (None, None) => None,
107        },
108        server_max_window_bits: match (
109            server_ext.server_max_window_bits,
110            client_ext.server_max_window_bits,
111        ) {
112            (Some(server_bits), Some(client_bits)) => Some(std::cmp::min(server_bits, client_bits)),
113            (Some(server_bits), None) => Some(server_bits),
114            (None, Some(client_bits)) => Some(client_bits),
115            (None, None) => None,
116        },
117    };
118    Some(merged_extensions)
119}
120
121// Function used for constructing the HTTP request headers for extensions
122pub fn add_extension_headers(request: &mut String, extensions: Option<Extensions>) {
123    match extensions {
124        None => {
125            request.push_str("\r\n");
126        }
127        Some(extensions) => {
128            if extensions.permessage_deflate {
129                request.push_str(&format!("Sec-WebSocket-Extensions: {}", PERMESSAGE_DEFLATE));
130                if let Some(true) = extensions.client_no_context_takeover {
131                    request.push_str(&format!("; {}", CLIENT_NO_CONTEXT_TAKEOVER))
132                }
133                if let Some(true) = extensions.server_no_context_takeover {
134                    request.push_str(&format!("; {}", SERVER_NO_CONTEXT_TAKEOVER))
135                }
136                if let Some(bits) = extensions.client_max_window_bits {
137                    request.push_str(&format!("; {}={}", CLIENT_MAX_WINDOW_BITS, bits))
138                }
139                if let Some(bits) = extensions.server_max_window_bits {
140                    request.push_str(&format!("; {}={}", SERVER_MAX_WINDOW_BITS, bits))
141                }
142            }
143            request.push_str("\r\n\r\n");
144        }
145    }
146}