1#![doc(
8 html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
9 html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
10)]
11
12use futures_util::{stream::SplitSink, SinkExt, StreamExt};
13use http::header::{HeaderName, HeaderValue};
14use serde::{ser::Serializer, Deserialize, Serialize};
15use tauri::{
16 ipc::Channel,
17 plugin::{Builder as PluginBuilder, TauriPlugin},
18 Manager, Runtime, State, Window,
19};
20use tokio::{net::TcpStream, sync::Mutex};
21#[cfg(any(
22 feature = "rustls-tls",
23 feature = "rustls-tls-native-roots",
24 feature = "native-tls"
25))]
26use tokio_tungstenite::connect_async_tls_with_config;
27#[cfg(not(any(
28 feature = "rustls-tls",
29 feature = "rustls-tls-native-roots",
30 feature = "native-tls"
31)))]
32use tokio_tungstenite::connect_async_with_config;
33use tokio_tungstenite::{
34 tungstenite::{
35 client::IntoClientRequest,
36 protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig},
37 Message,
38 },
39 Connector, MaybeTlsStream, WebSocketStream,
40};
41
42use std::collections::HashMap;
43use std::str::FromStr;
44
45type Id = u32;
46type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
47type WebSocketWriter = SplitSink<WebSocket, Message>;
48type Result<T> = std::result::Result<T, Error>;
49
50#[derive(Debug, thiserror::Error)]
51enum Error {
52 #[error(transparent)]
53 Websocket(#[from] tokio_tungstenite::tungstenite::Error),
54 #[error("connection not found for the given id: {0}")]
55 ConnectionNotFound(Id),
56 #[error(transparent)]
57 InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue),
58 #[error(transparent)]
59 InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName),
60}
61
62impl Serialize for Error {
63 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
64 where
65 S: Serializer,
66 {
67 serializer.serialize_str(self.to_string().as_str())
68 }
69}
70
71#[derive(Default)]
72struct ConnectionManager(Mutex<HashMap<Id, WebSocketWriter>>);
73
/// Plugin state holding the optional TLS connector configured through [`Builder`].
///
/// Only compiled when one of the TLS features is enabled; `connect` clones the
/// stored connector for every new connection.
#[cfg(any(
    feature = "rustls-tls",
    feature = "rustls-tls-native-roots",
    feature = "native-tls"
))]
struct TlsConnector(Mutex<Option<Connector>>);
80
81#[derive(Deserialize)]
82#[serde(untagged, rename_all = "camelCase")]
83enum Max {
84 None,
85 Number(usize),
86}
87
88#[derive(Deserialize)]
89#[serde(rename_all = "camelCase")]
90pub(crate) struct ConnectionConfig {
91 pub read_buffer_size: Option<usize>,
92 pub write_buffer_size: Option<usize>,
93 pub max_write_buffer_size: Option<usize>,
94 pub max_message_size: Option<Max>,
95 pub max_frame_size: Option<Max>,
96 #[serde(default)]
97 pub accept_unmasked_frames: bool,
98 pub headers: Option<Vec<(String, String)>>,
99}
100
101impl From<ConnectionConfig> for WebSocketConfig {
102 fn from(config: ConnectionConfig) -> Self {
103 let mut builder =
104 WebSocketConfig::default().accept_unmasked_frames(config.accept_unmasked_frames);
105
106 if let Some(read_buffer_size) = config.read_buffer_size {
107 builder = builder.read_buffer_size(read_buffer_size)
108 }
109
110 if let Some(write_buffer_size) = config.write_buffer_size {
111 builder = builder.write_buffer_size(write_buffer_size)
112 }
113
114 if let Some(max_write_buffer_size) = config.max_write_buffer_size {
115 builder = builder.max_write_buffer_size(max_write_buffer_size)
116 }
117
118 if let Some(max_message_size) = config.max_message_size {
119 let max_size = match max_message_size {
120 Max::None => Option::None,
121 Max::Number(n) => Some(n),
122 };
123 builder = builder.max_message_size(max_size);
124 }
125
126 if let Some(max_frame_size) = config.max_frame_size {
127 let max_size = match max_frame_size {
128 Max::None => Option::None,
129 Max::Number(n) => Some(n),
130 };
131 builder = builder.max_frame_size(max_size);
132 }
133
134 builder
135 }
136}
137
138#[derive(Deserialize, Serialize)]
139struct CloseFrame {
140 pub code: u16,
141 pub reason: String,
142}
143
144#[derive(Deserialize, Serialize)]
145#[serde(tag = "type", content = "data")]
146enum WebSocketMessage {
147 Text(String),
148 Binary(Vec<u8>),
149 Ping(Vec<u8>),
150 Pong(Vec<u8>),
151 Close(Option<CloseFrame>),
152}
153
154#[tauri::command]
155async fn connect<R: Runtime>(
156 window: Window<R>,
157 url: String,
158 on_message: Channel<serde_json::Value>,
159 config: Option<ConnectionConfig>,
160) -> Result<Id> {
161 let id = rand::random();
162 let mut request = url.into_client_request()?;
163
164 if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) {
165 for (k, v) in headers {
166 let header_name = HeaderName::from_str(k.as_str())?;
167 let header_value = HeaderValue::from_str(v.as_str())?;
168 request.headers_mut().insert(header_name, header_value);
169 }
170 }
171
172 #[cfg(any(
173 feature = "rustls-tls",
174 feature = "rustls-tls-native-roots",
175 feature = "native-tls"
176 ))]
177 let tls_connector = match window.try_state::<TlsConnector>() {
178 Some(tls_connector) => tls_connector.0.lock().await.clone(),
179 None => None,
180 };
181
182 #[cfg(any(
183 feature = "rustls-tls",
184 feature = "rustls-tls-native-roots",
185 feature = "native-tls"
186 ))]
187 let (ws_stream, _) =
188 connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector)
189 .await?;
190 #[cfg(not(any(
191 feature = "rustls-tls",
192 feature = "rustls-tls-native-roots",
193 feature = "native-tls"
194 )))]
195 let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?;
196
197 tauri::async_runtime::spawn(async move {
198 let (write, read) = ws_stream.split();
199 let manager = window.state::<ConnectionManager>();
200 manager.0.lock().await.insert(id, write);
201 read.for_each(move |message| {
202 let window_ = window.clone();
203 let on_message_ = on_message.clone();
204 async move {
205 if let Ok(Message::Close(_)) = message {
206 let manager = window_.state::<ConnectionManager>();
207 manager.0.lock().await.remove(&id);
208 }
209
210 let response = match message {
211 Ok(Message::Text(t)) => {
212 serde_json::to_value(WebSocketMessage::Text(t.to_string())).unwrap()
213 }
214 Ok(Message::Binary(t)) => {
215 serde_json::to_value(WebSocketMessage::Binary(t.to_vec())).unwrap()
216 }
217 Ok(Message::Ping(t)) => {
218 serde_json::to_value(WebSocketMessage::Ping(t.to_vec())).unwrap()
219 }
220 Ok(Message::Pong(t)) => {
221 serde_json::to_value(WebSocketMessage::Pong(t.to_vec())).unwrap()
222 }
223 Ok(Message::Close(t)) => {
224 serde_json::to_value(WebSocketMessage::Close(t.map(|v| CloseFrame {
225 code: v.code.into(),
226 reason: v.reason.to_string(),
227 })))
228 .unwrap()
229 }
230 Ok(Message::Frame(_)) => serde_json::Value::Null, Err(e) => serde_json::to_value(Error::from(e)).unwrap(),
232 };
233
234 let _ = on_message_.send(response);
235 }
236 })
237 .await;
238 });
239
240 Ok(id)
241}
242
243#[tauri::command]
244async fn send(
245 manager: State<'_, ConnectionManager>,
246 id: Id,
247 message: WebSocketMessage,
248) -> Result<()> {
249 if let Some(write) = manager.0.lock().await.get_mut(&id) {
250 write
251 .send(match message {
252 WebSocketMessage::Text(t) => Message::Text(t.into()),
253 WebSocketMessage::Binary(t) => Message::Binary(t.into()),
254 WebSocketMessage::Ping(t) => Message::Ping(t.into()),
255 WebSocketMessage::Pong(t) => Message::Pong(t.into()),
256 WebSocketMessage::Close(t) => Message::Close(t.map(|v| ProtocolCloseFrame {
257 code: v.code.into(),
258 reason: v.reason.into(),
259 })),
260 })
261 .await?;
262 Ok(())
263 } else {
264 Err(Error::ConnectionNotFound(id))
265 }
266}
267
268pub fn init<R: Runtime>() -> TauriPlugin<R> {
269 Builder::default().build()
270}
271
272#[derive(Default)]
273pub struct Builder {
274 tls_connector: Option<Connector>,
275}
276
277impl Builder {
278 pub fn new() -> Self {
279 Self {
280 tls_connector: None,
281 }
282 }
283
284 pub fn tls_connector(mut self, connector: Connector) -> Self {
285 self.tls_connector.replace(connector);
286 self
287 }
288
289 pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
290 PluginBuilder::new("websocket")
291 .invoke_handler(tauri::generate_handler![connect, send])
292 .setup(|app, _api| {
293 #[cfg(any(feature = "rustls-tls", feature = "rustls-tls-native-roots"))]
294 if (self.tls_connector.is_none()
295 || matches!(self.tls_connector, Some(Connector::Plain)))
296 && rustls::crypto::CryptoProvider::get_default().is_none()
297 {
298 let _ = rustls::crypto::ring::default_provider().install_default();
300 }
301
302 app.manage(ConnectionManager::default());
303 #[cfg(any(
304 feature = "rustls-tls",
305 feature = "rustls-tls-native-roots",
306 feature = "native-tls"
307 ))]
308 app.manage(TlsConnector(Mutex::new(self.tls_connector)));
309 Ok(())
310 })
311 .build()
312 }
313}