Skip to main content

salvo_oapi/
naming.rs

1use std::any::TypeId;
2use std::collections::BTreeMap;
3use std::sync::LazyLock;
4
5use parking_lot::{RwLock, RwLockReadGuard};
6use regex::Regex;
7
/// Strategy used when choosing the schema name of a Rust type.
#[derive(Clone, Copy, Debug, Default)]
pub enum NameRule {
    /// Let the active namer derive the name from the Rust type name.
    #[default]
    Auto,
    /// Use the supplied name as the base of the schema name instead of deriving it.
    Force(&'static str),
}
17
18static GLOBAL_NAMER: LazyLock<RwLock<Box<dyn Namer>>> =
19    LazyLock::new(|| RwLock::new(Box::new(FlexNamer::new())));
20static NAME_TYPES: LazyLock<RwLock<BTreeMap<String, (TypeId, &'static str)>>> =
21    LazyLock::new(Default::default);
22
23/// Set global namer.
24///
25/// Set global namer, all the types will be named by this namer. You should call this method before
26/// at before you generate OpenAPI schema.
27///
28/// # Example
29///
30/// ```rust
31/// # use salvo_oapi::extract::*;
32/// # use salvo_core::prelude::*;
33/// # #[tokio::main]
34/// # async fn main() {
35/// salvo_oapi::naming::set_namer(
36///     salvo_oapi::naming::FlexNamer::new()
37///         .short_mode(true)
38///         .generic_delimiter('_', '_'),
39/// );
40/// # }
41/// ```
42pub fn set_namer(namer: impl Namer) {
43    *GLOBAL_NAMER.write() = Box::new(namer);
44    NAME_TYPES.write().clear();
45}
46
/// Restore the global naming state to its defaults.
///
/// Installs a fresh default `FlexNamer` and removes every registered type name.
/// Intended for tests that need isolation from one another.
#[cfg(test)]
pub fn reset_global_state() {
    // `set_namer` swaps the namer and then clears the name registry.
    set_namer(FlexNamer::new());
}
56
57#[doc(hidden)]
58pub fn namer() -> RwLockReadGuard<'static, Box<dyn Namer>> {
59    GLOBAL_NAMER.read()
60}
61
62/// Get type info by name.
63pub fn type_info_by_name(name: &str) -> Option<(TypeId, &'static str)> {
64    NAME_TYPES.read().get(name).cloned()
65}
66
67/// Get registered name by rust type name (from `std::any::type_name`).
68///
69/// This searches through NAME_TYPES to find if a type with the given rust type name
70/// has been registered with a custom name.
71pub fn name_by_type_name(type_name: &str) -> Option<String> {
72    NAME_TYPES
73        .read()
74        .iter()
75        .find(|(_, (_, registered_type_name))| *registered_type_name == type_name)
76        .map(|(name, _)| name.clone())
77}
78
79/// Resolve generic type parameters to their registered names.
80///
81/// This function recursively processes a type name string and replaces any
82/// generic type parameters with their registered names from NAME_TYPES.
83///
84/// For example, if `CityDTO` is registered as `City`, then:
85/// - `Response<CityDTO>` becomes `Response<City>`
86/// - `Vec<HashMap<String, CityDTO>>` becomes `Vec<HashMap<String, City>>`
87#[must_use]
88pub fn resolve_generic_names(type_name: &str) -> String {
89    // First check if the entire type (without generics) has a registered name
90    if let Some(registered_name) = name_by_type_name(type_name) {
91        return registered_name;
92    }
93
94    // Find the position of the first '<' to separate base type from generic params
95    let Some(generic_start) = type_name.find('<') else {
96        // No generics, return as-is
97        return type_name.to_owned();
98    };
99
100    // Extract base type and generic part
101    let base_type = &type_name[..generic_start];
102    let generic_part = &type_name[generic_start..];
103
104    // Parse and resolve each generic parameter
105    let resolved_generic = resolve_generic_part(generic_part);
106
107    format!("{base_type}{resolved_generic}")
108}
109
110/// Parse generic part like `<A, B<C, D>, E>` and resolve each type parameter.
111fn resolve_generic_part(generic_part: &str) -> String {
112    if !generic_part.starts_with('<') || !generic_part.ends_with('>') {
113        return generic_part.to_owned();
114    }
115
116    // Remove outer < and >
117    let inner = &generic_part[1..generic_part.len() - 1];
118
119    // Split by top-level commas (not nested in <>)
120    let params = split_generic_params(inner);
121
122    let resolved_params: Vec<String> = params
123        .into_iter()
124        .map(|param| {
125            let param = param.trim();
126            // Check if this exact type has a registered name
127            if let Some(registered_name) = name_by_type_name(param) {
128                registered_name
129            } else if param.contains('<') {
130                // Recursively resolve nested generics
131                resolve_generic_names(param)
132            } else {
133                // Use short name for unregistered types (like primitive types)
134                // e.g., "alloc::string::String" -> "String"
135                short_type_name(param).to_owned()
136            }
137        })
138        .collect();
139
140    format!("<{}>", resolved_params.join(", "))
141}
142
/// Split a generic parameter list at its top-level commas.
///
/// Commas nested inside `<...>` do not split. Segments are returned untrimmed, and an empty
/// trailing segment is dropped. Only angle brackets are tracked: commas inside tuples or
/// function-pointer types count as top-level, and the `>` of `->` lowers the nesting level.
fn split_generic_params(s: &str) -> Vec<&str> {
    let mut segments = Vec::new();
    let mut nesting: i32 = 0;
    let mut segment_start = 0;

    // The delimiters are ASCII, so scanning bytes yields valid slice boundaries.
    for (idx, byte) in s.bytes().enumerate() {
        if byte == b'<' {
            nesting += 1;
        } else if byte == b'>' {
            nesting -= 1;
        } else if byte == b',' && nesting == 0 {
            segments.push(&s[segment_start..idx]);
            segment_start = idx + 1;
        }
    }

    // Whatever follows the last top-level comma is the final segment.
    if segment_start < s.len() {
        segments.push(&s[segment_start..]);
    }

    segments
}
168
/// Return the last segment of a `::`-separated type path.
///
/// For example:
/// - `alloc::string::String` -> `String`
/// - `std::collections::HashMap` -> `HashMap`
/// - `my_crate::module::MyType` -> `MyType`
///
/// A name without any `::` is returned unchanged.
fn short_type_name(type_name: &str) -> &str {
    match type_name.rsplit_once("::") {
        Some((_module_path, short)) => short,
        None => type_name,
    }
}
182
183/// Set type info by name.
184pub fn set_name_type_info(
185    name: String,
186    type_id: TypeId,
187    type_name: &'static str,
188) -> Option<(TypeId, &'static str)> {
189    NAME_TYPES.write().insert(name, (type_id, type_name))
190}
191
192/// Assign name to type and returns the name.
193///
194/// If the type is already named, return the existing name.
195pub fn assign_name<T: 'static>(rule: NameRule) -> String {
196    let type_id = TypeId::of::<T>();
197    let type_name = std::any::type_name::<T>();
198    for (name, (exist_id, _)) in NAME_TYPES.read().iter() {
199        if *exist_id == type_id {
200            return name.clone();
201        }
202    }
203    namer().assign_name(type_id, type_name, rule)
204}
205
206/// Get the name of the type. Panic if the name is not exist.
207pub fn get_name<T: 'static>() -> String {
208    let type_id = TypeId::of::<T>();
209    for (name, (exist_id, _)) in NAME_TYPES.read().iter() {
210        if *exist_id == type_id {
211            return name.clone();
212        }
213    }
214    panic!(
215        "Type not found in the name registry: {:?}",
216        std::any::type_name::<T>()
217    );
218}
219
/// Return the part of `type_name` starting at its first `<`, or an empty string when the
/// name contains no `<`.
fn type_generic_part(type_name: &str) -> String {
    type_name
        .find('<')
        .map_or_else(String::new, |start| type_name[start..].to_owned())
}
227
228/// Resolve generic part and format it according to namer settings.
229fn resolve_and_format_generic_part(type_name: &str, short_mode: bool) -> String {
230    let generic_part = type_generic_part(type_name);
231    if generic_part.is_empty() {
232        return generic_part;
233    }
234
235    // Resolve registered names in generic parameters
236    let resolved = resolve_generic_part(&generic_part);
237
238    // Apply formatting (:: -> . for non-short mode, or strip module paths for short mode)
239    if short_mode {
240        let re = Regex::new(r"([^<>, ]*::)+").expect("Invalid regex");
241        re.replace_all(&resolved, "").to_string()
242    } else {
243        resolved.replace("::", ".")
244    }
245}
/// Namer is used to assign names to types.
///
/// Implementations must be `Sync + Send + 'static`; the global namer is stored as a
/// `Box<dyn Namer>` inside a shared static.
pub trait Namer: Sync + Send + 'static {
    /// Assign a name to the type identified by `type_id` and return it.
    ///
    /// `type_name` is the Rust type name of the type (the crate-level `assign_name` passes
    /// the result of `std::any::type_name`), and `rule` selects between automatic and forced
    /// naming.
    fn assign_name(&self, type_id: TypeId, type_name: &'static str, rule: NameRule) -> String;
}
251
/// A configurable namer that builds names from Rust type names.
///
/// By default names keep their full module path, with `::` rewritten to `.`.
#[derive(Clone, Debug, Default)]
pub struct FlexNamer {
    /// When `true`, module paths are stripped from generated names.
    short_mode: bool,
    /// Optional replacements for the `<` and `>` that surround generic parameters.
    generic_delimiter: Option<(String, String)>,
}

