1#![doc(
8 html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
9 html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
10)]
11
12use futures_util::{stream::SplitSink, SinkExt, StreamExt};
13use http::header::{HeaderName, HeaderValue};
14use serde::{ser::Serializer, Deserialize, Serialize};
15use tauri::{
16 ipc::Channel,
17 plugin::{Builder as PluginBuilder, TauriPlugin},
18 Manager, Runtime, State, Window,
19};
20use tokio::{net::TcpStream, sync::Mutex};
21#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
22use tokio_tungstenite::connect_async_tls_with_config;
23#[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
24use tokio_tungstenite::connect_async_with_config;
25use tokio_tungstenite::{
26 tungstenite::{
27 client::IntoClientRequest,
28 protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig},
29 Message,
30 },
31 Connector, MaybeTlsStream, WebSocketStream,
32};
33
34use std::collections::HashMap;
35use std::str::FromStr;
36
/// Identifier handed back to the frontend for each open connection.
type Id = u32;
/// Client WebSocket stream over a TCP socket that may be TLS-wrapped (`MaybeTlsStream`).
type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of a split [`WebSocket`]; stored per connection so `send` can write to it.
type WebSocketWriter = SplitSink<WebSocket, Message>;
/// Plugin-local result type whose error is this module's [`Error`].
type Result<T> = std::result::Result<T, Error>;
41
42#[derive(Debug, thiserror::Error)]
43enum Error {
44 #[error(transparent)]
45 Websocket(#[from] tokio_tungstenite::tungstenite::Error),
46 #[error("connection not found for the given id: {0}")]
47 ConnectionNotFound(Id),
48 #[error(transparent)]
49 InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue),
50 #[error(transparent)]
51 InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName),
52}
53
54impl Serialize for Error {
55 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
56 where
57 S: Serializer,
58 {
59 serializer.serialize_str(self.to_string().as_str())
60 }
61}
62
/// Registry of open connections, mapping each [`Id`] returned by `connect` to its write half.
///
/// `connect` inserts entries, `send` looks them up, and the reader task spawned by `connect`
/// removes them. A `tokio` mutex is used because the lock is held across `.await` points.
#[derive(Default)]
struct ConnectionManager(Mutex<HashMap<Id, WebSocketWriter>>);

