Skip to main content

tauri_plugin_orpc/
lib.rs

1use std::sync::Arc;
2
3use futures_util::StreamExt;
4use orpc::Router;
5use orpc_procedure::ProcedureStream;
6use orpc_server::rpc;
7use orpc_server::sse;
8use serde::Deserialize;
9use tauri::ipc::Channel;
10use tauri::plugin::{Builder as PluginBuilder, TauriPlugin};
11use tauri::{Manager, Runtime, State};
12
/// IPC request from the TauriLink.
///
/// Deserialized from the raw JSON value passed to the `handle_rpc` command;
/// a payload that does not match this shape is answered with a 400
/// `BAD_REQUEST` response by `execute_rpc`.
#[derive(Debug, Deserialize)]
struct RpcRequest {
    /// Key used to look the procedure up in the router (`router.get(&path)`).
    path: String,
    /// Procedure input; re-serialized to bytes and handed to
    /// `rpc::decode_rpc_request` before the procedure is executed.
    input: serde_json::Value,
}
19
/// Heap-allocated, pinned future that is `Send` and resolves to `T`.
type BoxFuture<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send>>;
/// Signature of the type-erased RPC entry point: takes the raw JSON request
/// and the Tauri channel, and returns a boxed future resolving to the JSON
/// reply. The `Send + Sync` bounds allow it to be shared behind an `Arc`.
type HandlerFn = dyn Fn(serde_json::Value, Channel<serde_json::Value>) -> BoxFuture<serde_json::Value>
    + Send
    + Sync;
