trz_gateway_common/
id.rs

//! Utilities to generate string-based identifiers.
2
3use std::convert::Infallible;
4
5use axum::extract::OptionalFromRequestParts;
6use axum::http::HeaderName;
7use axum::http::request::Parts;
8
/// Declares a string-based identifier type.
///
/// `declare_identifier!(Name)` expands to a `pub struct Name` wrapping an
/// `Arc<str>`, so cloning an identifier only bumps a reference count instead
/// of copying the text. The generated type:
///
/// - derives `Clone`, `Debug`, `PartialEq`, `Eq`, `PartialOrd`, `Ord` and `Hash`;
///   since the struct has a single field, comparisons and hashing match those
///   of the wrapped `str`, which keeps the `Borrow<str>` impl consistent;
/// - is annotated with `#[nameth::nameth]`;
/// - (de)serializes as a bare string thanks to `#[serde(transparent)]`;
/// - converts from `String` and `&str`;
/// - exposes its text through `AsRef<str>`, `Deref<Target = str>` and
///   `Borrow<str>` (so e.g. a `HashMap<Name, _>` can be queried with a `&str`);
/// - implements `Display` by formatting the wrapped string.
///
/// Outer attributes placed before the name (typically `///` doc comments)
/// are forwarded to the generated struct.
///
/// The expansion refers to `::nameth` and `::serde` by absolute path, so the
/// invoking crate must depend on both.
#[macro_export]
macro_rules! declare_identifier {
    ($(#[$attr:meta])* $name:ident) => {
        $(#[$attr])*
        #[::nameth::nameth]
        #[derive(
            Clone,
            Debug,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            ::serde::Serialize,
            ::serde::Deserialize,
        )]
        #[serde(transparent)]
        pub struct $name {
            id: ::std::sync::Arc<str>,
        }

        impl From<String> for $name {
            fn from(id: String) -> Self {
                // Copy straight into the shared allocation; no intermediate
                // `Box<str>` shrink/reallocation.
                Self {
                    id: ::std::sync::Arc::from(id),
                }
            }
        }

        impl From<&str> for $name {
            fn from(id: &str) -> Self {
                // One allocation, instead of building a `String` first and
                // then copying it into the `Arc`.
                Self {
                    id: ::std::sync::Arc::from(id),
                }
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                &self.id
            }
        }

        impl ::std::ops::Deref for $name {
            type Target = str;
            fn deref(&self) -> &Self::Target {
                &self.id
            }
        }

        impl ::std::borrow::Borrow<str> for $name {
            fn borrow(&self) -> &str {
                // Borrow the wrapped text directly instead of relying on
                // deref-coercion of `&&Self`.
                &self.id
            }
        }

        impl ::std::fmt::Display for $name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                // Forward to `str`'s `Display` so width/alignment flags apply.
                ::std::fmt::Display::fmt(&*self.id, f)
            }
        }
    };
}
60
61declare_identifier!(ClientName);
62declare_identifier!(ClientId);
63
64/// Name of the header to pass the [ClientId].
65///
66/// This is used to trace connections from a client.
67pub static CLIENT_ID_HEADER: HeaderName = HeaderName::from_static("x-client-id");
68
69impl<S> OptionalFromRequestParts<S> for ClientId
70where
71    S: Send + Sync,
72{
73    type Rejection = Infallible;
74
75    async fn from_request_parts(
76        parts: &mut Parts,
77        _state: &S,
78    ) -> Result<Option<Self>, Self::Rejection> {
79        let Some(client_id) = parts.headers.get(&CLIENT_ID_HEADER) else {
80            return Ok(None);
81        };
82        Ok(client_id.to_str().ok().map(ClientId::from))
83    }
84}
85
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    /// Equality and ordering follow the wrapped string, whatever the source type.
    #[test]
    fn declare_identifier_compare() {
        declare_identifier!(ConnectionId);
        let c1: ConnectionId = "123".into();
        let c2: ConnectionId = "123".to_string().into();
        assert_eq!(c1, c2);

        let c3: ConnectionId = "124".to_string().into();
        assert_ne!(c1, c3);

        assert!(c1 == c2);
        assert!(c1 < c3);
    }

    /// Identifiers work as `HashMap` keys.
    #[test]
    fn declare_identifier_hash() {
        declare_identifier!(ConnectionId);
        let c1: ConnectionId = "123".into();
        let c2: ConnectionId = "124".to_string().into();

        let mut map: HashMap<ConnectionId, i32> = HashMap::new();
        map.insert(c1.clone(), 21);
        map.insert(c2.clone(), 34);

        assert_eq!(map[&c1], 21);
        assert_eq!(map[&c2], 34);
    }

    /// `Borrow<str>` allows looking keys up with a plain `&str`.
    #[test]
    fn declare_identifier_borrow_lookup() {
        declare_identifier!(ConnectionId);
        let mut map: HashMap<ConnectionId, i32> = HashMap::new();
        map.insert("123".into(), 21);

        assert_eq!(map.get("123"), Some(&21));
        assert_eq!(map.get("124"), None);
    }

    /// The text is reachable through `Display`, `AsRef<str>` and `Deref`.
    #[test]
    fn declare_identifier_str_access() {
        declare_identifier!(ConnectionId);
        let c = ConnectionId::from("ABC123");

        assert_eq!(c.to_string(), "ABC123");
        // `Display` forwards formatter flags such as width/alignment.
        assert_eq!(format!("[{c:>8}]"), "[  ABC123]");

        let s: &str = c.as_ref();
        assert_eq!(s, "ABC123");
        // `str::len` through `Deref<Target = str>`.
        assert_eq!(c.len(), 6);
    }

    /// Identifiers (de)serialize as bare JSON strings.
    #[test]
    fn declare_identifier_serde() {
        declare_identifier!(ConnectionId);
        let c = ConnectionId::from("ABC123");
        let s = serde_json::to_string(&c).unwrap();
        assert_eq!("\"ABC123\"", s);
        let cc: ConnectionId = serde_json::from_str(&s).unwrap();
        assert_eq!(c, cc);
    }
}