1#![doc = include_str!("../README.md")]
2
3use axum::body::Body;
4use axum::extract::Request;
5use axum::http::Response;
6use axum::Router;
7use futures_util::FutureExt;
8use http_body_util::BodyExt;
9use std::collections::HashMap;
10use std::future::Future;
11use std::marker::PhantomData;
12use std::ops::{Deref, DerefMut};
13use tauri::async_runtime::block_on;
14use tauri::ipc::{InvokeBody, Request as IpcRequest};
15use tauri::{plugin::TauriPlugin, Manager, Runtime};
16use tower::{Service, ServiceExt};
17
18mod commands;
19mod error;
20mod models;
21
22pub use error::{Error, Result};
23pub use models::*;
24
25pub trait AxumExt<R: Runtime> {
27 fn axum(&self) -> &Axum;
28}
29
30impl<R: Runtime, T: Manager<R>> crate::AxumExt<R> for T {
31 fn axum(&self) -> &Axum {
32 self.state::<Axum>().inner()
33 }
34}
35
36pub struct Axum(pub Router);
37
38impl Deref for Axum {
39 type Target = Router;
40
41 fn deref(&self) -> &Self::Target {
42 &self.0
43 }
44}
45
46impl DerefMut for Axum {
47 fn deref_mut(&mut self) -> &mut Self::Target {
48 &mut self.0
49 }
50}
51
52impl Axum {
53 pub async fn call(&self, req: IpcRequest<'_>) -> Result<AxumResponse> {
54 let mut rr = Request::builder();
55 *rr.headers_mut().unwrap() = req.headers().clone();
56
57 rr = rr.uri(req.headers().get("x-uri").ok_or(Error::Uri)?.as_ref());
58 rr = rr.method(req.headers().get("x-method").ok_or(Error::Method)?.as_ref());
59
60 let bytes = match req.body() {
61 InvokeBody::Json(v) => serde_json::to_vec(&v)?,
62 InvokeBody::Raw(r) => r.to_vec(),
63 };
64
65 let result = rr.body(Body::from(bytes))?;
66 let response = self.0.clone().call(result).await.unwrap();
67 let status = response.status();
68 let mut headers: HashMap<String, String> = response
69 .headers()
70 .iter()
71 .map(|(k, v)| (k.to_string(), String::from(v.to_str().unwrap_or_default())))
72 .collect();
73 let colleted = response.into_body().collect().await?;
74
75 if let Some(t) = colleted.trailers() {
76 for (k, v) in t.iter() {
77 headers.insert(k.to_string(), String::from(v.to_str().unwrap_or_default()));
78 }
79 }
80
81 Ok(AxumResponse {
82 status,
83 headers,
84 body: colleted.to_bytes(),
85 })
86 }
87
88 pub(crate) async fn call_json(&self, req: IpcRequest<'_>) -> Result<Vec<u8>> {
89 let body = match req.body() {
90 InvokeBody::Raw(raw) => raw.to_vec(),
91 InvokeBody::Json(_) => {
92 return Err(Error::Canceled);
93 }
94 };
95
96 let mut rr = Request::builder().header("Content-Type", "application/json");
97
98 rr = rr.uri(req.headers().get("x-uri").ok_or(Error::Uri)?.as_ref());
99 rr = rr.method(req.headers().get("x-method").ok_or(Error::Method)?.as_ref());
100
101 let res = self
102 .0
103 .clone()
104 .call(rr.body(Body::from(body))?)
105 .await
106 .unwrap();
107
108 let body = res.into_body().collect().await?.to_bytes().to_vec();
109
110 Ok(body)
111 }
112}
113
114pub fn init<R: Runtime>(router: Router) -> TauriPlugin<R> {
124 Builder::new(router).build()
125}
126
127pub fn block_init<R: Runtime, F: Future<Output = Router>>(f: F) -> TauriPlugin<R> {
128 block_on(f.map(init))
129}
130
131pub fn try_block_init<
132 R: Runtime,
133 F: Future<Output = std::result::Result<Router, Box<dyn std::error::Error>>>,
134>(
135 f: F,
136) -> std::result::Result<TauriPlugin<R>, Box<dyn std::error::Error>> {
137 Ok(block_on(f).map(init)?)
138}
139
140pub struct Builder<R: Runtime> {
141 router: Router,
142 _r: PhantomData<R>,
143}
144
145impl<R: Runtime> Builder<R> {
146 pub fn new(router: Router) -> Self {
147 Self {
148 router,
149 _r: PhantomData,
150 }
151 }
152
153 pub fn build(self) -> TauriPlugin<R> {
154 let mut router_clone = self.router.clone();
155
156 #[cfg(feature = "catch-panic")]
157 {
158 router_clone = router_clone.layer(tower_http::catch_panic::CatchPanicLayer::new());
159 }
160
161 #[cfg(feature = "cors")]
162 {
163 router_clone = router_clone.layer(tower_http::cors::CorsLayer::permissive());
164 }
165
166 tauri::plugin::Builder::new("axum")
167 .register_asynchronous_uri_scheme_protocol("axum", move |_ctx, request, responder| {
168 let svc = router_clone.clone();
169 tauri::async_runtime::spawn(async move {
170 let (mut parts, body) = svc
171 .oneshot(request.map(Body::from))
172 .await
173 .unwrap()
174 .into_parts();
175
176 let body = match body.collect().await {
177 Ok(b) => b.to_bytes().to_vec(),
178 Err(e) => {
179 parts.status = axum::http::StatusCode::INTERNAL_SERVER_ERROR;
180 e.to_string().into_bytes()
181 }
182 };
183 responder.respond(Response::from_parts(parts, body));
184 });
185 })
186 .invoke_handler(tauri::generate_handler![
187 commands::call,
188 commands::call_json,
189 commands::fetch,
190 commands::fetch_cancel,
191 commands::fetch_send,
192 commands::fetch_read_body
193 ])
194 .setup(|app, __api| {
195 app.manage(Axum(self.router));
196 Ok(())
197 })
198 .build()
199 }
200}
201
202impl<R: Runtime> Builder<R> {}