Skip to main content

tauri_plugin_axum/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use axum::body::Body;
4use axum::extract::Request;
5use axum::http::Response;
6use axum::Router;
7use futures_util::FutureExt;
8use http_body_util::BodyExt;
9use std::collections::HashMap;
10use std::future::Future;
11use std::marker::PhantomData;
12use std::ops::{Deref, DerefMut};
13use tauri::async_runtime::block_on;
14use tauri::ipc::{InvokeBody, Request as IpcRequest};
15use tauri::{plugin::TauriPlugin, Manager, Runtime};
16use tower::{Service, ServiceExt};
17
18mod commands;
19mod error;
20mod models;
21
22pub use error::{Error, Result};
23pub use models::*;
24
25/// Extensions to [`tauri::App`], [`tauri::AppHandle`] and [`tauri::Window`] to access the axum APIs.
26pub trait AxumExt<R: Runtime> {
27    fn axum(&self) -> &Axum;
28}
29
30impl<R: Runtime, T: Manager<R>> crate::AxumExt<R> for T {
31    fn axum(&self) -> &Axum {
32        self.state::<Axum>().inner()
33    }
34}
35
36pub struct Axum(pub Router);
37
38impl Deref for Axum {
39    type Target = Router;
40
41    fn deref(&self) -> &Self::Target {
42        &self.0
43    }
44}
45
46impl DerefMut for Axum {
47    fn deref_mut(&mut self) -> &mut Self::Target {
48        &mut self.0
49    }
50}
51
52impl Axum {
53    pub async fn call(&self, req: IpcRequest<'_>) -> Result<AxumResponse> {
54        let mut rr = Request::builder();
55        *rr.headers_mut().unwrap() = req.headers().clone();
56
57        rr = rr.uri(req.headers().get("x-uri").ok_or(Error::Uri)?.as_ref());
58        rr = rr.method(req.headers().get("x-method").ok_or(Error::Method)?.as_ref());
59
60        let bytes = match req.body() {
61            InvokeBody::Json(v) => serde_json::to_vec(&v)?,
62            InvokeBody::Raw(r) => r.to_vec(),
63        };
64
65        let result = rr.body(Body::from(bytes))?;
66        let response = self.0.clone().call(result).await.unwrap();
67        let status = response.status();
68        let mut headers: HashMap<String, String> = response
69            .headers()
70            .iter()
71            .map(|(k, v)| (k.to_string(), String::from(v.to_str().unwrap_or_default())))
72            .collect();
73        let colleted = response.into_body().collect().await?;
74
75        if let Some(t) = colleted.trailers() {
76            for (k, v) in t.iter() {
77                headers.insert(k.to_string(), String::from(v.to_str().unwrap_or_default()));
78            }
79        }
80
81        Ok(AxumResponse {
82            status,
83            headers,
84            body: colleted.to_bytes(),
85        })
86    }
87
88    pub(crate) async fn call_json(&self, req: IpcRequest<'_>) -> Result<Vec<u8>> {
89        let body = match req.body() {
90            InvokeBody::Raw(raw) => raw.to_vec(),
91            InvokeBody::Json(_) => {
92                return Err(Error::Canceled);
93            }
94        };
95
96        let mut rr = Request::builder().header("Content-Type", "application/json");
97
98        rr = rr.uri(req.headers().get("x-uri").ok_or(Error::Uri)?.as_ref());
99        rr = rr.method(req.headers().get("x-method").ok_or(Error::Method)?.as_ref());
100
101        let res = self
102            .0
103            .clone()
104            .call(rr.body(Body::from(body))?)
105            .await
106            .unwrap();
107
108        let body = res.into_body().collect().await?.to_bytes().to_vec();
109
110        Ok(body)
111    }
112}
113
114/// Initializes the plugin.
115/// ```rust,no_run
116/// tauri::Builder::default()
117///     .plugin(tauri_plugin_axum::init(
118///         Router::new()
119///             .route("/", routing::get(|| async { "Hello, World!" }))
120///             .route("/post", routing::post(post_handle))
121///     ))
122/// ```
123pub fn init<R: Runtime>(router: Router) -> TauriPlugin<R> {
124    Builder::new(router).build()
125}
126
127pub fn block_init<R: Runtime, F: Future<Output = Router>>(f: F) -> TauriPlugin<R> {
128    block_on(f.map(init))
129}
130
131pub fn try_block_init<
132    R: Runtime,
133    F: Future<Output = std::result::Result<Router, Box<dyn std::error::Error>>>,
134>(
135    f: F,
136) -> std::result::Result<TauriPlugin<R>, Box<dyn std::error::Error>> {
137    Ok(block_on(f).map(init)?)
138}
139
140pub struct Builder<R: Runtime> {
141    router: Router,
142    _r: PhantomData<R>,
143}
144
145impl<R: Runtime> Builder<R> {
146    pub fn new(router: Router) -> Self {
147        Self {
148            router,
149            _r: PhantomData,
150        }
151    }
152
153    pub fn build(self) -> TauriPlugin<R> {
154        let mut router_clone = self.router.clone();
155
156        #[cfg(feature = "catch-panic")]
157        {
158            router_clone = router_clone.layer(tower_http::catch_panic::CatchPanicLayer::new());
159        }
160
161        #[cfg(feature = "cors")]
162        {
163            router_clone = router_clone.layer(tower_http::cors::CorsLayer::permissive());
164        }
165
166        tauri::plugin::Builder::new("axum")
167            .register_asynchronous_uri_scheme_protocol("axum", move |_ctx, request, responder| {
168                let svc = router_clone.clone();
169                tauri::async_runtime::spawn(async move {
170                    let (mut parts, body) = svc
171                        .oneshot(request.map(Body::from))
172                        .await
173                        .unwrap()
174                        .into_parts();
175
176                    let body = match body.collect().await {
177                        Ok(b) => b.to_bytes().to_vec(),
178                        Err(e) => {
179                            parts.status = axum::http::StatusCode::INTERNAL_SERVER_ERROR;
180                            e.to_string().into_bytes()
181                        }
182                    };
183                    responder.respond(Response::from_parts(parts, body));
184                });
185            })
186            .invoke_handler(tauri::generate_handler![
187                commands::call,
188                commands::call_json,
189                commands::fetch,
190                commands::fetch_cancel,
191                commands::fetch_send,
192                commands::fetch_read_body
193            ])
194            .setup(|app, __api| {
195                app.manage(Axum(self.router));
196                Ok(())
197            })
198            .build()
199    }
200}
201
202impl<R: Runtime> Builder<R> {}