ttpkit_http/ws/
server.rs

1use std::{
2    future::Future,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use futures::FutureExt;
9
10use crate::{
11    Error, Method, Status, Version,
12    server::{IncomingRequest, OutgoingResponse},
13    ws::{AgentRole, WebSocket},
14};
15
/// WS server handshake.
///
/// Created from an incoming HTTP request using [`ServerHandshake::new`] and
/// finished using [`ServerHandshake::complete`].
#[derive(Debug)]
pub struct ServerHandshake {
    /// Value sent back in the `Sec-WebSocket-Accept` response header.
    accept: String,
    /// WS sub-protocols offered by the client in `Sec-WebSocket-Protocol`.
    protocols: Vec<String>,
}
21
22impl ServerHandshake {
23    /// Create a new handshake from a given incoming HTTP request.
24    pub fn new(request: IncomingRequest) -> Result<Self, Error> {
25        if request.method() != Method::Get {
26            return Err(Error::from_static_msg(
27                "invalid HTTP method for WS handshake",
28            ));
29        } else if request.version() == Version::Version10 {
30            return Err(Error::from_static_msg(
31                "this HTTP version is not supported for WS",
32            ));
33        }
34
35        let is_upgrade = request
36            .get_header_fields("connection")
37            .flat_map(|field| {
38                field
39                    .value()
40                    .map(|v| v.as_ref())
41                    .unwrap_or(b"")
42                    .split(|&b| b == b',')
43                    .map(|kw| kw.trim_ascii())
44                    .filter(|kw| !kw.is_empty())
45            })
46            .any(|kw| kw.eq_ignore_ascii_case(b"upgrade"));
47
48        if !is_upgrade {
49            return Err(Error::from_static_msg("not a connection upgrade"));
50        }
51
52        let is_websocket = request
53            .get_header_fields("upgrade")
54            .flat_map(|field| {
55                field
56                    .value()
57                    .map(|v| v.as_ref())
58                    .unwrap_or(b"")
59                    .split(|&b| b == b',')
60                    .map(|kw| kw.trim_ascii())
61                    .filter(|kw| !kw.is_empty())
62            })
63            .any(|kw| kw.eq_ignore_ascii_case(b"websocket"));
64
65        if !is_websocket {
66            return Err(Error::from_static_msg("not a WebSocket upgrade"));
67        }
68
69        let version = request
70            .get_header_field_value("sec-websocket-version")
71            .ok_or_else(|| Error::from_static_msg("missing WS version"))?
72            .trim_ascii();
73
74        if version != b"13" {
75            return Err(Error::from_static_msg("unsupported WS version"));
76        }
77
78        let key = request
79            .get_header_field_value("sec-websocket-key")
80            .ok_or_else(|| Error::from_static_msg("missing WS key"))?
81            .trim_ascii();
82
83        let protocols = request
84            .get_header_fields("sec-websocket-protocol")
85            .flat_map(|field| {
86                field
87                    .value()
88                    .map(|v| v.as_ref())
89                    .unwrap_or(b"")
90                    .split(|&b| b == b',')
91                    .map(|p| p.trim_ascii())
92                    .filter(|p| !p.is_empty())
93            })
94            .map(str::from_utf8)
95            .filter_map(|res| res.ok())
96            .map(|s| s.to_string())
97            .collect::<Vec<_>>();
98
99        let res = Self {
100            accept: super::create_accept_token(key),
101            protocols,
102        };
103
104        Ok(res)
105    }
106
107    /// Get list of WS sub-protocols supported by the client.
108    #[inline]
109    pub fn protocols(&self) -> &[String] {
110        &self.protocols
111    }
112
113    /// Complete the WS handshake.
114    ///
115    /// The method will prepare an HTTP response for the client and indicate
116    /// use of a given sub-protocol.
117    pub fn complete(
118        self,
119        protocol: Option<&str>,
120        input_buffer_capacity: usize,
121    ) -> (FutureServer, OutgoingResponse) {
122        let is_valid_protocol = if let Some(protocol) = protocol {
123            self.protocols.iter().any(|p| p == protocol)
124        } else {
125            true
126        };
127
128        assert!(is_valid_protocol);
129
130        let mut builder = OutgoingResponse::builder()
131            .set_status(Status::SWITCHING_PROTOCOLS)
132            .add_header_field(("Connection", "upgrade"))
133            .add_header_field(("Upgrade", "websocket"))
134            .add_header_field(("Sec-WebSocket-Accept", self.accept));
135
136        if let Some(protocol) = protocol {
137            builder = builder.add_header_field(("Sec-WebSocket-Protocol", protocol.to_string()));
138        }
139
140        let (response, upgrade) = builder.upgrade();
141
142        let server = async move {
143            upgrade
144                .await
145                .map(|upgraded| WebSocket::new(upgraded, AgentRole::Server, input_buffer_capacity))
146        };
147
148        let future = FutureServer {
149            inner: Box::pin(server),
150        };
151
152        (future, response)
153    }
154}
155
156/// Future WS server.
157pub struct FutureServer {
158    inner: Pin<Box<dyn Future<Output = io::Result<WebSocket>> + Send>>,
159}
160
161impl Future for FutureServer {
162    type Output = io::Result<WebSocket>;
163
164    #[inline]
165    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166        self.inner.poll_unpin(cx)
167    }
168}