Skip to main content

salvo_oapi/
naming.rs

1use std::any::TypeId;
2use std::collections::BTreeMap;
3use std::sync::LazyLock;
4
5use parking_lot::{RwLock, RwLockReadGuard};
6use regex::Regex;
7
/// Strategy used when assigning a schema name to a rust type.
#[derive(Clone, Copy, Debug, Default)]
pub enum NameRule {
    /// Let the global namer derive the name from the rust type name.
    #[default]
    Auto,
    /// Use the given name instead of deriving one (the built-in `FlexNamer` still appends the
    /// resolved generic parameters to it).
    Force(&'static str),
}
17
18static GLOBAL_NAMER: LazyLock<RwLock<Box<dyn Namer>>> =
19    LazyLock::new(|| RwLock::new(Box::new(FlexNamer::new())));
20static NAME_TYPES: LazyLock<RwLock<BTreeMap<String, (TypeId, &'static str)>>> =
21    LazyLock::new(Default::default);
22
23/// Set global namer.
24///
25/// Set global namer, all the types will be named by this namer. You should call this method before
26/// at before you generate OpenAPI schema.
27///
28/// # Example
29///
30/// ```rust
31/// # use salvo_oapi::extract::*;
32/// # use salvo_core::prelude::*;
33/// # #[tokio::main]
34/// # async fn main() {
35/// salvo_oapi::naming::set_namer(
36///     salvo_oapi::naming::FlexNamer::new()
37///         .short_mode(true)
38///         .generic_delimiter('_', '_'),
39/// );
40/// # }
41/// ```
42pub fn set_namer(namer: impl Namer) {
43    *GLOBAL_NAMER.write() = Box::new(namer);
44    NAME_TYPES.write().clear();
45}
46
/// Restore the default naming state.
///
/// Installs a fresh default `FlexNamer` and empties the name registry, so a test can start
/// from a known state regardless of what earlier tests registered.
#[cfg(test)]
pub fn reset_global_state() {
    set_namer(FlexNamer::new());
}
56
57#[doc(hidden)]
58pub fn namer() -> RwLockReadGuard<'static, Box<dyn Namer>> {
59    GLOBAL_NAMER.read()
60}
61
62/// Get type info by name.
63pub fn type_info_by_name(name: &str) -> Option<(TypeId, &'static str)> {
64    NAME_TYPES.read().get(name).cloned()
65}
66
67/// Get registered name by rust type name (from `std::any::type_name`).
68///
69/// This searches through NAME_TYPES to find if a type with the given rust type name
70/// has been registered with a custom name.
71pub fn name_by_type_name(type_name: &str) -> Option<String> {
72    NAME_TYPES
73        .read()
74        .iter()
75        .find(|(_, (_, registered_type_name))| *registered_type_name == type_name)
76        .map(|(name, _)| name.clone())
77}
78
79/// Resolve generic type parameters to their registered names.
80///
81/// This function recursively processes a type name string and replaces any
82/// generic type parameters with their registered names from NAME_TYPES.
83///
84/// For example, if `CityDTO` is registered as `City`, then:
85/// - `Response<CityDTO>` becomes `Response<City>`
86/// - `Vec<HashMap<String, CityDTO>>` becomes `Vec<HashMap<String, City>>`
87#[must_use]
88pub fn resolve_generic_names(type_name: &str) -> String {
89    // First check if the entire type (without generics) has a registered name
90    if let Some(registered_name) = name_by_type_name(type_name) {
91        return registered_name;
92    }
93
94    // Find the position of the first '<' to separate base type from generic params
95    let Some(generic_start) = type_name.find('<') else {
96        // No generics, return as-is
97        return type_name.to_owned();
98    };
99
100    // Extract base type and generic part
101    let Some(base_type) = type_name.get(..generic_start) else {
102        return type_name.to_owned();
103    };
104    let Some(generic_part) = type_name.get(generic_start..) else {
105        return type_name.to_owned();
106    };
107
108    // Parse and resolve each generic parameter
109    let resolved_generic = resolve_generic_part(generic_part);
110
111    format!("{base_type}{resolved_generic}")
112}
113
114/// Parse generic part like `<A, B<C, D>, E>` and resolve each type parameter.
115fn resolve_generic_part(generic_part: &str) -> String {
116    if !generic_part.starts_with('<') || !generic_part.ends_with('>') {
117        return generic_part.to_owned();
118    }
119
120    // Remove outer < and >
121    let Some(inner) = generic_part
122        .strip_prefix('<')
123        .and_then(|generic_part| generic_part.strip_suffix('>'))
124    else {
125        return generic_part.to_owned();
126    };
127
128    // Split by top-level commas (not nested in <>)
129    let params = split_generic_params(inner);
130
131    let resolved_params: Vec<String> = params
132        .into_iter()
133        .map(|param| {
134            let param = param.trim();
135            // Check if this exact type has a registered name
136            if let Some(registered_name) = name_by_type_name(param) {
137                registered_name
138            } else if param.contains('<') {
139                // Recursively resolve nested generics
140                resolve_generic_names(param)
141            } else {
142                // Use short name for unregistered types (like primitive types)
143                // e.g., "alloc::string::String" -> "String"
144                short_type_name(param).to_owned()
145            }
146        })
147        .collect();
148
149    format!("<{}>", resolved_params.join(", "))
150}
151
/// Split generic parameters at top-level commas.
///
/// Commas nested inside `<...>`, `(...)` (tuples, fn-pointer arguments) or `[...]` (arrays)
/// do not split. The `>` of a `->` return arrow is not treated as a closing bracket.
/// Segments are returned untrimmed, and an empty trailing segment is dropped.
fn split_generic_params(s: &str) -> Vec<&str> {
    let mut result = Vec::new();
    // Signed on purpose: unbalanced closers must not underflow/panic.
    let mut depth: i32 = 0;
    let mut start = 0;
    let mut prev = None;

    for (i, c) in s.char_indices() {
        match c {
            '<' | '(' | '[' => depth += 1,
            // `->` in `fn(A) -> B` is an arrow, not a closing angle bracket.
            '>' if prev == Some('-') => {}
            '>' | ')' | ']' => depth -= 1,
            ',' if depth == 0 => {
                if let Some(param) = s.get(start..i) {
                    result.push(param);
                }
                start = i + c.len_utf8();
            }
            _ => {}
        }
        prev = Some(c);
    }

    // The final segment has no trailing comma of its own.
    if let Some(param) = s.get(start..).filter(|param| !param.is_empty()) {
        result.push(param);
    }

    result
}
181
/// Strip module paths from every path in a type name.
///
/// Each run of path characters (alphanumerics, `_` and `:`) is reduced to its last `::`
/// segment. Everything else (`&`, `(`, `[`, `;`, spaces, ...) is copied unchanged, so
/// references, tuples and arrays keep their shape.
///
/// For example:
/// - `alloc::string::String` -> `String`
/// - `std::collections::HashMap` -> `HashMap`
/// - `my_crate::module::MyType` -> `MyType`
/// - `&alloc::string::String` -> `&String`
/// - `(alloc::string::String, i32)` -> `(String, i32)`
fn short_type_name(type_name: &str) -> String {
    fn is_path_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_' || c == ':'
    }

