// Source: ssi_jwt/claims/registered.rs

use std::{borrow::Cow, collections::BTreeMap, fmt};

use chrono::{DateTime, Utc};
use ssi_claims_core::{ClaimsValidity, DateTimeProvider, InvalidClaims, ValidateClaims};
use ssi_core::OneOrMany;
use ssi_jws::JwsPayload;

use super::{Claim, InvalidClaimValue, JWTClaims};
use crate::{CastClaim, ClaimSet, InfallibleClaimSet, NumericDate, StringOrURI};

9pub trait RegisteredClaim: Claim + Into<AnyRegisteredClaim> {
10    const JWT_REGISTERED_CLAIM_KIND: RegisteredClaimKind;
11
12    fn extract(claim: AnyRegisteredClaim) -> Option<Self>;
13
14    fn extract_ref(claim: &AnyRegisteredClaim) -> Option<&Self>;
15
16    fn extract_mut(claim: &mut AnyRegisteredClaim) -> Option<&mut Self>;
17}
18
19#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
20pub struct RegisteredClaims(BTreeMap<RegisteredClaimKind, AnyRegisteredClaim>);
21
22impl RegisteredClaims {
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    pub fn is_empty(&self) -> bool {
28        self.0.is_empty()
29    }
30
31    pub fn len(&self) -> usize {
32        self.0.len()
33    }
34
35    pub fn iter(&'_ self) -> RegisteredClaimsIter<'_> {
36        self.0.values()
37    }
38
39    pub fn contains<C: RegisteredClaim>(&self) -> bool {
40        self.0.contains_key(&C::JWT_REGISTERED_CLAIM_KIND)
41    }
42
43    pub fn get<C: RegisteredClaim>(&self) -> Option<&C> {
44        self.0
45            .get(&C::JWT_REGISTERED_CLAIM_KIND)
46            .and_then(C::extract_ref)
47    }
48
49    pub fn get_mut<C: RegisteredClaim>(&mut self) -> Option<&mut C> {
50        self.0
51            .get_mut(&C::JWT_REGISTERED_CLAIM_KIND)
52            .and_then(C::extract_mut)
53    }
54
55    pub fn set<C: RegisteredClaim>(&mut self, claim: C) -> Option<C> {
56        self.0
57            .insert(C::JWT_REGISTERED_CLAIM_KIND, claim.into())
58            .and_then(C::extract)
59    }
60
61    pub fn insert_any(&mut self, claim: AnyRegisteredClaim) -> Option<AnyRegisteredClaim> {
62        self.0.insert(claim.kind(), claim)
63    }
64
65    pub fn remove<C: RegisteredClaim>(&mut self) -> Option<C> {
66        self.0
67            .remove(&C::JWT_REGISTERED_CLAIM_KIND)
68            .and_then(C::extract)
69    }
70
71    pub fn with_private_claims<P>(self, claims: P) -> JWTClaims<P> {
72        JWTClaims {
73            registered: self,
74            private: claims,
75        }
76    }
77}
78
79impl InfallibleClaimSet for RegisteredClaims {}
80
81impl JwsPayload for RegisteredClaims {
82    fn typ(&self) -> Option<&'static str> {
83        Some("JWT")
84    }
85
86    fn payload_bytes(&'_ self) -> Cow<'_, [u8]> {
87        Cow::Owned(serde_json::to_vec(self).unwrap())
88    }
89}
90
91impl<E, P> ValidateClaims<E, P> for RegisteredClaims
92where
93    E: DateTimeProvider,
94{
95    fn validate_claims(&self, env: &E, _proof: &P) -> ClaimsValidity {
96        ClaimSet::validate_registered_claims(self, env)
97    }
98}
99
100pub type RegisteredClaimsIter<'a> =
101    std::collections::btree_map::Values<'a, RegisteredClaimKind, AnyRegisteredClaim>;
102
103impl<'a> IntoIterator for &'a RegisteredClaims {
104    type IntoIter = RegisteredClaimsIter<'a>;
105    type Item = &'a AnyRegisteredClaim;
106
107    fn into_iter(self) -> Self::IntoIter {
108        self.iter()
109    }
110}
111
112impl serde::Serialize for RegisteredClaims {
113    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
114    where
115        S: serde::Serializer,
116    {
117        use serde::ser::SerializeMap;
118        let mut map = serializer.serialize_map(Some(self.0.len()))?;
119
120        for claim in self {
121            claim.serialize(&mut map)?;
122        }
123
124        map.end()
125    }
126}
127
128impl<'de> serde::Deserialize<'de> for RegisteredClaims {
129    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
130    where
131        D: serde::Deserializer<'de>,
132    {
133        struct Visitor;
134
135        impl<'de> serde::de::Visitor<'de> for Visitor {
136            type Value = RegisteredClaims;
137
138            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
139                write!(formatter, "JWT registered claim set")
140            }
141
142            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
143            where
144                A: serde::de::MapAccess<'de>,
145            {
146                let mut result = RegisteredClaims::new();
147                while let Some(kind) = map.next_key::<RegisteredClaimKind>()? {
148                    let claim = kind.deserialize_value(&mut map)?;
149                    result.insert_any(claim);
150                }
151                Ok(result)
152            }
153        }
154
155        deserializer.deserialize_map(Visitor)
156    }
157}
158
/// Fallible conversion of a raw value into the claim type `C`.
///
/// The `registered_claims!` macro implements this, for every registered
/// claim, on any `T` that converts with `TryInto` into the claim's inner
/// value type.
pub trait TryIntoClaim<C> {
    /// Error returned when the conversion fails.
    type Error;

