websock_tungstenite/
connection.rs1use rustls::ClientConfig;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio_tungstenite::{Connector, WebSocketStream, tungstenite};
8use tungstenite::client::IntoClientRequest;
9use websock_proto::{ConnectOptions, Error, Message, Result};
10
/// Socket-level details about an established WebSocket connection.
///
/// The type is a small `Copy` value, so it can be freely copied out of a connection, compared,
/// and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionInfo {
    /// Address of the remote endpoint.
    pub peer: SocketAddr,
    /// Address of the local end of the socket.
    pub local: SocketAddr,
    /// `true` when the transport is wrapped in TLS.
    pub is_tls: bool,
}
20
21pub async fn connect(url: &str, opts: ConnectOptions) -> Result<Connection> {
23 connect_with_tls(url, opts, None).await
24}
25
26pub async fn connect_with_tls(
28 url: &str,
29 opts: ConnectOptions,
30 tls: Option<Arc<ClientConfig>>,
31) -> Result<Connection> {
32 let mut req = url
33 .into_client_request()
34 .map_err(|e| Error::InvalidUrl(e.to_string()))?;
35
36 {
38 let headers = req.headers_mut();
39 for (k, v) in opts.headers {
40 let name = tungstenite::http::header::HeaderName::from_bytes(k.as_bytes())
41 .map_err(|e| Error::Protocol(format!("invalid header name: {e}")))?;
42 let value = tungstenite::http::header::HeaderValue::from_str(&v)
43 .map_err(|e| Error::Protocol(format!("invalid header value: {e}")))?;
44 headers.append(name, value);
45 }
46
47 if !opts.protocols.is_empty() {
49 let joined = opts.protocols.join(",");
50 let value = tungstenite::http::header::HeaderValue::from_str(&joined)
51 .map_err(|e| Error::Protocol(format!("invalid protocol value: {e}")))?;
52 headers.insert(tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL, value);
53 }
54 }
55
56 let connector = tls.map(Connector::Rustls);
57 let (ws, _resp) = tokio_tungstenite::connect_async_tls_with_config(req, None, false, connector)
58 .await
59 .map_err(map_tungstenite_err)?;
60
61 let info = ConnectionInfo {
62 peer: ws
63 .get_ref()
64 .get_ref()
65 .peer_addr()
66 .map_err(|e| Error::Io(e.to_string()))?,
67 local: ws
68 .get_ref()
69 .get_ref()
70 .local_addr()
71 .map_err(|e| Error::Io(e.to_string()))?,
72 is_tls: matches!(ws.get_ref(), tokio_tungstenite::MaybeTlsStream::Rustls(_)),
73 };
74
75 Ok(Connection { ws, info })
76}
77
78pub struct Connection<S = tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>> {
80 pub(crate) ws: WebSocketStream<S>,
81 pub(crate) info: ConnectionInfo,
82}
83
84impl<S> Connection<S>
85where
86 S: AsyncRead + AsyncWrite + Unpin,
87{
88 pub async fn send(&mut self, msg: Message) -> Result<()> {
90 use futures_util::SinkExt;
91
92 let tmsg = match msg {
93 Message::Text(s) => tungstenite::Message::Text(s.into()),
94 Message::Binary(b) => tungstenite::Message::Binary(b),
95 };
96
97 self.ws.send(tmsg).await.map_err(map_tungstenite_err)?;
98 Ok(())
99 }
100
101 pub async fn recv(&mut self) -> Result<Message> {
103 use futures_util::{SinkExt, StreamExt};
104
105 loop {
106 let item = self.ws.next().await.ok_or(Error::Closed)?;
107 let msg = item.map_err(map_tungstenite_err)?;
108
109 match msg {
110 tungstenite::Message::Ping(p) => {
111 self.ws
112 .send(tungstenite::Message::Pong(p))
113 .await
114 .map_err(map_tungstenite_err)?;
115 continue;
116 }
117 tungstenite::Message::Pong(_) => continue,
118 tungstenite::Message::Text(s) => return Ok(Message::Text(s.to_string())),
119 tungstenite::Message::Binary(b) => return Ok(Message::Binary(b)),
120 tungstenite::Message::Close(_) => {
121 let _ = self.ws.close(None).await;
122 return Err(Error::Closed);
123 }
124 _ => return Err(Error::Protocol("unsupported ws message".into())),
125 }
126 }
127 }
128
129 pub async fn close(&mut self) -> Result<()> {
131 self.ws.close(None).await.map_err(map_tungstenite_err)?;
132 Ok(())
133 }
134
135 pub fn get_ref(&self) -> &S {
137 self.ws.get_ref()
138 }
139
140 pub fn get_mut(&mut self) -> &mut S {
142 self.ws.get_mut()
143 }
144}
145
146impl<S> websock_proto::WebSocketConnection for Connection<S>
147where
148 S: AsyncRead + AsyncWrite + Unpin,
149{
150 fn send<'a>(&'a mut self, msg: Message) -> websock_proto::LocalBoxFuture<'a, Result<()>> {
151 Box::pin(async move { Connection::send(self, msg).await })
152 }
153
154 fn recv<'a>(&'a mut self) -> websock_proto::LocalBoxFuture<'a, Result<Message>> {
155 Box::pin(async move { Connection::recv(self).await })
156 }
157
158 fn close<'a>(&'a mut self) -> websock_proto::LocalBoxFuture<'a, Result<()>> {
159 Box::pin(async move { Connection::close(self).await })
160 }
161}
162
163impl<S> Connection<S> {
164 pub fn peer_addr(&self) -> SocketAddr {
166 self.info.peer
167 }
168 pub fn local_addr(&self) -> SocketAddr {
170 self.info.local
171 }
172 pub fn is_tls(&self) -> bool {
174 self.info.is_tls
175 }
176 pub fn info(&self) -> ConnectionInfo {
178 self.info
179 }
180}
181
182pub(crate) fn map_tungstenite_err(e: tungstenite::Error) -> Error {
184 use tungstenite::Error as E;
185 match e {
186 E::ConnectionClosed | E::AlreadyClosed => Error::Closed,
187 E::Io(io) => Error::Io(io.to_string()),
188 E::Tls(tls) => Error::Tls(tls.to_string()),
189 E::Url(url) => Error::InvalidUrl(url.to_string()),
190 E::Protocol(err) => Error::Protocol(err.to_string()),
191 E::Utf8(err) => Error::Protocol(err),
192 E::Capacity(err) => Error::Protocol(err.to_string()),
193 E::HttpFormat(err) => Error::Protocol(err.to_string()),
194 other => Error::Other(other.to_string()),
195 }
196}