1use crate::error::Result;
4use crate::middleware::{BodyLimitLayer, LayerStack, MiddlewareLayer, DEFAULT_BODY_LIMIT};
5use crate::router::{MethodRouter, Router};
6use crate::server::Server;
7use std::collections::HashMap;
8use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
9
10pub struct RustApi {
28 router: Router,
29 openapi_spec: rustapi_openapi::OpenApiSpec,
30 layers: LayerStack,
31 body_limit: Option<usize>,
32}
33
34impl RustApi {
35 pub fn new() -> Self {
37 let _ = tracing_subscriber::registry()
39 .with(
40 EnvFilter::try_from_default_env()
41 .unwrap_or_else(|_| EnvFilter::new("info,rustapi=debug")),
42 )
43 .with(tracing_subscriber::fmt::layer())
44 .try_init();
45
46 Self {
47 router: Router::new(),
48 openapi_spec: rustapi_openapi::OpenApiSpec::new("RustAPI Application", "1.0.0")
49 .register::<rustapi_openapi::ErrorSchema>()
50 .register::<rustapi_openapi::ErrorBodySchema>()
51 .register::<rustapi_openapi::ValidationErrorSchema>()
52 .register::<rustapi_openapi::ValidationErrorBodySchema>()
53 .register::<rustapi_openapi::FieldErrorSchema>(),
54 layers: LayerStack::new(),
55 body_limit: Some(DEFAULT_BODY_LIMIT), }
57 }
58
59 #[cfg(feature = "swagger-ui")]
83 pub fn auto() -> Self {
84 Self::new().mount_auto_routes_grouped().docs("/docs")
86 }
87
88 #[cfg(not(feature = "swagger-ui"))]
93 pub fn auto() -> Self {
94 Self::new().mount_auto_routes_grouped()
95 }
96
97 pub fn config() -> RustApiConfig {
115 RustApiConfig::new()
116 }
117
118 pub fn body_limit(mut self, limit: usize) -> Self {
139 self.body_limit = Some(limit);
140 self
141 }
142
143 pub fn no_body_limit(mut self) -> Self {
156 self.body_limit = None;
157 self
158 }
159
160 pub fn layer<L>(mut self, layer: L) -> Self
180 where
181 L: MiddlewareLayer,
182 {
183 self.layers.push(Box::new(layer));
184 self
185 }
186
187 pub fn state<S>(self, _state: S) -> Self
203 where
204 S: Clone + Send + Sync + 'static,
205 {
206 let state = _state;
208 let mut app = self;
209 app.router = app.router.state(state);
210 app
211 }
212
213 pub fn register_schema<T: for<'a> rustapi_openapi::Schema<'a>>(mut self) -> Self {
225 self.openapi_spec = self.openapi_spec.register::<T>();
226 self
227 }
228
229 pub fn openapi_info(mut self, title: &str, version: &str, description: Option<&str>) -> Self {
231 self.openapi_spec.info.title = title.to_string();
234 self.openapi_spec.info.version = version.to_string();
235 self.openapi_spec.info.description = description.map(|d| d.to_string());
236 self
237 }
238
239 pub fn openapi_spec(&self) -> &rustapi_openapi::OpenApiSpec {
241 &self.openapi_spec
242 }
243
244 fn mount_auto_routes_grouped(mut self) -> Self {
245 let routes = crate::auto_route::collect_auto_routes();
246 let mut by_path: HashMap<String, MethodRouter> = HashMap::new();
247
248 for route in routes {
249 let method_enum = match route.method {
250 "GET" => http::Method::GET,
251 "POST" => http::Method::POST,
252 "PUT" => http::Method::PUT,
253 "DELETE" => http::Method::DELETE,
254 "PATCH" => http::Method::PATCH,
255 _ => http::Method::GET,
256 };
257
258 let path = if route.path.starts_with('/') {
259 route.path.to_string()
260 } else {
261 format!("/{}", route.path)
262 };
263
264 let entry = by_path.entry(path).or_insert_with(MethodRouter::new);
265 entry.insert_boxed_with_operation(method_enum, route.handler, route.operation);
266 }
267
268 let route_count = by_path
269 .values()
270 .map(|mr| mr.allowed_methods().len())
271 .sum::<usize>();
272 let path_count = by_path.len();
273
274 for (path, method_router) in by_path {
275 self = self.route(&path, method_router);
276 }
277
278 tracing::info!(paths = path_count, routes = route_count, "Auto-registered routes");
279
280 crate::auto_schema::apply_auto_schemas(&mut self.openapi_spec);
282
283 self
284 }
285
286 pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self {
297 for (method, op) in &method_router.operations {
299 let mut op = op.clone();
300 add_path_params_to_operation(path, &mut op);
301 self.openapi_spec = self.openapi_spec.path(path, method.as_str(), op);
302 }
303
304 self.router = self.router.route(path, method_router);
305 self
306 }
307
308 #[deprecated(note = "Use route() directly or mount_route() for macro-based routing")]
312 pub fn mount(self, path: &str, method_router: MethodRouter) -> Self {
313 self.route(path, method_router)
314 }
315
316 pub fn mount_route(mut self, route: crate::handler::Route) -> Self {
334 let method_enum = match route.method {
335 "GET" => http::Method::GET,
336 "POST" => http::Method::POST,
337 "PUT" => http::Method::PUT,
338 "DELETE" => http::Method::DELETE,
339 "PATCH" => http::Method::PATCH,
340 _ => http::Method::GET,
341 };
342
343 let mut op = route.operation;
345 add_path_params_to_operation(route.path, &mut op);
346 self.openapi_spec = self.openapi_spec.path(route.path, route.method, op);
347
348 self.route_with_method(route.path, method_enum, route.handler)
349 }
350
351 fn route_with_method(
353 self,
354 path: &str,
355 method: http::Method,
356 handler: crate::handler::BoxedHandler,
357 ) -> Self {
358 use crate::router::MethodRouter;
359 let path = if !path.starts_with('/') {
368 format!("/{}", path)
369 } else {
370 path.to_string()
371 };
372
373 let mut handlers = std::collections::HashMap::new();
382 handlers.insert(method, handler);
383
384 let method_router = MethodRouter::from_boxed(handlers);
385 self.route(&path, method_router)
386 }
387
388 pub fn nest(mut self, prefix: &str, router: Router) -> Self {
400 self.router = self.router.nest(prefix, router);
401 self
402 }
403
404 #[cfg(feature = "swagger-ui")]
420 pub fn docs(self, path: &str) -> Self {
421 let title = self.openapi_spec.info.title.clone();
422 let version = self.openapi_spec.info.version.clone();
423 let description = self.openapi_spec.info.description.clone();
424
425 self.docs_with_info(path, &title, &version, description.as_deref())
426 }
427
428 #[cfg(feature = "swagger-ui")]
437 pub fn docs_with_info(
438 mut self,
439 path: &str,
440 title: &str,
441 version: &str,
442 description: Option<&str>,
443 ) -> Self {
444 use crate::router::get;
445 self.openapi_spec.info.title = title.to_string();
447 self.openapi_spec.info.version = version.to_string();
448 if let Some(desc) = description {
449 self.openapi_spec.info.description = Some(desc.to_string());
450 }
451
452 let path = path.trim_end_matches('/');
453 let openapi_path = format!("{}/openapi.json", path);
454
455 let spec_json =
457 serde_json::to_string_pretty(&self.openapi_spec.to_json()).unwrap_or_default();
458 let openapi_url = openapi_path.clone();
459
460 let spec_handler = move || {
462 let json = spec_json.clone();
463 async move {
464 http::Response::builder()
465 .status(http::StatusCode::OK)
466 .header(http::header::CONTENT_TYPE, "application/json")
467 .body(http_body_util::Full::new(bytes::Bytes::from(json)))
468 .unwrap()
469 }
470 };
471
472 let docs_handler = move || {
474 let url = openapi_url.clone();
475 async move { rustapi_openapi::swagger_ui_html(&url) }
476 };
477
478 self.route(&openapi_path, get(spec_handler))
479 .route(path, get(docs_handler))
480 }
481
482 #[cfg(feature = "swagger-ui")]
498 pub fn docs_with_auth(self, path: &str, username: &str, password: &str) -> Self {
499 let title = self.openapi_spec.info.title.clone();
500 let version = self.openapi_spec.info.version.clone();
501 let description = self.openapi_spec.info.description.clone();
502
503 self.docs_with_auth_and_info(
504 path,
505 username,
506 password,
507 &title,
508 &version,
509 description.as_deref(),
510 )
511 }
512
513 #[cfg(feature = "swagger-ui")]
529 pub fn docs_with_auth_and_info(
530 mut self,
531 path: &str,
532 username: &str,
533 password: &str,
534 title: &str,
535 version: &str,
536 description: Option<&str>,
537 ) -> Self {
538 use crate::router::MethodRouter;
539 use base64::{engine::general_purpose::STANDARD, Engine};
540 use std::collections::HashMap;
541
542 self.openapi_spec.info.title = title.to_string();
544 self.openapi_spec.info.version = version.to_string();
545 if let Some(desc) = description {
546 self.openapi_spec.info.description = Some(desc.to_string());
547 }
548
549 let path = path.trim_end_matches('/');
550 let openapi_path = format!("{}/openapi.json", path);
551
552 let credentials = format!("{}:{}", username, password);
554 let encoded = STANDARD.encode(credentials.as_bytes());
555 let expected_auth = format!("Basic {}", encoded);
556
557 let spec_json =
559 serde_json::to_string_pretty(&self.openapi_spec.to_json()).unwrap_or_default();
560 let openapi_url = openapi_path.clone();
561 let expected_auth_spec = expected_auth.clone();
562 let expected_auth_docs = expected_auth;
563
564 let spec_handler: crate::handler::BoxedHandler =
566 std::sync::Arc::new(move |req: crate::Request| {
567 let json = spec_json.clone();
568 let expected = expected_auth_spec.clone();
569 Box::pin(async move {
570 if !check_basic_auth(&req, &expected) {
571 return unauthorized_response();
572 }
573 http::Response::builder()
574 .status(http::StatusCode::OK)
575 .header(http::header::CONTENT_TYPE, "application/json")
576 .body(http_body_util::Full::new(bytes::Bytes::from(json)))
577 .unwrap()
578 })
579 as std::pin::Pin<Box<dyn std::future::Future<Output = crate::Response> + Send>>
580 });
581
582 let docs_handler: crate::handler::BoxedHandler =
584 std::sync::Arc::new(move |req: crate::Request| {
585 let url = openapi_url.clone();
586 let expected = expected_auth_docs.clone();
587 Box::pin(async move {
588 if !check_basic_auth(&req, &expected) {
589 return unauthorized_response();
590 }
591 rustapi_openapi::swagger_ui_html(&url)
592 })
593 as std::pin::Pin<Box<dyn std::future::Future<Output = crate::Response> + Send>>
594 });
595
596 let mut spec_handlers = HashMap::new();
598 spec_handlers.insert(http::Method::GET, spec_handler);
599 let spec_router = MethodRouter::from_boxed(spec_handlers);
600
601 let mut docs_handlers = HashMap::new();
602 docs_handlers.insert(http::Method::GET, docs_handler);
603 let docs_router = MethodRouter::from_boxed(docs_handlers);
604
605 self.route(&openapi_path, spec_router)
606 .route(path, docs_router)
607 }
608
609 pub async fn run(mut self, addr: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
620 if let Some(limit) = self.body_limit {
622 self.layers.prepend(Box::new(BodyLimitLayer::new(limit)));
624 }
625
626 let server = Server::new(self.router, self.layers);
627 server.run(addr).await
628 }
629
630 pub fn into_router(self) -> Router {
632 self.router
633 }
634
635 pub fn layers(&self) -> &LayerStack {
637 &self.layers
638 }
639}
640
641fn add_path_params_to_operation(path: &str, op: &mut rustapi_openapi::Operation) {
642 let mut params: Vec<String> = Vec::new();
643 let mut in_brace = false;
644 let mut current = String::new();
645
646 for ch in path.chars() {
647 match ch {
648 '{' => {
649 in_brace = true;
650 current.clear();
651 }
652 '}' => {
653 if in_brace {
654 in_brace = false;
655 if !current.is_empty() {
656 params.push(current.clone());
657 }
658 }
659 }
660 _ => {
661 if in_brace {
662 current.push(ch);
663 }
664 }
665 }
666 }
667
668 if params.is_empty() {
669 return;
670 }
671
672 let op_params = op.parameters.get_or_insert_with(Vec::new);
673
674 for name in params {
675 let already = op_params
676 .iter()
677 .any(|p| p.location == "path" && p.name == name);
678 if already {
679 continue;
680 }
681
682 op_params.push(rustapi_openapi::Parameter {
683 name,
684 location: "path".to_string(),
685 required: true,
686 description: None,
687 schema: rustapi_openapi::SchemaRef::Inline(serde_json::json!({ "type": "string" })),
688 });
689 }
690}
691
692impl Default for RustApi {
693 fn default() -> Self {
694 Self::new()
695 }
696}
697
#[cfg(test)]
mod tests {
    use super::RustApi;
    use crate::extract::{FromRequestParts, State};
    use crate::request::Request;
    use bytes::Bytes;
    use http::Method;
    use std::collections::HashMap;

