1use std::borrow::Cow;
2
3use serde::{Deserialize, Serialize};
4use ssi_claims_core::{ClaimsValidity, DateTimeProvider, InvalidClaims, ValidateClaims};
5use ssi_jwk::Algorithm;
6use ssi_jws::{JwsPayload, ValidateJwsHeader};
7use ssi_jwt::{Claim, ExpirationTime, IssuedAt, Nonce, NotBefore};
8
9use crate::{SdAlg, SdJwt};
10
11pub const KB_JWT_TYP: &str = "kb+jwt";
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct KbJwtPayload<T = serde_json::Map<String, serde_json::Value>> {
17 pub iat: IssuedAt,
19
20 pub aud: String,
22
23 pub nonce: Nonce,
25
26 pub sd_hash: SdHash,
28
29 pub exp: Option<ExpirationTime>,
31
32 pub nbf: Option<NotBefore>,
34
35 #[serde(flatten)]
37 pub claims: T,
38}
39
40impl KbJwtPayload {
41 pub fn new(aud: String, nonce: String, sd_alg: SdAlg, sd_jwt: &SdJwt) -> Self {
43 Self {
44 iat: IssuedAt::now(),
45 aud,
46 nonce: Nonce(nonce),
47 sd_hash: SdHash::new(sd_alg, sd_jwt),
48 exp: None,
49 nbf: None,
50 claims: Default::default(),
51 }
52 }
53}
54
55impl JwsPayload for KbJwtPayload {
56 fn typ(&self) -> Option<&str> {
57 Some(KB_JWT_TYP)
58 }
59
60 fn payload_bytes(&self) -> Cow<'_, [u8]> {
61 Cow::Owned(serde_json::to_vec(self).unwrap())
62 }
63}
64
65impl<E> ValidateJwsHeader<E> for KbJwtPayload {
66 fn validate_jws_header(&self, _params: &E, header: &ssi_jws::Header) -> ClaimsValidity {
67 if header.type_.as_deref() != Some(KB_JWT_TYP) {
68 return Err(InvalidClaims::other("invalid JWT type"));
69 }
70
71 if header.algorithm == Algorithm::None {
72 return Err(InvalidClaims::other("algorithm can't be `none`"));
73 }
74
75 Ok(())
76 }
77}
78
79impl<E, P> ValidateClaims<E, P> for KbJwtPayload
80where
81 E: DateTimeProvider,
82{
83 fn validate_claims(&self, params: &E, _proof: &P) -> ClaimsValidity {
84 let now = params.date_time();
85
86 self.iat.verify(now)?;
87
88 if let Some(nbf) = &self.nbf {
89 nbf.verify(now)?;
90 }
91
92 if let Some(exp) = &self.exp {
93 exp.verify(now)?;
94 }
95
96 Ok(())
97 }
98}
99
100#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
102#[serde(transparent)]
103pub struct SdHash(pub String);
104
105impl SdHash {
106 pub fn new(sd_alg: SdAlg, sd_jwt: &SdJwt) -> Self {
108 Self(sd_alg.hash(sd_jwt.trim_kb().as_bytes()))
109 }
110
111 pub fn verify(&self, alg: SdAlg, sd_jwt: &SdJwt) -> bool {
113 alg.verify(sd_jwt.trim_kb().as_bytes(), &self.0)
114 }
115}
116
117impl Claim for SdHash {
118 const JWT_CLAIM_NAME: &str = "sd_hash";
119}
120
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use serde::Deserialize;
    use serde_json::json;
    use ssi_claims_core::{ValidateClaims, VerificationParameters};
    use ssi_core::JsonPointerBuf;
    use ssi_jwk::JWK;
    use ssi_jws::JwsPayload;
    use ssi_jwt::{ClaimSet, JWTClaims};

    use crate::{sd_jwt, ConcealJwtClaims, KbJwtPayload, SdAlg, SdJwt};

    /// Issues an SD-JWT with `jwk` and attaches a KB-JWT signed with a separate
    /// key (`cnf_jwk`). Then checks that the SD-JWT verifies and the KB-JWT
    /// claims round-trip, and that the KB-JWT signature verifies against `cnf_jwk`.
    #[async_std::test]
    async fn kb_sign() {
        let claims = JWTClaims::builder()
            .iss("https://example.com/issuer")
            .iat(1683000000)
            .exp(1883000000)
            .sub("user_42")
            .build()
            .unwrap();

        let jwk = JWK::generate_p256();
        let cnf_jwk = JWK::generate_p256();

        // No claims are concealed; only the KB-JWT binding is under test here.
        let pointers: &[JsonPointerBuf] = &[];
        let mut sd_jwt = claims
            .conceal_and_sign(SdAlg::Sha256, pointers, &jwk)
            .await
            .unwrap();

        let kb_jwt = KbJwtPayload::new(
            "issuer".to_owned(),
            "123nonce".to_owned(),
            SdAlg::Sha256,
            &sd_jwt,
        )
        .sign(&cnf_jwk)
        .await
        .unwrap();

        sd_jwt.set_kb(&kb_jwt);

        let params = VerificationParameters::from_resolver(&jwk);
        let (revealed, verification_result) =
            sd_jwt.decode_reveal_verify_any(&params).await.unwrap();

        verification_result.expect("SD-JWT verification failed");

        let kb_jwt = sd_jwt
            .decode_kb()
            .expect("invalid KB-JWT")
            .expect("missing KB-JWT");

        let kb_jwt_claims = &kb_jwt.signing_bytes.payload;
        assert_eq!(kb_jwt_claims.aud, "issuer");
        assert_eq!(kb_jwt_claims.nonce.0, "123nonce");
        assert!(kb_jwt_claims.sd_hash.verify(revealed.sd_alg, &sd_jwt));

        let params = VerificationParameters::from_resolver(cnf_jwk);
        kb_jwt
            .verify(&params)
            .await
            .expect("KB-JWT verification failed")
            .expect("invalid KB-JWT signature");
    }

    /// Verifies the fixed SD-JWT+KB fixture `SD_JWT_KB`.
    ///
    /// The SD-JWT is verified with `JWK`, and the KB-JWT claims are checked. The
    /// KB-JWT signature is verified with the key taken from the revealed
    /// `cnf.jwk` claim.
    #[async_std::test]
    async fn kb_verify() {
        let params = VerificationParameters::from_resolver(&*JWK);

        let (revealed, verification_result) = SD_JWT_KB
            .decode_reveal_verify::<ExampleClaims, _>(&params)
            .await
            .unwrap();

        let cnf_jwk = &revealed.jwt.signing_bytes.payload.private.cnf.jwk;

        verification_result.expect("SD-JWT verification failed");

        let kb_jwt = SD_JWT_KB
            .decode_kb()
            .expect("invalid KB-JWT")
            .expect("missing KB-JWT");

        let kb_jwt_claims = &kb_jwt.signing_bytes.payload;
        assert_eq!(kb_jwt_claims.aud, "https://verifier.example.org");
        assert_eq!(kb_jwt_claims.nonce.0, "1234567890");
        assert!(kb_jwt_claims.sd_hash.verify(revealed.sd_alg, SD_JWT_KB));

        let params = VerificationParameters::from_resolver(cnf_jwk);
        kb_jwt
            .verify(&params)
            .await
            .expect("KB-JWT verification failed")
            .expect("invalid KB-JWT signature");
    }

    /// Public key used to verify the issuer signature of `SD_JWT_KB`.
    static JWK: LazyLock<JWK> = LazyLock::new(|| {
        json!({
            "kty": "EC",
            "crv": "P-256",
            "x": "b28d4MwZMjw8-00CG4xfnn9SLMVMM19SlqZpVb_uNtQ",
            "y": "Xv5zWwuoaTgdS6hV43yI6gBwTnjukmFQQnJ_kCxzqk8"
        })
        .try_into()
        .unwrap()
    });

    /// SD-JWT fixture: an issuer-signed JWT, a set of disclosures, and a
    /// trailing KB-JWT.
    const SD_JWT_KB: &SdJwt = sd_jwt!("eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkNyUWU3UzVrcUJBSHQtbk1ZWGdjNmJkdDJTSDVhVFkxc1VfTS1QZ2tqUEkiLCAiSnpZakg0c3ZsaUgwUjNQeUVNZmVadTZKdDY5dTVxZWhabzdGN0VQWWxTRSIsICJQb3JGYnBLdVZ1Nnh5bUphZ3ZrRnNGWEFiUm9jMkpHbEFVQTJCQTRvN2NJIiwgIlRHZjRvTGJnd2Q1SlFhSHlLVlFaVTlVZEdFMHc1cnREc3JaemZVYW9tTG8iLCAiWFFfM2tQS3QxWHlYN0tBTmtxVlI2eVoyVmE1TnJQSXZQWWJ5TXZSS0JNTSIsICJYekZyendzY002R242Q0pEYzZ2Vks4QmtNbmZHOHZPU0tmcFBJWmRBZmRFIiwgImdiT3NJNEVkcTJ4Mkt3LXc1d1BFemFrb2I5aFYxY1JEMEFUTjNvUUw5Sk0iLCAianN1OXlWdWx3UVFsaEZsTV8zSmx6TWFTRnpnbGhRRzBEcGZheVF3TFVLNCJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAic3ViIjogInVzZXJfNDIiLCAibmF0aW9uYWxpdGllcyI6IFt7Ii4uLiI6ICJwRm5kamtaX1ZDem15VGE2VWpsWm8zZGgta284YUlLUWM5RGxHemhhVllvIn0sIHsiLi4uIjogIjdDZjZKa1B1ZHJ5M2xjYndIZ2VaOGtoQXYxVTFPU2xlclAwVmtCSnJXWjAifV0sICJfc2RfYWxnIjogInNoYS0yNTYiLCAiY25mIjogeyJqd2siOiB7Imt0eSI6ICJFQyIsICJjcnYiOiAiUC0yNTYiLCAieCI6ICJUQ0FFUjE5WnZ1M09IRjRqNFc0dmZTVm9ISVAxSUxpbERsczd2Q2VHZW1jIiwgInkiOiAiWnhqaVdXYlpNUUdIVldLVlE0aGJTSWlyc1ZmdWVjQ0U2dDRqVDlGMkhaUSJ9fX0.MczwjBFGtzf-6WMT-hIvYbkb11NrV1WMO-jTijpMPNbswNzZ87wY2uHz-CXo6R04b7jYrpj9mNRAvVssXou1iw~WyJlbHVWNU9nM2dTTklJOEVZbnN4QV9BIiwgImZhbWlseV9uYW1lIiwgIkRvZSJd~WyJBSngtMDk1VlBycFR0TjRRTU9xUk9BIiwgImFkZHJlc3MiLCB7InN0cmVldF9hZGRyZXNzIjogIjEyMyBNYWluIFN0IiwgImxvY2FsaXR5IjogIkFueXRvd24iLCAicmVnaW9uIjogIkFueXN0YXRlIiwgImNvdW50cnkiOiAiVVMifV0~WyIyR0xDNDJzS1F2ZUNmR2ZyeU5STjl3IiwgImdpdmVuX25hbWUiLCAiSm9obiJd~WyJsa2x4RjVqTVlsR1RQVW92TU5JdkNBIiwgIlVTIl0~eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImtiK2p3dCJ9.eyJub25jZSI6ICIxMjM0NTY3ODkwIiwgImF1ZCI6ICJodHRwczovL3ZlcmlmaWVyLmV4YW1wbGUub3JnIiwgImlhdCI6IDE3NDg1MzcyNDQsICJzZF9oYXNoIjogIjBfQWYtMkItRWhMV1g1eWRoX3cyeHp3bU82aU02NkJfMlFDRWFuSTRmVVkifQ.T3SIus2OidNl41nmVkTZVCKKhOAX97aOldMyHFiYjHm261eLiJ1YiuONFiMN8QlCmYzDlBLAdPvrXh52KaLgUQ");

    #[derive(Debug, PartialEq, Deserialize)]
    struct ExampleAddress {
        street_address: Option<String>,
        locality: Option<String>,
        region: Option<String>,
        country: Option<String>,
    }

    /// Private claims expected in the `SD_JWT_KB` fixture once disclosures are revealed.
    #[derive(Debug, PartialEq, Deserialize)]
    struct ExampleClaims {
        cnf: Cnf,
        given_name: Option<String>,
        family_name: Option<String>,
        email: Option<String>,
        phone_number: Option<String>,
        address: ExampleAddress,
        birthdate: Option<String>,
    }

    /// `cnf` claim holding the key used to verify the KB-JWT signature.
    #[derive(Debug, PartialEq, Deserialize)]
    struct Cnf {
        jwk: JWK,
    }

    impl ClaimSet for ExampleClaims {}
    impl<E, P> ValidateClaims<E, P> for ExampleClaims {}
}