zerodds_websocket_bridge/
close.rs1use alloc::string::String;
7use alloc::vec::Vec;
8
/// WebSocket close status code carried in the first two (big-endian) bytes of
/// a Close frame body: the RFC 6455 §7.4.1 codes plus the IANA-registered
/// 1012–1014.
///
/// Only these registered codes are representable. Library-defined
/// (3000–3999) and application-defined (4000–4999) codes have no variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum CloseCode {
    /// 1000 — normal closure.
    Normal = 1000,
    /// 1001 — the endpoint is going away.
    GoingAway = 1001,
    /// 1002 — protocol error.
    ProtocolError = 1002,
    /// 1003 — received a data type the endpoint cannot accept.
    UnsupportedData = 1003,
    /// 1005 — no status code was present. Reserved (see [`CloseCode::is_reserved`]).
    NoStatusReceived = 1005,
    /// 1006 — connection closed abnormally. Reserved (see [`CloseCode::is_reserved`]).
    AbnormalClosure = 1006,
    /// 1007 — message data inconsistent with its type.
    InvalidPayloadData = 1007,
    /// 1008 — message violated the endpoint's policy.
    PolicyViolation = 1008,
    /// 1009 — message too big to process.
    MessageTooBig = 1009,
    /// 1010 — an expected extension was not negotiated.
    MandatoryExtension = 1010,
    /// 1011 — unexpected internal condition.
    InternalError = 1011,
    /// 1012 — service restart.
    ServiceRestart = 1012,
    /// 1013 — try again later.
    TryAgainLater = 1013,
    /// 1014 — bad gateway.
    BadGateway = 1014,
    /// 1015 — TLS handshake failure. Reserved (see [`CloseCode::is_reserved`]).
    TlsHandshakeFailure = 1015,
}

impl CloseCode {
    /// Returns the numeric wire value of this code.
    #[must_use]
    pub const fn to_u16(self) -> u16 {
        self as u16
    }

    /// Maps a numeric status code to its [`CloseCode`] variant.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` for any value without a variant (e.g. 1004, 1016 and
    /// above, and every library/application code).
    // Unit error type is part of the existing public signature.
    #[allow(clippy::result_unit_err)]
    pub const fn from_u16(v: u16) -> Result<Self, ()> {
        match v {
            1000 => Ok(Self::Normal),
            1001 => Ok(Self::GoingAway),
            1002 => Ok(Self::ProtocolError),
            1003 => Ok(Self::UnsupportedData),
            1005 => Ok(Self::NoStatusReceived),
            1006 => Ok(Self::AbnormalClosure),
            1007 => Ok(Self::InvalidPayloadData),
            1008 => Ok(Self::PolicyViolation),
            1009 => Ok(Self::MessageTooBig),
            1010 => Ok(Self::MandatoryExtension),
            1011 => Ok(Self::InternalError),
            1012 => Ok(Self::ServiceRestart),
            1013 => Ok(Self::TryAgainLater),
            1014 => Ok(Self::BadGateway),
            1015 => Ok(Self::TlsHandshakeFailure),
            _ => Err(()),
        }
    }

    /// Returns `true` for the codes that must never be placed in a Close
    /// frame on the wire: 1005, 1006 and 1015.
    #[must_use]
    pub const fn is_reserved(self) -> bool {
        matches!(
            self,
            Self::NoStatusReceived | Self::AbnormalClosure | Self::TlsHandshakeFailure
        )
    }
}

/// Standard-trait form of [`CloseCode::to_u16`].
impl From<CloseCode> for u16 {
    fn from(code: CloseCode) -> Self {
        code.to_u16()
    }
}

/// Standard-trait form of [`CloseCode::from_u16`].
impl core::convert::TryFrom<u16> for CloseCode {
    type Error = ();

    fn try_from(v: u16) -> Result<Self, Self::Error> {
        Self::from_u16(v)
    }
}
88
/// Coarse bucket a numeric close status code falls into, following the
/// ranges laid out in RFC 6455 §7.4.2.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatusCodeRange {
    /// 0–999: not used by the protocol.
    Invalid,
    /// 1000–2999: reserved for the protocol and its extensions.
    ProtocolReserved,
    /// 3000–3999: reserved for libraries and frameworks.
    LibraryDefined,
    /// 4000–4999: reserved for private application use.
    ApplicationDefined,
    /// 5000 and above: outside every range the protocol defines.
    OutOfRange,
}

