trillium_client/
websocket.rs1use crate::{Conn, WebSocketConfig, WebSocketConn};
4use std::{
5 borrow::Cow,
6 error::Error,
7 fmt::{self, Display},
8 ops::{Deref, DerefMut},
9};
10use trillium_http::{
11 KnownHeaderName::{
12 Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion,
13 Upgrade as UpgradeHeader,
14 },
15 Method, Status, Upgrade, Version,
16};
17pub use trillium_websockets::Message;
18use trillium_websockets::{Role, websocket_accept_hash, websocket_key};
19
20impl Conn {
21 fn set_websocket_upgrade_headers_h1(&mut self) {
22 let headers = self.request_headers_mut();
23 headers.try_insert(UpgradeHeader, "websocket");
24 headers.try_insert(Connection, "upgrade");
25 headers.try_insert(SecWebsocketVersion, "13");
26 headers.try_insert(SecWebsocketKey, websocket_key());
27 }
28
29 pub async fn into_websocket(self) -> Result<WebSocketConn, WebSocketUpgradeError> {
41 self.into_websocket_with_config(WebSocketConfig::default())
42 .await
43 }
44
45 pub async fn into_websocket_with_config(
47 self,
48 config: WebSocketConfig,
49 ) -> Result<WebSocketConn, WebSocketUpgradeError> {
50 if self.status().is_some() {
51 return Err(WebSocketUpgradeError::new(self, ErrorKind::AlreadyExecuted));
52 }
53
54 match self.http_version() {
55 Version::Http2 | Version::Http3 => self.into_websocket_extended_connect(config).await,
56 _ => self.into_websocket_h1(config).await,
57 }
58 }
59
60 async fn into_websocket_h1(
61 mut self,
62 config: WebSocketConfig,
63 ) -> Result<WebSocketConn, WebSocketUpgradeError> {
64 self.set_websocket_upgrade_headers_h1();
65 if let Err(e) = (&mut self).await {
66 return Err(WebSocketUpgradeError::new(self, e.into()));
67 }
68 let status = self.status().expect("Response did not include status");
69 if status != Status::SwitchingProtocols {
70 return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
71 }
72 let key = self
73 .request_headers()
74 .get_str(SecWebsocketKey)
75 .expect("Request did not include Sec-WebSocket-Key");
76 let accept_key = websocket_accept_hash(key);
77 if self.response_headers().get_str(SecWebsocketAccept) != Some(&accept_key) {
78 return Err(WebSocketUpgradeError::new(self, ErrorKind::InvalidAccept));
79 }
80 let peer_ip = self.peer_addr().map(|addr| addr.ip());
81 let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
82 conn.set_peer_ip(peer_ip);
83 Ok(conn)
84 }
85
86 async fn into_websocket_extended_connect(
87 mut self,
88 config: WebSocketConfig,
89 ) -> Result<WebSocketConn, WebSocketUpgradeError> {
90 self.request_headers_mut()
96 .try_insert(SecWebsocketVersion, "13");
97 self.set_method(Method::Connect);
98 self.protocol = Some(Cow::Borrowed("websocket"));
99
100 if let Err(e) = (&mut self).await {
106 let kind = match e {
107 trillium_http::Error::ExtendedConnectUnsupported => {
108 ErrorKind::ExtendedConnectUnsupported
109 }
110 other => other.into(),
111 };
112 return Err(WebSocketUpgradeError::new(self, kind));
113 }
114
115 let status = self.status().expect("Response did not include status");
116 if status != Status::Ok {
117 return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
118 }
119
120 let peer_ip = self.peer_addr().map(|addr| addr.ip());
121 let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
122 conn.set_peer_ip(peer_ip);
123 Ok(conn)
124 }
125}
126
127#[derive(thiserror::Error, Debug)]
129#[non_exhaustive]
130pub enum ErrorKind {
131 #[error(transparent)]
133 Http(#[from] trillium_http::Error),
134
135 #[error("Unexpected response status {0} for websocket upgrade")]
138 Status(Status),
139
140 #[error("Response Sec-WebSocket-Accept was missing or invalid")]
142 InvalidAccept,
143
144 #[error(
148 "Conn::into_websocket called after execution — build the conn and await into_websocket \
149 instead of awaiting the conn separately"
150 )]
151 AlreadyExecuted,
152
153 #[error("peer does not support extended CONNECT")]
157 ExtendedConnectUnsupported,
158}
159
160#[derive(Debug)]
165pub struct WebSocketUpgradeError {
166 pub kind: ErrorKind,
168 conn: Box<Conn>,
169}
170
171impl WebSocketUpgradeError {
172 fn new(conn: Conn, kind: ErrorKind) -> Self {
173 let conn = Box::new(conn);
174 Self { conn, kind }
175 }
176}
177
178impl From<WebSocketUpgradeError> for Conn {
179 fn from(value: WebSocketUpgradeError) -> Self {
180 *value.conn
181 }
182}
183
184impl Deref for WebSocketUpgradeError {
185 type Target = Conn;
186
187 fn deref(&self) -> &Self::Target {
188 &self.conn
189 }
190}
191impl DerefMut for WebSocketUpgradeError {
192 fn deref_mut(&mut self) -> &mut Self::Target {
193 &mut self.conn
194 }
195}
196
197impl Error for WebSocketUpgradeError {}
198
199impl Display for WebSocketUpgradeError {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 self.kind.fmt(f)
202 }
203}