tauri_axum_htmx/
lib.rs

//! # Tauri Axum HTMX
//!
//! A library for creating interactive UIs using HTMX and Axum within Tauri applications.
//! This crate provides the infrastructure to route HTMX requests through Tauri's
//! command (IPC) bridge, process them with an Axum application, and return the
//! responses to the webview.
//!
//! ## Overview
//!
//! In a typical HTMX application, requests are sent to a server, which returns HTML to
//! the client. This crate enables the same pattern within Tauri applications by:
//! - Intercepting HTMX requests in the webview
//! - Forwarding them through Tauri's command (IPC) bridge
//! - Processing them with an Axum application running in the Tauri backend
//! - Returning the response to the webview, where HTMX handles it
13//! ## Quick Start
14//! 1. First, initialize the client-side integration in your HTML:
15//! ```html
16//! <!doctype html>
17//! <html lang="en">
18//!   <head>
19//!     <script src="https://unpkg.com/htmx.org@2.0.4"></script>
20//!     <script type="module">
21//!       import { initialize } from "https://unpkg.com/tauri-axum-htmx";
22//!
23//!       initialize("/"); // the initial path for the application to start on
24//!     </script>
25//!   </head>
26//! </html>
27//! ```
28//! 2. Then, set up the Tauri command to handle requests:
29//! ```rust,no_run
30//! use std::sync::Arc;
31//! use tokio::sync::Mutex;
32//! use axum::{Router, routing::get};
33//! use tauri::State;
34//! use tauri_axum_htmx::{LocalRequest, LocalResponse};
35//! struct TauriState {
36//!     router: Arc<Mutex<Router>>,
37//! }
38//! #[tauri::command]
39//! async fn local_app_request(
40//!     state: State<'_, TauriState>,
41//!     local_request: LocalRequest,
42//! ) -> Result<LocalResponse, ()> {
43//!     let mut router = state.router.lock().await;
44//!     let response = local_request.send_to_router(&mut router).await;
45//!     Ok(response)
46//! }
47//! fn main() {
48//!     let app = Router::new()
49//!         .route("/", get(|| async { "Hello, World!" }));
50//!     let tauri_state = TauriState {
51//!         router: Arc::new(Mutex::new(app)),
52//!     };
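//!     tauri::Builder::default()
//!         .manage(tauri_state)
//!         .invoke_handler(tauri::generate_handler![local_app_request]);
//!     // In a real application, continue this builder chain with
//!     // `.run(tauri::generate_context!())` so the app actually starts.
//! }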
57//! ```
58
59use axum::http::{self};
60use axum::response::Response;
61use axum::Router;
62use axum::{body::Body, http::Request};
63use serde::{Deserialize, Serialize};
64use std::collections::HashMap;
65use std::fmt::Display;
66use thiserror::Error;
67use tower_service::Service;
68
69#[derive(Error, Debug)]
70pub enum Error {
71    #[error("Could not parse method from LocalRequest")]
72    RequestMethodParseError(String),
73
74    #[error("Could not parse body from LocalRequest")]
75    RequestBodyParseError(#[from] http::Error),
76}
77
78/// Represents an HTTP request that can be processed by an Axum router.
79#[derive(Serialize, Deserialize, Clone, Debug)]
80pub struct LocalRequest {
81    pub uri: String,
82    pub method: String,
83    pub body: Option<String>,
84    pub headers: HashMap<String, String>,
85}
86
87impl LocalRequest {
88    pub async fn send_to_router(self, router: &mut Router) -> LocalResponse {
89        match self.to_axum_request() {
90            Ok(request) => match router.call(request).await {
91                Ok(response) => LocalResponse::from_response(response).await,
92                Err(error) => LocalResponse::internal_server_error(error),
93            },
94            Err(error) => LocalResponse::internal_server_error(error),
95        }
96    }
97
98    fn to_axum_request(&self) -> Result<http::Request<Body>, Error> {
99        let uri = self.uri.to_string();
100        let mut request_builder = match self.method.to_uppercase().as_str() {
101            "GET" => Ok(Request::get(uri)),
102            "POST" => Ok(Request::post(uri)),
103            "PUT" => Ok(Request::put(uri)),
104            "DELETE" => Ok(Request::delete(uri)),
105            "PATCH" => Ok(Request::patch(uri)),
106            _ => Err(Error::RequestMethodParseError(self.method.to_string())),
107        }?;
108
109        for (key, value) in self.headers.iter() {
110            request_builder = request_builder.header(key, value);
111        }
112
113        let request = match &self.body {
114            None => request_builder.body(Body::empty()),
115            Some(body) => request_builder.body(body.to_string().into()),
116        }?;
117
118        Ok(request)
119    }
120}
121
122/// Represents an HTTP response returned from an Axum router.
123#[derive(Serialize, Deserialize, Debug, Clone)]
124pub struct LocalResponse {
125    pub status_code: u16,
126    pub body: Vec<u8>,
127    pub headers: HashMap<String, String>,
128}
129
130impl LocalResponse {
131    pub fn internal_server_error(error: impl Display) -> Self {
132        let error_message = format!("An error occured: {}", error);
133        LocalResponse {
134            status_code: 500,
135            body: error_message.into(),
136            headers: Default::default(),
137        }
138    }
139}
140
141impl LocalResponse {
142    pub async fn from_response(response: Response) -> Self {
143        let code = response.status();
144        let response_headers = response.headers().clone();
145        let bytes_result = axum::body::to_bytes(response.into_body(), usize::MAX).await;
146
147        let mut headers: HashMap<String, String> = HashMap::new();
148        for (key, value) in response_headers.iter() {
149            headers.insert(key.to_string(), value.to_str().unwrap().to_string());
150        }
151
152        match bytes_result {
153            Ok(data) => LocalResponse {
154                status_code: code.as_u16(),
155                body: data.to_vec(),
156                headers,
157            },
158            Err(_) => LocalResponse {
159                status_code: code.as_u16(),
160                body: Vec::new(),
161                headers: headers.clone(),
162            },
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        routing::{get, post},
        Json,
    };
    use serde_json::json;

    // Router shared by most tests: a plain-text route, a POST echo route and a JSON route.
    fn create_test_router() -> Router {
        Router::new()
            .route("/test", get(|| async { "Hello, World!" }))
            .route("/echo", post(|body: String| async move { body }))
            .route("/json", get(|| async { Json(json!({"status": "ok"})) }))
    }

    mod local_request_tests {
        use super::*;

        #[tokio::test]
        async fn test_basic_get_request() {
            let mut router = create_test_router();
            let request = LocalRequest {
                uri: "/test".to_string(),
                method: "GET".to_string(),
                body: None,
                headers: HashMap::new(),
            };

            let response = request.send_to_router(&mut router).await;
            assert_eq!(response.status_code, 200);
            assert_eq!(String::from_utf8(response.body).unwrap(), "Hello, World!");
        }

        #[tokio::test]
        async fn test_post_request_with_body() {
            let mut router = create_test_router();
            let body = "Test Body";
            let request = LocalRequest {
                uri: "/echo".to_string(),
                method: "POST".to_string(),
                body: Some(body.to_string()),
                headers: HashMap::new(),
            };

            let response = request.send_to_router(&mut router).await;
            assert_eq!(response.status_code, 200);
            assert_eq!(String::from_utf8(response.body).unwrap(), body);
        }

        #[tokio::test]
        async fn test_json_response() {
            let mut router = create_test_router();
            let request = LocalRequest {
                uri: "/json".to_string(),
                method: "GET".to_string(),
                body: None,
                headers: HashMap::new(),
            };

            let response = request.send_to_router(&mut router).await;
            assert_eq!(response.status_code, 200);
            assert_eq!(
                response.headers.get("content-type").map(String::as_str),
                Some("application/json")
            );
            let body: serde_json::Value = serde_json::from_slice(&response.body).unwrap();
            assert_eq!(body, serde_json::json!({"status": "ok"}));
        }

        #[tokio::test]
        async fn test_unknown_route_returns_not_found() {
            let mut router = create_test_router();
            let request = LocalRequest {
                uri: "/missing".to_string(),
                method: "GET".to_string(),
                body: None,
                headers: HashMap::new(),
            };

            let response = request.send_to_router(&mut router).await;
            assert_eq!(response.status_code, 404);
        }

        #[tokio::test]
        async fn test_invalid_method() {
            let mut router = create_test_router();
            let request = LocalRequest {
                uri: "/test".to_string(),
                method: "INVALID".to_string(),
                body: None,
                headers: HashMap::new(),
            };

            let response = request.send_to_router(&mut router).await;
            assert_eq!(response.status_code, 500);
            assert!(String::from_utf8(response.body)
                .unwrap()
                .contains("Could not parse method"));
        }

        #[tokio::test]
        async fn test_invalid_header_name_returns_internal_server_error() {
            let mut router = create_test_router();
            let mut headers = HashMap::new();
            // A space is not a valid character in a header name.
            headers.insert("bad header".to_string(), "value".to_string());

            let request = LocalRequest {
                uri: "/test".to_string(),
                method: "GET".to_string(),
                body: None,
                headers,
            };

            let response = request.send_to_router(&mut router).await;
            assert_eq!(response.status_code, 500);
        }

        #[tokio::test]
        async fn test_request_with_headers() {
            let mut router = Router::new().route(
                "/headers",
                get(|req: Request<Body>| async move {
                    let header_value = req
                        .headers()
                        .get("X-Test-Header")
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("")
                        .to_string();
                    header_value
                }),
            );

            let mut headers = HashMap::new();
            headers.insert("X-Test-Header".to_string(), "test-value".to_string());

            let request = LocalRequest {
                uri: "/headers".to_string(),
                method: "GET".to_string(),
                body: None,
                headers,
            };

            let response = request.send_to_router(&mut router).await;
            assert_eq!(response.status_code, 200);
            assert_eq!(String::from_utf8(response.body).unwrap(), "test-value");
        }
    }

    mod local_response_tests {
        use super::*;
        use http::response::Builder;

        #[tokio::test]
        async fn test_response_creation_with_body() {
            let response = Builder::new()
                .status(200)
                .body(Body::from("test body"))
                .unwrap();

            let local_response = LocalResponse::from_response(response).await;
            assert_eq!(local_response.status_code, 200);
            assert_eq!(String::from_utf8(local_response.body).unwrap(), "test body");
        }

        #[tokio::test]
        async fn test_response_with_headers() {
            let response = Builder::new()
                .status(200)
                .header("X-Test", "test-value")
                .body(Body::empty())
                .unwrap();

            let local_response = LocalResponse::from_response(response).await;
            assert_eq!(local_response.status_code, 200);
            assert_eq!(local_response.headers.get("x-test").unwrap(), "test-value");
        }

        // Purely synchronous: no async runtime needed.
        #[test]
        fn test_internal_server_error() {
            let error_message = "Test error";
            let response = LocalResponse::internal_server_error(error_message);

            assert_eq!(response.status_code, 500);
            assert!(String::from_utf8(response.body)
                .unwrap()
                .contains(error_message));
            assert!(response.headers.is_empty());
        }
    }

    mod method_tests {
        use super::*;

        // `to_axum_request` is synchronous, so these are plain `#[test]`s.
        #[test]
        fn test_all_valid_methods() {
            for method in ["GET", "POST", "PUT", "DELETE", "PATCH"] {
                let request = LocalRequest {
                    uri: "/test".to_string(),
                    method: method.to_string(),
                    body: None,
                    headers: HashMap::new(),
                };

                assert!(
                    request.to_axum_request().is_ok(),
                    "method {method} should be accepted"
                );
            }
        }

        #[test]
        fn test_method_case_insensitivity() {
            let request = LocalRequest {
                uri: "/test".to_string(),
                method: "get".to_string(),
                body: None,
                headers: HashMap::new(),
            };

            assert!(request.to_axum_request().is_ok());
        }
    }
}