Skip to main content

ssi_jwt/claims/
matching.rs

1use std::borrow::Cow;
2
3use crate::Claim;
4
/// Dynamic claim type matching.
///
/// Compares a generic claim type parameter against a list of concrete claim
/// types using `TypeId` and evaluates the first arm whose type matches. If no
/// arm matches, the `_` arm is evaluated instead.
///
/// # Usage
///
/// There are two ways to use this macro.
/// The first one is to simply match on the value of a claim type parameter:
/// ```ignore
/// match_claim_type! {
///   match MyClaimTypeParameter {
///     TypeA => { ... },
///     TypeB => { ... },
///     _ => { ... }
///   }
/// }
/// ```
///
/// The second one also allows you to properly cast a claim variable.
/// ```ignore
/// match_claim_type! {
///   match claim: MyClaimTypeParameter {
///     TypeA => {
///       // In this block, `claim` has type `TypeA`.
///       ...
///     },
///     TypeB => {
///       // In this block, `claim` has type `TypeB`.
///       ...
///     },
///     _ => {
///       // In this block, `claim` has type `MyClaimTypeParameter`.
///       ...
///     },
///   }
/// }
/// ```
///
/// # Requirements
///
/// - Each typed arm expands to an `if` statement that `return`s from the
///   enclosing function, and the `_` arm becomes the trailing expression.
///   The macro is therefore meant to be used as the body (tail expression) of
///   a function.
/// - The value produced by a typed arm is converted back to the function's
///   return type with `CastClaim::<ArmType, MyClaimTypeParameter>`. That value
///   must implement this trait with `Target` equal to the return type.
/// - Types are compared with `TypeId::of`, so they must be `'static`.
/// - Type names in arms must be single identifiers (no paths).
/// - Every typed arm must end with a comma. The `_` arm may optionally be
///   followed by one.
#[macro_export]
macro_rules! match_claim_type {
    {
        match $id:ident {
            $($ty:ident => $e:expr,)*
            _ => $default_case:expr $(,)?
        }
    } => {
        $(
            if ::core::any::TypeId::of::<$id>() == ::core::any::TypeId::of::<$ty>() {
                let result = $e;
                // SAFETY: We just checked that `$ty` is equal to `$id`.
                return unsafe { $crate::CastClaim::<$ty, $id>::cast_claim(result) };
            }
        )*

        $default_case
    };
    {
        match $x:ident: $id:ident {
            $($ty:ident => $e:expr,)*
            _ => $default_case:expr $(,)?
        }
    } => {
        $(
            if ::core::any::TypeId::of::<$id>() == ::core::any::TypeId::of::<$ty>() {
                // SAFETY: We just checked that `$ty` is equal to `$id`.
                let $x: $ty = unsafe { $crate::CastClaim::<$id, $ty>::cast_claim($x) };
                let result = $e;
                // SAFETY: We just checked that `$ty` is equal to `$id`.
                return unsafe { $crate::CastClaim::<$ty, $id>::cast_claim(result) };
            }
        )*

        $default_case
    };
}
83
/// Reinterprets a value built around claim type `A` as the same value built
/// around claim type `B`.
///
/// `Self` is a type that may mention `A`. In this module that is the claim
/// itself, a reference to it, or an `Option`/`Result`/`Cow` wrapping such a
/// value. `()` and `bool` carry no claim at all and pass through unchanged.
/// [`CastClaim::Target`] is the same shape with `A` replaced by `B`.
pub trait CastClaim<A, B>: Sized {
    /// `Self` with every occurrence of `A` substituted by `B`.
    type Target;