    let mut out = String::with_capacity(type_name.len());
    let mut rest = type_name;
    while !rest.is_empty() {
        // Copy the punctuation/whitespace in front of the next path verbatim.
        let path_start = rest.find(is_path_char).unwrap_or(rest.len());
        let (punct, tail) = rest.split_at(path_start);
        out.push_str(punct);

        // Keep only the last segment of the path (always at least one char when non-empty).
        let path_end = tail.find(|c: char| !is_path_char(c)).unwrap_or(tail.len());
        let (path, tail) = tail.split_at(path_end);
        out.push_str(path.rsplit("::").next().unwrap_or(path));

        rest = tail;
    }
    out
}
195
196/// Set type info by name.
197pub fn set_name_type_info(
198    name: String,
199    type_id: TypeId,
200    type_name: &'static str,
201) -> Option<(TypeId, &'static str)> {
202    NAME_TYPES.write().insert(name, (type_id, type_name))
203}
204
205/// Assign name to type and returns the name.
206///
207/// If the type is already named, return the existing name.
208pub fn assign_name<T: 'static>(rule: NameRule) -> String {
209    let type_id = TypeId::of::<T>();
210    let type_name = std::any::type_name::<T>();
211    for (name, (exist_id, _)) in NAME_TYPES.read().iter() {
212        if *exist_id == type_id {
213            return name.clone();
214        }
215    }
216    namer().assign_name(type_id, type_name, rule)
217}
218
219/// Get the name of the type. Panic if the name is not exist.
220pub fn get_name<T: 'static>() -> String {
221    let type_id = TypeId::of::<T>();
222    for (name, (exist_id, _)) in NAME_TYPES.read().iter() {
223        if *exist_id == type_id {
224            return name.clone();
225        }
226    }
227    panic!(
228        "Type not found in the name registry: {:?}",
229        std::any::type_name::<T>()
230    );
231}
232
/// Return the generic part of `type_name`: everything from its first `<` to the end, or an
/// empty string when the name contains no `<`.
fn type_generic_part(type_name: &str) -> String {
    type_name
        .find('<')
        .and_then(|start| type_name.get(start..))
        .map(str::to_owned)
        .unwrap_or_default()
}
240
241/// Resolve generic part and format it according to namer settings.
242fn resolve_and_format_generic_part(type_name: &str, short_mode: bool) -> String {
243    let generic_part = type_generic_part(type_name);
244    if generic_part.is_empty() {
245        return generic_part;
246    }
247
248    // Resolve registered names in generic parameters
249    let resolved = resolve_generic_part(&generic_part);
250
251    // Apply formatting (:: -> . for non-short mode, or strip module paths for short mode)
252    if short_mode {
253        let re = Regex::new(r"([^<>, ]*::)+").expect("Invalid regex");
254        re.replace_all(&resolved, "").into_owned()
255    } else {
256        resolved.replace("::", ".")
257    }
258}
/// Strategy for turning rust types into schema names.
///
/// The active implementation is installed with `set_namer`. `assign_name` asks it only for
/// types that have no entry in the name registry yet.
///
/// NOTE(review): `assign_name` recognizes already-named types solely through the registry.
/// The built-in `FlexNamer` records each name with `set_name_type_info`. A custom namer that
/// does not do so will presumably be asked again on every call — confirm this is intended.
pub trait Namer: Sync + Send + 'static {
    /// Assign a name to the type identified by `type_id`.
    ///
    /// When called from `assign_name`, `type_name` is `std::any::type_name` of the type and
    /// `rule` is the caller's `NameRule`.
    fn assign_name(&self, type_id: TypeId, type_name: &'static str, rule: NameRule) -> String;
}
264
/// The default namer: derives schema names from rust type names.
///
/// By default, names keep their module path with `::` replaced by `.`. Short mode drops
/// module paths instead, and a custom generic delimiter replaces the `<`/`>` around generic
/// parameters.
#[derive(Clone, Debug, Default)]
pub struct FlexNamer {
    short_mode: bool,
    generic_delimiter: Option<(String, String)>,
}
impl FlexNamer {
    /// Create a `FlexNamer` with short mode off and no custom generic delimiter.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable or disable short mode (module paths stripped from generated names).
    #[must_use]
    pub fn short_mode(self, short_mode: bool) -> Self {
        Self { short_mode, ..self }
    }