/// Sorts `code` into its [`StatusCodeRange`] bucket.
#[must_use]
pub const fn classify_status_code(code: u16) -> StatusCodeRange {
    // Each threshold is the exclusive upper bound of the bucket it returns.
    if code < 1000 {
        StatusCodeRange::Invalid
    } else if code < 3000 {
        StatusCodeRange::ProtocolReserved
    } else if code < 4000 {
        StatusCodeRange::LibraryDefined
    } else if code < 5000 {
        StatusCodeRange::ApplicationDefined
    } else {
        StatusCodeRange::OutOfRange
    }
}
120
/// Reports whether `code` is one of the four status codes an endpoint must
/// never send in a Close frame: 1004, 1005, 1006 or 1015.
#[must_use]
pub const fn is_forbidden_on_wire(code: u16) -> bool {
    matches!(code, 1004..=1006 | 1015)
}
128
129#[allow(clippy::result_unit_err)]
135pub const fn validate_wire_status_code(code: u16) -> Result<(), ()> {
136 if is_forbidden_on_wire(code) {
137 return Err(());
138 }
139 match classify_status_code(code) {
140 StatusCodeRange::ProtocolReserved
141 | StatusCodeRange::LibraryDefined
142 | StatusCodeRange::ApplicationDefined => Ok(()),
143 StatusCodeRange::Invalid | StatusCodeRange::OutOfRange => Err(()),
144 }
145}
146
147#[derive(Debug, Clone, PartialEq, Eq)]
149pub struct ClosePayload {
150 pub code: CloseCode,
152 pub reason: String,
154}
155
156#[must_use]
158pub fn encode_close_payload(payload: &ClosePayload) -> Vec<u8> {
159 let mut out = Vec::with_capacity(2 + payload.reason.len());
160 out.extend_from_slice(&payload.code.to_u16().to_be_bytes());
161 out.extend_from_slice(payload.reason.as_bytes());
162 out
163}
164
165#[allow(clippy::result_unit_err)]
170pub fn decode_close_payload(bytes: &[u8]) -> Result<ClosePayload, ()> {
171 if bytes.is_empty() {
172 return Err(());
174 }
175 if bytes.len() < 2 {
176 return Err(());
177 }
178 let code_u16 = u16::from_be_bytes([bytes[0], bytes[1]]);
179 let code = CloseCode::from_u16(code_u16)?;
180 if code.is_reserved() {
181 return Err(());
182 }
183 let reason = core::str::from_utf8(&bytes[2..])
184 .map_err(|_| ())?
185 .to_string();
186 if reason.len() > 123 {
187 return Err(());
188 }
189 Ok(ClosePayload { code, reason })
190}
191
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn standard_codes_round_trip() {
        for c in [
            CloseCode::Normal,
            CloseCode::GoingAway,
            CloseCode::ProtocolError,
            CloseCode::UnsupportedData,
            CloseCode::InvalidPayloadData,
            CloseCode::PolicyViolation,
            CloseCode::MessageTooBig,
            CloseCode::MandatoryExtension,
            CloseCode::InternalError,
            CloseCode::ServiceRestart,
            CloseCode::TryAgainLater,
            CloseCode::BadGateway,
        ] {
            assert_eq!(CloseCode::from_u16(c.to_u16()).unwrap(), c);
        }
    }

    #[test]
    fn reserved_codes_flag_correctly() {
        assert!(CloseCode::NoStatusReceived.is_reserved());
        assert!(CloseCode::AbnormalClosure.is_reserved());
        assert!(CloseCode::TlsHandshakeFailure.is_reserved());
        assert!(!CloseCode::Normal.is_reserved());
    }

    #[test]
    fn unknown_code_rejected() {
        assert!(CloseCode::from_u16(2999).is_err());
    }

    #[test]
    fn round_trip_payload_with_reason() {
        let p = ClosePayload {
            code: CloseCode::Normal,
            reason: "bye".into(),
        };
        let buf = encode_close_payload(&p);
        assert_eq!(buf[0..2], [0x03, 0xe8]);
        let back = decode_close_payload(&buf).unwrap();
        assert_eq!(back, p);
    }

    #[test]
    fn round_trip_payload_no_reason() {
        let p = ClosePayload {
            code: CloseCode::GoingAway,
            reason: String::new(),
        };
        let buf = encode_close_payload(&p);
        let back = decode_close_payload(&buf).unwrap();
        assert_eq!(back, p);
    }

    #[test]
    fn encode_writes_big_endian_code_then_reason_bytes() {
        let p = ClosePayload {
            code: CloseCode::GoingAway,
            reason: "x".into(),
        };
        assert_eq!(encode_close_payload(&p), [0x03, 0xe9, b'x']);
    }

    #[test]
    fn decode_reserved_code_rejected() {
        // 0x03ed == 1005 (NoStatusReceived).
        let buf = [0x03, 0xed];
        assert!(decode_close_payload(&buf).is_err());
    }

    #[test]
    fn decode_every_reserved_code_rejected() {
        for code in [1005u16, 1006, 1015] {
            assert!(decode_close_payload(&code.to_be_bytes()).is_err());
        }
    }

    #[test]
    fn decode_code_without_variant_rejected() {
        // 0x0bb8 == 3000: a library-range code with no CloseCode variant.
        assert!(decode_close_payload(&[0x0b, 0xb8]).is_err());
    }

    #[test]
    fn decode_short_payload_rejected() {
        assert!(decode_close_payload(&[]).is_err());
        assert!(decode_close_payload(&[0x03]).is_err());
    }

    #[test]
    fn decode_invalid_utf8_reason_rejected() {
        assert!(decode_close_payload(&[0x03, 0xe8, 0xff]).is_err());
    }

    #[test]
    fn reason_at_limit_accepted() {
        let mut buf = alloc::vec![0x03, 0xe8];
        buf.resize(2 + 123, b'a');
        let p = decode_close_payload(&buf).unwrap();
        assert_eq!(p.code, CloseCode::Normal);
        assert_eq!(p.reason.len(), 123);
    }

    #[test]
    fn reason_too_long_rejected() {
        let mut buf = alloc::vec![0x03, 0xe8];
        buf.resize(2 + 124, b'a');
        assert!(decode_close_payload(&buf).is_err());
    }

    #[test]
    fn classify_status_code_recognizes_protocol_range() {
        assert_eq!(
            classify_status_code(1000),
            StatusCodeRange::ProtocolReserved
        );
        assert_eq!(
            classify_status_code(2999),
            StatusCodeRange::ProtocolReserved
        );
    }

    #[test]
    fn classify_status_code_recognizes_library_range() {
        assert_eq!(classify_status_code(3000), StatusCodeRange::LibraryDefined);
        assert_eq!(classify_status_code(3999), StatusCodeRange::LibraryDefined);
    }

    #[test]
    fn classify_status_code_recognizes_app_range() {
        assert_eq!(
            classify_status_code(4000),
            StatusCodeRange::ApplicationDefined
        );
        assert_eq!(
            classify_status_code(4999),
            StatusCodeRange::ApplicationDefined
        );
    }

    #[test]
    fn classify_status_code_recognizes_invalid_below_1000() {
        assert_eq!(classify_status_code(0), StatusCodeRange::Invalid);
        assert_eq!(classify_status_code(999), StatusCodeRange::Invalid);
    }

    #[test]
    fn classify_status_code_recognizes_out_of_range_above_5000() {
        assert_eq!(classify_status_code(5000), StatusCodeRange::OutOfRange);
    }

    #[test]
    fn is_forbidden_on_wire_covers_all_four() {
        assert!(is_forbidden_on_wire(1004));
        assert!(is_forbidden_on_wire(1005));
        assert!(is_forbidden_on_wire(1006));
        assert!(is_forbidden_on_wire(1015));
        assert!(!is_forbidden_on_wire(1000));
    }

    #[test]
    fn validate_wire_status_code_accepts_normal() {
        assert!(validate_wire_status_code(1000).is_ok());
        assert!(validate_wire_status_code(3000).is_ok());
        assert!(validate_wire_status_code(4500).is_ok());
    }

    #[test]
    fn validate_wire_status_code_rejects_forbidden() {
        assert!(validate_wire_status_code(1004).is_err());
        assert!(validate_wire_status_code(1005).is_err());
        assert!(validate_wire_status_code(1006).is_err());
        assert!(validate_wire_status_code(1015).is_err());
    }

    #[test]
    fn validate_wire_status_code_rejects_out_of_range() {
        assert!(validate_wire_status_code(0).is_err());
        assert!(validate_wire_status_code(999).is_err());
        assert!(validate_wire_status_code(5000).is_err());
    }

    #[test]
    fn handshake_starts_in_open_state() {
        let h = CloseHandshake::new();
        assert_eq!(h.state(), CloseState::Open);
        assert!(!h.is_closed());
    }

    #[test]
    fn default_handshake_matches_new() {
        let h = CloseHandshake::default();
        assert_eq!(h.state(), CloseState::Open);
        assert_eq!(h.failure_reason(), None);
    }

    #[test]
    fn initiator_send_close_transitions_to_closing() {
        let mut h = CloseHandshake::new();
        h.initiator_send_close(CloseCode::Normal).expect("ok");
        assert_eq!(h.state(), CloseState::ClosingInitiator);
    }

    #[test]
    fn initiator_recv_close_response_transitions_to_closed() {
        let mut h = CloseHandshake::new();
        h.initiator_send_close(CloseCode::Normal).expect("ok");
        h.recv_close_response(CloseCode::Normal).expect("ok");
        assert_eq!(h.state(), CloseState::Closed);
        assert!(h.is_closed());
    }

    #[test]
    fn responder_recv_close_transitions_to_closing_responder() {
        let mut h = CloseHandshake::new();
        h.responder_recv_close(CloseCode::Normal).expect("ok");
        assert_eq!(h.state(), CloseState::ClosingResponder);
    }

    #[test]
    fn responder_send_close_response_completes_normally() {
        let mut h = CloseHandshake::new();
        h.responder_recv_close(CloseCode::GoingAway).expect("ok");
        h.responder_send_close_response().expect("ok");
        assert_eq!(h.state(), CloseState::Closed);
    }

    #[test]
    fn responder_send_without_recv_is_rejected() {
        let mut h = CloseHandshake::new();
        assert!(h.responder_send_close_response().is_err());
        assert_eq!(h.state(), CloseState::Open);
    }

    #[test]
    fn responder_recv_after_initiating_is_rejected() {
        let mut h = CloseHandshake::new();
        h.initiator_send_close(CloseCode::Normal).expect("ok");
        assert!(h.responder_recv_close(CloseCode::Normal).is_err());
        assert_eq!(h.state(), CloseState::ClosingInitiator);
    }

    #[test]
    fn fail_marks_abnormal_closure() {
        let mut h = CloseHandshake::new();
        h.fail("transport error");
        assert_eq!(h.state(), CloseState::Failed);
        assert!(h.is_closed());
        assert_eq!(h.failure_reason(), Some("transport error"));
    }

    #[test]
    fn second_close_send_in_closing_is_rejected() {
        let mut h = CloseHandshake::new();
        h.initiator_send_close(CloseCode::Normal).expect("ok");
        assert!(h.initiator_send_close(CloseCode::Normal).is_err());
    }

    #[test]
    fn recv_close_in_open_state_is_responder_path() {
        // In `Open`, an incoming Close must go through `responder_recv_close`.
        let mut h = CloseHandshake::new();
        assert!(h.recv_close_response(CloseCode::Normal).is_err());
    }
}
413
/// Phase of a [`CloseHandshake`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CloseState {
    /// No Close frame has been sent or received yet.
    Open,
    /// This endpoint sent a Close and is waiting for the peer's reply.
    ClosingInitiator,
    /// The peer sent a Close that this endpoint has not answered yet.
    ClosingResponder,
    /// Close frames were exchanged in both directions.
    Closed,
    /// The handshake was abandoned via `CloseHandshake::fail`.
    Failed,
}
434
435#[derive(Debug, Clone)]
437pub struct CloseHandshake {
438 state: CloseState,
439 sent_code: Option<CloseCode>,
440 received_code: Option<CloseCode>,
441 failure_reason: Option<String>,
442}
443
444impl Default for CloseHandshake {
445 fn default() -> Self {
446 Self::new()
447 }
448}
449
450impl CloseHandshake {
451 #[must_use]
453 pub fn new() -> Self {
454 Self {
455 state: CloseState::Open,
456 sent_code: None,
457 received_code: None,
458 failure_reason: None,
459 }
460 }
461
462 #[must_use]
464 pub fn state(&self) -> CloseState {
465 self.state
466 }
467
468 #[must_use]
470 pub fn is_closed(&self) -> bool {
471 matches!(self.state, CloseState::Closed | CloseState::Failed)
472 }
473
474 #[must_use]
476 pub fn failure_reason(&self) -> Option<&str> {
477 self.failure_reason.as_deref()
478 }
479
480 #[allow(clippy::result_unit_err)]
486 pub fn initiator_send_close(&mut self, code: CloseCode) -> Result<(), ()> {
487 if self.state != CloseState::Open {
488 return Err(());
489 }
490 self.state = CloseState::ClosingInitiator;
491 self.sent_code = Some(code);
492 Ok(())
493 }
494
495 #[allow(clippy::result_unit_err)]
501 pub fn recv_close_response(&mut self, code: CloseCode) -> Result<(), ()> {
502 if self.state != CloseState::ClosingInitiator {
503 return Err(());
504 }
505 self.received_code = Some(code);
506 self.state = CloseState::Closed;
507 Ok(())
508 }
509
510 #[allow(clippy::result_unit_err)]
516 pub fn responder_recv_close(&mut self, code: CloseCode) -> Result<(), ()> {
517 if self.state != CloseState::Open {
518 return Err(());
519 }
520 self.received_code = Some(code);
521 self.state = CloseState::ClosingResponder;
522 Ok(())
523 }
524
525 #[allow(clippy::result_unit_err)]
531 pub fn responder_send_close_response(&mut self) -> Result<(), ()> {
532 if self.state != CloseState::ClosingResponder {
533 return Err(());
534 }
535 self.sent_code = self.received_code;
537 self.state = CloseState::Closed;
538 Ok(())
539 }
540
541 pub fn fail(&mut self, reason: impl Into<String>) {
544 self.state = CloseState::Failed;
545 self.failure_reason = Some(reason.into());
546 }
547}