1use std::collections::HashSet;
5use std::fmt::Display;
6use std::ops::Deref;
7use std::ops::DerefMut;
8use std::str::FromStr;
9
10use crate::jwt::Jwt;
11use crate::Disclosure;
12use crate::Error;
13use crate::Hasher;
14use crate::JsonObject;
15use crate::KeyBindingJwt;
16use crate::RequiredKeyBinding;
17use crate::Result;
18use crate::SdObjectDecoder;
19use crate::ARRAY_DIGEST_KEY;
20use crate::DIGESTS_KEY;
21use crate::SHA_ALG_NAME;
22use indexmap::IndexMap;
23use itertools::Either;
24use itertools::Itertools;
25use serde::Deserialize;
26use serde::Serialize;
27use serde_json::Value;
28
29#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Default)]
30pub struct SdJwtClaims {
31 #[serde(skip_serializing_if = "Vec::is_empty", default)]
32 pub _sd: Vec<String>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub _sd_alg: Option<String>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub cnf: Option<RequiredKeyBinding>,
37 #[serde(flatten)]
38 properties: JsonObject,
39}
40
41impl Deref for SdJwtClaims {
42 type Target = JsonObject;
43 fn deref(&self) -> &Self::Target {
44 &self.properties
45 }
46}
47
48impl DerefMut for SdJwtClaims {
49 fn deref_mut(&mut self) -> &mut Self::Target {
50 &mut self.properties
51 }
52}
53
54#[derive(Debug, Clone, Eq, PartialEq)]
57pub struct SdJwt {
58 jwt: Jwt<SdJwtClaims>,
60 disclosures: Vec<Disclosure>,
62 key_binding_jwt: Option<KeyBindingJwt>,
64}
65
66impl SdJwt {
67 pub(crate) fn new(
69 jwt: Jwt<SdJwtClaims>,
70 disclosures: Vec<Disclosure>,
71 key_binding_jwt: Option<KeyBindingJwt>,
72 ) -> Self {
73 Self {
74 jwt,
75 disclosures,
76 key_binding_jwt,
77 }
78 }
79
80 pub fn header(&self) -> &JsonObject {
81 &self.jwt.header
82 }
83
84 pub fn claims(&self) -> &SdJwtClaims {
85 &self.jwt.claims
86 }
87
88 pub fn claims_mut(&mut self) -> &mut SdJwtClaims {
93 &mut self.jwt.claims
94 }
95
96 pub fn disclosures(&self) -> &[Disclosure] {
97 &self.disclosures
98 }
99
100 pub fn required_key_bind(&self) -> Option<&RequiredKeyBinding> {
101 self.claims().cnf.as_ref()
102 }
103
104 pub fn key_binding_jwt(&self) -> Option<&KeyBindingJwt> {
105 self.key_binding_jwt.as_ref()
106 }
107
108 pub fn presentation(&self) -> String {
113 let disclosures = self.disclosures.iter().map(ToString::to_string).join("~");
114 let key_bindings = self
115 .key_binding_jwt
116 .as_ref()
117 .map(ToString::to_string)
118 .unwrap_or_default();
119 if disclosures.is_empty() {
120 format!("{}~{}", self.jwt, key_bindings)
121 } else {
122 format!("{}~{}~{}", self.jwt, disclosures, key_bindings)
123 }
124 }
125
126 pub fn parse(sd_jwt: &str) -> Result<Self> {
128 let sd_segments: Vec<&str> = sd_jwt.split('~').collect();
129 let num_of_segments = sd_segments.len();
130 if num_of_segments < 2 {
131 return Err(Error::DeserializationError(
132 "SD-JWT format is invalid, less than 2 segments".to_string(),
133 ));
134 }
135
136 let jwt = sd_segments.first().unwrap().parse()?;
137
138 let disclosures = sd_segments[1..num_of_segments - 1]
139 .iter()
140 .map(|s| Disclosure::parse(s))
141 .try_collect()?;
142
143 let key_binding_jwt = sd_segments
144 .last()
145 .filter(|segment| !segment.is_empty())
146 .map(|segment| segment.parse())
147 .transpose()?;
148
149 Ok(Self {
150 jwt,
151 disclosures,
152 key_binding_jwt,
153 })
154 }
155
156 pub fn into_presentation(self, hasher: &dyn Hasher) -> Result<SdJwtPresentationBuilder> {
161 SdJwtPresentationBuilder::new(self, hasher)
162 }
163
164 pub fn into_disclosed_object(self, hasher: &dyn Hasher) -> Result<JsonObject> {
167 let decoder = SdObjectDecoder;
168 let object = serde_json::to_value(self.claims()).unwrap();
169
170 let disclosure_map = self
171 .disclosures
172 .into_iter()
173 .map(|disclosure| (hasher.encoded_digest(disclosure.as_str()), disclosure))
174 .collect();
175
176 decoder.decode(object.as_object().unwrap(), &disclosure_map)
177 }
178}
179
180impl Display for SdJwt {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.write_str(&(self.presentation()))
183 }
184}
185
186impl FromStr for SdJwt {
187 type Err = Error;
188 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
189 Self::parse(s)
190 }
191}
192
193#[derive(Debug, Clone)]
194pub struct SdJwtPresentationBuilder {
195 sd_jwt: SdJwt,
196 disclosures: IndexMap<String, Disclosure>,
197 disclosures_to_omit: HashSet<usize>,
198 object: Value,
199}
200
201impl Deref for SdJwtPresentationBuilder {
202 type Target = SdJwt;
203 fn deref(&self) -> &Self::Target {
204 &self.sd_jwt
205 }
206}
207
208impl SdJwtPresentationBuilder {
209 pub fn new(mut sd_jwt: SdJwt, hasher: &dyn Hasher) -> Result<Self> {
210 let required_hasher = sd_jwt.claims()._sd_alg.as_deref().unwrap_or(SHA_ALG_NAME);
211 if required_hasher != hasher.alg_name() {
212 return Err(Error::InvalidHasher(format!(
213 "hasher \"{}\" was provided, but \"{required_hasher} is required\"",
214 hasher.alg_name()
215 )));
216 }
217 let disclosures = std::mem::take(&mut sd_jwt.disclosures)
218 .into_iter()
219 .map(|disclosure| (hasher.encoded_digest(disclosure.as_str()), disclosure))
220 .collect();
221 let object = {
222 let sd = std::mem::take(&mut sd_jwt.jwt.claims._sd)
223 .into_iter()
224 .map(Value::String)
225 .collect();
226 let mut object = Value::Object(std::mem::take(&mut sd_jwt.jwt.claims.properties));
227 object
228 .as_object_mut()
229 .unwrap()
230 .insert(DIGESTS_KEY.to_string(), Value::Array(sd));
231
232 object
233 };
234 Ok(Self {
235 sd_jwt,
236 disclosures,
237 disclosures_to_omit: HashSet::default(),
238 object,
239 })
240 }
241
242 pub fn conceal(mut self, path: &str) -> Result<Self> {
248 self
249 .disclosures_to_omit
250 .extend(find_disclosure_and_sub_disclosures_for_value_at_path(
251 &self.object,
252 path,
253 &self.disclosures,
254 )?);
255 Ok(self)
256 }
257
258 pub fn conceal_all(mut self) -> Self {
262 self.disclosures_to_omit.extend(0..self.disclosures.len());
263 self
264 }
265
266 pub fn disclose(mut self, path: &str) -> Result<Self> {
274 let disclosing = find_disclosure_and_parent_disclosures_for_value_at_path(&self.object, path, &self.disclosures)?;
275 for idx in disclosing {
276 self.disclosures_to_omit.remove(&idx);
277 }
278 Ok(self)
279 }
280
281 pub fn attach_key_binding_jwt(mut self, kb_jwt: KeyBindingJwt) -> Self {
283 self.sd_jwt.key_binding_jwt = Some(kb_jwt);
284 self
285 }
286
287 pub fn finish(self) -> Result<(SdJwt, Vec<Disclosure>)> {
291 if self.sd_jwt.required_key_bind().is_some() && self.key_binding_jwt.is_none() {
292 return Err(Error::MissingKeyBindingJwt);
293 }
294
295 let SdJwtPresentationBuilder {
297 mut sd_jwt,
298 disclosures,
299 disclosures_to_omit,
300 object,
301 ..
302 } = self;
303
304 let (disclosures_to_keep, omitted_disclosures) =
305 disclosures
306 .into_values()
307 .enumerate()
308 .partition_map(|(idx, disclosure)| {
309 if disclosures_to_omit.contains(&idx) {
310 Either::Right(disclosure)
311 } else {
312 Either::Left(disclosure)
313 }
314 });
315
316 let Value::Object(mut obj) = object else {
317 unreachable!();
318 };
319 let Value::Array(sd) = obj.remove(DIGESTS_KEY).unwrap_or(Value::Array(vec![])) else {
320 unreachable!()
321 };
322 sd_jwt.jwt.claims._sd = sd
323 .into_iter()
324 .map(|value| {
325 if let Value::String(s) = value {
326 s
327 } else {
328 unreachable!()
329 }
330 })
331 .collect();
332 sd_jwt.jwt.claims.properties = obj;
333 sd_jwt.disclosures = disclosures_to_keep;
334
335 Ok((sd_jwt, omitted_disclosures))
336 }
337}
338
339fn find_disclosure_and_sub_disclosures_for_value_at_path<'a>(
340 value: &'a Value,
341 path: &str,
342 disclosures: &'a IndexMap<String, Disclosure>,
343) -> Result<Vec<usize>> {
344 let path_segments = path.trim_start_matches('/').split('/').collect_vec();
345 let (value, mut visited_disclosures) = traverse_disclosable_object(value, &path_segments, disclosures)
346 .ok_or_else(|| Error::InvalidPath("the referenced element doesn't exist or is not concealable".to_owned()))?;
347 let path_referenced_disclosure = visited_disclosures
348 .pop()
349 .ok_or_else(|| Error::InvalidPath("the referenced element doesn't exist or is not concealable".to_owned()))?;
350
351 let mut disclosures_to_omit = get_all_sub_disclosures(value, disclosures);
352 disclosures_to_omit.push(path_referenced_disclosure);
353
354 Ok(disclosures_to_omit)
355}
356
357fn find_disclosure_and_parent_disclosures_for_value_at_path<'a>(
358 value: &'a Value,
359 path: &str,
360 disclosures: &'a IndexMap<String, Disclosure>,
361) -> Result<Vec<usize>> {
362 let path_segments = path.trim_start_matches('/').split('/').collect_vec();
363 traverse_disclosable_object(value, &path_segments, disclosures)
364 .map(|(_, disclosures)| disclosures)
365 .ok_or_else(|| Error::InvalidPath("the referenced element doesn't exist or is not concealable".to_owned()))
366}
367
368fn find_disclosure(object: &JsonObject, key: &str, disclosures: &IndexMap<String, Disclosure>) -> Option<usize> {
369 object
372 .get(DIGESTS_KEY)
373 .and_then(|value| value.as_array())
374 .iter()
375 .flat_map(|values| values.iter())
376 .flat_map(|value| value.as_str())
377 .find(|digest| {
378 disclosures
379 .get(*digest)
380 .and_then(|disclosure| disclosure.claim_name.as_deref())
381 .is_some_and(|name| name == key)
382 })
383 .and_then(|digest| disclosures.get_index_of(digest))
384}
385
386fn traverse_disclosable_object<'a>(
387 mut value: &'a Value,
388 path: &[&str],
389 disclosures: &'a IndexMap<String, Disclosure>,
390) -> Option<(&'a Value, Vec<usize>)> {
391 let mut visited_disclosures = vec![];
392 for path_segment in path {
393 let step = traverse_disclosable_object_step(value, path_segment, disclosures)?;
394 value = step.value;
395 if let Some(disclosure) = step.disclosure {
396 visited_disclosures.push(disclosure)
397 }
398 }
399
400 Some((value, visited_disclosures))
401}
402
403fn traverse_disclosable_object_step<'a>(
404 value: &'a Value,
405 path_fragment: &str,
406 disclosures: &'a IndexMap<String, Disclosure>,
407) -> Option<TraversalResult<'a>> {
408 match value {
409 Value::Object(object) if object.contains_key(path_fragment) => {
411 Some(TraversalResult::new_value(object.get(path_fragment).unwrap()))
412 }
413 Value::Object(object) => {
415 let idx = find_disclosure(object, path_fragment, disclosures)?;
416 let (_, disclosure) = disclosures.get_index(idx).unwrap();
417 Some(TraversalResult::new_from_disclosure(idx, disclosure))
418 }
419 Value::Array(array) => {
420 let arr_idx = path_fragment.parse::<usize>().ok()?;
421 let value = array.get(arr_idx)?;
422
423 if let Some(digest) = value.get(ARRAY_DIGEST_KEY).and_then(|value| value.as_str()) {
425 disclosures
426 .get_full(digest)
427 .map(|(idx, _, disclosure)| TraversalResult::new_from_disclosure(idx, disclosure))
428 } else {
429 Some(TraversalResult::new_value(value))
430 }
431 }
432 _ => None,
433 }
434}
435
436#[derive(Debug)]
438struct TraversalResult<'a> {
439 value: &'a Value,
441 disclosure: Option<usize>,
443}
444
445impl<'a> TraversalResult<'a> {
446 fn new_value(value: &'a Value) -> Self {
447 Self {
448 value,
449 disclosure: None,
450 }
451 }
452
453 fn new_from_disclosure(idx: usize, disclosure: &'a Disclosure) -> Self {
454 Self {
455 value: &disclosure.claim_value,
456 disclosure: Some(idx),
457 }
458 }
459}
460
461fn get_all_sub_disclosures<'a>(value: &'a Value, disclosures: &'a IndexMap<String, Disclosure>) -> Vec<usize> {
462 let mut sub_disclosures = vec![];
463 match value {
464 Value::Object(object) => {
465 object
467 .get(DIGESTS_KEY)
468 .and_then(|sd| sd.as_array())
469 .map(|sd| sd.iter())
470 .unwrap_or_default()
471 .flat_map(|value| value.as_str())
472 .filter_map(|digest| disclosures.get_index_of(digest))
473 .for_each(|idx| sub_disclosures.push(idx));
474 object.values().for_each(|value| {
476 let found_sub_disclosures = get_all_sub_disclosures(value, disclosures);
477 sub_disclosures.extend(found_sub_disclosures);
478 });
479 }
480 Value::Array(arr) => {
481 for value in arr.iter().filter(|value| value.is_object()) {
482 if let Some(idx) = value
483 .get(ARRAY_DIGEST_KEY)
484 .and_then(|value| value.as_str())
485 .and_then(|digest| disclosures.get_index_of(digest))
486 {
487 sub_disclosures.push(idx);
488 } else {
489 sub_disclosures.extend(get_all_sub_disclosures(value, disclosures));
490 }
491 }
492 }
493 _ => (),
494 }
495
496 sub_disclosures
497}
498
#[cfg(test)]
mod test {
  use crate::SdJwt;

