salvo_oapi/
routing.rs

1use std::any::TypeId;
2use std::collections::{BTreeSet, HashMap};
3use std::sync::{LazyLock, RwLock};
4
5use regex::Regex;
6use salvo_core::Router;
7
8use crate::{SecurityRequirement, path::PathItemType};
9
10#[derive(Debug, Default)]
11pub(crate) struct NormNode {
12    // pub(crate) router_id: usize,
13    pub(crate) handler_type_id: Option<TypeId>,
14    pub(crate) handler_type_name: Option<&'static str>,
15    pub(crate) method: Option<PathItemType>,
16    pub(crate) path: Option<String>,
17    pub(crate) children: Vec<NormNode>,
18    pub(crate) metadata: Metadata,
19}
20
21impl NormNode {
22    pub(crate) fn new(router: &Router, inherited_metadata: Metadata) -> Self {
23        let mut node = Self {
24            // router_id: router.id,
25            metadata: inherited_metadata,
26            ..Self::default()
27        };
28        let registry = METADATA_REGISTRY
29            .read()
30            .expect("failed to lock METADATA_REGISTRY for read");
31        if let Some(metadata) = registry.get(&router.id) {
32            node.metadata.tags.extend(metadata.tags.iter().cloned());
33            node.metadata
34                .securities
35                .extend(metadata.securities.iter().cloned());
36        }
37
38        let regex = Regex::new(r#"<([^/:>]+)(:[^>]*)?>"#).expect("invalid regex");
39        for filter in router.filters() {
40            let info = format!("{filter:?}");
41            if info.starts_with("path:") {
42                let path = info
43                    .split_once(':')
44                    .expect("split once by ':' should not be get `None`")
45                    .1;
46                node.path = Some(regex.replace_all(path, "{$1}").to_string());
47            } else if info.starts_with("method:") {
48                match info
49                    .split_once(':')
50                    .expect("split once by ':' should not be get `None`.")
51                    .1
52                {
53                    "GET" => node.method = Some(PathItemType::Get),
54                    "POST" => node.method = Some(PathItemType::Post),
55                    "PUT" => node.method = Some(PathItemType::Put),
56                    "DELETE" => node.method = Some(PathItemType::Delete),
57                    "HEAD" => node.method = Some(PathItemType::Head),
58                    "OPTIONS" => node.method = Some(PathItemType::Options),
59                    "CONNECT" => node.method = Some(PathItemType::Connect),
60                    "TRACE" => node.method = Some(PathItemType::Trace),
61                    "PATCH" => node.method = Some(PathItemType::Patch),
62                    _ => {}
63                }
64            }
65        }
66        node.handler_type_id = router.goal.as_ref().map(|h| h.type_id());
67        node.handler_type_name = router.goal.as_ref().map(|h| h.type_name());
68        let routers = router.routers();
69        if !routers.is_empty() {
70            for router in routers {
71                node.children.push(Self::new(router, node.metadata.clone()));
72            }
73        }
74        node
75    }
76}
77
78/// A component for save router metadata.
79type MetadataMap = RwLock<HashMap<usize, Metadata>>;
80static METADATA_REGISTRY: LazyLock<MetadataMap> = LazyLock::new(MetadataMap::default);
81
82/// Router extension trait for openapi metadata.
83pub trait RouterExt {
84    /// Add security requirement to the router.
85    ///
86    /// All endpoints in the router and it's descents will inherit this security requirement.
87    #[must_use]
88    fn oapi_security(self, security: SecurityRequirement) -> Self;
89
90    /// Add security requirements to the router.
91    ///
92    /// All endpoints in the router and it's descents will inherit these security requirements.
93    #[must_use]
94    fn oapi_securities<I>(self, security: I) -> Self
95    where
96        I: IntoIterator<Item = SecurityRequirement>;
97
98    /// Add tag to the router.
99    ///
100    /// All endpoints in the router and it's descents will inherit this tag.
101    #[must_use]
102    fn oapi_tag(self, tag: impl Into<String>) -> Self;
103
104    /// Add tags to the router.
105    ///
106    /// All endpoints in the router and it's descents will inherit these tags.
107    #[must_use]
108    fn oapi_tags<I, V>(self, tags: I) -> Self
109    where
110        I: IntoIterator<Item = V>,
111        V: Into<String>;
112}
113
114impl RouterExt for Router {
115    fn oapi_security(self, security: SecurityRequirement) -> Self {
116        let mut guard = METADATA_REGISTRY
117            .write()
118            .expect("failed to lock METADATA_REGISTRY for write");
119        let metadata = guard.entry(self.id).or_default();
120        metadata.securities.push(security);
121        self
122    }
123    fn oapi_securities<I>(self, iter: I) -> Self
124    where
125        I: IntoIterator<Item = SecurityRequirement>,
126    {
127        let mut guard = METADATA_REGISTRY
128            .write()
129            .expect("failed to lock METADATA_REGISTRY for write");
130        let metadata = guard.entry(self.id).or_default();
131        metadata.securities.extend(iter);
132        self
133    }
134    fn oapi_tag(self, tag: impl Into<String>) -> Self {
135        let mut guard = METADATA_REGISTRY
136            .write()
137            .expect("failed to lock METADATA_REGISTRY for write");
138        let metadata = guard.entry(self.id).or_default();
139        metadata.tags.insert(tag.into());
140        self
141    }
142    fn oapi_tags<I, V>(self, iter: I) -> Self
143    where
144        I: IntoIterator<Item = V>,
145        V: Into<String>,
146    {
147        let mut guard = METADATA_REGISTRY
148            .write()
149            .expect("failed to lock METADATA_REGISTRY for write");
150        let metadata = guard.entry(self.id).or_default();
151        metadata.tags.extend(iter.into_iter().map(Into::into));
152        self
153    }
154}
155
156#[non_exhaustive]
157#[derive(Default, Clone, Debug)]
158pub(crate) struct Metadata {
159    pub(crate) tags: BTreeSet<String>,
160    pub(crate) securities: Vec<SecurityRequirement>,
161}