    /// Attempts to convert `self` into the claim `C`.
    fn try_into_claim(self) -> Result<C, Self::Error>;
}

// Defines the registered claims.
//
// For every `"name": Variant(Type)` entry this generates:
// - a transparent newtype `Variant(pub Type)` implementing `Claim`,
//   `RegisteredClaim`, and `TryIntoClaim<Variant>` for any `T: TryInto<Type>`;
// - a `RegisteredClaimKind::Variant` and an `AnyRegisteredClaim::Variant`;
// - a `From<Variant>` impl for `AnyRegisteredClaim`.
//
// It also implements `ClaimSet` for `RegisteredClaims` by dispatching on the
// requested claim's `TypeId`. A trailing comma after the last entry is
// accepted.
macro_rules! registered_claims {
    ($($(#[$meta:meta])* $name:literal: $variant:ident ( $ty:ty )),* $(,)?) => {
        $(
            $(#[$meta])*
            #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
            #[derive(serde::Serialize, serde::Deserialize)]
            #[serde(transparent)]
            pub struct $variant(pub $ty);

            impl Claim for $variant {
                const JWT_CLAIM_NAME: &'static str = $name;
            }

            impl<T> TryIntoClaim<$variant> for T
            where
                T: TryInto<$ty>,
            {
                type Error = T::Error;

                fn try_into_claim(self) -> Result<$variant, Self::Error> {
                    self.try_into().map($variant)
                }
            }

            impl RegisteredClaim for $variant {
                const JWT_REGISTERED_CLAIM_KIND: RegisteredClaimKind =
                    RegisteredClaimKind::$variant;

                fn extract(claim: AnyRegisteredClaim) -> Option<Self> {
                    match claim {
                        AnyRegisteredClaim::$variant(value) => Some(value),
                        _ => None,
                    }
                }

                fn extract_ref(claim: &AnyRegisteredClaim) -> Option<&Self> {
                    match claim {
                        AnyRegisteredClaim::$variant(value) => Some(value),
                        _ => None,
                    }
                }

                fn extract_mut(claim: &mut AnyRegisteredClaim) -> Option<&mut Self> {
                    match claim {
                        AnyRegisteredClaim::$variant(value) => Some(value),
                        _ => None,
                    }
                }
            }
        )*

        impl ClaimSet for RegisteredClaims {
            fn contains<C: Claim>(&self) -> bool {
                $(
                    if std::any::TypeId::of::<C>() == std::any::TypeId::of::<$variant>() {
                        return self.contains::<$variant>();
                    }
                )*

                false
            }

            fn try_get<C: Claim>(&self) -> Result<Option<Cow<'_, C>>, InvalidClaimValue> {
                $(
                    if std::any::TypeId::of::<C>() == std::any::TypeId::of::<$variant>() {
                        // SAFETY: the `TypeId` equality above proves that `C`
                        // and `$variant` are the same type, so the cast is an
                        // identity conversion.
                        return Ok(unsafe { CastClaim::cast_claim(self.get::<$variant>()) }.map(Cow::Borrowed));
                    }
                )*

                Ok(None)
            }

            fn try_set<C: Claim>(&mut self, claim: C) -> Result<Result<(), C>, InvalidClaimValue> {
                $(
                    if std::any::TypeId::of::<C>() == std::any::TypeId::of::<$variant>() {
                        // SAFETY: the `TypeId` equality above proves that `C`
                        // and `$variant` are the same type, so the cast is an
                        // identity conversion.
                        self.set::<$variant>(unsafe { CastClaim::cast_claim(claim) });
                        return Ok(Ok(()));
                    }
                )*

                // `C` is not a registered claim: hand it back to the caller.
                Ok(Err(claim))
            }

            fn try_remove<C: Claim>(&mut self) -> Result<Option<C>, InvalidClaimValue> {
                $(
                    if std::any::TypeId::of::<C>() == std::any::TypeId::of::<$variant>() {
                        // SAFETY: the `TypeId` equality above proves that `C`
                        // and `$variant` are the same type, so the cast is an
                        // identity conversion.
                        return Ok(unsafe { CastClaim::cast_claim(self.remove::<$variant>()) });
                    }
                )*

                Ok(None)
            }
        }

        /// Identifies a registered claim; used as the key of a
        /// [`RegisteredClaims`] set.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
        pub enum RegisteredClaimKind {
            $(
                $variant
            ),*
        }

        impl RegisteredClaimKind {
            /// Returns the kind whose JWT claim name is `s`, or `None` if `s`
            /// is not a registered claim name.
            pub fn new(s: &str) -> Option<Self> {
                match s {
                    $(
                        $name => Some(Self::$variant),
                    )*
                    _ => None,
                }
            }

            /// Returns the JWT claim name of this kind (e.g. `"iss"`).
            pub fn as_str(&self) -> &'static str {
                match self {
                    $(
                        Self::$variant => $name,
                    )*
                }
            }

            /// Reads the next value of `map` as the claim type associated
            /// with this kind.
            pub(crate) fn deserialize_value<'de, M: serde::de::MapAccess<'de>>(
                &self,
                map: &mut M,
            ) -> Result<AnyRegisteredClaim, M::Error> {
                match self {
                    $(
                        Self::$variant => map.next_value().map(AnyRegisteredClaim::$variant),
                    )*
                }
            }
        }

        /// Deserializes a registered claim name, failing on any name that is
        /// not registered.
        impl<'de> serde::Deserialize<'de> for RegisteredClaimKind {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                let name = <String as serde::Deserialize>::deserialize(deserializer)?;
                Self::new(&name).ok_or_else(|| {
                    serde::de::Error::custom(format!("unknown registered claim `{}`", name))
                })
            }
        }

        /// Any registered claim, tagged by its kind.
        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
        pub enum AnyRegisteredClaim {
            $(
                $variant($variant)
            ),*
        }

        impl AnyRegisteredClaim {
            /// Returns the kind of this claim.
            pub fn kind(&self) -> RegisteredClaimKind {
                match self {
                    $(
                        Self::$variant(_) => RegisteredClaimKind::$variant
                    ),*
                }
            }

            /// Writes this claim into `serializer` as a `name: value` entry.
            fn serialize<S: serde::ser::SerializeMap>(&self, serializer: &mut S) -> Result<(), S::Error> {
                match self {
                    $(
                        Self::$variant(value) => serializer.serialize_entry($name, value)
                    ),*
                }
            }
        }

        $(
            impl From<$variant> for AnyRegisteredClaim {
                fn from(value: $variant) -> Self {
                    Self::$variant(value)
                }
            }
        )*
    };
}

