Skip to main content

salvo_oapi/
routing.rs

1use std::any::TypeId;
2use std::collections::{BTreeSet, HashMap};
3use std::sync::{LazyLock, RwLock};
4
5use salvo_core::Router;
6
7use crate::SecurityRequirement;
8use crate::path::PathItemType;
9
/// Rewrites a route pattern into the path-template form used for OpenAPI.
///
/// Every `{name:...}` or `{name|...}` parameter is reduced to `{name}`; a
/// parameter without a `:` / `|` separator is copied unchanged. A doubled
/// opening brace (`{{`) is copied through verbatim. If a parameter is never
/// closed, the rest of `path` (starting at its opening brace) is appended
/// untouched and processing stops.
fn normalize_oapi_path(path: &str) -> String {
    let mut out = String::with_capacity(path.len());
    let mut cursor = path.char_indices().peekable();

    while let Some((open_at, ch)) = cursor.next() {
        if ch != '{' {
            out.push(ch);
            continue;
        }
        // `{{` is an escaped literal brace, not the start of a parameter.
        if matches!(cursor.peek(), Some((_, '{'))) {
            cursor.next();
            out.push_str("{{");
            continue;
        }

        // Look for the `}` that closes this parameter. Braces nested inside
        // the parameter body (e.g. regex quantifiers) are balanced via
        // `depth`, and the character following a backslash is never
        // interpreted.
        let mut depth = 0usize;
        let mut skip_next = false;
        let close_at = loop {
            let Some((idx, current)) = cursor.next() else {
                break None;
            };
            if skip_next {
                skip_next = false;
                continue;
            }
            match current {
                '\\' => skip_next = true,
                '{' => depth += 1,
                '}' if depth == 0 => break Some(idx),
                '}' => depth -= 1,
                _ => {}
            }
        };

        let Some(close_at) = close_at else {
            // Unterminated parameter: emit the remainder as-is.
            out.push_str(&path[open_at..]);
            break;
        };

        let body = &path[open_at + ch.len_utf8()..close_at];
        // Keep only the parameter name, dropping any constraint or regex part.
        let name = body.find([':', '|']).map_or(body, |cut| &body[..cut]);
        out.push('{');
        out.push_str(name);
        out.push('}');
    }
    out
}
69
/// One node of the normalized router tree built by `NormNode::new`.
///
/// Each node mirrors a single `Router` and records the pieces of it that are
/// extracted in `NormNode::new`: the goal handler's type, the HTTP method and
/// path taken from the router's filters, and the accumulated metadata.
#[derive(Debug, Default)]
pub(crate) struct NormNode {
    // pub(crate) router_id: usize,
    /// `TypeId` of the router's goal handler, or `None` if the router has no goal.
    pub(crate) handler_type_id: Option<TypeId>,
    /// Type name of the router's goal handler, or `None` if the router has no goal.
    pub(crate) handler_type_name: Option<&'static str>,
    /// HTTP method taken from a `method:` filter on the router, if one was recognized.
    pub(crate) method: Option<PathItemType>,
    /// Normalized path taken from a `path:` filter on the router, if present.
    pub(crate) path: Option<String>,
    /// Nodes built from the router's child routers.
    pub(crate) children: Vec<Self>,
    /// Metadata inherited from ancestors, extended with this router's own registered metadata.
    pub(crate) metadata: Metadata,
}
80
81impl NormNode {
82    pub(crate) fn new(router: &Router, inherited_metadata: Metadata) -> Self {
83        let mut node = Self {
84            // router_id: router.id,
85            metadata: inherited_metadata,
86            ..Self::default()
87        };
88        let registry = METADATA_REGISTRY
89            .read()
90            .expect("failed to lock METADATA_REGISTRY for read");
91        if let Some(metadata) = registry.get(&router.id) {
92            node.metadata.tags.extend(metadata.tags.iter().cloned());
93            node.metadata
94                .securities
95                .extend(metadata.securities.iter().cloned());
96        }
97
98        for filter in router.filters() {
99            let info = format!("{filter:?}");
100            if info.starts_with("path:") {
101                let path = info
102                    .split_once(':')
103                    .expect("split once by ':' should not be get `None`")
104                    .1;
105                node.path = Some(normalize_oapi_path(path));
106            } else if info.starts_with("method:") {
107                match info
108                    .split_once(':')
109                    .expect("split once by ':' should not be get `None`.")
110                    .1
111                {
112                    "GET" => node.method = Some(PathItemType::Get),
113                    "POST" => node.method = Some(PathItemType::Post),
114                    "PUT" => node.method = Some(PathItemType::Put),
115                    "DELETE" => node.method = Some(PathItemType::Delete),
116                    "HEAD" => node.method = Some(PathItemType::Head),
117                    "OPTIONS" => node.method = Some(PathItemType::Options),
118                    "CONNECT" => node.method = Some(PathItemType::Connect),
119                    "TRACE" => node.method = Some(PathItemType::Trace),
120                    "PATCH" => node.method = Some(PathItemType::Patch),
121                    _ => {}
122                }
123            }
124        }
125        node.handler_type_id = router.goal.as_ref().map(|h| h.type_id());
126        node.handler_type_name = router.goal.as_ref().map(|h| h.type_name());
127        let routers = router.routers();
128        if !routers.is_empty() {
129            for router in routers {
130                node.children.push(Self::new(router, node.metadata.clone()));
131            }
132        }
133        node
134    }
135}
136
/// Lock-protected map from a router's id to the OpenAPI metadata registered for it.
type MetadataMap = RwLock<HashMap<usize, Metadata>>;
/// Global registry of per-router metadata, lazily initialized to an empty map.
static METADATA_REGISTRY: LazyLock<MetadataMap> = LazyLock::new(MetadataMap::default);
140
/// Router extension trait for attaching OpenAPI metadata (tags and security
/// requirements) to a router.
pub trait RouterExt {
    /// Adds a security requirement to the router.
    ///
    /// All endpoints in the router and its descendants will inherit this security requirement.
    #[must_use]
    fn oapi_security(self, security: SecurityRequirement) -> Self;

