tauri_plugin_rspc/
lib.rs

1//! [Tauri](https://tauri.app) integration for [rspc](https://rspc.dev).
2//!
3//! # Example
4//!
5//! ```rust
6//! use rspc::Router;
7//!
8//! let router = Router::new();
9//! let (procedures, _types) = router.build().unwrap();
10//!
11//! tauri::Builder::default()
12//!     .plugin(tauri_plugin_rspc::init(procedures, |window| todo!()))
13//!     .run(tauri::generate_context!())
14//!     .expect("error while running tauri application");
15//! ```
16//!
17#![forbid(unsafe_code)]
18#![cfg_attr(docsrs, feature(doc_cfg))]
19#![doc(
20    html_logo_url = "https://github.com/specta-rs/rspc/raw/main/.github/logo.png",
21    html_favicon_url = "https://github.com/specta-rs/rspc/raw/main/.github/logo.png"
22)]
23
24use std::{
25    borrow::Cow,
26    collections::HashMap,
27    sync::{Arc, Mutex, MutexGuard, PoisonError},
28};
29
30use rspc_procedure::{ProcedureError, Procedures};
31use serde::{de::Error, Deserialize, Serialize};
32use serde_json::value::RawValue;
33use tauri::{
34    async_runtime::{spawn, JoinHandle},
35    generate_handler,
36    ipc::{Channel, InvokeResponseBody, IpcResponse},
37    plugin::{Builder, TauriPlugin},
38    Manager,
39};
40
41struct RpcHandler<R, TCtxFn, TCtx> {
42    subscriptions: Mutex<HashMap<u32, JoinHandle<()>>>,
43    ctx_fn: TCtxFn,
44    procedures: Procedures<TCtx>,
45    phantom: std::marker::PhantomData<fn() -> R>,
46}
47
48impl<R, TCtxFn, TCtx> RpcHandler<R, TCtxFn, TCtx>
49where
50    R: tauri::Runtime,
51    TCtxFn: Fn(tauri::Window<R>) -> TCtx + Send + Sync + 'static,
52    TCtx: Send + 'static,
53{
54    fn subscriptions(&self) -> MutexGuard<HashMap<u32, JoinHandle<()>>> {
55        self.subscriptions
56            .lock()
57            .unwrap_or_else(PoisonError::into_inner)
58    }
59
60    fn handle_rpc_impl(
61        self: Arc<Self>,
62        window: tauri::Window<R>,
63        channel: tauri::ipc::Channel<IpcResultResponse>,
64        req: Request,
65    ) {
66        match req {
67            Request::Request { path, input } => {
68                let id = channel.id();
69                let ctx = (self.ctx_fn)(window);
70
71                let Some(procedure) = self.procedures.get(&Cow::Borrowed(&*path)) else {
72                    let err = ProcedureError::NotFound;
73                    send(
74                        &channel,
75                        Response::Value {
76                            code: match err {
77                                ProcedureError::NotFound => 404,
78                                ProcedureError::Deserialize(_) => 400,
79                                ProcedureError::Downcast(_) => 400,
80                                ProcedureError::Resolver(_) => 500, // This is a breaking change. It previously came from the user.
81                                ProcedureError::Unwind(_) => 500,
82                            },
83                            value: &err,
84                        },
85                    );
86                    send::<()>(&channel, Response::Done);
87                    return;
88                };
89
90                let mut stream = match input {
91                    Some(i) => procedure.exec_with_deserializer(ctx, i.as_ref()),
92                    None => procedure.exec_with_deserializer(ctx, serde_json::Value::Null),
93                };
94
95                let this = self.clone();
96                let handle = spawn(async move {
97                    while let Some(value) = stream.next().await {
98                        match value {
99                            Ok(v) => send(
100                                &channel,
101                                Response::Value {
102                                    code: 200,
103                                    value: &v.as_serialize().unwrap(),
104                                },
105                            ),
106                            Err(err) => send(
107                                &channel,
108                                Response::Value {
109                                    code: match err {
110                                        ProcedureError::NotFound => 404,
111                                        ProcedureError::Deserialize(_) => 400,
112                                        ProcedureError::Downcast(_) => 400,
113                                        ProcedureError::Resolver(_) => 500, // This is a breaking change. It previously came from the user.
114                                        ProcedureError::Unwind(_) => 500,
115                                    },
116                                    value: &err,
117                                },
118                            ),
119                        }
120                    }
121
122                    this.subscriptions().remove(&id);
123                    send::<()>(&channel, Response::Done);
124                });
125
126                // if the client uses an existing ID, we will assume the previous subscription is no longer required
127                if let Some(old) = self.subscriptions().insert(id, handle) {
128                    old.abort();
129                }
130            }
131            Request::Abort(id) => {
132                if let Some(h) = self.subscriptions().remove(&id) {
133                    h.abort();
134                }
135            }
136        }
137    }
138}
139
140trait HandleRpc<R: tauri::Runtime>: Send + Sync {
141    fn handle_rpc(
142        self: Arc<Self>,
143        window: tauri::Window<R>,
144        channel: tauri::ipc::Channel<IpcResultResponse>,
145        req: Request,
146    );
147}
148
149impl<R, TCtxFn, TCtx> HandleRpc<R> for RpcHandler<R, TCtxFn, TCtx>
150where
151    R: tauri::Runtime,
152    TCtxFn: Fn(tauri::Window<R>) -> TCtx + Send + Sync + 'static,
153    TCtx: Send + 'static,
154{
155    fn handle_rpc(
156        self: Arc<Self>,
157        window: tauri::Window<R>,
158        channel: tauri::ipc::Channel<IpcResultResponse>,
159        req: Request,
160    ) {
161        Self::handle_rpc_impl(self, window, channel, req);
162    }
163}
164
165// Tauri commands can't be generic except for their runtime,
166// so we need to store + access the handler behind a trait.
167// This way handle_rpc_impl has full access to the generics it was instantiated with,
168// while State can be stored a) as a singleton (enforced by the type system!) and b) as type erased Tauri state
169struct State<R>(Arc<dyn HandleRpc<R>>);
170
171#[tauri::command]
172fn handle_rpc<R: tauri::Runtime>(
173    state: tauri::State<'_, State<R>>,
174    window: tauri::Window<R>,
175    channel: tauri::ipc::Channel<IpcResultResponse>,
176    req: Request,
177) {
178    state.0.clone().handle_rpc(window, channel, req);
179}
180
181pub fn init<R, TCtxFn, TCtx>(
182    procedures: impl Into<Procedures<TCtx>>,
183    ctx_fn: TCtxFn,
184) -> TauriPlugin<R>
185where
186    R: tauri::Runtime,
187    TCtxFn: Fn(tauri::Window<R>) -> TCtx + Send + Sync + 'static,
188    TCtx: Send + Sync + 'static,
189{
190    let procedures = procedures.into();
191
192    Builder::new("rspc")
193        .invoke_handler(generate_handler![handle_rpc])
194        .setup(move |app_handle, _| {
195            if !app_handle.manage(State(Arc::new(RpcHandler {
196                subscriptions: Default::default(),
197                ctx_fn,
198                procedures,
199                phantom: Default::default(),
200            }))) {
201                panic!("Attempted to mount `rspc_tauri::plugin` multiple times. Please ensure you only mount it once!");
202            }
203
204            Ok(())
205        })
206        .build()
207}
208
209#[derive(Deserialize, Serialize)]
210#[serde(tag = "method", content = "params", rename_all = "camelCase")]
211enum Request {
212    /// A request to execute a procedure.
213    Request {
214        path: String,
215        // #[serde(borrow)]
216        input: Option<Box<RawValue>>,
217    },
218    /// Abort a running task
219    /// You must provide the ID of the Tauri channel provided when the task was started.
220    Abort(u32),
221}
222
223#[derive(Serialize)]
224#[serde(untagged)]
225enum Response<'a, T: Serialize> {
226    Value { code: u16, value: &'a T },
227    Done,
228}
229
230fn send<'a, T: Serialize>(channel: &Channel<IpcResultResponse>, value: Response<'a, T>) {
231    channel
232        .send(IpcResultResponse(
233            serde_json::to_string(&value)
234                .map(|value| InvokeResponseBody::Json(value))
235                .map_err(|err| err.to_string()),
236        ))
237        .ok();
238}
239
240#[derive(Clone)]
241struct IpcResultResponse(Result<InvokeResponseBody, String>);
242
243impl IpcResponse for IpcResultResponse {
244    fn body(self) -> tauri::Result<InvokeResponseBody> {
245        self.0.map_err(|err| serde_json::Error::custom(err).into())
246    }
247}