tauri_plugin_websocket/
lib.rs

1// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
2// SPDX-License-Identifier: Apache-2.0
3// SPDX-License-Identifier: MIT
4
5//! Open a WebSocket connection using a Rust client in JS.
6
7#![doc(
8    html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
9    html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
10)]
11
12use futures_util::{stream::SplitSink, SinkExt, StreamExt};
13use http::header::{HeaderName, HeaderValue};
14use serde::{ser::Serializer, Deserialize, Serialize};
15use tauri::{
16    ipc::Channel,
17    plugin::{Builder as PluginBuilder, TauriPlugin},
18    Manager, Runtime, State, Window,
19};
20use tokio::{net::TcpStream, sync::Mutex};
21#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
22use tokio_tungstenite::connect_async_tls_with_config;
23#[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
24use tokio_tungstenite::connect_async_with_config;
25use tokio_tungstenite::{
26    tungstenite::{
27        client::IntoClientRequest,
28        protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig},
29        Message,
30    },
31    Connector, MaybeTlsStream, WebSocketStream,
32};
33
34use std::collections::HashMap;
35use std::str::FromStr;
36
37type Id = u32;
38type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
39type WebSocketWriter = SplitSink<WebSocket, Message>;
40type Result<T> = std::result::Result<T, Error>;
41
42#[derive(Debug, thiserror::Error)]
43enum Error {
44    #[error(transparent)]
45    Websocket(#[from] tokio_tungstenite::tungstenite::Error),
46    #[error("connection not found for the given id: {0}")]
47    ConnectionNotFound(Id),
48    #[error(transparent)]
49    InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue),
50    #[error(transparent)]
51    InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName),
52}
53
54impl Serialize for Error {
55    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
56    where
57        S: Serializer,
58    {
59        serializer.serialize_str(self.to_string().as_str())
60    }
61}
62
63#[derive(Default)]
64struct ConnectionManager(Mutex<HashMap<Id, WebSocketWriter>>);
65
66#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
67struct TlsConnector(Mutex<Option<Connector>>);
68
69#[derive(Deserialize)]
70#[serde(untagged, rename_all = "camelCase")]
71enum Max {
72    None,
73    Number(usize),
74}
75
76#[derive(Deserialize)]
77#[serde(rename_all = "camelCase")]
78pub(crate) struct ConnectionConfig {
79    pub read_buffer_size: Option<usize>,
80    pub write_buffer_size: Option<usize>,
81    pub max_write_buffer_size: Option<usize>,
82    pub max_message_size: Option<Max>,
83    pub max_frame_size: Option<Max>,
84    #[serde(default)]
85    pub accept_unmasked_frames: bool,
86    pub headers: Option<Vec<(String, String)>>,
87}
88
89impl From<ConnectionConfig> for WebSocketConfig {
90    fn from(config: ConnectionConfig) -> Self {
91        let mut builder =
92            WebSocketConfig::default().accept_unmasked_frames(config.accept_unmasked_frames);
93
94        if let Some(read_buffer_size) = config.read_buffer_size {
95            builder = builder.read_buffer_size(read_buffer_size)
96        }
97
98        if let Some(write_buffer_size) = config.write_buffer_size {
99            builder = builder.write_buffer_size(write_buffer_size)
100        }
101
102        if let Some(max_write_buffer_size) = config.max_write_buffer_size {
103            builder = builder.max_write_buffer_size(max_write_buffer_size)
104        }
105
106        if let Some(max_message_size) = config.max_message_size {
107            let max_size = match max_message_size {
108                Max::None => Option::None,
109                Max::Number(n) => Some(n),
110            };
111            builder = builder.max_message_size(max_size);
112        }
113
114        if let Some(max_frame_size) = config.max_frame_size {
115            let max_size = match max_frame_size {
116                Max::None => Option::None,
117                Max::Number(n) => Some(n),
118            };
119            builder = builder.max_frame_size(max_size);
120        }
121
122        builder
123    }
124}
125
/// Close-frame payload exchanged with JS: a numeric close code plus a reason string.
///
/// Field names are part of the IPC format and must not be renamed.
#[derive(Deserialize, Serialize)]
struct CloseFrame {
    pub code: u16,
    pub reason: String,
}

