1use crate::{error, SDJWTJson};
6use error::Result;
7use std::collections::{HashMap, VecDeque};
8use std::str::FromStr;
9use std::vec::Vec;
10
11use jsonwebtoken::jwk::Jwk;
12use jsonwebtoken::{Algorithm, EncodingKey, Header};
13use rand::Rng;
14use serde_json::Value;
15use serde_json::{json, Map as SJMap, Map};
16
17use crate::disclosure::SDJWTDisclosure;
18use crate::error::Error;
19use crate::utils::{base64_hash, generate_salt};
20use crate::{
21 SDJWTCommon, CNF_KEY, COMBINED_SERIALIZATION_FORMAT_SEPARATOR, DEFAULT_DIGEST_ALG,
22 DEFAULT_SIGNING_ALG, DIGEST_ALG_KEY, JWK_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX,
23 SDJWTSerializationFormat,
24};
25
26pub struct SDJWTIssuer {
27 sign_alg: String,
29 add_decoy_claims: bool,
30 extra_header_parameters: Option<HashMap<String, String>>,
31
32 issuer_key: EncodingKey,
34 holder_key: Option<Jwk>,
35
36 inner: SDJWTCommon,
38 all_disclosures: Vec<SDJWTDisclosure>,
39 sd_jwt_payload: SJMap<String, Value>,
40 signed_sd_jwt: String,
41 serialized_sd_jwt: String,
42}
43
/// Chooses which claims `SDJWTIssuer::issue_sd_jwt` makes selectively disclosable.
///
/// Independently of the strategy, the top-level `iss`, `iat` and `exp` claims are
/// always issued in plain text.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ClaimsForSelectiveDisclosureStrategy<'a> {
    /// No claim is made selectively disclosable.
    NoSDClaims,
    /// Only the top-level claims are selectively disclosable; nested values are
    /// embedded inside their parent's disclosure unchanged.
    TopLevel,
    /// Every object member and array element, at every nesting level, is selectively
    /// disclosable.
    AllLevels,
    /// An explicit list of claim paths to make selectively disclosable.
    ///
    /// Every path must start with `$.` (otherwise issuance fails with an invalid-path
    /// error). Object members are separated by `.` and array elements are addressed as
    /// `[index]`, e.g. `$.address.country` or `$.nationalities[0]`.
    Custom(Vec<&'a str>),
}
63
64impl<'a> ClaimsForSelectiveDisclosureStrategy<'a> {
65 fn finalize_input(&mut self) -> Result<()> {
66 match self {
67 ClaimsForSelectiveDisclosureStrategy::Custom(keys) => {
68 for key in keys.iter_mut() {
69 if let Some(new_key) = key.strip_prefix("$.") {
70 *key = new_key;
71 } else {
72 return Err(Error::InvalidPath("Invalid JSONPath".to_owned()));
73 }
74 }
75 Ok(())
76 }
77 _ => Ok(()),
78 }
79 }
80
81 fn next_level(&self, key: &str) -> Self {
82 match self {
83 Self::NoSDClaims => Self::NoSDClaims,
84 Self::TopLevel => Self::NoSDClaims,
85 Self::AllLevels => Self::AllLevels,
86 Self::Custom(sd_keys) => {
87 let next_sd_keys = sd_keys
88 .iter()
89 .filter_map(|str| {
90 str.strip_prefix(key).and_then(|str|
91 match str.chars().next() {
92 Some('.') => Some(&str[1..]), Some('[') => Some(str), _ => None,
95 }
96 )
97 })
98 .collect();
99 Self::Custom(next_sd_keys)
100 }
101 }
102 }
103
104 fn sd_for_key(&self, key: &str) -> bool {
105 match self {
106 Self::NoSDClaims => false,
107 Self::TopLevel => true,
108 Self::AllLevels => true,
109 Self::Custom(sd_keys) => sd_keys.contains(&key),
110 }
111 }
112}
113
114impl SDJWTIssuer {
115 const DECOY_MIN_ELEMENTS: u32 = 2;
116 const DECOY_MAX_ELEMENTS: u32 = 5;
117
118 pub fn new(issuer_key: EncodingKey, sign_alg: Option<String>) -> Self {
129 SDJWTIssuer {
130 sign_alg: sign_alg.unwrap_or(DEFAULT_SIGNING_ALG.to_owned()),
131 add_decoy_claims: false,
132 extra_header_parameters: None,
133 issuer_key,
134 holder_key: None,
135 inner: Default::default(),
136 all_disclosures: vec![],
137 sd_jwt_payload: Default::default(),
138 signed_sd_jwt: "".to_string(),
139 serialized_sd_jwt: "".to_string(),
140 }
141 }
142
143 fn reset(&mut self) {
144 self.extra_header_parameters = Default::default();
145 self.all_disclosures = Default::default();
146 self.sd_jwt_payload = Default::default();
147 self.signed_sd_jwt = Default::default();
148 self.serialized_sd_jwt = Default::default();
149 }
150
151 pub fn issue_sd_jwt(
163 &mut self,
164 user_claims: Value,
165 mut sd_strategy: ClaimsForSelectiveDisclosureStrategy,
166 holder_key: Option<Jwk>,
167 add_decoy_claims: bool,
168 serialization_format: SDJWTSerializationFormat,
169 ) -> Result<String> {
171 let inner = SDJWTCommon {
172 serialization_format,
173 ..Default::default()
174 };
175
176 sd_strategy.finalize_input()?;
177
178 SDJWTCommon::check_for_sd_claim(&user_claims)?;
179
180 self.reset();
181 self.inner = inner;
182 self.holder_key = holder_key;
183 self.add_decoy_claims = add_decoy_claims;
184
185 self.assemble_sd_jwt_payload(user_claims, sd_strategy)?;
186 self.create_signed_jws()?;
187 self.create_combined()?;
188
189 Ok(self.serialized_sd_jwt.clone())
190 }
191
192 fn assemble_sd_jwt_payload(
193 &mut self,
194 mut user_claims: Value,
195 sd_strategy: ClaimsForSelectiveDisclosureStrategy,
196 ) -> Result<()> {
197 let claims_obj_ref = user_claims
198 .as_object_mut()
199 .ok_or(Error::ConversionError("json object".to_string()))?;
200 let always_revealed_root_keys = vec!["iss", "iat", "exp"];
201 let mut always_revealed_claims: Map<String, Value> = always_revealed_root_keys
202 .into_iter()
203 .filter_map(|key| claims_obj_ref.shift_remove_entry(key))
204 .collect();
205
206 self.sd_jwt_payload = self
207 .create_sd_claims(&user_claims, sd_strategy)
208 .as_object()
209 .ok_or(Error::ConversionError("json object".to_string()))?
210 .clone();
211
212 self.sd_jwt_payload.insert(
213 DIGEST_ALG_KEY.to_owned(),
214 Value::String(DEFAULT_DIGEST_ALG.to_owned()),
215 ); self.sd_jwt_payload.append(&mut always_revealed_claims);
217
218 if let Some(holder_key) = &self.holder_key {
219 self.sd_jwt_payload
220 .entry(CNF_KEY)
221 .or_insert_with(|| json!({JWK_KEY: holder_key}));
222 }
223
224 Ok(())
225 }
226
227 fn create_sd_claims(&mut self, user_claims: &Value, sd_strategy: ClaimsForSelectiveDisclosureStrategy) -> Value {
228 match user_claims {
229 Value::Array(list) => self.create_sd_claims_list(list, sd_strategy),
230 Value::Object(object) => self.create_sd_claims_object(object, sd_strategy),
231 _ => user_claims.to_owned(),
232 }
233 }
234
235 fn create_sd_claims_list(&mut self, list: &[Value], sd_strategy: ClaimsForSelectiveDisclosureStrategy) -> Value {
236 let mut claims = Vec::new();
237 for (idx, object) in list.iter().enumerate() {
238 let key = format!("[{idx}]");
239 let strategy_for_child = sd_strategy.next_level(&key);
240 let subtree = self.create_sd_claims(object, strategy_for_child);
241
242 if sd_strategy.sd_for_key(&key) {
243 let disclosure = SDJWTDisclosure::new(None, subtree);
244 claims.push(json!({ SD_LIST_PREFIX: disclosure.hash}));
245 self.all_disclosures.push(disclosure);
246 } else {
247 claims.push(subtree);
248 }
249 }
250 Value::Array(claims)
251 }
252
253 fn create_sd_claims_object(
254 &mut self,
255 user_claims: &SJMap<String, Value>,
256 sd_strategy: ClaimsForSelectiveDisclosureStrategy,
257 ) -> Value {
258 let mut claims = SJMap::new();
259
260 claims.insert(SD_DIGESTS_KEY.to_owned(), Value::Null);
262
263 let mut sd_claims = Vec::new();
264
265 for (key, value) in user_claims.iter() {
266 let strategy_for_child = sd_strategy.next_level(key);
267 let subtree_from_here = self.create_sd_claims(value, strategy_for_child);
268
269 if sd_strategy.sd_for_key(key) {
270 let disclosure = SDJWTDisclosure::new(Some(key.to_owned()), subtree_from_here);
271 sd_claims.push(disclosure.hash.clone());
272 self.all_disclosures.push(disclosure);
273 } else {
274 claims.insert(key.to_owned(), subtree_from_here);
275 }
276 }
277
278 if self.add_decoy_claims {
279 let num_decoy_elements =
280 rand::thread_rng().gen_range(Self::DECOY_MIN_ELEMENTS..Self::DECOY_MAX_ELEMENTS);
281 for _ in 0..num_decoy_elements {
282 sd_claims.push(self.create_decoy_claim_entry());
283 }
284 }
285
286 if !sd_claims.is_empty() {
287 sd_claims.sort();
288 claims.insert(
289 SD_DIGESTS_KEY.to_owned(),
290 Value::Array(sd_claims.into_iter().map(Value::String).collect()),
291 );
292 } else {
293 claims.shift_remove(SD_DIGESTS_KEY);
294 }
295
296 Value::Object(claims)
297 }
298
299 fn create_signed_jws(&mut self) -> Result<()> {
300 if let Some(extra_headers) = &self.extra_header_parameters {
301 let mut _protected_headers = extra_headers.clone();
302 for (key, value) in extra_headers.iter() {
303 _protected_headers.insert(key.to_string(), value.to_string());
304 }
305 unimplemented!("extra_headers are not supported for issuance");
306 }
307
308 let mut header = Header::new(
309 Algorithm::from_str(&self.sign_alg)
310 .map_err(|e| Error::DeserializationError(e.to_string()))?,
311 );
312 header.typ = self.inner.typ.clone();
313 self.signed_sd_jwt = jsonwebtoken::encode(&header, &self.sd_jwt_payload, &self.issuer_key)
314 .map_err(|e| Error::DeserializationError(e.to_string()))?;
315
316 Ok(())
317 }
318
319 fn create_combined(&mut self) -> Result<()> {
320 if self.inner.serialization_format == SDJWTSerializationFormat::Compact {
321 let mut disclosures: VecDeque<String> = self
322 .all_disclosures
323 .iter()
324 .map(|d| d.raw_b64.to_string())
325 .collect();
326 disclosures.push_front(self.signed_sd_jwt.clone());
327
328 let disclosures: Vec<&str> = disclosures.iter().map(|s| s.as_str()).collect();
329
330 self.serialized_sd_jwt = format!(
331 "{}{}",
332 disclosures.join(COMBINED_SERIALIZATION_FORMAT_SEPARATOR),
333 COMBINED_SERIALIZATION_FORMAT_SEPARATOR,
334 );
335 } else if self.inner.serialization_format == SDJWTSerializationFormat::JSON {
336 let jwt: Vec<&str> = self.signed_sd_jwt.split('.').collect();
337 if jwt.len() != 3 {
338 return Err(Error::InvalidInput(format!(
339 "Invalid JWT, JWT must contain three parts after splitting with \".\": jwt {}",
340 self.signed_sd_jwt
341 )));
342 }
343 let sd_jwt_json = SDJWTJson {
344 protected: jwt[0].to_owned(),
345 payload: jwt[1].to_owned(),
346 signature: jwt[2].to_owned(),
347 kb_jwt: None,
348 disclosures: self
349 .all_disclosures
350 .iter()
351 .map(|d| d.raw_b64.to_string())
352 .collect(),
353 };
354 self.serialized_sd_jwt = serde_json::to_string(&sd_jwt_json)
355 .map_err(|e| Error::DeserializationError(e.to_string()))?;
356 } else {
357 return Err(Error::InvalidInput(
358 format!("Unknown serialization format {}, only \"Compact\" or \"JSON\" formats are supported", self.inner.serialization_format)
359 ));
360 }
361
362 Ok(())
363 }
364
365 fn create_decoy_claim_entry(&mut self) -> String {
366 let digest = base64_hash(generate_salt().as_bytes()).to_string();
367 digest
368 }
369}
370
#[cfg(test)]
mod tests {
    use jsonwebtoken::EncodingKey;
    use log::trace;
    use serde_json::json;