347registered_claims! {
348    /// Issuer (`iss`) claim.
349    ///
350    /// Principal that issued the JWT. The processing of this claim is generally
351    /// application specific.
352    "iss": Issuer(StringOrURI),
353
354    /// Subject (`sub`) claim.
355    ///
356    /// Principal that is the subject of the JWT. The claims in a JWT are
357    /// normally statements about the subject. The subject value MUST either be
358    /// scoped to be locally unique in the context of the issuer or be globally
359    /// unique.
360    ///
361    /// The processing of this claim is generally application specific.
362    "sub": Subject(StringOrURI),
363
364    /// Audience (`aud`) claim.
365    ///
366    /// Recipients that the JWT is intended for. Each principal intended to
367    /// process the JWT MUST identify itself with a value in the audience claim.
368    /// If the principal processing the claim does not identify itself with a
369    /// value in the `aud` claim when this claim is present, then the JWT MUST
370    /// be rejected.
371    "aud": Audience(OneOrMany<StringOrURI>),
372
373    /// Expiration Time (`exp`) claim.
374    ///
375    /// Expiration time on or after which the JWT MUST NOT be accepted for
376    /// processing. The processing of the `exp` claim requires that the current
377    /// date/time MUST be before the expiration date/time listed in the `exp`
378    /// claim.
379    #[derive(Copy)]
380    "exp": ExpirationTime(NumericDate),
381
382    /// Not Before (`nbf`) claim.
383    ///
384    /// Time before which the JWT MUST NOT be accepted for processing. The
385    /// processing of the `nbf` claim requires that the current date/time MUST
386    /// be after or equal to the not-before date/time listed in the "nbf" claim.
387    /// Implementers MAY provide for some small leeway, usually no more than a
388    /// few minutes, to account for clock skew.
389    #[derive(Copy)]
390    "nbf": NotBefore(NumericDate),
391
392    /// Issued At (`iat`) claim.
393    ///
394    /// Time at which the JWT was issued. This claim can be used to determine
395    /// the age of the JWT.
396    #[derive(Copy)]
397    "iat": IssuedAt(NumericDate),
398
399    /// JWT ID (`jti`) claim.
400    ///
401    /// Unique identifier for the JWT. The identifier value MUST be assigned in
402    /// a manner that ensures that there is a negligible probability that the
403    /// same value will be accidentally assigned to a different data object; if
404    /// the application uses multiple issuers, collisions MUST be prevented
405    /// among values produced by different issuers as well.
406    ///
407    /// The "jti" claim can be used to prevent the JWT from being replayed.
408    "jti": JwtId(String),
409
410    "nonce": Nonce(String),
411
412    "vc": VerifiableCredential(json_syntax::Value),
413
414    "vp": VerifiablePresentation(json_syntax::Value)
415}
416
417pub enum JwtClaimValidationFailed {
418    Premature {
419        now: DateTime<Utc>,
420        valid_from: DateTime<Utc>,
421    },
422    Expired {
423        now: DateTime<Utc>,
424        valid_until: DateTime<Utc>,
425    },
426}
427
428impl From<JwtClaimValidationFailed> for InvalidClaims {
429    fn from(value: JwtClaimValidationFailed) -> Self {
430        match value {
431            JwtClaimValidationFailed::Premature { now, valid_from } => {
432                Self::Premature { now, valid_from }
433            }
434            JwtClaimValidationFailed::Expired { now, valid_until } => {
435                Self::Expired { now, valid_until }
436            }
437        }
438    }
439}
440
441impl ExpirationTime {
442    pub fn verify(&self, now: DateTime<Utc>) -> Result<(), JwtClaimValidationFailed> {
443        let exp: DateTime<Utc> = self.0.into();
444        if exp > now {
445            Ok(())
446        } else {
447            Err(JwtClaimValidationFailed::Expired {
448                now,
449                valid_until: exp,
450            })
451        }
452    }
453}
454
455impl NotBefore {
456    pub fn verify(&self, now: DateTime<Utc>) -> Result<(), JwtClaimValidationFailed> {
457        let nbf: DateTime<Utc> = self.0.into();
458        if nbf <= now {
459            Ok(())
460        } else {
461            Err(JwtClaimValidationFailed::Premature {
462                now,
463                valid_from: nbf,
464            })
465        }
466    }
467}
468
469impl IssuedAt {
470    pub fn now() -> Self {
471        Self(Utc::now().into())
472    }
473
474    pub fn verify(&self, now: DateTime<Utc>) -> Result<(), JwtClaimValidationFailed> {
475        let iat: DateTime<Utc> = self.0.into();
476        if iat <= now {
477            Ok(())
478        } else {
479            Err(JwtClaimValidationFailed::Premature {
480                now,
481                valid_from: iat,
482            })
483        }
484    }
485}