/// A WebSocket message as seen by JS, in both directions (`send` input and
/// `connect`'s `on_message` output).
///
/// Adjacently tagged: encoded as `{ "type": "<Variant>", "data": <payload> }`.
/// Variant names are part of the IPC format and must not be renamed.
#[derive(Deserialize, Serialize)]
#[serde(tag = "type", content = "data")]
enum WebSocketMessage {
    Text(String),
    Binary(Vec<u8>),
    Ping(Vec<u8>),
    Pong(Vec<u8>),
    // `None` when the close frame carried no code/reason.
    Close(Option<CloseFrame>),
}
141
142#[tauri::command]
143async fn connect<R: Runtime>(
144    window: Window<R>,
145    url: String,
146    on_message: Channel<serde_json::Value>,
147    config: Option<ConnectionConfig>,
148) -> Result<Id> {
149    let id = rand::random();
150    let mut request = url.into_client_request()?;
151
152    if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) {
153        for (k, v) in headers {
154            let header_name = HeaderName::from_str(k.as_str())?;
155            let header_value = HeaderValue::from_str(v.as_str())?;
156            request.headers_mut().insert(header_name, header_value);
157        }
158    }
159
160    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
161    let tls_connector = match window.try_state::<TlsConnector>() {
162        Some(tls_connector) => tls_connector.0.lock().await.clone(),
163        None => None,
164    };
165
166    #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
167    let (ws_stream, _) =
168        connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector)
169            .await?;
170    #[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
171    let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?;
172
173    tauri::async_runtime::spawn(async move {
174        let (write, read) = ws_stream.split();
175        let manager = window.state::<ConnectionManager>();
176        manager.0.lock().await.insert(id, write);
177        read.for_each(move |message| {
178            let window_ = window.clone();
179            let on_message_ = on_message.clone();
180            async move {
181                if let Ok(Message::Close(_)) = message {
182                    let manager = window_.state::<ConnectionManager>();
183                    manager.0.lock().await.remove(&id);
184                }
185
186                let response = match message {
187                    Ok(Message::Text(t)) => {
188                        serde_json::to_value(WebSocketMessage::Text(t.to_string())).unwrap()
189                    }
190                    Ok(Message::Binary(t)) => {
191                        serde_json::to_value(WebSocketMessage::Binary(t.to_vec())).unwrap()
192                    }
193                    Ok(Message::Ping(t)) => {
194                        serde_json::to_value(WebSocketMessage::Ping(t.to_vec())).unwrap()
195                    }
196                    Ok(Message::Pong(t)) => {
197                        serde_json::to_value(WebSocketMessage::Pong(t.to_vec())).unwrap()
198                    }
199                    Ok(Message::Close(t)) => {
200                        serde_json::to_value(WebSocketMessage::Close(t.map(|v| CloseFrame {
201                            code: v.code.into(),
202                            reason: v.reason.to_string(),
203                        })))
204                        .unwrap()
205                    }
206                    Ok(Message::Frame(_)) => serde_json::Value::Null, // This value can't be recieved.
207                    Err(e) => serde_json::to_value(Error::from(e)).unwrap(),
208                };
209
210                let _ = on_message_.send(response);
211            }
212        })
213        .await;
214    });
215
216    Ok(id)
217}
218
219#[tauri::command]
220async fn send(
221    manager: State<'_, ConnectionManager>,
222    id: Id,
223    message: WebSocketMessage,
224) -> Result<()> {
225    if let Some(write) = manager.0.lock().await.get_mut(&id) {
226        write
227            .send(match message {
228                WebSocketMessage::Text(t) => Message::Text(t.into()),
229                WebSocketMessage::Binary(t) => Message::Binary(t.into()),
230                WebSocketMessage::Ping(t) => Message::Ping(t.into()),
231                WebSocketMessage::Pong(t) => Message::Pong(t.into()),
232                WebSocketMessage::Close(t) => Message::Close(t.map(|v| ProtocolCloseFrame {
233                    code: v.code.into(),
234                    reason: v.reason.into(),
235                })),
236            })
237            .await?;
238        Ok(())
239    } else {
240        Err(Error::ConnectionNotFound(id))
241    }
242}
243
244pub fn init<R: Runtime>() -> TauriPlugin<R> {
245    Builder::default().build()
246}
247
248#[derive(Default)]
249pub struct Builder {
250    tls_connector: Option<Connector>,
251}
252
253impl Builder {
254    pub fn new() -> Self {
255        Self {
256            tls_connector: None,
257        }
258    }
259
260    pub fn tls_connector(mut self, connector: Connector) -> Self {
261        self.tls_connector.replace(connector);
262        self
263    }
264
265    pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
266        PluginBuilder::new("websocket")
267            .invoke_handler(tauri::generate_handler![connect, send])
268            .setup(|app, _api| {
269                app.manage(ConnectionManager::default());
270                #[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
271                app.manage(TlsConnector(Mutex::new(self.tls_connector)));
272                Ok(())
273            })
274            .build()
275    }
276}