impl FlexNamer {
    /// Create a `FlexNamer` with short mode off and no custom generic delimiter.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable or disable short mode.
    #[must_use]
    pub fn short_mode(self, short_mode: bool) -> Self {
        Self { short_mode, ..self }
    }

    /// Set the strings that replace `<` (`open`) and `>` (`close`) in generic type names.
    #[must_use]
    pub fn generic_delimiter(self, open: impl Into<String>, close: impl Into<String>) -> Self {
        let delimiter = (open.into(), close.into());
        Self {
            generic_delimiter: Some(delimiter),
            ..self
        }
    }
}
279impl Namer for FlexNamer {
280    fn assign_name(&self, type_id: TypeId, type_name: &'static str, rule: NameRule) -> String {
281        let name = match rule {
282            NameRule::Auto => {
283                // First resolve any registered names in generic parameters
284                let resolved_type_name = resolve_generic_names(type_name);
285
286                let mut base = if self.short_mode {
287                    let re = Regex::new(r"([^<>, ]*::)+").expect("Invalid regex");
288                    re.replace_all(&resolved_type_name, "").to_string()
289                } else {
290                    resolved_type_name.replace("::", ".")
291                };
292                if let Some((open, close)) = &self.generic_delimiter {
293                    base = base.replace('<', open).replace('>', close);
294                }
295                let mut name = base.clone();
296                let mut count = 1;
297                while let Some(exist_id) = type_info_by_name(&name).map(|t| t.0) {
298                    if exist_id != type_id {
299                        count += 1;
300                        name = format!("{base}{count}");
301                    } else {
302                        break;
303                    }
304                }
305                name
306            }
307            NameRule::Force(force_name) => {
308                // Resolve registered names in generic parameters
309                let resolved_generic = resolve_and_format_generic_part(type_name, self.short_mode);
310
311                let mut base = if self.short_mode {
312                    // In short mode with Force, use the forced name + resolved generics
313                    format!("{force_name}{resolved_generic}")
314                } else {
315                    format!("{force_name}{resolved_generic}")
316                };
317                if let Some((open, close)) = &self.generic_delimiter {
318                    base = base.replace('<', open).replace('>', close);
319                }
320                let mut name = base.clone();
321                let mut count = 1;
322                while let Some((exist_id, exist_name)) = type_info_by_name(&name) {
323                    if exist_id != type_id {
324                        count += 1;
325                        tracing::error!("Duplicate name for types: {}, {}", exist_name, type_name);
326                        name = format!("{base}{count}");
327                    } else {
328                        break;
329                    }
330                }
331                name
332            }
333        };
334        set_name_type_info(name.clone(), type_id, type_name);
335        name
336    }
337}
338
// Unit tests for the naming helpers. Tests that touch the global namer or the name registry
// are marked `#[serial]` and reset the global state first, because that state is shared.
#[cfg(test)]
mod tests {
    use serial_test::serial;

