1#![forbid(unsafe_code)]
18#![cfg_attr(docsrs, feature(doc_cfg))]
19#![doc(
20 html_logo_url = "https://github.com/specta-rs/rspc/raw/main/.github/logo.png",
21 html_favicon_url = "https://github.com/specta-rs/rspc/raw/main/.github/logo.png"
22)]
23
24use std::{
25 borrow::Cow,
26 collections::HashMap,
27 sync::{Arc, Mutex, MutexGuard, PoisonError},
28};
29
30use rspc_procedure::{ProcedureError, Procedures};
31use serde::{de::Error, Deserialize, Serialize};
32use serde_json::value::RawValue;
33use tauri::{
34 async_runtime::{spawn, JoinHandle},
35 generate_handler,
36 ipc::{Channel, InvokeResponseBody, IpcResponse},
37 plugin::{Builder, TauriPlugin},
38 Manager,
39};
40
41struct RpcHandler<R, TCtxFn, TCtx> {
42 subscriptions: Mutex<HashMap<u32, JoinHandle<()>>>,
43 ctx_fn: TCtxFn,
44 procedures: Procedures<TCtx>,
45 phantom: std::marker::PhantomData<fn() -> R>,
46}
47
48impl<R, TCtxFn, TCtx> RpcHandler<R, TCtxFn, TCtx>
49where
50 R: tauri::Runtime,
51 TCtxFn: Fn(tauri::Window<R>) -> TCtx + Send + Sync + 'static,
52 TCtx: Send + 'static,
53{
54 fn subscriptions(&self) -> MutexGuard<HashMap<u32, JoinHandle<()>>> {
55 self.subscriptions
56 .lock()
57 .unwrap_or_else(PoisonError::into_inner)
58 }
59
60 fn handle_rpc_impl(
61 self: Arc<Self>,
62 window: tauri::Window<R>,
63 channel: tauri::ipc::Channel<IpcResultResponse>,
64 req: Request,
65 ) {
66 match req {
67 Request::Request { path, input } => {
68 let id = channel.id();
69 let ctx = (self.ctx_fn)(window);
70
71 let Some(procedure) = self.procedures.get(&Cow::Borrowed(&*path)) else {
72 let err = ProcedureError::NotFound;
73 send(
74 &channel,
75 Response::Value {
76 code: match err {
77 ProcedureError::NotFound => 404,
78 ProcedureError::Deserialize(_) => 400,
79 ProcedureError::Downcast(_) => 400,
80 ProcedureError::Resolver(_) => 500, ProcedureError::Unwind(_) => 500,
82 },
83 value: &err,
84 },
85 );
86 send::<()>(&channel, Response::Done);
87 return;
88 };
89
90 let mut stream = match input {
91 Some(i) => procedure.exec_with_deserializer(ctx, i.as_ref()),
92 None => procedure.exec_with_deserializer(ctx, serde_json::Value::Null),
93 };
94
95 let this = self.clone();
96 let handle = spawn(async move {
97 while let Some(value) = stream.next().await {
98 match value {
99 Ok(v) => send(
100 &channel,
101 Response::Value {
102 code: 200,
103 value: &v.as_serialize().unwrap(),
104 },
105 ),
106 Err(err) => send(
107 &channel,
108 Response::Value {
109 code: match err {
110 ProcedureError::NotFound => 404,
111 ProcedureError::Deserialize(_) => 400,
112 ProcedureError::Downcast(_) => 400,
113 ProcedureError::Resolver(_) => 500, ProcedureError::Unwind(_) => 500,
115 },
116 value: &err,
117 },
118 ),
119 }
120 }
121
122 this.subscriptions().remove(&id);
123 send::<()>(&channel, Response::Done);
124 });
125
126 if let Some(old) = self.subscriptions().insert(id, handle) {
128 old.abort();
129 }
130 }
131 Request::Abort(id) => {
132 if let Some(h) = self.subscriptions().remove(&id) {
133 h.abort();
134 }
135 }
136 }
137 }
138}
139
140trait HandleRpc<R: tauri::Runtime>: Send + Sync {
141 fn handle_rpc(
142 self: Arc<Self>,
143 window: tauri::Window<R>,
144 channel: tauri::ipc::Channel<IpcResultResponse>,
145 req: Request,
146 );
147}
148
149impl<R, TCtxFn, TCtx> HandleRpc<R> for RpcHandler<R, TCtxFn, TCtx>
150where
151 R: tauri::Runtime,
152 TCtxFn: Fn(tauri::Window<R>) -> TCtx + Send + Sync + 'static,
153 TCtx: Send + 'static,
154{
155 fn handle_rpc(
156 self: Arc<Self>,
157 window: tauri::Window<R>,
158 channel: tauri::ipc::Channel<IpcResultResponse>,
159 req: Request,
160 ) {
161 Self::handle_rpc_impl(self, window, channel, req);
162 }
163}
164
165struct State<R>(Arc<dyn HandleRpc<R>>);
170
171#[tauri::command]
172fn handle_rpc<R: tauri::Runtime>(
173 state: tauri::State<'_, State<R>>,
174 window: tauri::Window<R>,
175 channel: tauri::ipc::Channel<IpcResultResponse>,
176 req: Request,
177) {
178 state.0.clone().handle_rpc(window, channel, req);
179}
180
181pub fn init<R, TCtxFn, TCtx>(
182 procedures: impl Into<Procedures<TCtx>>,
183 ctx_fn: TCtxFn,
184) -> TauriPlugin<R>
185where
186 R: tauri::Runtime,
187 TCtxFn: Fn(tauri::Window<R>) -> TCtx + Send + Sync + 'static,
188 TCtx: Send + Sync + 'static,
189{
190 let procedures = procedures.into();
191
192 Builder::new("rspc")
193 .invoke_handler(generate_handler![handle_rpc])
194 .setup(move |app_handle, _| {
195 if !app_handle.manage(State(Arc::new(RpcHandler {
196 subscriptions: Default::default(),
197 ctx_fn,
198 procedures,
199 phantom: Default::default(),
200 }))) {
201 panic!("Attempted to mount `rspc_tauri::plugin` multiple times. Please ensure you only mount it once!");
202 }
203
204 Ok(())
205 })
206 .build()
207}
208
209#[derive(Deserialize, Serialize)]
210#[serde(tag = "method", content = "params", rename_all = "camelCase")]
211enum Request {
212 Request {
214 path: String,
215 input: Option<Box<RawValue>>,
217 },
218 Abort(u32),
221}
222
223#[derive(Serialize)]
224#[serde(untagged)]
225enum Response<'a, T: Serialize> {
226 Value { code: u16, value: &'a T },
227 Done,
228}
229
230fn send<'a, T: Serialize>(channel: &Channel<IpcResultResponse>, value: Response<'a, T>) {
231 channel
232 .send(IpcResultResponse(
233 serde_json::to_string(&value)
234 .map(|value| InvokeResponseBody::Json(value))
235 .map_err(|err| err.to_string()),
236 ))
237 .ok();
238}
239
240#[derive(Clone)]
241struct IpcResultResponse(Result<InvokeResponseBody, String>);
242
243impl IpcResponse for IpcResultResponse {
244 fn body(self) -> tauri::Result<InvokeResponseBody> {
245 self.0.map_err(|err| serde_json::Error::custom(err).into())
246 }
247}