  /// Example SD-JWT with two disclosures and no key-binding JWT (trailing `~`).
  const SD_JWT: &str = "eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkM5aW5wNllvUmFFWFI0Mjd6WUpQN1FyazFXSF84YmR3T0FfWVVyVW5HUVUiLCAiS3VldDF5QWEwSElRdlluT1ZkNTloY1ZpTzlVZzZKMmtTZnFZUkJlb3d2RSIsICJNTWxkT0ZGekIyZDB1bWxtcFRJYUdlcmhXZFVfUHBZZkx2S2hoX2ZfOWFZIiwgIlg2WkFZT0lJMnZQTjQwVjd4RXhad1Z3ejd5Um1MTmNWd3Q1REw4Ukx2NGciLCAiWTM0em1JbzBRTExPdGRNcFhHd2pCZ0x2cjE3eUVoaFlUMEZHb2ZSLWFJRSIsICJmeUdwMFdUd3dQdjJKRFFsbjFsU2lhZW9iWnNNV0ExMGJRNTk4OS05RFRzIiwgIm9tbUZBaWNWVDhMR0hDQjB1eXd4N2ZZdW8zTUhZS08xNWN6LVJaRVlNNVEiLCAiczBCS1lzTFd4UVFlVTh0VmxsdE03TUtzSVJUckVJYTFQa0ptcXhCQmY1VSJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAiYWRkcmVzcyI6IHsiX3NkIjogWyI2YVVoelloWjdTSjFrVm1hZ1FBTzN1MkVUTjJDQzFhSGhlWnBLbmFGMF9FIiwgIkF6TGxGb2JrSjJ4aWF1cFJFUHlvSnotOS1OU2xkQjZDZ2pyN2ZVeW9IemciLCAiUHp6Y1Z1MHFiTXVCR1NqdWxmZXd6a2VzRDl6dXRPRXhuNUVXTndrclEtayIsICJiMkRrdzBqY0lGOXJHZzhfUEY4WmN2bmNXN3p3Wmo1cnlCV3ZYZnJwemVrIiwgImNQWUpISVo4VnUtZjlDQ3lWdWIyVWZnRWs4anZ2WGV6d0sxcF9KbmVlWFEiLCAiZ2xUM2hyU1U3ZlNXZ3dGNVVEWm1Xd0JUdzMyZ25VbGRJaGk4aEdWQ2FWNCIsICJydkpkNmlxNlQ1ZWptc0JNb0d3dU5YaDlxQUFGQVRBY2k0MG9pZEVlVnNBIiwgInVOSG9XWWhYc1poVkpDTkUyRHF5LXpxdDd0NjlnSkt5NVFhRnY3R3JNWDQiXX0sICJfc2RfYWxnIjogInNoYS0yNTYifQ.gR6rSL7urX79CNEvTQnP1MH5xthG11ucIV44SqKFZ4Pvlu_u16RfvXQd4k4CAIBZNKn2aTI18TfvFwV97gJFoA~WyJHMDJOU3JRZmpGWFE3SW8wOXN5YWpBIiwgInJlZ2lvbiIsICJcdTZlMmZcdTUzM2EiXQ~WyJsa2x4RjVqTVlsR1RQVW92TU5JdkNBIiwgImNvdW50cnkiLCAiSlAiXQ~";

  /// Parsing picks up both disclosures and no key-binding JWT.
  #[test]
  fn parse() {
    let parsed = SdJwt::parse(SD_JWT).unwrap();
    assert_eq!(parsed.disclosures().len(), 2);
    assert!(parsed.key_binding_jwt().is_none());
  }

  /// Parsing then re-serializing reproduces the input exactly.
  #[test]
  fn round_trip_ser_des() {
    let reserialized = SdJwt::parse(SD_JWT).unwrap().to_string();
    assert_eq!(reserialized, SD_JWT);
  }
}