salvo_oapi/
naming.rs

1use std::any::TypeId;
2use std::collections::BTreeMap;
3use std::sync::LazyLock;
4
5use parking_lot::{RwLock, RwLockReadGuard};
6use regex::Regex;
7
/// How the schema name of a type is chosen.
///
/// The default, [`NameRule::Auto`], leaves the choice to the active namer.
#[derive(Debug, Default, Clone, Copy)]
pub enum NameRule {
    /// Let the namer derive the name from the type itself.
    #[default]
    Auto,
    /// Use the given name. `FlexNamer` appends the type's generic arguments to it and, if the
    /// result clashes with another type's name, a numeric suffix.
    Force(&'static str),
}
17
18static GLOBAL_NAMER: LazyLock<RwLock<Box<dyn Namer>>> =
19    LazyLock::new(|| RwLock::new(Box::new(FlexNamer::new())));
20static NAME_TYPES: LazyLock<RwLock<BTreeMap<String, (TypeId, &'static str)>>> =
21    LazyLock::new(Default::default);
22
23/// Set global namer.
24///
25/// Set global namer, all the types will be named by this namer. You should call this method before
26/// at before you generate OpenAPI schema.
27///
28/// # Example
29///
30/// ```rust
31/// # use salvo_oapi::extract::*;
32/// # use salvo_core::prelude::*;
33/// # #[tokio::main]
34/// # async fn main() {
35///     salvo_oapi::naming::set_namer(salvo_oapi::naming::FlexNamer::new().short_mode(true).generic_delimiter('_', '_'));
36/// # }
37/// ```
38pub fn set_namer(namer: impl Namer) {
39    *GLOBAL_NAMER.write() = Box::new(namer);
40    NAME_TYPES.write().clear();
41}
42
43#[doc(hidden)]
44pub fn namer() -> RwLockReadGuard<'static, Box<dyn Namer>> {
45    GLOBAL_NAMER.read()
46}
47
48/// Get type info by name.
49pub fn type_info_by_name(name: &str) -> Option<(TypeId, &'static str)> {
50    NAME_TYPES.read().get(name).cloned()
51}
52
53/// Set type info by name.
54pub fn set_name_type_info(
55    name: String,
56    type_id: TypeId,
57    type_name: &'static str,
58) -> Option<(TypeId, &'static str)> {
59    NAME_TYPES
60        .write()
61        .insert(name.clone(), (type_id, type_name))
62}
63
64/// Assign name to type and returns the name.
65///
66/// If the type is already named, return the existing name.
67pub fn assign_name<T: 'static>(rule: NameRule) -> String {
68    let type_id = TypeId::of::<T>();
69    let type_name = std::any::type_name::<T>();
70    for (name, (exist_id, _)) in NAME_TYPES.read().iter() {
71        if *exist_id == type_id {
72            return name.clone();
73        }
74    }
75    namer().assign_name(type_id, type_name, rule)
76}
77
78/// Get the name of the type. Panic if the name is not exist.
79pub fn get_name<T: 'static>() -> String {
80    let type_id = TypeId::of::<T>();
81    for (name, (exist_id, _)) in NAME_TYPES.read().iter() {
82        if *exist_id == type_id {
83            return name.clone();
84        }
85    }
86    panic!(
87        "Type not found in the name registry: {:?}",
88        std::any::type_name::<T>()
89    );
90}
91
/// Returns the generic-argument part of `type_name`, starting at its first `<`.
///
/// For `alloc::vec::Vec<alloc::string::String>` this is `<alloc::string::String>`; a name
/// containing no `<` yields an empty string.
fn type_generic_part(type_name: &str) -> String {
    // Equivalent to stripping the leading run of non-`<` characters (the former
    // `Regex::new(r"^[^<]+")`), without compiling a regex on every call.
    type_name
        .find('<')
        .map_or_else(String::new, |start| type_name[start..].to_owned())
}
97/// Namer is used to assign names to types.
98pub trait Namer: Sync + Send + 'static {
99    /// Assign name to type.
100    fn assign_name(&self, type_id: TypeId, type_name: &'static str, rule: NameRule) -> String;
101}
102
/// A configurable [`Namer`] that derives schema names from Rust type names.
///
/// With the defaults (`FlexNamer::new()`), a name is the full type path with `::` replaced
/// by `.`. See [`FlexNamer::short_mode`] and [`FlexNamer::generic_delimiter`] for the options.
#[derive(Debug, Clone, Default)]
pub struct FlexNamer {
    /// Drop module paths from type names.
    short_mode: bool,
    /// Strings substituted for `<` and `>` around generic arguments, if set.
    generic_delimiter: Option<(String, String)>,
}
impl FlexNamer {
    /// Creates a `FlexNamer` with short mode off and no generic delimiter.
    pub fn new() -> Self {
        Self::default()
    }