    use crate::issuer::ClaimsForSelectiveDisclosureStrategy;
    use crate::{SDJWTIssuer, SDJWTSerializationFormat, COMBINED_SERIALIZATION_FORMAT_SEPARATOR};

    const PRIVATE_ISSUER_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgUr2bNKuBPOrAaxsR\nnbSH6hIhmNTxSGXshDSUD1a1y7ihRANCAARvbx3gzBkyPDz7TQIbjF+ef1IsxUwz\nX1KWpmlVv+421F7+c1sLqGk4HUuoVeN8iOoAcE547pJhUEJyf5Asc6pP\n-----END PRIVATE KEY-----\n";

    /// EC signing key shared by the issuance tests.
    fn issuer_key() -> EncodingKey {
        EncodingKey::from_ec_pem(PRIVATE_ISSUER_PEM.as_bytes()).unwrap()
    }

    #[test]
    fn test_assembly_sd_full_recursive() {
        let user_claims = json!({
            "sub": "6c5c0a49-b589-431d-bae7-219122a9ec2c",
            "iss": "https://example.com/issuer",
            "iat": 1683000000,
            "exp": 1883000000,
            "address": {
                "street_address": "Schulstr. 12",
                "locality": "Schulpforta",
                "region": "Sachsen-Anhalt",
                "country": "DE"
            }
        });
        let sd_jwt = SDJWTIssuer::new(issuer_key(), None)
            .issue_sd_jwt(
                user_claims,
                ClaimsForSelectiveDisclosureStrategy::AllLevels,
                None,
                false,
                SDJWTSerializationFormat::Compact,
            )
            .unwrap();
        trace!("{:?}", sd_jwt);

        // `sub`, `address` and its four members each become a disclosure, while `iss`,
        // `iat` and `exp` stay in plain text. The compact form is
        // `<jws><sep>d1<sep>...<sep>d6<sep>`, so splitting yields 1 + 6 + 1 parts.
        let parts: Vec<&str> = sd_jwt.split(COMBINED_SERIALIZATION_FORMAT_SEPARATOR).collect();
        assert_eq!(parts.len(), 8);
        assert_eq!(parts[0].split('.').count(), 3);
        assert!(parts[1..7].iter().all(|disclosure| !disclosure.is_empty()));
        assert_eq!(parts[7], "");
    }