    /// Auto naming with the default namer keeps module paths (`::` rendered as `.`) and
    /// distinguishes two structs that share a short name.
    #[test]
    #[serial]
    fn test_name() {
        use super::*;

        // Reset global state to ensure deterministic test results
        reset_global_state();

        struct MyString;
        mod nest {
            pub(crate) struct MyString;
        }

        let name = assign_name::<String>(NameRule::Auto);
        assert_eq!(name, "alloc.string.String");
        let name = assign_name::<Vec<String>>(NameRule::Auto);
        assert_eq!(name, "alloc.vec.Vec<alloc.string.String>");

        let name = assign_name::<MyString>(NameRule::Auto);
        assert!(
            name.contains("MyString") && !name.contains("nest"),
            "Expected name containing 'MyString' but not 'nest', got: {name}"
        );
        let name = assign_name::<nest::MyString>(NameRule::Auto);
        assert!(
            name.contains("nest") && name.contains("MyString"),
            "Expected name containing 'nest.MyString', got: {name}"
        );
    }

    /// Registered names are substituted into generic parameters, including nested ones.
    #[test]
    #[serial]
    fn test_resolve_generic_names() {
        use super::*;

        // Reset global state to ensure deterministic test results
        reset_global_state();

        // Simulate registering CityDTO as "City"
        let city_type_name = "test_module::CityDTO";
        set_name_type_info(
            "City".to_string(),
            TypeId::of::<()>(), // dummy TypeId
            city_type_name,
        );

        // Test resolve_generic_names with registered type
        let resolved = resolve_generic_names("Response<test_module::CityDTO>");
        assert_eq!(resolved, "Response<City>");

        // Test nested generics - unregistered types get short names
        let resolved = resolve_generic_names("Vec<HashMap<String, test_module::CityDTO>>");
        assert_eq!(resolved, "Vec<HashMap<String, City>>");

        // Test multiple generic parameters
        let resolved = resolve_generic_names("Tuple<test_module::CityDTO, test_module::CityDTO>");
        assert_eq!(resolved, "Tuple<City, City>");
    }