    /// Replace `<` with `open` and `>` with `close` in generated names.
    #[must_use]
    pub fn generic_delimiter(self, open: impl Into<String>, close: impl Into<String>) -> Self {
        Self {
            generic_delimiter: Some((open.into(), close.into())),
            ..self
        }
    }
}
292impl Namer for FlexNamer {
293    fn assign_name(&self, type_id: TypeId, type_name: &'static str, rule: NameRule) -> String {
294        let name = match rule {
295            NameRule::Auto => {
296                // First resolve any registered names in generic parameters
297                let resolved_type_name = resolve_generic_names(type_name);
298
299                let mut base = if self.short_mode {
300                    let re = Regex::new(r"([^<>, ]*::)+").expect("Invalid regex");
301                    re.replace_all(&resolved_type_name, "").into_owned()
302                } else {
303                    resolved_type_name.replace("::", ".")
304                };
305                if let Some((open, close)) = &self.generic_delimiter {
306                    base = base.replace('<', open).replace('>', close);
307                }
308                let mut name = base.clone();
309                let mut count = 1;
310                while let Some(exist_id) = type_info_by_name(&name).map(|t| t.0) {
311                    if exist_id != type_id {
312                        count += 1;
313                        name = format!("{base}{count}");
314                    } else {
315                        break;
316                    }
317                }
318                name
319            }
320            NameRule::Force(force_name) => {
321                // Resolve registered names in generic parameters
322                let resolved_generic = resolve_and_format_generic_part(type_name, self.short_mode);
323
324                let mut base = if self.short_mode {
325                    // In short mode with Force, use the forced name + resolved generics
326                    format!("{force_name}{resolved_generic}")
327                } else {
328                    format!("{force_name}{resolved_generic}")
329                };
330                if let Some((open, close)) = &self.generic_delimiter {
331                    base = base.replace('<', open).replace('>', close);
332                }
333                let mut name = base.clone();
334                let mut count = 1;
335                while let Some((exist_id, exist_name)) = type_info_by_name(&name) {
336                    if exist_id != type_id {
337                        count += 1;
338                        tracing::error!("Duplicate name for types: {}, {}", exist_name, type_name);
339                        name = format!("{base}{count}");
340                    } else {
341                        break;
342                    }
343                }
344                name
345            }
346        };
347        set_name_type_info(name.clone(), type_id, type_name);
348        name
349    }
350}
351
#[cfg(test)]
mod tests {
    // Tests that touch the process-wide namer/registry are `#[serial]` and begin with
    // `reset_global_state()` so their expectations do not depend on test ordering.
    use serial_test::serial;

