1use std::{
2 future::Future,
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use futures::FutureExt;
9
10use crate::{
11 Error, Method, Status, Version,
12 server::{IncomingRequest, OutgoingResponse},
13 ws::{AgentRole, WebSocket},
14};
15
/// A validated client WebSocket opening handshake, ready to be answered.
///
/// Created by [`ServerHandshake::new`] from an incoming HTTP request and
/// consumed by [`ServerHandshake::complete`], which builds the
/// `101 Switching Protocols` response.
#[derive(Debug)]
pub struct ServerHandshake {
    /// Value for the `Sec-WebSocket-Accept` response header, derived from the
    /// client's `Sec-WebSocket-Key`.
    accept: String,
    /// Sub-protocols offered by the client in `Sec-WebSocket-Protocol`, in
    /// the order they appeared.
    protocols: Vec<String>,
}
21
22impl ServerHandshake {
23 pub fn new(request: IncomingRequest) -> Result<Self, Error> {
25 if request.method() != Method::Get {
26 return Err(Error::from_static_msg(
27 "invalid HTTP method for WS handshake",
28 ));
29 } else if request.version() == Version::Version10 {
30 return Err(Error::from_static_msg(
31 "this HTTP version is not supported for WS",
32 ));
33 }
34
35 let is_upgrade = request
36 .get_header_fields("connection")
37 .flat_map(|field| {
38 field
39 .value()
40 .map(|v| v.as_ref())
41 .unwrap_or(b"")
42 .split(|&b| b == b',')
43 .map(|kw| kw.trim_ascii())
44 .filter(|kw| !kw.is_empty())
45 })
46 .any(|kw| kw.eq_ignore_ascii_case(b"upgrade"));
47
48 if !is_upgrade {
49 return Err(Error::from_static_msg("not a connection upgrade"));
50 }
51
52 let is_websocket = request
53 .get_header_fields("upgrade")
54 .flat_map(|field| {
55 field
56 .value()
57 .map(|v| v.as_ref())
58 .unwrap_or(b"")
59 .split(|&b| b == b',')
60 .map(|kw| kw.trim_ascii())
61 .filter(|kw| !kw.is_empty())
62 })
63 .any(|kw| kw.eq_ignore_ascii_case(b"websocket"));
64
65 if !is_websocket {
66 return Err(Error::from_static_msg("not a WebSocket upgrade"));
67 }
68
69 let version = request
70 .get_header_field_value("sec-websocket-version")
71 .ok_or_else(|| Error::from_static_msg("missing WS version"))?
72 .trim_ascii();
73
74 if version != b"13" {
75 return Err(Error::from_static_msg("unsupported WS version"));
76 }
77
78 let key = request
79 .get_header_field_value("sec-websocket-key")
80 .ok_or_else(|| Error::from_static_msg("missing WS key"))?
81 .trim_ascii();
82
83 let protocols = request
84 .get_header_fields("sec-websocket-protocol")
85 .flat_map(|field| {
86 field
87 .value()
88 .map(|v| v.as_ref())
89 .unwrap_or(b"")
90 .split(|&b| b == b',')
91 .map(|p| p.trim_ascii())
92 .filter(|p| !p.is_empty())
93 })
94 .map(str::from_utf8)
95 .filter_map(|res| res.ok())
96 .map(|s| s.to_string())
97 .collect::<Vec<_>>();
98
99 let res = Self {
100 accept: super::create_accept_token(key),
101 protocols,
102 };
103
104 Ok(res)
105 }
106
107 #[inline]
109 pub fn protocols(&self) -> &[String] {
110 &self.protocols
111 }
112
113 pub fn complete(
118 self,
119 protocol: Option<&str>,
120 input_buffer_capacity: usize,
121 ) -> (FutureServer, OutgoingResponse) {
122 let is_valid_protocol = if let Some(protocol) = protocol {
123 self.protocols.iter().any(|p| p == protocol)
124 } else {
125 true
126 };
127
128 assert!(is_valid_protocol);
129
130 let mut builder = OutgoingResponse::builder()
131 .set_status(Status::SWITCHING_PROTOCOLS)
132 .add_header_field(("Connection", "upgrade"))
133 .add_header_field(("Upgrade", "websocket"))
134 .add_header_field(("Sec-WebSocket-Accept", self.accept));
135
136 if let Some(protocol) = protocol {
137 builder = builder.add_header_field(("Sec-WebSocket-Protocol", protocol.to_string()));
138 }
139
140 let (response, upgrade) = builder.upgrade();
141
142 let server = async move {
143 upgrade
144 .await
145 .map(|upgraded| WebSocket::new(upgraded, AgentRole::Server, input_buffer_capacity))
146 };
147
148 let future = FutureServer {
149 inner: Box::pin(server),
150 };
151
152 (future, response)
153 }
154}
155
156pub struct FutureServer {
158 inner: Pin<Box<dyn Future<Output = io::Result<WebSocket>> + Send>>,
159}
160
161impl Future for FutureServer {
162 type Output = io::Result<WebSocket>;
163
164 #[inline]
165 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166 self.inner.poll_unpin(cx)
167 }
168}