1use std::convert::Infallible;
4
5use axum::extract::OptionalFromRequestParts;
6use axum::http::HeaderName;
7use axum::http::request::Parts;
8
/// Declares a cheaply clonable, string-backed identifier newtype.
///
/// `declare_identifier!(Foo)` expands to a `pub struct Foo` wrapping an
/// `Arc<str>`, so cloning an identifier only bumps a reference count.
///
/// The generated type:
/// - derives `Clone`, `Debug`, `PartialEq`, `Eq`, `PartialOrd`, `Ord` and
///   `Hash`. Comparison and hashing delegate to the inner string.
/// - (de)serializes as a bare string, via `#[serde(transparent)]`.
/// - converts from `String` and `&str`.
/// - exposes its text through `AsRef<str>`, `Deref<Target = str>`,
///   `Borrow<str>` and `Display`.
///
/// Because `Eq`, `Ord` and `Hash` all follow the inner string, the
/// `Borrow<str>` impl is consistent with them. Maps and sets keyed by the
/// identifier can therefore be queried with a plain `&str`.
///
/// The expansion names the `nameth` and `serde` crates by path, so any crate
/// invoking this macro must depend on both.
#[macro_export]
macro_rules! declare_identifier {
    ($name:ident) => {
        #[nameth::nameth]
        #[derive(
            Clone,
            Debug,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            serde::Serialize,
            serde::Deserialize,
        )]
        #[serde(transparent)]
        pub struct $name {
            id: ::std::sync::Arc<str>,
        }

        impl ::std::convert::From<::std::string::String> for $name {
            fn from(id: ::std::string::String) -> Self {
                // `Arc<str>: From<String>` copies the bytes straight into the
                // reference-counted allocation. This avoids the extra
                // shrink-to-fit reallocation that `into_boxed_str` may perform.
                Self {
                    id: ::std::sync::Arc::from(id),
                }
            }
        }

        impl ::std::convert::From<&str> for $name {
            fn from(id: &str) -> Self {
                // Build the `Arc<str>` directly rather than through an
                // intermediate `String`, saving one allocation.
                Self {
                    id: ::std::sync::Arc::from(id),
                }
            }
        }

        impl ::std::convert::AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                &self.id
            }
        }

        impl ::std::ops::Deref for $name {
            type Target = str;

            fn deref(&self) -> &Self::Target {
                &self.id
            }
        }

        impl ::std::borrow::Borrow<str> for $name {
            fn borrow(&self) -> &str {
                // Return the field directly instead of relying on a
                // `&&Self -> &str` double deref-coercion.
                &self.id
            }
        }

        impl ::std::fmt::Display for $name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                ::std::fmt::Display::fmt(&*self.id, f)
            }
        }
    };
}
60
// Concrete identifier types generated by `declare_identifier!`. Each wraps an
// `Arc<str>` and gets the macro's conversion, comparison and serde impls.
// `ClientId` can also be extracted from request headers; see the
// `OptionalFromRequestParts` impl for `ClientId` in this file.
declare_identifier!(ClientName);
declare_identifier!(ClientId);
63
/// Name of the HTTP request header (`x-client-id`) whose value the `ClientId`
/// extractor reads.
pub static CLIENT_ID_HEADER: HeaderName = HeaderName::from_static("x-client-id");
68
69impl<S> OptionalFromRequestParts<S> for ClientId
70where
71 S: Send + Sync,
72{
73 type Rejection = Infallible;
74
75 async fn from_request_parts(
76 parts: &mut Parts,
77 _state: &S,
78 ) -> Result<Option<Self>, Self::Rejection> {
79 let Some(client_id) = parts.headers.get(&CLIENT_ID_HEADER) else {
80 return Ok(None);
81 };
82 Ok(client_id.to_str().ok().map(ClientId::from))
83 }
84}
85
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    /// Equality and ordering follow the wrapped string, whatever the source
    /// type (`&str` or `String`) was.
    #[test]
    fn declare_identifier_compare() {
        declare_identifier!(ConnectionId);
        let c1: ConnectionId = "123".into();
        let c2: ConnectionId = "123".to_string().into();
        assert_eq!(c1, c2);

        let c3: ConnectionId = "124".to_string().into();
        assert_ne!(c1, c3);

        // Ordering is the lexicographic ordering of the underlying strings.
        assert!(c1 < c3);
        assert_eq!(c1.cmp(&c3), "123".cmp("124"));
    }

    /// Identifiers work as hash-map keys and, through `Borrow<str>`, can be
    /// looked up with a plain `&str`.
    #[test]
    fn declare_identifier_hash() {
        declare_identifier!(ConnectionId);
        let c1: ConnectionId = "123".into();
        let c2: ConnectionId = "124".to_string().into();

        let mut map: HashMap<ConnectionId, i32> = HashMap::new();
        map.insert(c1.clone(), 21);
        map.insert(c2.clone(), 34);

        assert_eq!(map[&c1], 21);
        assert_eq!(map[&c2], 34);

        // `Borrow<str>` must hash and compare consistently with the key type.
        assert_eq!(map.get("123"), Some(&21));
        assert_eq!(map.get("124"), Some(&34));
        assert!(!map.contains_key("125"));
    }

    /// `AsRef<str>`, `Deref<Target = str>` and `Display` all expose the
    /// original text unchanged.
    #[test]
    fn declare_identifier_str_views() {
        declare_identifier!(ConnectionId);
        let c = ConnectionId::from("ABC123");
        assert_eq!(AsRef::<str>::as_ref(&c), "ABC123");
        assert_eq!(&*c, "ABC123");
        assert_eq!(c.len(), 6);
        assert_eq!(c.to_string(), "ABC123");
        assert_eq!(format!("<{c}>"), "<ABC123>");
    }

    /// `#[serde(transparent)]` makes the identifier round-trip as a bare string.
    #[test]
    fn declare_identifier_serde() {
        declare_identifier!(ConnectionId);
        let c = ConnectionId::from("ABC123");
        let s = serde_json::to_string(&c).unwrap();
        assert_eq!("\"ABC123\"", s);
        let cc: ConnectionId = serde_json::from_str(&s).unwrap();
        assert_eq!(c, cc);
    }
}