24
/// Type-erased handler stored as Tauri managed state.
///
/// Erasing the router's context type and the runtime type here lets the
/// `handle_rpc` command stay non-generic.
struct RpcHandler {
    /// Closure installed by `init` during plugin setup; it forwards each
    /// request to `execute_rpc`.
    handler: Arc<HandlerFn>,
}
29
30/// Create a Tauri plugin that serves an oRPC router via IPC.
31///
32/// Registers a single IPC command `plugin:orpc|handle_rpc` that auto-detects
33/// single-value vs subscription procedures. Single-value results are returned
34/// directly; subscriptions are streamed via the Tauri Channel.
35///
36/// # Example
37/// ```ignore
38/// tauri::Builder::default()
39///     .plugin(tauri_plugin_orpc::init(router, |app_handle| AppCtx { ... }))
40///     .run(tauri::generate_context!())
41///     .unwrap();
42/// ```
43pub fn init<TCtx, R, F>(router: Router<TCtx>, ctx_fn: F) -> TauriPlugin<R>
44where
45    TCtx: Send + Sync + 'static,
46    R: Runtime,
47    F: Fn(&tauri::AppHandle<R>) -> TCtx + Send + Sync + 'static,
48{
49    let router = Arc::new(router);
50    let ctx_fn = Arc::new(ctx_fn);
51
52    PluginBuilder::<R>::new("orpc")
53        .invoke_handler(tauri::generate_handler![handle_rpc])
54        .setup(move |app, _api| {
55            let router = router.clone();
56            let ctx_fn = ctx_fn.clone();
57            let app_handle = app.clone();
58
59            app.manage(RpcHandler {
60                handler: Arc::new(move |request, channel| {
61                    let router = router.clone();
62                    let ctx_fn = ctx_fn.clone();
63                    let app_handle = app_handle.clone();
64                    Box::pin(async move {
65                        execute_rpc(&router, &*ctx_fn, &app_handle, request, channel).await
66                    })
67                }),
68            });
69            Ok(())
70        })
71        .build()
72}
73
74/// Unified handler: auto-detects single-value vs subscription.
75///
76/// For single-value procedures the JSON response is returned directly.
77/// For subscriptions, streaming is spawned as a background task and a
78/// `{"type": "subscription"}` marker is returned immediately while
79/// events flow through the Channel.
80#[tauri::command]
81async fn handle_rpc(
82    handler: State<'_, RpcHandler>,
83    request: serde_json::Value,
84    channel: Channel<serde_json::Value>,
85) -> Result<serde_json::Value, String> {
86    Ok((handler.handler)(request, channel).await)
87}
88
89/// Stream ProcedureStream items through a Tauri Channel.
90async fn stream_to_channel(mut stream: ProcedureStream, channel: Channel<serde_json::Value>) {
91    let mut id: u64 = 0;
92    while let Some(item) = stream.next().await {
93        match item {
94            Ok(output) => {
95                let value = output.to_value().unwrap_or_default();
96                let _ = channel.send(serde_json::json!({
97                    "event": "message",
98                    "id": id,
99                    "data": { "json": value }
100                }));
101                id += 1;
102            }
103            Err(err) => {
104                let orpc_err = rpc::procedure_error_to_orpc_error(err);
105                let _ = channel.send(serde_json::json!({
106                    "event": "error",
107                    "data": { "json": serde_json::to_value(&orpc_err).unwrap_or_default() }
108                }));
109                return;
110            }
111        }
112    }
113    let _ = channel.send(serde_json::json!({ "event": "done" }));
114}
115
116async fn execute_rpc<TCtx, R, F>(
117    router: &Router<TCtx>,
118    ctx_fn: &F,
119    app_handle: &tauri::AppHandle<R>,
120    request: serde_json::Value,
121    channel: Channel<serde_json::Value>,
122) -> serde_json::Value
123where
124    TCtx: Send + Sync + 'static,
125    R: Runtime,
126    F: Fn(&tauri::AppHandle<R>) -> TCtx,
127{
128    let req: RpcRequest = match serde_json::from_value(request) {
129        Ok(r) => r,
130        Err(e) => {
131            return make_error_response(400, "BAD_REQUEST", &format!("Invalid request: {e}"));
132        }
133    };
134
135    let procedure = match router.get(&req.path) {
136        Some(p) => p,
137        None => {
138            return make_error_response(
139                404,
140                "NOT_FOUND",
141                &format!("Procedure not found: {}", req.path),
142            );
143        }
144    };
145
146    let input_bytes = serde_json::to_vec(&req.input).unwrap_or_default();
147    let input = match rpc::decode_rpc_request(&input_bytes) {
148        Ok(i) => i,
149        Err(err) => {
150            let (status, body) = rpc::encode_rpc_error(&err);
151            return serde_json::json!({
152                "type": "response",
153                "status": status.as_u16(),
154                "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
155            });
156        }
157    };
158
159    let ctx = ctx_fn(app_handle);
160    let stream = procedure.exec(ctx, input);
161
162    if sse::is_subscription(&stream) {
163        tokio::spawn(async move {
164            stream_to_channel(stream, channel).await;
165        });
166        return serde_json::json!({ "type": "subscription" });
167    }
168
169    // Single-value: consume first item
170    let mut stream = stream;
171    match stream.next().await {
172        Some(Ok(output)) => match rpc::encode_rpc_success(output) {
173            Ok((status, body)) => serde_json::json!({
174                "type": "response",
175                "status": status.as_u16(),
176                "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
177            }),
178            Err(err) => {
179                let orpc_err = rpc::procedure_error_to_orpc_error(err);
180                let (status, body) = rpc::encode_rpc_error(&orpc_err);
181                serde_json::json!({
182                    "type": "response",
183                    "status": status.as_u16(),
184                    "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
185                })
186            }
187        },
188        Some(Err(err)) => {
189            let orpc_err = rpc::procedure_error_to_orpc_error(err);
190            let (status, body) = rpc::encode_rpc_error(&orpc_err);
191            serde_json::json!({
192                "type": "response",
193                "status": status.as_u16(),
194                "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
195            })
196        }
197        None => make_error_response(500, "INTERNAL_SERVER_ERROR", "Procedure returned no output"),
198    }
199}
200
201fn make_error_response(status: u16, code: &str, message: &str) -> serde_json::Value {
202    serde_json::json!({
203        "type": "response",
204        "status": status,
205        "body": {
206            "json": {
207                "code": code,
208                "status": status,
209                "message": message,
210                "defined": false
211            }
212        }
213    })
214}