// socket_flow/extensions.rs

/// Extension token for the per-message DEFLATE compression extension (RFC 7692).
const PERMESSAGE_DEFLATE: &str = "permessage-deflate";
/// Extension parameter name: the client will not reuse its compression context across messages.
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
/// Extension parameter name: the server will not reuse its compression context across messages.
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
/// Extension parameter name: limit on the client's LZ77 window size (base-2 logarithm).
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
/// Extension parameter name: limit on the server's LZ77 window size (base-2 logarithm).
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
6
/// Parameters of the `permessage-deflate` WebSocket extension (RFC 7692).
///
/// The same struct is used for a local configuration, for a peer's parsed
/// `Sec-WebSocket-Extensions` header, and for the merged result.
/// `Option` fields are `None` when the corresponding parameter is absent.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Extensions {
    /// Whether the `permessage-deflate` extension is enabled at all.
    pub permessage_deflate: bool,
    /// The `client_no_context_takeover` parameter; the parser sets `Some(true)` when it is present.
    pub client_no_context_takeover: Option<bool>,
    /// The `server_no_context_takeover` parameter; the parser sets `Some(true)` when it is present.
    pub server_no_context_takeover: Option<bool>,
    /// The `client_max_window_bits` parameter value (RFC 7692 §7.1.2).
    pub client_max_window_bits: Option<u8>,
    /// The `server_max_window_bits` parameter value (RFC 7692 §7.1.2).
    pub server_max_window_bits: Option<u8>,
}
41
42pub fn parse_extensions(extensions_header_value: String) -> Option<Extensions> {
48 let extensions_str = extensions_header_value.split(';');
49 let mut extensions = Extensions::default();
50
51 for extension_str in extensions_str.into_iter() {
52 if extension_str.trim() == PERMESSAGE_DEFLATE {
53 extensions.permessage_deflate = true;
54 } else if extension_str.trim().starts_with(CLIENT_NO_CONTEXT_TAKEOVER) {
55 extensions.client_no_context_takeover = Some(true);
56 } else if extension_str.trim().starts_with(SERVER_NO_CONTEXT_TAKEOVER) {
57 extensions.server_no_context_takeover = Some(true);
58 } else if extension_str.trim().starts_with(CLIENT_MAX_WINDOW_BITS) {
59 if !extension_str.contains('=') {
60 extensions.client_max_window_bits = Some(15);
61 } else {
62 extensions.client_max_window_bits =
63 extension_str.trim().split('=').last()?.parse::<u8>().ok();
64 }
65 } else if extension_str.trim().starts_with(SERVER_MAX_WINDOW_BITS) {
66 if !extension_str.contains('=') {
67 extensions.server_max_window_bits = Some(15);
68 } else {
69 extensions.server_max_window_bits =
70 extension_str.trim().split('=').last()?.parse::<u8>().ok();
71 }
72 }
73 }
74 if !extensions.permessage_deflate {
75 return None;
76 }
77
78 Some(extensions)
79}
80
81pub fn merge_extensions(
86 server_extensions: Option<Extensions>,
87 client_extensions: Option<Extensions>,
88) -> Option<Extensions> {
89 let server_ext = server_extensions?;
90 let client_ext = client_extensions?;
91 let merged_extensions = Extensions {
92 permessage_deflate: client_ext.permessage_deflate && server_ext.permessage_deflate,
93 client_no_context_takeover: server_ext
94 .client_no_context_takeover
95 .and(client_ext.client_no_context_takeover),
96 server_no_context_takeover: server_ext
97 .server_no_context_takeover
98 .and(client_ext.server_no_context_takeover),
99 client_max_window_bits: match (
100 server_ext.client_max_window_bits,
101 client_ext.client_max_window_bits,
102 ) {
103 (Some(server_bits), Some(client_bits)) => Some(std::cmp::min(server_bits, client_bits)),
104 (Some(server_bits), None) => Some(server_bits),
105 (None, Some(client_bits)) => Some(client_bits),
106 (None, None) => None,
107 },
108 server_max_window_bits: match (
109 server_ext.server_max_window_bits,
110 client_ext.server_max_window_bits,
111 ) {
112 (Some(server_bits), Some(client_bits)) => Some(std::cmp::min(server_bits, client_bits)),
113 (Some(server_bits), None) => Some(server_bits),
114 (None, Some(client_bits)) => Some(client_bits),
115 (None, None) => None,
116 },
117 };
118 Some(merged_extensions)
119}
120
121pub fn add_extension_headers(request: &mut String, extensions: Option<Extensions>) {
123 match extensions {
124 None => {
125 request.push_str("\r\n");
126 }
127 Some(extensions) => {
128 if extensions.permessage_deflate {
129 request.push_str(&format!("Sec-WebSocket-Extensions: {}", PERMESSAGE_DEFLATE));
130 if let Some(true) = extensions.client_no_context_takeover {
131 request.push_str(&format!("; {}", CLIENT_NO_CONTEXT_TAKEOVER))
132 }
133 if let Some(true) = extensions.server_no_context_takeover {
134 request.push_str(&format!("; {}", SERVER_NO_CONTEXT_TAKEOVER))
135 }
136 if let Some(bits) = extensions.client_max_window_bits {
137 request.push_str(&format!("; {}={}", CLIENT_MAX_WINDOW_BITS, bits))
138 }
139 if let Some(bits) = extensions.server_max_window_bits {
140 request.push_str(&format!("; {}={}", SERVER_MAX_WINDOW_BITS, bits))
141 }
142 }
143 request.push_str("\r\n\r\n");
144 }
145 }
146}