    /// Adds multiple security requirements to the router.
    ///
    /// All endpoints in the router and its descendants will inherit these security requirements.
    #[must_use]
    fn oapi_securities<I>(self, security: I) -> Self
    where
        I: IntoIterator<Item = SecurityRequirement>;

    /// Adds a tag to the router.
    ///
    /// All endpoints in the router and its descendants will inherit this tag.
    #[must_use]
    fn oapi_tag(self, tag: impl Into<String>) -> Self;

    /// Adds multiple tags to the router.
    ///
    /// All endpoints in the router and its descendants will inherit these tags.
    #[must_use]
    fn oapi_tags<I, V>(self, tags: I) -> Self
    where
        I: IntoIterator<Item = V>,
        V: Into<String>;
}
172
173impl RouterExt for Router {
174    fn oapi_security(self, security: SecurityRequirement) -> Self {
175        let mut guard = METADATA_REGISTRY
176            .write()
177            .expect("failed to lock METADATA_REGISTRY for write");
178        let metadata = guard.entry(self.id).or_default();
179        metadata.securities.push(security);
180        self
181    }
182    fn oapi_securities<I>(self, iter: I) -> Self
183    where
184        I: IntoIterator<Item = SecurityRequirement>,
185    {
186        let mut guard = METADATA_REGISTRY
187            .write()
188            .expect("failed to lock METADATA_REGISTRY for write");
189        let metadata = guard.entry(self.id).or_default();
190        metadata.securities.extend(iter);
191        self
192    }
193    fn oapi_tag(self, tag: impl Into<String>) -> Self {
194        let mut guard = METADATA_REGISTRY
195            .write()
196            .expect("failed to lock METADATA_REGISTRY for write");
197        let metadata = guard.entry(self.id).or_default();
198        metadata.tags.insert(tag.into());
199        self
200    }
201    fn oapi_tags<I, V>(self, iter: I) -> Self
202    where
203        I: IntoIterator<Item = V>,
204        V: Into<String>,
205    {
206        let mut guard = METADATA_REGISTRY
207            .write()
208            .expect("failed to lock METADATA_REGISTRY for write");
209        let metadata = guard.entry(self.id).or_default();
210        metadata.tags.extend(iter.into_iter().map(Into::into));
211        self
212    }
213}
214
/// OpenAPI metadata attached to a router: its tags and security requirements.
#[non_exhaustive]
#[derive(Default, Clone, Debug)]
pub(crate) struct Metadata {
    /// Tag names; stored in a `BTreeSet`, so duplicates collapse and iteration is ordered.
    pub(crate) tags: BTreeSet<String>,
    /// Security requirements, kept in insertion order.
    pub(crate) securities: Vec<SecurityRequirement>,
}
221
#[cfg(test)]
mod tests {
    use super::normalize_oapi_path;

    /// Constraint (`:`) and regex (`|`) suffixes are stripped from parameters.
    #[test]
    fn normalize_braced_path_constraints() {
        assert_eq!(normalize_oapi_path("/posts/{id}"), "/posts/{id}");
        assert_eq!(normalize_oapi_path("/posts/{id:num}"), "/posts/{id}");
        assert_eq!(
            normalize_oapi_path("/posts/{id:num(3..=10)}"),
            "/posts/{id}"
        );
        assert_eq!(normalize_oapi_path(r"/posts/{id|\d+}"), "/posts/{id}");
        assert_eq!(normalize_oapi_path("/posts/{id|[a-z]{2}}"), "/posts/{id}");
        assert_eq!(
            normalize_oapi_path("/posts/article_{id:num}"),
            "/posts/article_{id}"
        );
    }

    /// Branches not reached by the constraint test: plain/empty paths, the
    /// `{{` escape, a backslash-escaped `}`, an unterminated parameter, and
    /// several parameters in one path.
    #[test]
    fn normalize_path_edge_cases() {
        assert_eq!(normalize_oapi_path(""), "");
        assert_eq!(normalize_oapi_path("/posts"), "/posts");
        assert_eq!(
            normalize_oapi_path("/files/{{literal}}"),
            "/files/{{literal}}"
        );
        assert_eq!(normalize_oapi_path(r"/posts/{id|\}+}/x"), "/posts/{id}/x");
        assert_eq!(normalize_oapi_path("/posts/{id"), "/posts/{id");
        assert_eq!(normalize_oapi_path(r"/{a:num}/{b|\d+}"), "/{a}/{b}");
    }
}