/// Optional custom TLS connector set through `Builder::tls_connector` and cloned by `connect`
/// for each new connection. Only present when a TLS feature is enabled.
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
struct TlsConnector(Mutex<Option<Connector>>);
68
/// A size limit from the frontend: either a concrete number or "no limit".
///
/// Deserialized untagged: a number becomes [`Max::Number`]; [`Max::None`] is the unit variant.
/// `From<ConnectionConfig>` maps `Max::None` to `Option::None` and `Max::Number(n)` to `Some(n)`.
// NOTE(review): this type is only used as `Option<Max>`. serde maps a JSON `null` for an
// `Option` field to `Option::None` before `Max` is consulted, and a string such as "none"
// matches neither variant, so `Max::None` may be unreachable from JSON input — confirm what
// the guest side sends to disable a limit.
#[derive(Deserialize)]
#[serde(untagged, rename_all = "camelCase")]
enum Max {
    /// No limit.
    None,
    /// A concrete limit, passed through to `WebSocketConfig`.
    Number(usize),
}
75
/// Optional per-connection settings supplied to the `connect` command.
///
/// Field names are expected in camelCase (e.g. `readBufferSize`). Every field except `headers`
/// is forwarded to tungstenite's `WebSocketConfig`; omitted fields keep the library defaults.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct ConnectionConfig {
    /// Forwarded to `WebSocketConfig::read_buffer_size` when set.
    pub read_buffer_size: Option<usize>,
    /// Forwarded to `WebSocketConfig::write_buffer_size` when set.
    pub write_buffer_size: Option<usize>,
    /// Forwarded to `WebSocketConfig::max_write_buffer_size` when set.
    pub max_write_buffer_size: Option<usize>,
    /// Forwarded to `WebSocketConfig::max_message_size` when set; see [`Max`].
    pub max_message_size: Option<Max>,
    /// Forwarded to `WebSocketConfig::max_frame_size` when set; see [`Max`].
    pub max_frame_size: Option<Max>,
    /// Forwarded to `WebSocketConfig::accept_unmasked_frames`; `false` when omitted.
    #[serde(default)]
    pub accept_unmasked_frames: bool,
    /// Extra HTTP headers for the handshake request, as `(name, value)` pairs.
    pub headers: Option<Vec<(String, String)>>,
}
88
89impl From<ConnectionConfig> for WebSocketConfig {
90 fn from(config: ConnectionConfig) -> Self {
91 let mut builder =
92 WebSocketConfig::default().accept_unmasked_frames(config.accept_unmasked_frames);
93
94 if let Some(read_buffer_size) = config.read_buffer_size {
95 builder = builder.read_buffer_size(read_buffer_size)
96 }
97
98 if let Some(write_buffer_size) = config.write_buffer_size {
99 builder = builder.write_buffer_size(write_buffer_size)
100 }
101
102 if let Some(max_write_buffer_size) = config.max_write_buffer_size {
103 builder = builder.max_write_buffer_size(max_write_buffer_size)
104 }
105
106 if let Some(max_message_size) = config.max_message_size {
107 let max_size = match max_message_size {
108 Max::None => Option::None,
109 Max::Number(n) => Some(n),
110 };
111 builder = builder.max_message_size(max_size);
112 }
113
114 if let Some(max_frame_size) = config.max_frame_size {
115 let max_size = match max_frame_size {
116 Max::None => Option::None,
117 Max::Number(n) => Some(n),
118 };
119 builder = builder.max_frame_size(max_size);
120 }
121
122 builder
123 }
124}
125
/// Serializable close frame exchanged with the frontend.
///
/// Converted from tungstenite's close frame when reading (in `connect`) and into
/// `ProtocolCloseFrame` when writing (in `send`).
#[derive(Deserialize, Serialize)]
struct CloseFrame {
    /// Numeric close code, converted from/into tungstenite's close code via `into()`.
    pub code: u16,
    /// Close reason text.
    pub reason: String,
}
131
/// Message as seen by the frontend: the input of `send` and the payload `connect` forwards to
/// its `on_message` channel.
///
/// Adjacently tagged: serialized as `{ "type": "<Variant>", "data": <payload> }`.
#[derive(Deserialize, Serialize)]
#[serde(tag = "type", content = "data")]
enum WebSocketMessage {
    /// Text message.
    Text(String),
    /// Binary message (a `Vec<u8>`, i.e. a JSON array of numbers).
    Binary(Vec<u8>),
    /// Ping payload.
    Ping(Vec<u8>),
    /// Pong payload.
    Pong(Vec<u8>),
    /// Close message with an optional code and reason.
    Close(Option<CloseFrame>),
}
141
142#[tauri::command]
143async fn connect<R: Runtime>(
144 window: Window<R>,
145 url: String,
146 on_message: Channel<serde_json::Value>,
147 config: Option<ConnectionConfig>,
148) -> Result<Id> {
149 let id = rand::random();
150 let mut request = url.into_client_request()?;
151
152 if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) {
153 for (k, v) in headers {
154 let header_name = HeaderName::from_str(k.as_str())?;
155 let header_value = HeaderValue::from_str(v.as_str())?;
156 request.headers_mut().insert(header_name, header_value);
157 }
158 }
159
160 #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
161 let tls_connector = match window.try_state::<TlsConnector>() {
162 Some(tls_connector) => tls_connector.0.lock().await.clone(),
163 None => None,
164 };
165
166 #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
167 let (ws_stream, _) =
168 connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector)
169 .await?;
170 #[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
171 let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?;
172
173 tauri::async_runtime::spawn(async move {
174 let (write, read) = ws_stream.split();
175 let manager = window.state::<ConnectionManager>();
176 manager.0.lock().await.insert(id, write);
177 read.for_each(move |message| {
178 let window_ = window.clone();
179 let on_message_ = on_message.clone();
180 async move {
181 if let Ok(Message::Close(_)) = message {
182 let manager = window_.state::<ConnectionManager>();
183 manager.0.lock().await.remove(&id);
184 }
185
186 let response = match message {
187 Ok(Message::Text(t)) => {
188 serde_json::to_value(WebSocketMessage::Text(t.to_string())).unwrap()
189 }
190 Ok(Message::Binary(t)) => {
191 serde_json::to_value(WebSocketMessage::Binary(t.to_vec())).unwrap()
192 }
193 Ok(Message::Ping(t)) => {
194 serde_json::to_value(WebSocketMessage::Ping(t.to_vec())).unwrap()
195 }
196 Ok(Message::Pong(t)) => {
197 serde_json::to_value(WebSocketMessage::Pong(t.to_vec())).unwrap()
198 }
199 Ok(Message::Close(t)) => {
200 serde_json::to_value(WebSocketMessage::Close(t.map(|v| CloseFrame {
201 code: v.code.into(),
202 reason: v.reason.to_string(),
203 })))
204 .unwrap()
205 }
206 Ok(Message::Frame(_)) => serde_json::Value::Null, Err(e) => serde_json::to_value(Error::from(e)).unwrap(),
208 };
209
210 let _ = on_message_.send(response);
211 }
212 })
213 .await;
214 });
215
216 Ok(id)
217}
218
219#[tauri::command]
220async fn send(
221 manager: State<'_, ConnectionManager>,
222 id: Id,
223 message: WebSocketMessage,
224) -> Result<()> {
225 if let Some(write) = manager.0.lock().await.get_mut(&id) {
226 write
227 .send(match message {
228 WebSocketMessage::Text(t) => Message::Text(t.into()),
229 WebSocketMessage::Binary(t) => Message::Binary(t.into()),
230 WebSocketMessage::Ping(t) => Message::Ping(t.into()),
231 WebSocketMessage::Pong(t) => Message::Pong(t.into()),
232 WebSocketMessage::Close(t) => Message::Close(t.map(|v| ProtocolCloseFrame {
233 code: v.code.into(),
234 reason: v.reason.into(),
235 })),
236 })
237 .await?;
238 Ok(())
239 } else {
240 Err(Error::ConnectionNotFound(id))
241 }
242}
243
244pub fn init<R: Runtime>() -> TauriPlugin<R> {
245 Builder::default().build()
246}
247
248#[derive(Default)]
249pub struct Builder {
250 tls_connector: Option<Connector>,
251}
252
253impl Builder {
254 pub fn new() -> Self {
255 Self {
256 tls_connector: None,
257 }
258 }
259
260 pub fn tls_connector(mut self, connector: Connector) -> Self {
261 self.tls_connector.replace(connector);
262 self
263 }
264
265 pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
266 PluginBuilder::new("websocket")
267 .invoke_handler(tauri::generate_handler![connect, send])
268 .setup(|app, _api| {
269 app.manage(ConnectionManager::default());
270 #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
271 app.manage(TlsConnector(Mutex::new(self.tls_connector)));
272 Ok(())
273 })
274 .build()
275 }
276}