    /// Unregistered generic parameters are shortened to their last path segment, while the
    /// outermost base type keeps its path.
    #[test]
    #[serial]
    fn test_resolve_primitive_types() {
        use super::*;

        // Reset global state to ensure deterministic test results
        reset_global_state();

        // Test with primitive types (not registered, should use short names in generic params)
        let resolved = resolve_generic_names("Response<alloc::string::String>");
        assert_eq!(resolved, "Response<String>");

        // Note: The base type (Vec) keeps its path, only generic params are shortened
        // FlexNamer::assign_name handles the full path transformation later
        let resolved = resolve_generic_names("Vec<alloc::vec::Vec<alloc::string::String>>");
        assert_eq!(resolved, "Vec<alloc::vec::Vec<String>>");

        // Test HashMap with primitive types
        let resolved =
            resolve_generic_names("std::collections::HashMap<alloc::string::String, i32>");
        assert_eq!(resolved, "std::collections::HashMap<String, i32>");

        // Test that nested generic base types are also shortened in their generics
        let resolved = resolve_generic_names("Option<Vec<alloc::string::String>>");
        assert_eq!(resolved, "Option<Vec<String>>");
    }

    /// `short_type_name` returns the text after the last `::`, or the input when there is none.
    #[test]
    fn test_short_type_name() {
        use super::*;

        assert_eq!(short_type_name("alloc::string::String"), "String");
        assert_eq!(short_type_name("std::collections::HashMap"), "HashMap");
        assert_eq!(short_type_name("MyType"), "MyType");
        assert_eq!(short_type_name("my_crate::module::submodule::Type"), "Type");
    }

    /// `split_generic_params` splits only at commas outside angle brackets and does not trim
    /// the segments.
    #[test]
    fn test_split_generic_params() {
        use super::*;

        let params = split_generic_params("A, B, C");
        assert_eq!(params, vec!["A", " B", " C"]);

        let params = split_generic_params("A<X, Y>, B, C<Z>");
        assert_eq!(params, vec!["A<X, Y>", " B", " C<Z>"]);

        let params = split_generic_params("A<X<Y, Z>>, B");
        assert_eq!(params, vec!["A<X<Y, Z>>", " B"]);
    }

    /// A forced name registered for a type is reused when that type appears as a generic
    /// parameter, under both `Force` and `Auto` rules.
    #[test]
    #[serial]
    fn test_assign_name_with_generic_resolution() {
        use super::*;

        // Reset global state to ensure deterministic test results
        reset_global_state();

        // Define unique test types for this test to avoid conflicts with other tests
        mod test_generic_resolution {
            pub(super) struct CityDTO;
            pub(super) struct Response<T>(std::marker::PhantomData<T>);
            pub(super) struct Wrapper<T>(std::marker::PhantomData<T>);
        }
        use test_generic_resolution::*;

        // First, register CityDTO with a custom name "City"
        let city_name = assign_name::<CityDTO>(NameRule::Force("City"));
        assert_eq!(city_name, "City");

        // Now register Response<CityDTO> with Force("Response")
        // It should resolve CityDTO to "City" in the generic parameter
        let response_name = assign_name::<Response<CityDTO>>(NameRule::Force("Response"));
        assert_eq!(response_name, "Response<City>");

        // Test with Auto mode - should also resolve generic parameters
        let wrapper_name = assign_name::<Wrapper<CityDTO>>(NameRule::Auto);
        // The base type will have full path, but CityDTO should be resolved to City
        assert!(
            wrapper_name.contains("<City>"),
            "Expected wrapper name to contain '<City>', got: {}",
            wrapper_name
        );
    }

    /// Forced names combine with unregistered generic parameters, which appear by short name.
    #[test]
    #[serial]
    fn test_assign_name_with_primitive_generics() {
        use super::*;

        // Reset global state to ensure deterministic test results
        reset_global_state();

        mod test_primitive_generics {
            pub(super) struct Response<T>(std::marker::PhantomData<T>);
        }
        use test_primitive_generics::*;

        // Test Response<String> with Force("Response")
        // String is not registered, but should be shortened to "String"
        let response_name = assign_name::<Response<String>>(NameRule::Force("Response"));
        assert_eq!(response_name, "Response<String>");

        // Test Response<Vec<String>> - nested generics with primitives
        let response_vec_name =
            assign_name::<Response<Vec<String>>>(NameRule::Force("ResponseVec"));
        assert!(
            response_vec_name.contains("<String>"),
            "Expected name to contain '<String>', got: {}",
            response_vec_name
        );
    }
}