    /// Turns short mode on or off; short mode drops module paths from type names.
    pub fn short_mode(self, short_mode: bool) -> Self {
        Self { short_mode, ..self }
    }

    /// Sets the strings that replace `<` (`open`) and `>` (`close`) in generic type names.
    pub fn generic_delimiter(self, open: impl Into<String>, close: impl Into<String>) -> Self {
        Self {
            generic_delimiter: Some((open.into(), close.into())),
            ..self
        }
    }
}
127impl Namer for FlexNamer {
128    fn assign_name(&self, type_id: TypeId, type_name: &'static str, rule: NameRule) -> String {
129        let name = match rule {
130            NameRule::Auto => {
131                let mut base = if self.short_mode {
132                    let re = Regex::new(r"([^<>]*::)+").expect("Invalid regex");
133                    re.replace_all(type_name, "").to_string()
134                } else {
135                    type_name.replace("::", ".")
136                };
137                if let Some((open, close)) = &self.generic_delimiter {
138                    base = base.replace('<', open).replace('>', close);
139                }
140                let mut name = base.to_string();
141                let mut count = 1;
142                while let Some(exist_id) = type_info_by_name(&name).map(|t| t.0) {
143                    if exist_id != type_id {
144                        count += 1;
145                        name = format!("{}{}", base, count);
146                    } else {
147                        break;
148                    }
149                }
150                name
151            }
152            NameRule::Force(force_name) => {
153                let mut base = if self.short_mode {
154                    let re = Regex::new(r"([^<>]*::)+").expect("Invalid regex");
155                    re.replace_all(type_name, "").to_string()
156                } else {
157                    format! {"{}{}", force_name, type_generic_part(type_name).replace("::", ".")}
158                };
159                if let Some((open, close)) = &self.generic_delimiter {
160                    base = base.replace('<', open).replace('>', close).to_string();
161                }
162                let mut name = base.to_string();
163                let mut count = 1;
164                while let Some((exist_id, exist_name)) = type_info_by_name(&name) {
165                    if exist_id != type_id {
166                        count += 1;
167                        tracing::error!("Duplicate name for types: {}, {}", exist_name, type_name);
168                        name = format!("{}{}", base, count);
169                    } else {
170                        break;
171                    }
172                }
173                name.to_string()
174            }
175        };
176        set_name_type_info(name.clone(), type_id, type_name);
177        name
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the names produced by the default (long-mode) global namer.
    #[test]
    fn test_name() {
        struct MyString;
        mod nest {
            pub(crate) struct MyString;
        }

        let name = assign_name::<String>(NameRule::Auto);
        assert_eq!(name, "alloc.string.String");
        let name = assign_name::<Vec<String>>(NameRule::Auto);
        assert_eq!(name, "alloc.vec.Vec<alloc.string.String>");

        let name = assign_name::<MyString>(NameRule::Auto);
        assert_eq!(name, "salvo_oapi.naming.tests.test_name.MyString");
        let name = assign_name::<nest::MyString>(NameRule::Auto);
        assert_eq!(name, "salvo_oapi.naming.tests.test_name.nest.MyString");

        // `set_namer` swaps the process-wide namer and clears the shared name registry.
        // Tests run in parallel by default, so any other test assigning names at the same
        // time would observe that swap. Custom-namer scenarios (generic delimiters, short
        // mode, collision suffixes) are therefore not exercised here.
    }

    /// `type_generic_part` is pure and can be checked without touching global state.
    #[test]
    fn test_type_generic_part() {
        assert_eq!(
            type_generic_part("alloc::vec::Vec<alloc::string::String>"),
            "<alloc::string::String>"
        );
        assert_eq!(type_generic_part("alloc::string::String"), "");
        assert_eq!(type_generic_part(""), "");
    }
}