trillium_client/
websocket.rs1use crate::{Conn, WebSocketConfig, WebSocketConn};
4use std::{
5 borrow::Cow,
6 error::Error,
7 fmt::{self, Display},
8 ops::{Deref, DerefMut},
9};
10use trillium_http::{
11 KnownHeaderName::{
12 Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion,
13 Upgrade as UpgradeHeader,
14 },
15 Method, Status, Upgrade, Version,
16};
17pub use trillium_websockets::Message;
18use trillium_websockets::{Role, websocket_accept_hash, websocket_key};
19
20impl Conn {
21 fn set_websocket_upgrade_headers_h1(&mut self) {
22 let headers = self.request_headers_mut();
23 headers.try_insert(UpgradeHeader, "websocket");
24 headers.try_insert(Connection, "upgrade");
25 headers.try_insert(SecWebsocketVersion, "13");
26 headers.try_insert(SecWebsocketKey, websocket_key());
27 }
28
29 pub async fn into_websocket(self) -> Result<WebSocketConn, WebSocketUpgradeError> {
46 self.into_websocket_with_config(WebSocketConfig::default())
47 .await
48 }
49
50 pub async fn into_websocket_with_config(
52 self,
53 config: WebSocketConfig,
54 ) -> Result<WebSocketConn, WebSocketUpgradeError> {
55 if self.status().is_some() {
56 return Err(WebSocketUpgradeError::new(self, ErrorKind::AlreadyExecuted));
57 }
58
59 match self.http_version() {
60 Version::Http2 => self.into_websocket_extended_connect(config).await,
61 Version::Http3 => Err(WebSocketUpgradeError::new(
62 self,
63 ErrorKind::ExtendedConnectUnsupported,
64 )),
65 _ => self.into_websocket_h1(config).await,
66 }
67 }
68
69 async fn into_websocket_h1(
70 mut self,
71 config: WebSocketConfig,
72 ) -> Result<WebSocketConn, WebSocketUpgradeError> {
73 self.set_websocket_upgrade_headers_h1();
74 if let Err(e) = (&mut self).await {
75 return Err(WebSocketUpgradeError::new(self, e.into()));
76 }
77 let status = self.status().expect("Response did not include status");
78 if status != Status::SwitchingProtocols {
79 return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
80 }
81 let key = self
82 .request_headers()
83 .get_str(SecWebsocketKey)
84 .expect("Request did not include Sec-WebSocket-Key");
85 let accept_key = websocket_accept_hash(key);
86 if self.response_headers().get_str(SecWebsocketAccept) != Some(&accept_key) {
87 return Err(WebSocketUpgradeError::new(self, ErrorKind::InvalidAccept));
88 }
89 let peer_ip = self.peer_addr().map(|addr| addr.ip());
90 let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
91 conn.set_peer_ip(peer_ip);
92 Ok(conn)
93 }
94
95 async fn into_websocket_extended_connect(
96 mut self,
97 config: WebSocketConfig,
98 ) -> Result<WebSocketConn, WebSocketUpgradeError> {
99 self.request_headers_mut()
105 .try_insert(SecWebsocketVersion, "13");
106 self.set_method(Method::Connect);
107 self.protocol = Some(Cow::Borrowed("websocket"));
108
109 if let Err(e) = (&mut self).await {
115 let kind = match e {
116 trillium_http::Error::ExtendedConnectUnsupported => {
117 ErrorKind::ExtendedConnectUnsupported
118 }
119 other => other.into(),
120 };
121 return Err(WebSocketUpgradeError::new(self, kind));
122 }
123
124 let status = self.status().expect("Response did not include status");
125 if status != Status::Ok {
126 return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
127 }
128
129 let peer_ip = self.peer_addr().map(|addr| addr.ip());
130 let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
131 conn.set_peer_ip(peer_ip);
132 Ok(conn)
133 }
134}
135
136#[derive(thiserror::Error, Debug)]
138#[non_exhaustive]
139pub enum ErrorKind {
140 #[error(transparent)]
142 Http(#[from] trillium_http::Error),
143
144 #[error("Unexpected response status {0} for websocket upgrade")]
147 Status(Status),
148
149 #[error("Response Sec-WebSocket-Accept was missing or invalid")]
151 InvalidAccept,
152
153 #[error(
157 "Conn::into_websocket called after execution — build the conn and await into_websocket \
158 instead of awaiting the conn separately"
159 )]
160 AlreadyExecuted,
161
162 #[error("peer does not support extended CONNECT, or h3 client websocket framing is missing")]
169 ExtendedConnectUnsupported,
170}
171
172#[derive(Debug)]
177pub struct WebSocketUpgradeError {
178 pub kind: ErrorKind,
180 conn: Box<Conn>,
181}
182
183impl WebSocketUpgradeError {
184 fn new(conn: Conn, kind: ErrorKind) -> Self {
185 let conn = Box::new(conn);
186 Self { conn, kind }
187 }
188}
189
190impl From<WebSocketUpgradeError> for Conn {
191 fn from(value: WebSocketUpgradeError) -> Self {
192 *value.conn
193 }
194}
195
196impl Deref for WebSocketUpgradeError {
197 type Target = Conn;
198
199 fn deref(&self) -> &Self::Target {
200 &self.conn
201 }
202}
203impl DerefMut for WebSocketUpgradeError {
204 fn deref_mut(&mut self) -> &mut Self::Target {
205 &mut self.conn
206 }
207}
208
209impl Error for WebSocketUpgradeError {}
210
211impl Display for WebSocketUpgradeError {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 self.kind.fmt(f)
214 }
215}