#![doc(
html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
)]
use futures_util::{stream::SplitSink, SinkExt, StreamExt};
use http::header::{HeaderName, HeaderValue};
use serde::{ser::Serializer, Deserialize, Serialize};
use tauri::{
ipc::Channel,
plugin::{Builder as PluginBuilder, TauriPlugin},
Manager, Runtime, State, Window,
};
use tokio::{net::TcpStream, sync::Mutex};
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
use tokio_tungstenite::connect_async_tls_with_config;
#[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
use tokio_tungstenite::connect_async_with_config;
use tokio_tungstenite::{
tungstenite::{
client::IntoClientRequest,
protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig},
Message,
},
Connector, MaybeTlsStream, WebSocketStream,
};
use std::collections::HashMap;
use std::str::FromStr;
type Id = u32;
type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
type WebSocketWriter = SplitSink<WebSocket, Message>;
type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, thiserror::Error)]
enum Error {
#[error(transparent)]
Websocket(#[from] tokio_tungstenite::tungstenite::Error),
#[error("connection not found for the given id: {0}")]
ConnectionNotFound(Id),
#[error(transparent)]
InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue),
#[error(transparent)]
InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName),
}
impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.to_string().as_str())
}
}
#[derive(Default)]
struct ConnectionManager(Mutex<HashMap<Id, WebSocketWriter>>);
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
struct TlsConnector(Mutex<Option<Connector>>);
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ConnectionConfig {
pub write_buffer_size: Option<usize>,
pub max_write_buffer_size: Option<usize>,
pub max_message_size: Option<usize>,
pub max_frame_size: Option<usize>,
#[serde(default)]
pub accept_unmasked_frames: bool,
pub headers: Option<Vec<(String, String)>>,
}
impl From<ConnectionConfig> for WebSocketConfig {
fn from(config: ConnectionConfig) -> Self {
#[allow(deprecated)]
Self {
max_send_queue: None,
write_buffer_size: config.write_buffer_size.unwrap_or(128 * 1024),
max_write_buffer_size: config.max_write_buffer_size.unwrap_or(usize::MAX),
max_message_size: config.max_message_size,
max_frame_size: config.max_frame_size,
accept_unmasked_frames: config.accept_unmasked_frames,
}
}
}
#[derive(Deserialize, Serialize)]
struct CloseFrame {
pub code: u16,
pub reason: String,
}
#[derive(Deserialize, Serialize)]
#[serde(tag = "type", content = "data")]
enum WebSocketMessage {
Text(String),
Binary(Vec<u8>),
Ping(Vec<u8>),
Pong(Vec<u8>),
Close(Option<CloseFrame>),
}
#[tauri::command]
async fn connect<R: Runtime>(
window: Window<R>,
url: String,
on_message: Channel,
config: Option<ConnectionConfig>,
) -> Result<Id> {
let id = rand::random();
let mut request = url.into_client_request()?;
if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) {
for (k, v) in headers {
let header_name = HeaderName::from_str(k.as_str())?;
let header_value = HeaderValue::from_str(v.as_str())?;
request.headers_mut().insert(header_name, header_value);
}
}
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
let tls_connector = match window.try_state::<TlsConnector>() {
Some(tls_connector) => tls_connector.0.lock().await.clone(),
None => None,
};
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
let (ws_stream, _) =
connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector)
.await?;
#[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?;
tauri::async_runtime::spawn(async move {
let (write, read) = ws_stream.split();
let manager = window.state::<ConnectionManager>();
manager.0.lock().await.insert(id, write);
read.for_each(move |message| {
let window_ = window.clone();
let on_message_ = on_message.clone();
async move {
if let Ok(Message::Close(_)) = message {
let manager = window_.state::<ConnectionManager>();
manager.0.lock().await.remove(&id);
}
let response = match message {
Ok(Message::Text(t)) => {
serde_json::to_value(WebSocketMessage::Text(t)).unwrap()
}
Ok(Message::Binary(t)) => {
serde_json::to_value(WebSocketMessage::Binary(t)).unwrap()
}
Ok(Message::Ping(t)) => {
serde_json::to_value(WebSocketMessage::Ping(t)).unwrap()
}
Ok(Message::Pong(t)) => {
serde_json::to_value(WebSocketMessage::Pong(t)).unwrap()
}
Ok(Message::Close(t)) => {
serde_json::to_value(WebSocketMessage::Close(t.map(|v| CloseFrame {
code: v.code.into(),
reason: v.reason.into_owned(),
})))
.unwrap()
}
Ok(Message::Frame(_)) => serde_json::Value::Null, Err(e) => serde_json::to_value(Error::from(e)).unwrap(),
};
let _ = on_message_.send(response);
}
})
.await;
});
Ok(id)
}
#[tauri::command]
async fn send(
manager: State<'_, ConnectionManager>,
id: Id,
message: WebSocketMessage,
) -> Result<()> {
if let Some(write) = manager.0.lock().await.get_mut(&id) {
write
.send(match message {
WebSocketMessage::Text(t) => Message::Text(t),
WebSocketMessage::Binary(t) => Message::Binary(t),
WebSocketMessage::Ping(t) => Message::Ping(t),
WebSocketMessage::Pong(t) => Message::Pong(t),
WebSocketMessage::Close(t) => Message::Close(t.map(|v| ProtocolCloseFrame {
code: v.code.into(),
reason: std::borrow::Cow::Owned(v.reason),
})),
})
.await?;
Ok(())
} else {
Err(Error::ConnectionNotFound(id))
}
}
pub fn init<R: Runtime>() -> TauriPlugin<R> {
Builder::default().build()
}
#[derive(Default)]
pub struct Builder {
tls_connector: Option<Connector>,
}
impl Builder {
pub fn new() -> Self {
Self {
tls_connector: None,
}
}
pub fn tls_connector(mut self, connector: Connector) -> Self {
self.tls_connector.replace(connector);
self
}
pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
PluginBuilder::new("websocket")
.invoke_handler(tauri::generate_handler![connect, send])
.setup(|app, _api| {
app.manage(ConnectionManager::default());
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
app.manage(TlsConnector(Mutex::new(self.tls_connector)));
Ok(())
})
.build()
}
}