1use super::{Claim, InvalidClaimValue, JWTClaims};
2use crate::{CastClaim, ClaimSet, InfallibleClaimSet, NumericDate, StringOrURI};
3use chrono::{DateTime, Utc};
4use ssi_claims_core::{ClaimsValidity, DateTimeProvider, InvalidClaims, ValidateClaims};
5use ssi_core::OneOrMany;
6use ssi_jws::JwsPayload;
7use std::{borrow::Cow, collections::BTreeMap};
8
/// A claim type that is part of the registered JWT claim set.
///
/// Each implementor is associated with exactly one [`RegisteredClaimKind`] and
/// can be converted into the type-erased [`AnyRegisteredClaim`] (through the
/// `Into` supertrait) or recovered from it with the `extract*` functions.
pub trait RegisteredClaim: Claim + Into<AnyRegisteredClaim> {
    /// Kind used as the key for this claim inside a [`RegisteredClaims`] set.
    const JWT_REGISTERED_CLAIM_KIND: RegisteredClaimKind;

    /// Recovers the concrete claim from an owned [`AnyRegisteredClaim`].
    ///
    /// The macro-generated implementations return `None` when `claim` holds a
    /// different variant.
    fn extract(claim: AnyRegisteredClaim) -> Option<Self>;

    /// Borrowing counterpart of [`RegisteredClaim::extract`].
    fn extract_ref(claim: &AnyRegisteredClaim) -> Option<&Self>;

    /// Mutably borrowing counterpart of [`RegisteredClaim::extract`].
    fn extract_mut(claim: &mut AnyRegisteredClaim) -> Option<&mut Self>;
}
18
/// Set of registered JWT claims, holding at most one claim per
/// [`RegisteredClaimKind`].
///
/// Claims are kept in a `BTreeMap` keyed by their kind, so iteration follows
/// the ordering of `RegisteredClaimKind`.
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RegisteredClaims(BTreeMap<RegisteredClaimKind, AnyRegisteredClaim>);
21
22impl RegisteredClaims {
23 pub fn new() -> Self {
24 Self::default()
25 }
26
27 pub fn is_empty(&self) -> bool {
28 self.0.is_empty()
29 }
30
31 pub fn len(&self) -> usize {
32 self.0.len()
33 }
34
35 pub fn iter(&'_ self) -> RegisteredClaimsIter<'_> {
36 self.0.values()
37 }
38
39 pub fn contains<C: RegisteredClaim>(&self) -> bool {
40 self.0.contains_key(&C::JWT_REGISTERED_CLAIM_KIND)
41 }
42
43 pub fn get<C: RegisteredClaim>(&self) -> Option<&C> {
44 self.0
45 .get(&C::JWT_REGISTERED_CLAIM_KIND)
46 .and_then(C::extract_ref)
47 }
48
49 pub fn get_mut<C: RegisteredClaim>(&mut self) -> Option<&mut C> {
50 self.0
51 .get_mut(&C::JWT_REGISTERED_CLAIM_KIND)
52 .and_then(C::extract_mut)
53 }
54
55 pub fn set<C: RegisteredClaim>(&mut self, claim: C) -> Option<C> {
56 self.0
57 .insert(C::JWT_REGISTERED_CLAIM_KIND, claim.into())
58 .and_then(C::extract)
59 }
60
61 pub fn insert_any(&mut self, claim: AnyRegisteredClaim) -> Option<AnyRegisteredClaim> {
62 self.0.insert(claim.kind(), claim)
63 }
64
65 pub fn remove<C: RegisteredClaim>(&mut self) -> Option<C> {
66 self.0
67 .remove(&C::JWT_REGISTERED_CLAIM_KIND)
68 .and_then(C::extract)
69 }
70
71 pub fn with_private_claims<P>(self, claims: P) -> JWTClaims<P> {
72 JWTClaims {
73 registered: self,
74 private: claims,
75 }
76 }
77}
78
// Marker impl with no items. NOTE(review): presumably this advertises that the
// `ClaimSet` accessors of `RegisteredClaims` never fail — the macro-generated
// `try_*` methods in this file only ever return `Ok(..)`; confirm against the
// trait definition.
impl InfallibleClaimSet for RegisteredClaims {}
80
81impl JwsPayload for RegisteredClaims {
82 fn typ(&self) -> Option<&'static str> {
83 Some("JWT")
84 }
85
86 fn payload_bytes(&'_ self) -> Cow<'_, [u8]> {
87 Cow::Owned(serde_json::to_vec(self).unwrap())
88 }
89}
90
91impl<E, P> ValidateClaims<E, P> for RegisteredClaims
92where
93 E: DateTimeProvider,
94{
95 fn validate_claims(&self, env: &E, _proof: &P) -> ClaimsValidity {
96 ClaimSet::validate_registered_claims(self, env)
97 }
98}
99
/// Iterator over the claims of a [`RegisteredClaims`] set, as returned by
/// [`RegisteredClaims::iter`]. It yields `&AnyRegisteredClaim` items.
pub type RegisteredClaimsIter<'a> =
    std::collections::btree_map::Values<'a, RegisteredClaimKind, AnyRegisteredClaim>;
102
103impl<'a> IntoIterator for &'a RegisteredClaims {
104 type IntoIter = RegisteredClaimsIter<'a>;
105 type Item = &'a AnyRegisteredClaim;
106
107 fn into_iter(self) -> Self::IntoIter {
108 self.iter()
109 }
110}
111
112impl serde::Serialize for RegisteredClaims {
113 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
114 where
115 S: serde::Serializer,
116 {
117 use serde::ser::SerializeMap;
118 let mut map = serializer.serialize_map(Some(self.0.len()))?;
119
120 for claim in self {
121 claim.serialize(&mut map)?;
122 }
123
124 map.end()
125 }
126}
127
128impl<'de> serde::Deserialize<'de> for RegisteredClaims {
129 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
130 where
131 D: serde::Deserializer<'de>,
132 {
133 struct Visitor;
134
135 impl<'de> serde::de::Visitor<'de> for Visitor {
136 type Value = RegisteredClaims;
137
138 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
139 write!(formatter, "JWT registered claim set")
140 }
141
142 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
143 where
144 A: serde::de::MapAccess<'de>,
145 {
146 let mut result = RegisteredClaims::new();
147 while let Some(kind) = map.next_key::<RegisteredClaimKind>()? {
148 let claim = kind.deserialize_value(&mut map)?;
149 result.insert_any(claim);
150 }
151 Ok(result)
152 }
153 }
154
155 deserializer.deserialize_map(Visitor)
156 }
157}
158
/// Fallible conversion of a value into the claim type `C`.
///
/// The `registered_claims!` macro implements this trait for every `T` that is
/// `TryInto` the inner type of a registered claim.
pub trait TryIntoClaim<C> {
    /// Error returned when the conversion fails.
    type Error;