    use super::*;

    #[test]
    #[serial]
    fn test_name() {
        reset_global_state();

        struct MyString;
        mod nest {
            pub(crate) struct MyString;
        }

        let name = assign_name::<String>(NameRule::Auto);
        assert_eq!(name, "alloc.string.String");
        let name = assign_name::<Vec<String>>(NameRule::Auto);
        assert_eq!(name, "alloc.vec.Vec<alloc.string.String>");

        let name = assign_name::<MyString>(NameRule::Auto);
        assert!(
            name.contains("MyString") && !name.contains("nest"),
            "Expected name containing 'MyString' but not 'nest', got: {name}"
        );
        let name = assign_name::<nest::MyString>(NameRule::Auto);
        assert!(
            name.contains("nest") && name.contains("MyString"),
            "Expected name containing 'nest.MyString', got: {name}"
        );
    }

    #[test]
    #[serial]
    fn test_resolve_generic_names() {
        reset_global_state();

        // Simulate registering CityDTO as "City"
        let city_type_name = "test_module::CityDTO";
        set_name_type_info(
            "City".to_owned(),
            TypeId::of::<()>(), // dummy TypeId
            city_type_name,
        );

        // Registered type as the only parameter
        let resolved = resolve_generic_names("Response<test_module::CityDTO>");
        assert_eq!(resolved, "Response<City>");

        // Nested generics - unregistered types get short names
        let resolved = resolve_generic_names("Vec<HashMap<String, test_module::CityDTO>>");
        assert_eq!(resolved, "Vec<HashMap<String, City>>");

        // Multiple generic parameters
        let resolved = resolve_generic_names("Tuple<test_module::CityDTO, test_module::CityDTO>");
        assert_eq!(resolved, "Tuple<City, City>");
    }

