1use crate::{Event, ProtocolError};
2use std::collections::{HashMap, HashSet};
3use std::fmt::Debug;
4use std::hash::Hash;
5
/// One of the two peers whose state is tracked by [`ConnectionState`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Role {
    /// The peer that sends `Request` events.
    Client,
    /// The peer that sends responses; it also observes the client's request.
    Server,
}
14
/// Per-role state of the connection state machine driven by [`ConnectionState`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum State {
    /// Initial state of both roles; `start_next_cycle` also resets both roles to it.
    Idle,
    /// Server only: the client's request has been observed and no final response
    /// has been sent yet (informational responses keep the server here).
    SendResponse,
    /// Sending a message body: entered on `Request` (client) or `NormalResponse`
    /// (server); `Data` stays here and `EndOfMessage` moves to `Done`.
    SendBody,
    /// The role finished its message (`EndOfMessage` while in `SendBody`).
    Done,
    /// The role is finished but the connection cannot be reused, e.g. keep-alive was
    /// disabled or the other role closed or errored.
    MustClose,
    /// The role has processed `ConnectionClosed`.
    Closed,
    /// Set explicitly by `process_error`.
    Error,
    /// Client only: `Done` while at least one protocol-switch proposal is pending.
    MightSwitchProtocol,
    /// The server accepted a protocol switch; the client follows from
    /// `MightSwitchProtocol`.
    SwitchedProtocol,
}
37
/// Protocol-switch kinds: proposed by the client through
/// `ConnectionState::process_client_switch_proposal` and accepted by the server through
/// the `server_switch_event` argument of `ConnectionState::process_event`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Switch {
    /// Accepted together with an `InformationalResponse`.
    SwitchUpgrade,
    /// Accepted together with a `NormalResponse`.
    SwitchConnect,
    /// Paired with a `Request` to form `EventType::RequestClient`, the server-side view
    /// of the client's request.
    Client,
}
48
/// Event kinds fed to the connection state machine.
///
/// The first eight variants mirror the variants of [`Event`]. The last three are
/// combined kinds that `ConnectionState::process_event` builds internally.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum EventType {
    /// A request; only the client may send it.
    Request,
    /// An informational (non-final) response.
    InformationalResponse,
    /// A final response.
    NormalResponse,
    /// A chunk of message body.
    Data,
    /// End of the current message.
    EndOfMessage,
    /// The role closed its side of the connection.
    ConnectionClosed,
    /// Mirrors `Event::NeedData`; no transition accepts it.
    NeedData,
    /// Mirrors `Event::Paused`; no transition accepts it.
    Paused,
    /// The server's view of the client's `Request`.
    RequestClient,
    /// An `InformationalResponse` that accepts a pending `Switch::SwitchUpgrade`.
    InformationalResponseSwitchUpgrade,
    /// A `NormalResponse` that accepts a pending `Switch::SwitchConnect`.
    NormalResponseSwitchConnect,
}
76
77impl From<&Event> for EventType {
78 fn from(value: &Event) -> Self {
79 match value {
80 Event::Request(_) => EventType::Request,
81 Event::NormalResponse(_) => EventType::NormalResponse,
82 Event::InformationalResponse(_) => EventType::InformationalResponse,
83 Event::Data(_) => EventType::Data,
84 Event::EndOfMessage(_) => EventType::EndOfMessage,
85 Event::ConnectionClosed(_) => EventType::ConnectionClosed,
86 Event::NeedData() => EventType::NeedData,
87 Event::Paused() => EventType::Paused,
88 }
89 }
90}
91
92pub struct ConnectionState {
93 pub keep_alive: bool,
94 pub pending_switch_proposals: HashSet<Switch>,
95 pub states: HashMap<Role, State>,
96}
97
98impl ConnectionState {
99 pub fn new() -> Self {
100 ConnectionState {
101 keep_alive: true,
102 pending_switch_proposals: HashSet::new(),
103 states: HashMap::from([(Role::Client, State::Idle), (Role::Server, State::Idle)]),
104 }
105 }
106
107 pub fn process_error(&mut self, role: Role) {
108 self.states.insert(role, State::Error);
109 self._fire_state_triggered_transitions();
110 }
111
112 pub fn process_keep_alive_disabled(&mut self) {
113 self.keep_alive = false;
114 self._fire_state_triggered_transitions();
115 }
116
117 pub fn process_client_switch_proposal(&mut self, switch_event: Switch) {
118 self.pending_switch_proposals.insert(switch_event);
119 self._fire_state_triggered_transitions();
120 }
121
122 pub fn process_event(
123 &mut self,
124 role: Role,
125 event_type: EventType,
126 server_switch_event: Option<Switch>,
127 ) -> Result<(), ProtocolError> {
128 let mut _event_type = event_type;
129 if let Some(server_switch_event) = server_switch_event {
130 if role != Role::Server {
131 return Err(ProtocolError::LocalProtocolError(
132 format!(
133 "Received server switch event {:?} for role {:?}",
134 server_switch_event, role
135 )
136 .into(),
137 ));
138 }
139 if !self.pending_switch_proposals.contains(&server_switch_event) {
140 return Err(ProtocolError::LocalProtocolError(
141 format!(
142 "Received server {:?} event without a pending proposal",
143 server_switch_event
144 )
145 .into(),
146 ));
147 }
148 _event_type = match (event_type, server_switch_event) {
149 (EventType::Request, Switch::Client) => EventType::RequestClient,
150 (EventType::NormalResponse, Switch::SwitchConnect) => {
151 EventType::NormalResponseSwitchConnect
152 }
153 (EventType::InformationalResponse, Switch::SwitchUpgrade) => {
154 EventType::InformationalResponseSwitchUpgrade
155 }
156 _ => {
157 return Err(ProtocolError::LocalProtocolError(
158 format!(
159 "Can't handle event type {:?} with server switch {:?} when role={:?} and state={:?}",
160 _event_type, server_switch_event, role, self.states[&role]
161 )
162 .into(),
163 ))
164 }
165 };
166 }
167 if server_switch_event.is_none() && _event_type == EventType::NormalResponse {
168 self.pending_switch_proposals.clear();
169 }
170 self._fire_event_triggered_transitions(role, _event_type)?;
171 if _event_type == EventType::Request {
172 if role != Role::Client {
173 return Err(ProtocolError::LocalProtocolError(
174 format!("Received request event for role {:?}", role).into(),
175 ));
176 }
177 self._fire_event_triggered_transitions(Role::Server, EventType::RequestClient)?
178 }
179 self._fire_state_triggered_transitions();
180 Ok(())
181 }
182
183 fn _fire_event_triggered_transitions(
184 &mut self,
185 role: Role,
186 event_type: EventType,
187 ) -> Result<(), ProtocolError> {
188 let state = self.states[&role];
189 let new_state = match (role, state, event_type) {
190 (Role::Client, State::Idle, EventType::Request) => State::SendBody,
191 (Role::Client, State::Idle, EventType::ConnectionClosed) => State::Closed,
192 (Role::Client, State::SendBody, EventType::Data) => State::SendBody,
193 (Role::Client, State::SendBody, EventType::EndOfMessage) => State::Done,
194 (Role::Client, State::Done, EventType::ConnectionClosed) => State::Closed,
195 (Role::Client, State::MustClose, EventType::ConnectionClosed) => State::Closed,
196 (Role::Client, State::Closed, EventType::ConnectionClosed) => State::Closed,
197
198 (Role::Server, State::Idle, EventType::ConnectionClosed) => State::Closed,
199 (Role::Server, State::Idle, EventType::NormalResponse) => State::SendBody,
200 (Role::Server, State::Idle, EventType::RequestClient) => State::SendResponse,
201 (Role::Server, State::SendResponse, EventType::InformationalResponse) => {
202 State::SendResponse
203 }
204 (Role::Server, State::SendResponse, EventType::NormalResponse) => State::SendBody,
205 (Role::Server, State::SendResponse, EventType::InformationalResponseSwitchUpgrade) => {
206 State::SwitchedProtocol
207 }
208 (Role::Server, State::SendResponse, EventType::NormalResponseSwitchConnect) => {
209 State::SwitchedProtocol
210 }
211 (Role::Server, State::SendBody, EventType::Data) => State::SendBody,
212 (Role::Server, State::SendBody, EventType::EndOfMessage) => State::Done,
213 (Role::Server, State::Done, EventType::ConnectionClosed) => State::Closed,
214 (Role::Server, State::MustClose, EventType::ConnectionClosed) => State::Closed,
215 (Role::Server, State::Closed, EventType::ConnectionClosed) => State::Closed,
216 _ => {
217 return Err(ProtocolError::LocalProtocolError(
218 format!(
219 "Can't handle event type {:?} when role={:?} and state={:?}",
220 event_type, role, state
221 )
222 .into(),
223 ))
224 }
225 };
226 self.states.insert(role, new_state);
227 Ok(())
228 }
229
230 fn _fire_state_triggered_transitions(&mut self) {
231 loop {
232 let start_states = self.states.clone();
233
234 if self.pending_switch_proposals.len() > 0 {
235 if self.states[&Role::Client] == State::Done {
236 self.states.insert(Role::Client, State::MightSwitchProtocol);
237 }
238 }
239
240 if self.pending_switch_proposals.is_empty() {
241 if self.states[&Role::Client] == State::MightSwitchProtocol {
242 self.states.insert(Role::Client, State::Done);
243 }
244 }
245
246 if !self.keep_alive {
247 for role in &[Role::Client, Role::Server] {
248 if self.states[role] == State::Done {
249 self.states.insert(*role, State::MustClose);
250 }
251 }
252 }
253
254 let joint_state = (self.states[&Role::Client], self.states[&Role::Server]);
255 let changes = match joint_state {
256 (State::MightSwitchProtocol, State::SwitchedProtocol) => {
257 vec![(Role::Client, State::SwitchedProtocol)]
258 }
259 (State::Closed, State::Done) => {
260 vec![(Role::Server, State::MustClose)]
261 }
262 (State::Closed, State::Idle) => {
263 vec![(Role::Server, State::MustClose)]
264 }
265 (State::Error, State::Done) => vec![(Role::Server, State::MustClose)],
266 (State::Done, State::Closed) => {
267 vec![(Role::Client, State::MustClose)]
268 }
269 (State::Idle, State::Closed) => {
270 vec![(Role::Client, State::MustClose)]
271 }
272 (State::Done, State::Error) => vec![(Role::Client, State::MustClose)],
273 _ => vec![],
274 };
275 for (role, new_state) in changes {
276 self.states.insert(role, new_state);
277 }
278
279 if self.states == start_states {
280 return;
281 }
282 }
283 }
284
285 pub fn start_next_cycle(&mut self) -> Result<(), ProtocolError> {
286 if self.states != HashMap::from([(Role::Client, State::Done), (Role::Server, State::Done)])
287 {
288 return Err(ProtocolError::LocalProtocolError(
289 format!("Not in a reusable state. self.states={:?}", self.states).into(),
290 ));
291 }
292 assert!(self.keep_alive);
293 assert!(self.pending_switch_proposals.is_empty());
294 self.states.clear();
295 self.states.insert(Role::Client, State::Idle);
296 self.states.insert(Role::Server, State::Idle);
297 Ok(())
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected `states` map for a `(client, server)` pair.
    fn states(client: State, server: State) -> HashMap<Role, State> {
        HashMap::from([(Role::Client, client), (Role::Server, server)])
    }

    #[test]
    fn test_connection_state() {
        let mut cs = ConnectionState::new();
        assert_eq!(cs.states, states(State::Idle, State::Idle));

        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        assert_eq!(cs.states, states(State::SendBody, State::SendResponse));

        // A second Request is rejected and must not disturb either role.
        cs.process_event(Role::Client, EventType::Request, None)
            .expect_err("Expected LocalProtocolError");
        assert_eq!(cs.states, states(State::SendBody, State::SendResponse));

        // Informational responses keep the server in SendResponse.
        cs.process_event(Role::Server, EventType::InformationalResponse, None).unwrap();
        assert_eq!(cs.states, states(State::SendBody, State::SendResponse));

        cs.process_event(Role::Server, EventType::NormalResponse, None).unwrap();
        assert_eq!(cs.states, states(State::SendBody, State::SendBody));

        cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
        cs.process_event(Role::Server, EventType::EndOfMessage, None).unwrap();
        assert_eq!(cs.states, states(State::Done, State::Done));

        // Once the server closes, the client can no longer reuse the connection.
        cs.process_event(Role::Server, EventType::ConnectionClosed, None).unwrap();
        assert_eq!(cs.states, states(State::MustClose, State::Closed));
    }

    #[test]
    fn test_connection_state_keep_alive() {
        let mut cs = ConnectionState::new();
        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_keep_alive_disabled();
        cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
        assert_eq!(cs.states, states(State::MustClose, State::SendResponse));

        cs.process_event(Role::Server, EventType::NormalResponse, None).unwrap();
        cs.process_event(Role::Server, EventType::EndOfMessage, None).unwrap();
        assert_eq!(cs.states, states(State::MustClose, State::MustClose));
    }

    #[test]
    fn test_connection_state_keep_alive_in_done() {
        let mut cs = ConnectionState::new();
        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
        assert_eq!(cs.states[&Role::Client], State::Done);
        // Disabling keep-alive after the fact still moves a Done role to MustClose.
        cs.process_keep_alive_disabled();
        assert_eq!(cs.states[&Role::Client], State::MustClose);
    }

    #[test]
    fn test_connection_state_switch_denied() {
        for switch_type in [Switch::SwitchConnect, Switch::SwitchUpgrade] {
            for deny_early in [true, false] {
                let mut cs = ConnectionState::new();
                cs.process_client_switch_proposal(switch_type);
                cs.process_event(Role::Client, EventType::Request, None).unwrap();
                cs.process_event(Role::Client, EventType::Data, None).unwrap();
                assert_eq!(cs.states, states(State::SendBody, State::SendResponse));
                assert!(cs.pending_switch_proposals.contains(&switch_type));

                if deny_early {
                    // A plain final response denies the proposal before the client ends.
                    cs.process_event(Role::Server, EventType::NormalResponse, None).unwrap();
                    assert!(cs.pending_switch_proposals.is_empty());
                }

                cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();

                if deny_early {
                    assert_eq!(cs.states, states(State::Done, State::SendBody));
                } else {
                    assert_eq!(
                        cs.states,
                        states(State::MightSwitchProtocol, State::SendResponse)
                    );

                    cs.process_event(Role::Server, EventType::InformationalResponse, None)
                        .unwrap();
                    assert_eq!(
                        cs.states,
                        states(State::MightSwitchProtocol, State::SendResponse)
                    );

                    cs.process_event(Role::Server, EventType::NormalResponse, None).unwrap();
                    assert_eq!(cs.states, states(State::Done, State::SendBody));
                    assert!(cs.pending_switch_proposals.is_empty());
                }
            }
        }
    }

    #[test]
    fn test_connection_state_protocol_switch_accepted() {
        for switch_event in [Switch::SwitchUpgrade, Switch::SwitchConnect] {
            let mut cs = ConnectionState::new();
            cs.process_client_switch_proposal(switch_event);
            cs.process_event(Role::Client, EventType::Request, None).unwrap();
            cs.process_event(Role::Client, EventType::Data, None).unwrap();
            assert_eq!(cs.states, states(State::SendBody, State::SendResponse));

            cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
            assert_eq!(cs.states, states(State::MightSwitchProtocol, State::SendResponse));

            cs.process_event(Role::Server, EventType::InformationalResponse, None).unwrap();
            assert_eq!(cs.states, states(State::MightSwitchProtocol, State::SendResponse));

            let accepting_event = match switch_event {
                Switch::SwitchUpgrade => EventType::InformationalResponse,
                Switch::SwitchConnect => EventType::NormalResponse,
                Switch::Client => unreachable!(),
            };
            cs.process_event(Role::Server, accepting_event, Some(switch_event)).unwrap();
            assert_eq!(
                cs.states,
                states(State::SwitchedProtocol, State::SwitchedProtocol)
            );
        }
    }

    #[test]
    fn test_connection_state_double_protocol_switch() {
        for server_switch in [None, Some(Switch::SwitchUpgrade), Some(Switch::SwitchConnect)] {
            let mut cs = ConnectionState::new();
            cs.process_client_switch_proposal(Switch::SwitchUpgrade);
            cs.process_client_switch_proposal(Switch::SwitchConnect);
            cs.process_event(Role::Client, EventType::Request, None).unwrap();
            cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
            assert_eq!(cs.states, states(State::MightSwitchProtocol, State::SendResponse));

            let server_event = match server_switch {
                Some(Switch::SwitchUpgrade) => EventType::InformationalResponse,
                Some(Switch::SwitchConnect) | None => EventType::NormalResponse,
                Some(Switch::Client) => unreachable!(),
            };
            cs.process_event(Role::Server, server_event, server_switch).unwrap();

            if server_switch.is_none() {
                assert_eq!(cs.states, states(State::Done, State::SendBody));
            } else {
                assert_eq!(
                    cs.states,
                    states(State::SwitchedProtocol, State::SwitchedProtocol)
                );
            }
        }
    }

    #[test]
    fn test_connection_state_inconsistent_protocol_switch() {
        for (client_switches, server_switch) in [
            (vec![], Switch::SwitchUpgrade),
            (vec![], Switch::SwitchConnect),
            (vec![Switch::SwitchUpgrade], Switch::SwitchConnect),
            (vec![Switch::SwitchConnect], Switch::SwitchUpgrade),
        ] {
            let mut cs = ConnectionState::new();
            for client_switch in client_switches {
                cs.process_client_switch_proposal(client_switch);
            }
            cs.process_event(Role::Client, EventType::Request, None).unwrap();
            cs.process_event(Role::Server, EventType::NormalResponse, Some(server_switch))
                .expect_err("Expected LocalProtocolError");
        }
    }

    #[test]
    fn test_connection_state_invalid_switch_event_returns_error() {
        let mut cs = ConnectionState::new();
        cs.process_client_switch_proposal(Switch::SwitchUpgrade);
        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_event(Role::Server, EventType::Data, Some(Switch::SwitchUpgrade))
            .expect_err("Expected LocalProtocolError");
    }

    #[test]
    fn test_connection_state_keepalive_protocol_switch_interaction() {
        let mut cs = ConnectionState::new();
        cs.process_client_switch_proposal(Switch::SwitchUpgrade);
        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_keep_alive_disabled();
        cs.process_event(Role::Client, EventType::Data, None).unwrap();
        assert_eq!(cs.states, states(State::SendBody, State::SendResponse));
    }

    #[test]
    fn test_connection_state_reuse() {
        let mut cs = ConnectionState::new();
        cs.start_next_cycle()
            .expect_err("Expected LocalProtocolError");

        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
        cs.start_next_cycle()
            .expect_err("Expected LocalProtocolError");

        cs.process_event(Role::Server, EventType::NormalResponse, None).unwrap();
        cs.process_event(Role::Server, EventType::EndOfMessage, None).unwrap();
        cs.start_next_cycle().unwrap();
        assert_eq!(cs.states, states(State::Idle, State::Idle));

        // Keep-alive disabled mid-cycle: not reusable.
        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_keep_alive_disabled();
        cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
        cs.process_event(Role::Server, EventType::NormalResponse, None).unwrap();
        cs.process_event(Role::Server, EventType::EndOfMessage, None).unwrap();
        cs.start_next_cycle()
            .expect_err("Expected LocalProtocolError");

        // Client closed: not reusable.
        let mut cs = ConnectionState::new();
        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
        cs.process_event(Role::Client, EventType::ConnectionClosed, None).unwrap();
        cs.process_event(Role::Server, EventType::NormalResponse, None).unwrap();
        cs.process_event(Role::Server, EventType::EndOfMessage, None).unwrap();
        cs.start_next_cycle()
            .expect_err("Expected LocalProtocolError");

        // Protocol switch accepted: not reusable.
        let mut cs = ConnectionState::new();
        cs.process_client_switch_proposal(Switch::SwitchUpgrade);
        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
        cs.process_event(
            Role::Server,
            EventType::InformationalResponse,
            Some(Switch::SwitchUpgrade),
        )
        .unwrap();
        cs.start_next_cycle()
            .expect_err("Expected LocalProtocolError");

        // Protocol switch denied: reusable.
        let mut cs = ConnectionState::new();
        cs.process_client_switch_proposal(Switch::SwitchUpgrade);
        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
        cs.process_event(Role::Server, EventType::NormalResponse, None).unwrap();
        cs.process_event(Role::Server, EventType::EndOfMessage, None).unwrap();
        cs.start_next_cycle().unwrap();
        assert_eq!(cs.states, states(State::Idle, State::Idle));
    }

    #[test]
    fn test_server_request_is_illegal() {
        let mut cs = ConnectionState::new();
        cs.process_event(Role::Server, EventType::Request, None)
            .expect_err("Expected LocalProtocolError");
    }

    #[test]
    fn test_process_error_forces_peer_to_close() {
        // Server errors while the client is Done.
        let mut cs = ConnectionState::new();
        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
        cs.process_error(Role::Server);
        assert_eq!(cs.states, states(State::MustClose, State::Error));

        // Client errors while the server is Done.
        let mut cs = ConnectionState::new();
        cs.process_event(Role::Client, EventType::Request, None).unwrap();
        cs.process_event(Role::Client, EventType::EndOfMessage, None).unwrap();
        cs.process_event(Role::Server, EventType::NormalResponse, None).unwrap();
        cs.process_event(Role::Server, EventType::EndOfMessage, None).unwrap();
        cs.process_error(Role::Client);
        assert_eq!(cs.states, states(State::Error, State::MustClose));
    }

    #[test]
    fn test_close_while_idle_forces_peer_to_close() {
        let mut cs = ConnectionState::new();
        cs.process_event(Role::Client, EventType::ConnectionClosed, None).unwrap();
        assert_eq!(cs.states, states(State::Closed, State::MustClose));

        let mut cs = ConnectionState::new();
        cs.process_event(Role::Server, EventType::ConnectionClosed, None).unwrap();
        assert_eq!(cs.states, states(State::MustClose, State::Closed));
    }

    #[test]
    fn test_client_cannot_send_server_switch_event() {
        let mut cs = ConnectionState::new();
        cs.process_client_switch_proposal(Switch::SwitchUpgrade);
        cs.process_event(Role::Client, EventType::Request, Some(Switch::SwitchUpgrade))
            .expect_err("Expected LocalProtocolError");
        assert_eq!(cs.states, states(State::Idle, State::Idle));
    }
}