    /// Attempts to turn `self` into a claim of type `C`.
    fn try_into_claim(self) -> Result<C, Self::Error>;
}
164
// Generates all the per-claim boilerplate from a comma-separated list of
// `"name": Variant(Type)` entries (each optionally preceded by attributes):
//
// - a transparent newtype `Variant(pub Type)` per entry, together with its
//   `Claim`, `TryIntoClaim` and `RegisteredClaim` implementations;
// - the `ClaimSet` implementation for `RegisteredClaims`;
// - the `RegisteredClaimKind` and `AnyRegisteredClaim` enums and their helpers;
// - `From<Variant> for AnyRegisteredClaim` for every entry.
//
// The matcher does not accept a trailing comma after the last entry.
macro_rules! registered_claims {
    ($($(#[$meta:meta])* $name:literal: $variant:ident ( $ty:ty )),*) => {
        $(
            // Claim newtype. Attributes given at the invocation site (doc
            // comments, extra derives) are forwarded first.
            $(#[$meta])*
            #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
            #[derive(serde::Serialize, serde::Deserialize)]
            #[serde(transparent)]
            pub struct $variant(pub $ty);

            impl Claim for $variant {
                const JWT_CLAIM_NAME: &'static str = $name;
            }

            // Anything convertible into the inner type is convertible into the claim.
            impl<T> TryIntoClaim<$variant> for T
            where
                T: TryInto<$ty>
            {
                type Error = T::Error;

                fn try_into_claim(self) -> Result<$variant, Self::Error> {
                    self.try_into().map($variant)
                }
            }

            impl RegisteredClaim for $variant {
                const JWT_REGISTERED_CLAIM_KIND: RegisteredClaimKind = RegisteredClaimKind::$variant;

                // Each `extract*` returns `None` when `claim` holds another variant.
                fn extract(claim: AnyRegisteredClaim) -> Option<Self> {
                    match claim {
                        AnyRegisteredClaim::$variant(value) => Some(value),
                        _ => None
                    }
                }

                fn extract_ref(claim: &AnyRegisteredClaim) -> Option<&Self> {
                    match claim {
                        AnyRegisteredClaim::$variant(value) => Some(value),
                        _ => None
                    }
                }

                fn extract_mut(claim: &mut AnyRegisteredClaim) -> Option<&mut Self> {
                    match claim {
                        AnyRegisteredClaim::$variant(value) => Some(value),
                        _ => None
                    }
                }
            }
        )*

        // Generic `ClaimSet` access. Each method compares the `TypeId` of the
        // requested claim type `C` with every registered claim type in turn and
        // dispatches to the matching typed accessor. No method returns `Err`.
        impl ClaimSet for RegisteredClaims {
            /// Tells whether a claim of type `C` is stored. Always `false` when
            /// `C` is not a registered claim type.
            fn contains<C: Claim>(&self) -> bool {
                $(
                    if std::any::TypeId::of::<C>() == std::any::TypeId::of::<$variant>() {
                        return self.contains::<$variant>();
                    }
                )*

                false
            }

            /// Borrows the stored claim of type `C`. Yields `Ok(None)` when the
            /// claim is absent or `C` is not a registered claim type.
            fn try_get<C: Claim>(&'_ self) -> Result<Option<Cow<'_, C>>, InvalidClaimValue> {
                $(
                    if std::any::TypeId::of::<C>() == std::any::TypeId::of::<$variant>() {
                        // SAFETY: the `TypeId` comparison above establishes that `C` and
                        // `$variant` are the same type. NOTE(review): presumably that is
                        // the whole contract of `CastClaim::cast_claim` — confirm.
                        return Ok(unsafe { CastClaim::cast_claim(self.get::<$variant>()) }.map(Cow::Borrowed));
                    }
                )*

                Ok(None)
            }

            /// Stores `claim` and returns `Ok(Ok(()))` when `C` is a registered
            /// claim type (a previously stored claim of that type is dropped);
            /// otherwise hands the claim back as `Ok(Err(claim))`.
            fn try_set<C: Claim>(&mut self, claim: C) -> Result<Result<(), C>, InvalidClaimValue> {
                $(
                    if std::any::TypeId::of::<C>() == std::any::TypeId::of::<$variant>() {
                        // SAFETY: same-type guarantee from the `TypeId` comparison above
                        // (see the review note in `try_get`).
                        self.set::<$variant>(unsafe { CastClaim::cast_claim(claim) });
                        return Ok(Ok(()))
                    }
                )*

                Ok(Err(claim))
            }

            /// Removes and returns the stored claim of type `C`. Yields
            /// `Ok(None)` when the claim is absent or `C` is not a registered
            /// claim type.
            fn try_remove<C: Claim>(&mut self) -> Result<Option<C>, InvalidClaimValue> {
                $(
                    if std::any::TypeId::of::<C>() == std::any::TypeId::of::<$variant>() {
                        // SAFETY: same-type guarantee from the `TypeId` comparison above
                        // (see the review note in `try_get`).
                        return Ok(unsafe { CastClaim::cast_claim(self.remove::<$variant>()) });
                    }
                )*

                Ok(None)
            }
        }

        /// Identifies a registered claim type, without carrying its value.
        ///
        /// Used as the key of the map inside `RegisteredClaims`. The derived
        /// ordering follows the order of the entries given to the macro.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
        pub enum RegisteredClaimKind {
            $(
                $variant
            ),*
        }

        impl RegisteredClaimKind {
            /// Looks up the kind whose JWT claim name is `s`; `None` for any
            /// other string.
            pub fn new(s: &str) -> Option<Self> {
                match s {
                    $(
                        $name => Some(Self::$variant),
                    )*
                    _ => None
                }
            }

            /// JWT claim name of this kind.
            pub fn as_str(&self) -> &'static str {
                match self {
                    $(
                        Self::$variant => $name,
                    )*
                }
            }

            // Reads the next map value as the claim type matching this kind and
            // wraps it into the corresponding `AnyRegisteredClaim` variant.
            pub(crate) fn deserialize_value<'de, M: serde::de::MapAccess<'de>>(&self, map: &mut M) -> Result<AnyRegisteredClaim, M::Error> {
                match self {
                    $(
                        Self::$variant => {
                            map.next_value().map(AnyRegisteredClaim::$variant)
                        }
                    ),*
                }
            }
        }

        // Deserializes a kind from its claim name; any other string is an error.
        impl<'de> serde::Deserialize<'de> for RegisteredClaimKind {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>
            {
                let name = String::deserialize(deserializer)?;
                match Self::new(&name) {
                    Some(r) => Ok(r),
                    None => Err(serde::de::Error::custom(format!("unknown registered claim `{}`", name)))
                }
            }
        }

        /// Type-erased registered claim: one variant per registered claim
        /// type, wrapping a value of that type.
        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
        pub enum AnyRegisteredClaim {
            $(
                $variant($variant)
            ),*
        }

        impl AnyRegisteredClaim {
            /// Kind of the claim held by this value.
            pub fn kind(&self) -> RegisteredClaimKind {
                match self {
                    $(
                        Self::$variant(_) => RegisteredClaimKind::$variant
                    ),*
                }
            }

            // Writes this claim as a single `name: value` entry of the map being
            // serialized. This is an inherent helper, not `serde::Serialize`.
            fn serialize<S: serde::ser::SerializeMap>(&self, serializer: &mut S) -> Result<(), S::Error> {
                match self {
                    $(
                        Self::$variant(value) => {
                            serializer.serialize_entry(
                                $name,
                                value
                            )
                        }
                    ),*
                }
            }
        }

        $(
            // Satisfies the `Into<AnyRegisteredClaim>` bound of `RegisteredClaim`.
            impl From<$variant> for AnyRegisteredClaim {
                fn from(value: $variant) -> Self {
                    Self::$variant(value)
                }
            }
        )*
    };
}
346
registered_claims! {
    /// Issuer (`iss`) claim, see RFC 7519 section 4.1.1.
    "iss": Issuer(StringOrURI),

    /// Subject (`sub`) claim, see RFC 7519 section 4.1.2.
    "sub": Subject(StringOrURI),

    /// Audience (`aud`) claim, see RFC 7519 section 4.1.3.
    ///
    /// Holds either a single value or several values.
    "aud": Audience(OneOrMany<StringOrURI>),

    /// Expiration time (`exp`) claim, see RFC 7519 section 4.1.4.
    #[derive(Copy)]
    "exp": ExpirationTime(NumericDate),

    /// Not before (`nbf`) claim, see RFC 7519 section 4.1.5.
    #[derive(Copy)]
    "nbf": NotBefore(NumericDate),

    /// Issued at (`iat`) claim, see RFC 7519 section 4.1.6.
    #[derive(Copy)]
    "iat": IssuedAt(NumericDate),

    /// JWT ID (`jti`) claim, see RFC 7519 section 4.1.7.
    "jti": JwtId(String),

    /// Nonce (`nonce`) claim.
    "nonce": Nonce(String),

    /// Verifiable credential (`vc`) claim, stored as an arbitrary JSON value.
    "vc": VerifiableCredential(json_syntax::Value),

    /// Verifiable presentation (`vp`) claim, stored as an arbitrary JSON value.
    "vp": VerifiablePresentation(json_syntax::Value)
}
416
417pub enum JwtClaimValidationFailed {
418 Premature {
419 now: DateTime<Utc>,
420 valid_from: DateTime<Utc>,
421 },
422 Expired {
423 now: DateTime<Utc>,
424 valid_until: DateTime<Utc>,
425 },
426}
427
428impl From<JwtClaimValidationFailed> for InvalidClaims {
429 fn from(value: JwtClaimValidationFailed) -> Self {
430 match value {
431 JwtClaimValidationFailed::Premature { now, valid_from } => {
432 Self::Premature { now, valid_from }
433 }
434 JwtClaimValidationFailed::Expired { now, valid_until } => {
435 Self::Expired { now, valid_until }
436 }
437 }
438 }
439}
440
441impl ExpirationTime {
442 pub fn verify(&self, now: DateTime<Utc>) -> Result<(), JwtClaimValidationFailed> {
443 let exp: DateTime<Utc> = self.0.into();
444 if exp > now {
445 Ok(())
446 } else {
447 Err(JwtClaimValidationFailed::Expired {
448 now,
449 valid_until: exp,
450 })
451 }
452 }
453}
454
455impl NotBefore {
456 pub fn verify(&self, now: DateTime<Utc>) -> Result<(), JwtClaimValidationFailed> {
457 let nbf: DateTime<Utc> = self.0.into();
458 if nbf <= now {
459 Ok(())
460 } else {
461 Err(JwtClaimValidationFailed::Premature {
462 now,
463 valid_from: nbf,
464 })
465 }
466 }
467}
468
469impl IssuedAt {
470 pub fn now() -> Self {
471 Self(Utc::now().into())
472 }
473
474 pub fn verify(&self, now: DateTime<Utc>) -> Result<(), JwtClaimValidationFailed> {
475 let iat: DateTime<Utc> = self.0.into();
476 if iat <= now {
477 Ok(())
478 } else {
479 Err(JwtClaimValidationFailed::Premature {
480 now,
481 valid_from: iat,
482 })
483 }
484 }
485}