    /// State registered through `RustApi::state` must be visible to the `State` extractor.
    #[test]
    fn state_is_available_via_extractor() {
        let router = RustApi::new().state(123u32).into_router();

        let (parts, _body) = http::Request::builder()
            .method(Method::GET)
            .uri("/test")
            .body(())
            .unwrap()
            .into_parts();

        let request = Request::new(parts, Bytes::new(), router.state_ref(), HashMap::new());
        let State(extracted) = State::<u32>::from_request_parts(&request).unwrap();

        assert_eq!(extracted, 123u32);
    }
}
724
/// Returns `true` when the request's `Authorization` header matches `expected`.
///
/// `expected` is the full header value, i.e. `"Basic <base64>"`. The auth-scheme is compared
/// case-insensitively (RFC 7235 §2.1). The credential token is compared in constant time, so
/// response timing does not reveal how many leading bytes of a guess were correct. Only the
/// token length can leak.
///
/// A missing header, a non-ASCII header, or a header without a scheme/token separator is
/// rejected.
#[cfg(feature = "swagger-ui")]
fn check_basic_auth(req: &crate::Request, expected: &str) -> bool {
    /// Compares two byte strings without short-circuiting on the first differing byte.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    }

    let provided = match req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
    {
        Some(value) => value.trim(),
        None => return false,
    };

    match (provided.split_once(' '), expected.split_once(' ')) {
        (Some((scheme, token)), Some((expected_scheme, expected_token))) => {
            // The scheme name is not secret, so short-circuiting on it is fine.
            scheme.eq_ignore_ascii_case(expected_scheme)
                && constant_time_eq(token.trim_start().as_bytes(), expected_token.as_bytes())
        }
        _ => false,
    }
}
734
/// Builds the `401 Unauthorized` reply used when docs credentials are missing or wrong.
///
/// The response advertises the Basic scheme with the `API Documentation` realm via
/// `WWW-Authenticate`, and carries a plain-text body.
#[cfg(feature = "swagger-ui")]
fn unauthorized_response() -> crate::Response {
    let body = bytes::Bytes::from("Unauthorized");

    let mut builder = http::Response::builder().status(http::StatusCode::UNAUTHORIZED);
    builder = builder.header(
        http::header::WWW_AUTHENTICATE,
        "Basic realm=\"API Documentation\"",
    );
    builder = builder.header(http::header::CONTENT_TYPE, "text/plain");

    builder.body(http_body_util::Full::new(body)).unwrap()
}
750
751pub struct RustApiConfig {
753 docs_path: Option<String>,
754 docs_enabled: bool,
755 api_title: String,
756 api_version: String,
757 api_description: Option<String>,
758 body_limit: Option<usize>,
759 layers: LayerStack,
760}
761
762impl RustApiConfig {
763 pub fn new() -> Self {
764 Self {
765 docs_path: Some("/docs".to_string()),
766 docs_enabled: true,
767 api_title: "RustAPI".to_string(),
768 api_version: "1.0.0".to_string(),
769 api_description: None,
770 body_limit: None,
771 layers: LayerStack::new(),
772 }
773 }
774
775 pub fn docs_path(mut self, path: impl Into<String>) -> Self {
777 self.docs_path = Some(path.into());
778 self
779 }
780
781 pub fn docs_enabled(mut self, enabled: bool) -> Self {
783 self.docs_enabled = enabled;
784 self
785 }
786
787 pub fn openapi_info(
789 mut self,
790 title: impl Into<String>,
791 version: impl Into<String>,
792 description: Option<impl Into<String>>,
793 ) -> Self {
794 self.api_title = title.into();
795 self.api_version = version.into();
796 self.api_description = description.map(|d| d.into());
797 self
798 }
799
800 pub fn body_limit(mut self, limit: usize) -> Self {
802 self.body_limit = Some(limit);
803 self
804 }
805
806 pub fn layer<L>(mut self, layer: L) -> Self
808 where
809 L: MiddlewareLayer,
810 {
811 self.layers.push(Box::new(layer));
812 self
813 }
814
815 pub fn build(self) -> RustApi {
817 let mut app = RustApi::new().mount_auto_routes_grouped();
818
819 if let Some(limit) = self.body_limit {
821 app = app.body_limit(limit);
822 }
823
824 app = app.openapi_info(
825 &self.api_title,
826 &self.api_version,
827 self.api_description.as_deref(),
828 );
829
830 #[cfg(feature = "swagger-ui")]
831 if self.docs_enabled {
832 if let Some(path) = self.docs_path {
833 app = app.docs(&path);
834 }
835 }
836
837 app.layers.extend(self.layers.into_iter());
840
841 app
842 }
843
844 pub async fn run(
846 self,
847 addr: impl AsRef<str>,
848 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
849 self.build().run(addr.as_ref()).await
850 }
851}