Skip to main content

tauri_plugin_orpc/
lib.rs

1use std::sync::Arc;
2
3use futures_util::StreamExt;
4use orpc::Router;
5use orpc_procedure::ProcedureStream;
6use orpc_server::rpc;
7use orpc_server::sse;
8use serde::Deserialize;
9use tauri::ipc::Channel;
10use tauri::plugin::{Builder as PluginBuilder, TauriPlugin};
11use tauri::{Manager, Runtime, State};
12
13/// IPC request from the TauriLink.
14#[derive(Debug, Deserialize)]
15struct RpcRequest {
16    path: String,
17    input: serde_json::Value,
18}
19
20type BoxFuture<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send>>;
21type HandlerFn = dyn Fn(serde_json::Value, Channel<serde_json::Value>) -> BoxFuture<serde_json::Value>
22    + Send
23    + Sync;
24
25/// Type-erased handler stored as Tauri managed state.
26struct RpcHandler {
27    handler: Arc<HandlerFn>,
28}
29
30/// Create a Tauri plugin that serves an oRPC router via IPC.
31///
32/// Registers a single IPC command `plugin:orpc|handle_rpc` that auto-detects
33/// single-value vs subscription procedures. Single-value results are returned
34/// directly; subscriptions are streamed via the Tauri Channel.
35///
36/// # Example
37/// ```ignore
38/// tauri::Builder::default()
39///     .plugin(tauri_plugin_orpc::init(router, |app_handle| AppCtx { ... }))
40///     .run(tauri::generate_context!())
41///     .unwrap();
42/// ```
43pub fn init<TCtx, R, F>(router: Router<TCtx>, ctx_fn: F) -> TauriPlugin<R>
44where
45    TCtx: Send + Sync + 'static,
46    R: Runtime,
47    F: Fn(&tauri::AppHandle<R>) -> TCtx + Send + Sync + 'static,
48{
49    let router = Arc::new(router);
50    let ctx_fn = Arc::new(ctx_fn);
51
52    PluginBuilder::<R>::new("orpc")
53        .invoke_handler(tauri::generate_handler![handle_rpc])
54        .setup(move |app, _api| {
55            let router = router.clone();
56            let ctx_fn = ctx_fn.clone();
57            let app_handle = app.clone();
58
59            app.manage(RpcHandler {
60                handler: Arc::new(move |request, channel| {
61                    let router = router.clone();
62                    let ctx_fn = ctx_fn.clone();
63                    let app_handle = app_handle.clone();
64                    Box::pin(async move {
65                        execute_rpc(&router, &*ctx_fn, &app_handle, request, channel).await
66                    })
67                }),
68            });
69            Ok(())
70        })
71        .build()
72}
73
74/// Unified handler: auto-detects single-value vs subscription.
75///
76/// For single-value procedures the JSON response is returned directly.
77/// For subscriptions, streaming is spawned as a background task and a
78/// `{"type": "subscription"}` marker is returned immediately while
79/// events flow through the Channel.
80#[tauri::command]
81async fn handle_rpc(
82    handler: State<'_, RpcHandler>,
83    request: serde_json::Value,
84    channel: Channel<serde_json::Value>,
85) -> Result<serde_json::Value, String> {
86    Ok((handler.handler)(request, channel).await)
87}
88
89/// Stream ProcedureStream items through a Tauri Channel.
90///
91/// Stops streaming when the channel is closed (frontend disconnected)
92/// to avoid leaking background tasks and resources.
93async fn stream_to_channel(mut stream: ProcedureStream, channel: Channel<serde_json::Value>) {
94    let mut id: u64 = 0;
95    while let Some(item) = stream.next().await {
96        match item {
97            Ok(output) => {
98                let value = output.to_value().unwrap_or_default();
99                if channel
100                    .send(serde_json::json!({
101                        "event": "message",
102                        "id": id,
103                        "data": { "json": value }
104                    }))
105                    .is_err()
106                {
107                    // Channel closed — frontend disconnected, stop streaming.
108                    return;
109                }
110                id += 1;
111            }
112            Err(err) => {
113                let orpc_err = rpc::procedure_error_to_orpc_error(err);
114                let _ = channel.send(serde_json::json!({
115                    "event": "error",
116                    "data": { "json": serde_json::to_value(&orpc_err).unwrap_or_default() }
117                }));
118                return;
119            }
120        }
121    }
122    let _ = channel.send(serde_json::json!({ "event": "done" }));
123}
124
125async fn execute_rpc<TCtx, R, F>(
126    router: &Router<TCtx>,
127    ctx_fn: &F,
128    app_handle: &tauri::AppHandle<R>,
129    request: serde_json::Value,
130    channel: Channel<serde_json::Value>,
131) -> serde_json::Value
132where
133    TCtx: Send + Sync + 'static,
134    R: Runtime,
135    F: Fn(&tauri::AppHandle<R>) -> TCtx,
136{
137    let req: RpcRequest = match serde_json::from_value(request) {
138        Ok(r) => r,
139        Err(e) => {
140            return make_error_response(400, "BAD_REQUEST", &format!("Invalid request: {e}"));
141        }
142    };
143
144    let procedure = match router.get(&req.path) {
145        Some(p) => p,
146        None => {
147            return make_error_response(
148                404,
149                "NOT_FOUND",
150                &format!("Procedure not found: {}", req.path),
151            );
152        }
153    };
154
155    let input_bytes = serde_json::to_vec(&req.input).unwrap_or_default();
156    let input = match rpc::decode_rpc_request(&input_bytes) {
157        Ok(i) => i,
158        Err(err) => {
159            let (status, body) = rpc::encode_rpc_error(&err);
160            return serde_json::json!({
161                "type": "response",
162                "status": status.as_u16(),
163                "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
164            });
165        }
166    };
167
168    let ctx = ctx_fn(app_handle);
169    let stream = procedure.exec(ctx, input);
170
171    if sse::is_subscription(&stream) {
172        tokio::spawn(async move {
173            stream_to_channel(stream, channel).await;
174        });
175        return serde_json::json!({ "type": "subscription" });
176    }
177
178    // Single-value: consume first item
179    let mut stream = stream;
180    match stream.next().await {
181        Some(Ok(output)) => match rpc::encode_rpc_success(output) {
182            Ok((status, body)) => serde_json::json!({
183                "type": "response",
184                "status": status.as_u16(),
185                "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
186            }),
187            Err(err) => {
188                let orpc_err = rpc::procedure_error_to_orpc_error(err);
189                let (status, body) = rpc::encode_rpc_error(&orpc_err);
190                serde_json::json!({
191                    "type": "response",
192                    "status": status.as_u16(),
193                    "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
194                })
195            }
196        },
197        Some(Err(err)) => {
198            let orpc_err = rpc::procedure_error_to_orpc_error(err);
199            let (status, body) = rpc::encode_rpc_error(&orpc_err);
200            serde_json::json!({
201                "type": "response",
202                "status": status.as_u16(),
203                "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
204            })
205        }
206        None => make_error_response(500, "INTERNAL_SERVER_ERROR", "Procedure returned no output"),
207    }
208}
209
210fn make_error_response(status: u16, code: &str, message: &str) -> serde_json::Value {
211    serde_json::json!({
212        "type": "response",
213        "status": status,
214        "body": {
215            "json": {
216                "code": code,
217                "status": status,
218                "message": message,
219                "defined": false
220            }
221        }
222    })
223}