    /// Converts `value` from its `A`-based type to its `B`-based type.
    ///
    /// # Safety
    ///
    /// `A` **must** be equal to `B`.
    unsafe fn cast_claim(value: Self) -> Self::Target;
}
95
96impl<A: Claim, B: Claim> CastClaim<A, B> for A {
97    type Target = B;
98
99    unsafe fn cast_claim(value: Self) -> Self::Target {
100        // SAFETY: The precondition to this function is that `A` (`Self`) is
101        //         equal to `B`, meaning that the transmute does nothing.
102        //         We are just copying `value`, and forgetting the original.
103        let result = std::mem::transmute_copy(&value);
104        std::mem::forget(value);
105        result
106    }
107}
108
109impl<'a, A: Claim, B: Claim> CastClaim<A, B> for &'a A {
110    type Target = &'a B;
111
112    unsafe fn cast_claim(value: Self) -> Self::Target {
113        std::mem::transmute_copy(&value)
114    }
115}
116
117impl<A, B> CastClaim<A, B> for () {
118    type Target = Self;
119
120    unsafe fn cast_claim(value: Self) -> Self::Target {
121        value
122    }
123}
124
125impl<A, B> CastClaim<A, B> for bool {
126    type Target = Self;
127
128    unsafe fn cast_claim(value: Self) -> Self::Target {
129        value
130    }
131}
132
133impl<A, B, T: CastClaim<A, B>> CastClaim<A, B> for Option<T> {
134    type Target = Option<T::Target>;
135
136    unsafe fn cast_claim(value: Self) -> Self::Target {
137        value.map(|t| T::cast_claim(t))
138    }
139}
140
141impl<A, B, T: CastClaim<A, B>, E> CastClaim<A, B> for Result<T, E> {
142    type Target = Result<T::Target, E>;
143
144    unsafe fn cast_claim(value: Self) -> Self::Target {
145        value.map(|t| T::cast_claim(t))
146    }
147}
148
149impl<'a, A: Claim, B: Claim> CastClaim<A, B> for Cow<'a, A> {
150    type Target = Cow<'a, B>;
151
152    unsafe fn cast_claim(value: Self) -> Self::Target {
153        match value {
154            Self::Owned(value) => Cow::Owned(CastClaim::cast_claim(value)),
155            Self::Borrowed(value) => Cow::Borrowed(CastClaim::cast_claim(value)),
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    //! Compile-only check of both `match_claim_type!` forms. `CustomClaimSet`
    //! keeps `CustomClaim` in a dedicated field and forwards every other claim
    //! type to an `AnyClaims` fallback. There are no `#[test]` functions here;
    //! the module passes as long as it type-checks.

    use std::borrow::Cow;

    use serde::{Deserialize, Serialize};

    use crate::{AnyClaims, Claim, ClaimSet, InvalidClaimValue};

    /// Claim stored in its own `CustomClaimSet` field.
    #[derive(Clone, Serialize, Deserialize)]
    struct CustomClaim;

    impl Claim for CustomClaim {
        const JWT_CLAIM_NAME: &'static str = "custom";
    }

    /// Claim set with a dedicated slot for `CustomClaim`.
    // Never constructed: it only exists so the `ClaimSet` impl below compiles.
    #[allow(unused)]
    struct CustomClaimSet {
        custom: Option<CustomClaim>,
        other_claims: AnyClaims,
    }

    impl ClaimSet for CustomClaimSet {
        fn contains<C: Claim>(&self) -> bool {
            // First macro form: branch on the type parameter only.
            match_claim_type! {
                match C {
                    CustomClaim => self.custom.is_some(),
                    _ => <AnyClaims as ClaimSet>::contains::<C>(&self.other_claims)
                }
            }
        }

        fn try_get<C: Claim>(&self) -> Result<Option<Cow<'_, C>>, InvalidClaimValue> {
            match_claim_type! {
                match C {
                    CustomClaim => {
                        let custom = self.custom.as_ref();
                        Ok(custom.map(Cow::Borrowed))
                    },
                    _ => <AnyClaims as ClaimSet>::try_get::<C>(&self.other_claims)
                }
            }
        }

        fn try_set<C: Claim>(&mut self, claim: C) -> Result<Result<(), C>, InvalidClaimValue> {
            // Second macro form: `claim` is re-bound as a `CustomClaim` in the
            // typed arm and keeps type `C` in the fallback arm.
            match_claim_type! {
                match claim: C {
                    CustomClaim => {
                        self.custom = Some(claim);
                        Ok(Ok(()))
                    },
                    _ => <AnyClaims as ClaimSet>::try_set::<C>(&mut self.other_claims, claim)
                }
            }
        }

        fn try_remove<C: Claim>(&mut self) -> Result<Option<C>, InvalidClaimValue> {
            match_claim_type! {
                match C {
                    CustomClaim => Ok(self.custom.take()),
                    _ => <AnyClaims as ClaimSet>::try_remove::<C>(&mut self.other_claims)
                }
            }
        }
    }
}