1use std::sync::Arc;
2
3use futures_util::StreamExt;
4use orpc::Router;
5use orpc_procedure::ProcedureStream;
6use orpc_server::rpc;
7use orpc_server::sse;
8use serde::Deserialize;
9use tauri::ipc::Channel;
10use tauri::plugin::{Builder as PluginBuilder, TauriPlugin};
11use tauri::{Manager, Runtime, State};
12
13#[derive(Debug, Deserialize)]
15struct RpcRequest {
16 path: String,
17 input: serde_json::Value,
18}
19
20type BoxFuture<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send>>;
21type HandlerFn = dyn Fn(serde_json::Value, Channel<serde_json::Value>) -> BoxFuture<serde_json::Value>
22 + Send
23 + Sync;
24
25struct RpcHandler {
27 handler: Arc<HandlerFn>,
28}
29
30pub fn init<TCtx, R, F>(router: Router<TCtx>, ctx_fn: F) -> TauriPlugin<R>
44where
45 TCtx: Send + Sync + 'static,
46 R: Runtime,
47 F: Fn(&tauri::AppHandle<R>) -> TCtx + Send + Sync + 'static,
48{
49 let router = Arc::new(router);
50 let ctx_fn = Arc::new(ctx_fn);
51
52 PluginBuilder::<R>::new("orpc")
53 .invoke_handler(tauri::generate_handler![handle_rpc])
54 .setup(move |app, _api| {
55 let router = router.clone();
56 let ctx_fn = ctx_fn.clone();
57 let app_handle = app.clone();
58
59 app.manage(RpcHandler {
60 handler: Arc::new(move |request, channel| {
61 let router = router.clone();
62 let ctx_fn = ctx_fn.clone();
63 let app_handle = app_handle.clone();
64 Box::pin(async move {
65 execute_rpc(&router, &*ctx_fn, &app_handle, request, channel).await
66 })
67 }),
68 });
69 Ok(())
70 })
71 .build()
72}
73
74#[tauri::command]
81async fn handle_rpc(
82 handler: State<'_, RpcHandler>,
83 request: serde_json::Value,
84 channel: Channel<serde_json::Value>,
85) -> Result<serde_json::Value, String> {
86 Ok((handler.handler)(request, channel).await)
87}
88
89async fn stream_to_channel(mut stream: ProcedureStream, channel: Channel<serde_json::Value>) {
94 let mut id: u64 = 0;
95 while let Some(item) = stream.next().await {
96 match item {
97 Ok(output) => {
98 let value = output.to_value().unwrap_or_default();
99 if channel
100 .send(serde_json::json!({
101 "event": "message",
102 "id": id,
103 "data": { "json": value }
104 }))
105 .is_err()
106 {
107 return;
109 }
110 id += 1;
111 }
112 Err(err) => {
113 let orpc_err = rpc::procedure_error_to_orpc_error(err);
114 let _ = channel.send(serde_json::json!({
115 "event": "error",
116 "data": { "json": serde_json::to_value(&orpc_err).unwrap_or_default() }
117 }));
118 return;
119 }
120 }
121 }
122 let _ = channel.send(serde_json::json!({ "event": "done" }));
123}
124
125async fn execute_rpc<TCtx, R, F>(
126 router: &Router<TCtx>,
127 ctx_fn: &F,
128 app_handle: &tauri::AppHandle<R>,
129 request: serde_json::Value,
130 channel: Channel<serde_json::Value>,
131) -> serde_json::Value
132where
133 TCtx: Send + Sync + 'static,
134 R: Runtime,
135 F: Fn(&tauri::AppHandle<R>) -> TCtx,
136{
137 let req: RpcRequest = match serde_json::from_value(request) {
138 Ok(r) => r,
139 Err(e) => {
140 return make_error_response(400, "BAD_REQUEST", &format!("Invalid request: {e}"));
141 }
142 };
143
144 let procedure = match router.get(&req.path) {
145 Some(p) => p,
146 None => {
147 return make_error_response(
148 404,
149 "NOT_FOUND",
150 &format!("Procedure not found: {}", req.path),
151 );
152 }
153 };
154
155 let input_bytes = serde_json::to_vec(&req.input).unwrap_or_default();
156 let input = match rpc::decode_rpc_request(&input_bytes) {
157 Ok(i) => i,
158 Err(err) => {
159 let (status, body) = rpc::encode_rpc_error(&err);
160 return serde_json::json!({
161 "type": "response",
162 "status": status.as_u16(),
163 "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
164 });
165 }
166 };
167
168 let ctx = ctx_fn(app_handle);
169 let stream = procedure.exec(ctx, input);
170
171 if sse::is_subscription(&stream) {
172 tokio::spawn(async move {
173 stream_to_channel(stream, channel).await;
174 });
175 return serde_json::json!({ "type": "subscription" });
176 }
177
178 let mut stream = stream;
180 match stream.next().await {
181 Some(Ok(output)) => match rpc::encode_rpc_success(output) {
182 Ok((status, body)) => serde_json::json!({
183 "type": "response",
184 "status": status.as_u16(),
185 "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
186 }),
187 Err(err) => {
188 let orpc_err = rpc::procedure_error_to_orpc_error(err);
189 let (status, body) = rpc::encode_rpc_error(&orpc_err);
190 serde_json::json!({
191 "type": "response",
192 "status": status.as_u16(),
193 "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
194 })
195 }
196 },
197 Some(Err(err)) => {
198 let orpc_err = rpc::procedure_error_to_orpc_error(err);
199 let (status, body) = rpc::encode_rpc_error(&orpc_err);
200 serde_json::json!({
201 "type": "response",
202 "status": status.as_u16(),
203 "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
204 })
205 }
206 None => make_error_response(500, "INTERNAL_SERVER_ERROR", "Procedure returned no output"),
207 }
208}
209
210fn make_error_response(status: u16, code: &str, message: &str) -> serde_json::Value {
211 serde_json::json!({
212 "type": "response",
213 "status": status,
214 "body": {
215 "json": {
216 "code": code,
217 "status": status,
218 "message": message,
219 "defined": false
220 }
221 }
222 })
223}