1use std::sync::Arc;
2
3use futures_util::StreamExt;
4use orpc::Router;
5use orpc_procedure::ProcedureStream;
6use orpc_server::rpc;
7use orpc_server::sse;
8use serde::Deserialize;
9use tauri::ipc::Channel;
10use tauri::plugin::{Builder as PluginBuilder, TauriPlugin};
11use tauri::{Manager, Runtime, State};
12
13#[derive(Debug, Deserialize)]
15struct RpcRequest {
16 path: String,
17 input: serde_json::Value,
18}
19
20type BoxFuture<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send>>;
21type HandlerFn = dyn Fn(serde_json::Value, Channel<serde_json::Value>) -> BoxFuture<serde_json::Value>
22 + Send
23 + Sync;
24
25struct RpcHandler {
27 handler: Arc<HandlerFn>,
28}
29
30pub fn init<TCtx, R, F>(router: Router<TCtx>, ctx_fn: F) -> TauriPlugin<R>
44where
45 TCtx: Send + Sync + 'static,
46 R: Runtime,
47 F: Fn(&tauri::AppHandle<R>) -> TCtx + Send + Sync + 'static,
48{
49 let router = Arc::new(router);
50 let ctx_fn = Arc::new(ctx_fn);
51
52 PluginBuilder::<R>::new("orpc")
53 .invoke_handler(tauri::generate_handler![handle_rpc])
54 .setup(move |app, _api| {
55 let router = router.clone();
56 let ctx_fn = ctx_fn.clone();
57 let app_handle = app.clone();
58
59 app.manage(RpcHandler {
60 handler: Arc::new(move |request, channel| {
61 let router = router.clone();
62 let ctx_fn = ctx_fn.clone();
63 let app_handle = app_handle.clone();
64 Box::pin(async move {
65 execute_rpc(&router, &*ctx_fn, &app_handle, request, channel).await
66 })
67 }),
68 });
69 Ok(())
70 })
71 .build()
72}
73
74#[tauri::command]
81async fn handle_rpc(
82 handler: State<'_, RpcHandler>,
83 request: serde_json::Value,
84 channel: Channel<serde_json::Value>,
85) -> Result<serde_json::Value, String> {
86 Ok((handler.handler)(request, channel).await)
87}
88
89async fn stream_to_channel(mut stream: ProcedureStream, channel: Channel<serde_json::Value>) {
91 let mut id: u64 = 0;
92 while let Some(item) = stream.next().await {
93 match item {
94 Ok(output) => {
95 let value = output.to_value().unwrap_or_default();
96 let _ = channel.send(serde_json::json!({
97 "event": "message",
98 "id": id,
99 "data": { "json": value }
100 }));
101 id += 1;
102 }
103 Err(err) => {
104 let orpc_err = rpc::procedure_error_to_orpc_error(err);
105 let _ = channel.send(serde_json::json!({
106 "event": "error",
107 "data": { "json": serde_json::to_value(&orpc_err).unwrap_or_default() }
108 }));
109 return;
110 }
111 }
112 }
113 let _ = channel.send(serde_json::json!({ "event": "done" }));
114}
115
116async fn execute_rpc<TCtx, R, F>(
117 router: &Router<TCtx>,
118 ctx_fn: &F,
119 app_handle: &tauri::AppHandle<R>,
120 request: serde_json::Value,
121 channel: Channel<serde_json::Value>,
122) -> serde_json::Value
123where
124 TCtx: Send + Sync + 'static,
125 R: Runtime,
126 F: Fn(&tauri::AppHandle<R>) -> TCtx,
127{
128 let req: RpcRequest = match serde_json::from_value(request) {
129 Ok(r) => r,
130 Err(e) => {
131 return make_error_response(400, "BAD_REQUEST", &format!("Invalid request: {e}"));
132 }
133 };
134
135 let procedure = match router.get(&req.path) {
136 Some(p) => p,
137 None => {
138 return make_error_response(
139 404,
140 "NOT_FOUND",
141 &format!("Procedure not found: {}", req.path),
142 );
143 }
144 };
145
146 let input_bytes = serde_json::to_vec(&req.input).unwrap_or_default();
147 let input = match rpc::decode_rpc_request(&input_bytes) {
148 Ok(i) => i,
149 Err(err) => {
150 let (status, body) = rpc::encode_rpc_error(&err);
151 return serde_json::json!({
152 "type": "response",
153 "status": status.as_u16(),
154 "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
155 });
156 }
157 };
158
159 let ctx = ctx_fn(app_handle);
160 let stream = procedure.exec(ctx, input);
161
162 if sse::is_subscription(&stream) {
163 tokio::spawn(async move {
164 stream_to_channel(stream, channel).await;
165 });
166 return serde_json::json!({ "type": "subscription" });
167 }
168
169 let mut stream = stream;
171 match stream.next().await {
172 Some(Ok(output)) => match rpc::encode_rpc_success(output) {
173 Ok((status, body)) => serde_json::json!({
174 "type": "response",
175 "status": status.as_u16(),
176 "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
177 }),
178 Err(err) => {
179 let orpc_err = rpc::procedure_error_to_orpc_error(err);
180 let (status, body) = rpc::encode_rpc_error(&orpc_err);
181 serde_json::json!({
182 "type": "response",
183 "status": status.as_u16(),
184 "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
185 })
186 }
187 },
188 Some(Err(err)) => {
189 let orpc_err = rpc::procedure_error_to_orpc_error(err);
190 let (status, body) = rpc::encode_rpc_error(&orpc_err);
191 serde_json::json!({
192 "type": "response",
193 "status": status.as_u16(),
194 "body": serde_json::from_slice::<serde_json::Value>(&body).unwrap_or_default()
195 })
196 }
197 None => make_error_response(500, "INTERNAL_SERVER_ERROR", "Procedure returned no output"),
198 }
199}
200
201fn make_error_response(status: u16, code: &str, message: &str) -> serde_json::Value {
202 serde_json::json!({
203 "type": "response",
204 "status": status,
205 "body": {
206 "json": {
207 "code": code,
208 "status": status,
209 "message": message,
210 "defined": false
211 }
212 }
213 })
214}