salvo_oapi/
routing.rs

1use std::any::TypeId;
2use std::collections::{BTreeSet, HashMap};
3use std::sync::{LazyLock, RwLock};
4
5use regex::Regex;
6use salvo_core::Router;
7
8use crate::SecurityRequirement;
9use crate::path::PathItemType;
10
11#[derive(Debug, Default)]
12pub(crate) struct NormNode {
13    // pub(crate) router_id: usize,
14    pub(crate) handler_type_id: Option<TypeId>,
15    pub(crate) handler_type_name: Option<&'static str>,
16    pub(crate) method: Option<PathItemType>,
17    pub(crate) path: Option<String>,
18    pub(crate) children: Vec<Self>,
19    pub(crate) metadata: Metadata,
20}
21
22impl NormNode {
23    pub(crate) fn new(router: &Router, inherited_metadata: Metadata) -> Self {
24        let mut node = Self {
25            // router_id: router.id,
26            metadata: inherited_metadata,
27            ..Self::default()
28        };
29        let registry = METADATA_REGISTRY
30            .read()
31            .expect("failed to lock METADATA_REGISTRY for read");
32        if let Some(metadata) = registry.get(&router.id) {
33            node.metadata.tags.extend(metadata.tags.iter().cloned());
34            node.metadata
35                .securities
36                .extend(metadata.securities.iter().cloned());
37        }
38
39        let regex = Regex::new(r#"<([^/:>]+)(:[^>]*)?>"#).expect("invalid regex");
40        for filter in router.filters() {
41            let info = format!("{filter:?}");
42            if info.starts_with("path:") {
43                let path = info
44                    .split_once(':')
45                    .expect("split once by ':' should not be get `None`")
46                    .1;
47                node.path = Some(regex.replace_all(path, "{$1}").to_string());
48            } else if info.starts_with("method:") {
49                match info
50                    .split_once(':')
51                    .expect("split once by ':' should not be get `None`.")
52                    .1
53                {
54                    "GET" => node.method = Some(PathItemType::Get),
55                    "POST" => node.method = Some(PathItemType::Post),
56                    "PUT" => node.method = Some(PathItemType::Put),
57                    "DELETE" => node.method = Some(PathItemType::Delete),
58                    "HEAD" => node.method = Some(PathItemType::Head),
59                    "OPTIONS" => node.method = Some(PathItemType::Options),
60                    "CONNECT" => node.method = Some(PathItemType::Connect),
61                    "TRACE" => node.method = Some(PathItemType::Trace),
62                    "PATCH" => node.method = Some(PathItemType::Patch),
63                    _ => {}
64                }
65            }
66        }
67        node.handler_type_id = router.goal.as_ref().map(|h| h.type_id());
68        node.handler_type_name = router.goal.as_ref().map(|h| h.type_name());
69        let routers = router.routers();
70        if !routers.is_empty() {
71            for router in routers {
72                node.children.push(Self::new(router, node.metadata.clone()));
73            }
74        }
75        node
76    }
77}
78
79/// A component for save router metadata.
80type MetadataMap = RwLock<HashMap<usize, Metadata>>;
81static METADATA_REGISTRY: LazyLock<MetadataMap> = LazyLock::new(MetadataMap::default);
82
83/// Router extension trait for openapi metadata.
84pub trait RouterExt {
85    /// Add security requirement to the router.
86    ///
87    /// All endpoints in the router and it's descents will inherit this security requirement.
88    #[must_use]
89    fn oapi_security(self, security: SecurityRequirement) -> Self;
90
91    /// Add security requirements to the router.
92    ///
93    /// All endpoints in the router and it's descents will inherit these security requirements.
94    #[must_use]
95    fn oapi_securities<I>(self, security: I) -> Self
96    where
97        I: IntoIterator<Item = SecurityRequirement>;
98
99    /// Add tag to the router.
100    ///
101    /// All endpoints in the router and it's descents will inherit this tag.
102    #[must_use]
103    fn oapi_tag(self, tag: impl Into<String>) -> Self;
104
105    /// Add tags to the router.
106    ///
107    /// All endpoints in the router and it's descents will inherit these tags.
108    #[must_use]
109    fn oapi_tags<I, V>(self, tags: I) -> Self
110    where
111        I: IntoIterator<Item = V>,
112        V: Into<String>;
113}
114
115impl RouterExt for Router {
116    fn oapi_security(self, security: SecurityRequirement) -> Self {
117        let mut guard = METADATA_REGISTRY
118            .write()
119            .expect("failed to lock METADATA_REGISTRY for write");
120        let metadata = guard.entry(self.id).or_default();
121        metadata.securities.push(security);
122        self
123    }
124    fn oapi_securities<I>(self, iter: I) -> Self
125    where
126        I: IntoIterator<Item = SecurityRequirement>,
127    {
128        let mut guard = METADATA_REGISTRY
129            .write()
130            .expect("failed to lock METADATA_REGISTRY for write");
131        let metadata = guard.entry(self.id).or_default();
132        metadata.securities.extend(iter);
133        self
134    }
135    fn oapi_tag(self, tag: impl Into<String>) -> Self {
136        let mut guard = METADATA_REGISTRY
137            .write()
138            .expect("failed to lock METADATA_REGISTRY for write");
139        let metadata = guard.entry(self.id).or_default();
140        metadata.tags.insert(tag.into());
141        self
142    }
143    fn oapi_tags<I, V>(self, iter: I) -> Self
144    where
145        I: IntoIterator<Item = V>,
146        V: Into<String>,
147    {
148        let mut guard = METADATA_REGISTRY
149            .write()
150            .expect("failed to lock METADATA_REGISTRY for write");
151        let metadata = guard.entry(self.id).or_default();
152        metadata.tags.extend(iter.into_iter().map(Into::into));
153        self
154    }
155}
156
157#[non_exhaustive]
158#[derive(Default, Clone, Debug)]
159pub(crate) struct Metadata {
160    pub(crate) tags: BTreeSet<String>,
161    pub(crate) securities: Vec<SecurityRequirement>,
162}