    #[test]
    fn test_custom_strategy_selects_only_listed_claims() {
        let user_claims = json!({
            "sub": "user_42",
            "given_name": "John"
        });
        let sd_jwt = SDJWTIssuer::new(issuer_key(), None)
            .issue_sd_jwt(
                user_claims,
                ClaimsForSelectiveDisclosureStrategy::Custom(vec!["$.given_name"]),
                None,
                false,
                SDJWTSerializationFormat::Compact,
            )
            .unwrap();

        // Only `given_name` is disclosable: `<jws><sep><disclosure><sep>`.
        let parts: Vec<&str> = sd_jwt.split(COMBINED_SERIALIZATION_FORMAT_SEPARATOR).collect();
        assert_eq!(parts.len(), 3);
        assert!(!parts[1].is_empty());
        assert_eq!(parts[2], "");
    }

    #[test]
    fn test_custom_strategy_rejects_path_without_root() {
        let result = SDJWTIssuer::new(issuer_key(), None).issue_sd_jwt(
            json!({ "sub": "user_42" }),
            ClaimsForSelectiveDisclosureStrategy::Custom(vec!["sub"]),
            None,
            false,
            SDJWTSerializationFormat::Compact,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_next_level_array() {
        let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec![
            "name",
            "addresses[1]",
            "addresses[1].country",
            "nationalities[0]",
        ]);

        let next_strategy = strategy.next_level("addresses");
        assert_eq!(
            &next_strategy,
            &ClaimsForSelectiveDisclosureStrategy::Custom(vec!["[1]", "[1].country"])
        );
        let next_strategy = next_strategy.next_level("[1]");
        assert_eq!(
            &next_strategy,
            &ClaimsForSelectiveDisclosureStrategy::Custom(vec!["country"])
        );
    }

    #[test]
    fn test_next_level_object() {
        let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec![
            "address.street_address",
            "address.locality",
            "address.region",
            "address.country",
        ]);

        let next_strategy = strategy.next_level("address");
        assert_eq!(
            &next_strategy,
            &ClaimsForSelectiveDisclosureStrategy::Custom(vec![
                "street_address",
                "locality",
                "region",
                "country"
            ])
        );
    }
}