    #[test]
    #[serial]
    fn test_resolve_primitive_types() {
        reset_global_state();

        // Unregistered types in generic params are shortened
        let resolved = resolve_generic_names("Response<alloc::string::String>");
        assert_eq!(resolved, "Response<String>");

        // The base type keeps its path, only generic params are shortened;
        // FlexNamer::assign_name handles the full path transformation later
        let resolved = resolve_generic_names("Vec<alloc::vec::Vec<alloc::string::String>>");
        assert_eq!(resolved, "Vec<alloc::vec::Vec<String>>");

        let resolved =
            resolve_generic_names("std::collections::HashMap<alloc::string::String, i32>");
        assert_eq!(resolved, "std::collections::HashMap<String, i32>");

        let resolved = resolve_generic_names("Option<Vec<alloc::string::String>>");
        assert_eq!(resolved, "Option<Vec<String>>");
    }

    #[test]
    fn test_short_type_name() {
        for (full, short) in [
            ("alloc::string::String", "String"),
            ("std::collections::HashMap", "HashMap"),
            ("MyType", "MyType"),
            ("my_crate::module::submodule::Type", "Type"),
        ] {
            assert_eq!(short_type_name(full), short);
        }
    }

    #[test]
    fn test_split_generic_params() {
        let cases: [(&str, &[&str]); 3] = [
            ("A, B, C", &["A", " B", " C"]),
            ("A<X, Y>, B, C<Z>", &["A<X, Y>", " B", " C<Z>"]),
            ("A<X<Y, Z>>, B", &["A<X<Y, Z>>", " B"]),
        ];
        for (input, expected) in cases {
            assert_eq!(split_generic_params(input), expected, "input: {input}");
        }
    }

    #[test]
    #[serial]
    fn test_assign_name_with_generic_resolution() {
        reset_global_state();

        // Unique types so this test cannot collide with names registered elsewhere
        mod test_generic_resolution {
            pub(super) struct CityDTO;
            pub(super) struct Response<T>(std::marker::PhantomData<T>);
            pub(super) struct Wrapper<T>(std::marker::PhantomData<T>);
        }
        use test_generic_resolution::*;

        // Register CityDTO under the custom name "City"
        let city_name = assign_name::<CityDTO>(NameRule::Force("City"));
        assert_eq!(city_name, "City");

        // Force mode resolves CityDTO to "City" in the generic parameter
        let response_name = assign_name::<Response<CityDTO>>(NameRule::Force("Response"));
        assert_eq!(response_name, "Response<City>");

        // Auto mode resolves generic parameters too; the base keeps its full path
        let wrapper_name = assign_name::<Wrapper<CityDTO>>(NameRule::Auto);
        assert!(
            wrapper_name.contains("<City>"),
            "Expected wrapper name to contain '<City>', got: {wrapper_name}"
        );
    }

    #[test]
    #[serial]
    fn test_assign_name_with_primitive_generics() {
        reset_global_state();

        mod test_primitive_generics {
            pub(super) struct Response<T>(std::marker::PhantomData<T>);
        }
        use test_primitive_generics::*;

        // String is not registered, but is shortened to "String"
        let response_name = assign_name::<Response<String>>(NameRule::Force("Response"));
        assert_eq!(response_name, "Response<String>");

        // Nested generics with primitives
        let response_vec_name =
            assign_name::<Response<Vec<String>>>(NameRule::Force("ResponseVec"));
        assert!(
            response_vec_name.contains("<String>"),
            "Expected name to contain '<String>', got: {response_vec_name}"
        );
    }
}