sd_jwt_payload/
sd_jwt.rs

1// Copyright 2020-2023 IOTA Stiftung
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5use std::fmt::Display;
6use std::ops::Deref;
7use std::ops::DerefMut;
8use std::str::FromStr;
9
10use crate::jwt::Jwt;
11use crate::Disclosure;
12use crate::Error;
13use crate::Hasher;
14use crate::JsonObject;
15use crate::KeyBindingJwt;
16use crate::RequiredKeyBinding;
17use crate::Result;
18use crate::SdObjectDecoder;
19use crate::ARRAY_DIGEST_KEY;
20use crate::DIGESTS_KEY;
21use crate::SHA_ALG_NAME;
22use indexmap::IndexMap;
23use itertools::Either;
24use itertools::Itertools;
25use serde::Deserialize;
26use serde::Serialize;
27use serde_json::Value;
28
29#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Default)]
30pub struct SdJwtClaims {
31  #[serde(skip_serializing_if = "Vec::is_empty", default)]
32  pub _sd: Vec<String>,
33  #[serde(skip_serializing_if = "Option::is_none")]
34  pub _sd_alg: Option<String>,
35  #[serde(skip_serializing_if = "Option::is_none")]
36  pub cnf: Option<RequiredKeyBinding>,
37  #[serde(flatten)]
38  properties: JsonObject,
39}
40
41impl Deref for SdJwtClaims {
42  type Target = JsonObject;
43  fn deref(&self) -> &Self::Target {
44    &self.properties
45  }
46}
47
48impl DerefMut for SdJwtClaims {
49  fn deref_mut(&mut self) -> &mut Self::Target {
50    &mut self.properties
51  }
52}
53
54/// Representation of an SD-JWT of the format
55/// `<Issuer-signed JWT>~<Disclosure 1>~<Disclosure 2>~...~<Disclosure N>~<optional KB-JWT>`.
56#[derive(Debug, Clone, Eq, PartialEq)]
57pub struct SdJwt {
58  /// The JWT part.
59  jwt: Jwt<SdJwtClaims>,
60  /// The disclosures part.
61  disclosures: Vec<Disclosure>,
62  /// The optional key binding JWT.
63  key_binding_jwt: Option<KeyBindingJwt>,
64}
65
66impl SdJwt {
67  /// Creates a new [`SdJwt`] from its components.
68  pub(crate) fn new(
69    jwt: Jwt<SdJwtClaims>,
70    disclosures: Vec<Disclosure>,
71    key_binding_jwt: Option<KeyBindingJwt>,
72  ) -> Self {
73    Self {
74      jwt,
75      disclosures,
76      key_binding_jwt,
77    }
78  }
79
80  pub fn header(&self) -> &JsonObject {
81    &self.jwt.header
82  }
83
84  pub fn claims(&self) -> &SdJwtClaims {
85    &self.jwt.claims
86  }
87
88  /// Returns a mutable reference to this SD-JWT's claims.
89  /// ## Warning
90  /// Modifying the claims might invalidate the signature.
91  /// Use this method carefully.
92  pub fn claims_mut(&mut self) -> &mut SdJwtClaims {
93    &mut self.jwt.claims
94  }
95
96  pub fn disclosures(&self) -> &[Disclosure] {
97    &self.disclosures
98  }
99
100  pub fn required_key_bind(&self) -> Option<&RequiredKeyBinding> {
101    self.claims().cnf.as_ref()
102  }
103
104  pub fn key_binding_jwt(&self) -> Option<&KeyBindingJwt> {
105    self.key_binding_jwt.as_ref()
106  }
107
108  /// Serializes the components into the final SD-JWT.
109  ///
110  /// ## Error
111  /// Returns [`Error::DeserializationError`] if parsing fails.
112  pub fn presentation(&self) -> String {
113    let disclosures = self.disclosures.iter().map(ToString::to_string).join("~");
114    let key_bindings = self
115      .key_binding_jwt
116      .as_ref()
117      .map(ToString::to_string)
118      .unwrap_or_default();
119    if disclosures.is_empty() {
120      format!("{}~{}", self.jwt, key_bindings)
121    } else {
122      format!("{}~{}~{}", self.jwt, disclosures, key_bindings)
123    }
124  }
125
126  /// Parses an SD-JWT into its components as [`SdJwt`].
127  pub fn parse(sd_jwt: &str) -> Result<Self> {
128    let sd_segments: Vec<&str> = sd_jwt.split('~').collect();
129    let num_of_segments = sd_segments.len();
130    if num_of_segments < 2 {
131      return Err(Error::DeserializationError(
132        "SD-JWT format is invalid, less than 2 segments".to_string(),
133      ));
134    }
135
136    let jwt = sd_segments.first().unwrap().parse()?;
137
138    let disclosures = sd_segments[1..num_of_segments - 1]
139      .iter()
140      .map(|s| Disclosure::parse(s))
141      .try_collect()?;
142
143    let key_binding_jwt = sd_segments
144      .last()
145      .filter(|segment| !segment.is_empty())
146      .map(|segment| segment.parse())
147      .transpose()?;
148
149    Ok(Self {
150      jwt,
151      disclosures,
152      key_binding_jwt,
153    })
154  }
155
156  /// Prepares this [`SdJwt`] for a presentation, returning an [`SdJwtPresentationBuilder`].
157  /// ## Errors
158  /// - [`Error::InvalidHasher`] is returned if the provided `hasher`'s algorithm doesn't match the algorithm specified
159  ///   by SD-JWT's `_sd_alg` claim. "sha-256" is used if the claim is missing.
160  pub fn into_presentation(self, hasher: &dyn Hasher) -> Result<SdJwtPresentationBuilder> {
161    SdJwtPresentationBuilder::new(self, hasher)
162  }
163
164  /// Returns the JSON object obtained by replacing all disclosures into their
165  /// corresponding JWT concealable claims.
166  pub fn into_disclosed_object(self, hasher: &dyn Hasher) -> Result<JsonObject> {
167    let decoder = SdObjectDecoder;
168    let object = serde_json::to_value(self.claims()).unwrap();
169
170    let disclosure_map = self
171      .disclosures
172      .into_iter()
173      .map(|disclosure| (hasher.encoded_digest(disclosure.as_str()), disclosure))
174      .collect();
175
176    decoder.decode(object.as_object().unwrap(), &disclosure_map)
177  }
178}
179
180impl Display for SdJwt {
181  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182    f.write_str(&(self.presentation()))
183  }
184}
185
186impl FromStr for SdJwt {
187  type Err = Error;
188  fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
189    Self::parse(s)
190  }
191}
192
193#[derive(Debug, Clone)]
194pub struct SdJwtPresentationBuilder {
195  sd_jwt: SdJwt,
196  disclosures: IndexMap<String, Disclosure>,
197  disclosures_to_omit: HashSet<usize>,
198  object: Value,
199}
200
201impl Deref for SdJwtPresentationBuilder {
202  type Target = SdJwt;
203  fn deref(&self) -> &Self::Target {
204    &self.sd_jwt
205  }
206}
207
208impl SdJwtPresentationBuilder {
209  pub fn new(mut sd_jwt: SdJwt, hasher: &dyn Hasher) -> Result<Self> {
210    let required_hasher = sd_jwt.claims()._sd_alg.as_deref().unwrap_or(SHA_ALG_NAME);
211    if required_hasher != hasher.alg_name() {
212      return Err(Error::InvalidHasher(format!(
213        "hasher \"{}\" was provided, but \"{required_hasher} is required\"",
214        hasher.alg_name()
215      )));
216    }
217    let disclosures = std::mem::take(&mut sd_jwt.disclosures)
218      .into_iter()
219      .map(|disclosure| (hasher.encoded_digest(disclosure.as_str()), disclosure))
220      .collect();
221    let object = {
222      let sd = std::mem::take(&mut sd_jwt.jwt.claims._sd)
223        .into_iter()
224        .map(Value::String)
225        .collect();
226      let mut object = Value::Object(std::mem::take(&mut sd_jwt.jwt.claims.properties));
227      object
228        .as_object_mut()
229        .unwrap()
230        .insert(DIGESTS_KEY.to_string(), Value::Array(sd));
231
232      object
233    };
234    Ok(Self {
235      sd_jwt,
236      disclosures,
237      disclosures_to_omit: HashSet::default(),
238      object,
239    })
240  }
241
242  /// Removes the disclosure for the property at `path`, concealing it.
243  ///
244  /// ## Notes
245  /// - When concealing a claim more than one disclosure may be removed: the disclosure for the claim itself and the
246  ///   disclosures for any concealable sub-claim.
247  pub fn conceal(mut self, path: &str) -> Result<Self> {
248    self
249      .disclosures_to_omit
250      .extend(find_disclosure_and_sub_disclosures_for_value_at_path(
251        &self.object,
252        path,
253        &self.disclosures,
254      )?);
255    Ok(self)
256  }
257
258  /// Removes all disclosures from this SD-JWT, resulting in a token that,
259  /// when presented, will have *all* selectively-disclosable properties
260  /// omitted.
261  pub fn conceal_all(mut self) -> Self {
262    self.disclosures_to_omit.extend(0..self.disclosures.len());
263    self
264  }
265
266  /// Discloses a value that was previously concealed.
267  /// # Notes
268  /// - This method may disclose multiple values, if the given path references a disclosable value stored within another
269  ///   disclosable value. That is, [disclose](Self::disclose) will unconceal the selectively disclosable value at
270  ///   `path` together with *all* its parents that are disclosable values themselves.
271  /// - By default *all* disclosable claims are disclosed, therefore this method can only be used to *undo* any
272  ///   concealment operations previously performed by either [Self::conceal] or [Self::conceal_all].
273  pub fn disclose(mut self, path: &str) -> Result<Self> {
274    let disclosing = find_disclosure_and_parent_disclosures_for_value_at_path(&self.object, path, &self.disclosures)?;
275    for idx in disclosing {
276      self.disclosures_to_omit.remove(&idx);
277    }
278    Ok(self)
279  }
280
281  /// Adds a [`KeyBindingJwt`] to this [`SdJwt`]'s presentation.
282  pub fn attach_key_binding_jwt(mut self, kb_jwt: KeyBindingJwt) -> Self {
283    self.sd_jwt.key_binding_jwt = Some(kb_jwt);
284    self
285  }
286
287  /// Returns the resulting [`SdJwt`] together with all removed disclosures.
288  /// ## Errors
289  /// - Fails with [`Error::MissingKeyBindingJwt`] if this [`SdJwt`] requires a key binding but none was provided.
290  pub fn finish(self) -> Result<(SdJwt, Vec<Disclosure>)> {
291    if self.sd_jwt.required_key_bind().is_some() && self.key_binding_jwt.is_none() {
292      return Err(Error::MissingKeyBindingJwt);
293    }
294
295    // Put everything back in its place.
296    let SdJwtPresentationBuilder {
297      mut sd_jwt,
298      disclosures,
299      disclosures_to_omit,
300      object,
301      ..
302    } = self;
303
304    let (disclosures_to_keep, omitted_disclosures) =
305      disclosures
306        .into_values()
307        .enumerate()
308        .partition_map(|(idx, disclosure)| {
309          if disclosures_to_omit.contains(&idx) {
310            Either::Right(disclosure)
311          } else {
312            Either::Left(disclosure)
313          }
314        });
315
316    let Value::Object(mut obj) = object else {
317      unreachable!();
318    };
319    let Value::Array(sd) = obj.remove(DIGESTS_KEY).unwrap_or(Value::Array(vec![])) else {
320      unreachable!()
321    };
322    sd_jwt.jwt.claims._sd = sd
323      .into_iter()
324      .map(|value| {
325        if let Value::String(s) = value {
326          s
327        } else {
328          unreachable!()
329        }
330      })
331      .collect();
332    sd_jwt.jwt.claims.properties = obj;
333    sd_jwt.disclosures = disclosures_to_keep;
334
335    Ok((sd_jwt, omitted_disclosures))
336  }
337}
338
339fn find_disclosure_and_sub_disclosures_for_value_at_path<'a>(
340  value: &'a Value,
341  path: &str,
342  disclosures: &'a IndexMap<String, Disclosure>,
343) -> Result<Vec<usize>> {
344  let path_segments = path.trim_start_matches('/').split('/').collect_vec();
345  let (value, mut visited_disclosures) = traverse_disclosable_object(value, &path_segments, disclosures)
346    .ok_or_else(|| Error::InvalidPath("the referenced element doesn't exist or is not concealable".to_owned()))?;
347  let path_referenced_disclosure = visited_disclosures
348    .pop()
349    .ok_or_else(|| Error::InvalidPath("the referenced element doesn't exist or is not concealable".to_owned()))?;
350
351  let mut disclosures_to_omit = get_all_sub_disclosures(value, disclosures);
352  disclosures_to_omit.push(path_referenced_disclosure);
353
354  Ok(disclosures_to_omit)
355}
356
357fn find_disclosure_and_parent_disclosures_for_value_at_path<'a>(
358  value: &'a Value,
359  path: &str,
360  disclosures: &'a IndexMap<String, Disclosure>,
361) -> Result<Vec<usize>> {
362  let path_segments = path.trim_start_matches('/').split('/').collect_vec();
363  traverse_disclosable_object(value, &path_segments, disclosures)
364    .map(|(_, disclosures)| disclosures)
365    .ok_or_else(|| Error::InvalidPath("the referenced element doesn't exist or is not concealable".to_owned()))
366}
367
368fn find_disclosure(object: &JsonObject, key: &str, disclosures: &IndexMap<String, Disclosure>) -> Option<usize> {
369  // Try to find the digest for disclosable property `key` in
370  // the `_sd` field of `object`.
371  object
372    .get(DIGESTS_KEY)
373    .and_then(|value| value.as_array())
374    .iter()
375    .flat_map(|values| values.iter())
376    .flat_map(|value| value.as_str())
377    .find(|digest| {
378      disclosures
379        .get(*digest)
380        .and_then(|disclosure| disclosure.claim_name.as_deref())
381        .is_some_and(|name| name == key)
382    })
383    .and_then(|digest| disclosures.get_index_of(digest))
384}
385
386fn traverse_disclosable_object<'a>(
387  mut value: &'a Value,
388  path: &[&str],
389  disclosures: &'a IndexMap<String, Disclosure>,
390) -> Option<(&'a Value, Vec<usize>)> {
391  let mut visited_disclosures = vec![];
392  for path_segment in path {
393    let step = traverse_disclosable_object_step(value, path_segment, disclosures)?;
394    value = step.value;
395    if let Some(disclosure) = step.disclosure {
396      visited_disclosures.push(disclosure)
397    }
398  }
399
400  Some((value, visited_disclosures))
401}
402
403fn traverse_disclosable_object_step<'a>(
404  value: &'a Value,
405  path_fragment: &str,
406  disclosures: &'a IndexMap<String, Disclosure>,
407) -> Option<TraversalResult<'a>> {
408  match value {
409    // Object has an entry for the element we are searching.
410    Value::Object(object) if object.contains_key(path_fragment) => {
411      Some(TraversalResult::new_value(object.get(path_fragment).unwrap()))
412    }
413    // No entry for path fragment, searching object's disclosures.
414    Value::Object(object) => {
415      let idx = find_disclosure(object, path_fragment, disclosures)?;
416      let (_, disclosure) = disclosures.get_index(idx).unwrap();
417      Some(TraversalResult::new_from_disclosure(idx, disclosure))
418    }
419    Value::Array(array) => {
420      let arr_idx = path_fragment.parse::<usize>().ok()?;
421      let value = array.get(arr_idx)?;
422
423      // Check if the value is a disclosable value.
424      if let Some(digest) = value.get(ARRAY_DIGEST_KEY).and_then(|value| value.as_str()) {
425        disclosures
426          .get_full(digest)
427          .map(|(idx, _, disclosure)| TraversalResult::new_from_disclosure(idx, disclosure))
428      } else {
429        Some(TraversalResult::new_value(value))
430      }
431    }
432    _ => None,
433  }
434}
435
436/// The result of a step in the traversal of a disclosable value.
437#[derive(Debug)]
438struct TraversalResult<'a> {
439  /// The reached value.
440  value: &'a Value,
441  /// The index of the disclosure we had to walk through to reach `value`.
442  disclosure: Option<usize>,
443}
444
445impl<'a> TraversalResult<'a> {
446  fn new_value(value: &'a Value) -> Self {
447    Self {
448      value,
449      disclosure: None,
450    }
451  }
452
453  fn new_from_disclosure(idx: usize, disclosure: &'a Disclosure) -> Self {
454    Self {
455      value: &disclosure.claim_value,
456      disclosure: Some(idx),
457    }
458  }
459}
460
461fn get_all_sub_disclosures<'a>(value: &'a Value, disclosures: &'a IndexMap<String, Disclosure>) -> Vec<usize> {
462  let mut sub_disclosures = vec![];
463  match value {
464    Value::Object(object) => {
465      // Check object's "_sd" entry.
466      object
467        .get(DIGESTS_KEY)
468        .and_then(|sd| sd.as_array())
469        .map(|sd| sd.iter())
470        .unwrap_or_default()
471        .flat_map(|value| value.as_str())
472        .filter_map(|digest| disclosures.get_index_of(digest))
473        .for_each(|idx| sub_disclosures.push(idx));
474      // Recursively check all object's property.
475      object.values().for_each(|value| {
476        let found_sub_disclosures = get_all_sub_disclosures(value, disclosures);
477        sub_disclosures.extend(found_sub_disclosures);
478      });
479    }
480    Value::Array(arr) => {
481      for value in arr.iter().filter(|value| value.is_object()) {
482        if let Some(idx) = value
483          .get(ARRAY_DIGEST_KEY)
484          .and_then(|value| value.as_str())
485          .and_then(|digest| disclosures.get_index_of(digest))
486        {
487          sub_disclosures.push(idx);
488        } else {
489          sub_disclosures.extend(get_all_sub_disclosures(value, disclosures));
490        }
491      }
492    }
493    _ => (),
494  }
495
496  sub_disclosures
497}
498
#[cfg(test)]
mod test {
  use crate::SdJwt;

  /// Sample SD-JWT with two disclosures and no key binding JWT (note the trailing `~`).
  const SD_JWT: &str = "eyJhbGciOiAiRVMyNTYiLCAidHlwIjogImV4YW1wbGUrc2Qtand0In0.eyJfc2QiOiBbIkM5aW5wNllvUmFFWFI0Mjd6WUpQN1FyazFXSF84YmR3T0FfWVVyVW5HUVUiLCAiS3VldDF5QWEwSElRdlluT1ZkNTloY1ZpTzlVZzZKMmtTZnFZUkJlb3d2RSIsICJNTWxkT0ZGekIyZDB1bWxtcFRJYUdlcmhXZFVfUHBZZkx2S2hoX2ZfOWFZIiwgIlg2WkFZT0lJMnZQTjQwVjd4RXhad1Z3ejd5Um1MTmNWd3Q1REw4Ukx2NGciLCAiWTM0em1JbzBRTExPdGRNcFhHd2pCZ0x2cjE3eUVoaFlUMEZHb2ZSLWFJRSIsICJmeUdwMFdUd3dQdjJKRFFsbjFsU2lhZW9iWnNNV0ExMGJRNTk4OS05RFRzIiwgIm9tbUZBaWNWVDhMR0hDQjB1eXd4N2ZZdW8zTUhZS08xNWN6LVJaRVlNNVEiLCAiczBCS1lzTFd4UVFlVTh0VmxsdE03TUtzSVJUckVJYTFQa0ptcXhCQmY1VSJdLCAiaXNzIjogImh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwgImlhdCI6IDE2ODMwMDAwMDAsICJleHAiOiAxODgzMDAwMDAwLCAiYWRkcmVzcyI6IHsiX3NkIjogWyI2YVVoelloWjdTSjFrVm1hZ1FBTzN1MkVUTjJDQzFhSGhlWnBLbmFGMF9FIiwgIkF6TGxGb2JrSjJ4aWF1cFJFUHlvSnotOS1OU2xkQjZDZ2pyN2ZVeW9IemciLCAiUHp6Y1Z1MHFiTXVCR1NqdWxmZXd6a2VzRDl6dXRPRXhuNUVXTndrclEtayIsICJiMkRrdzBqY0lGOXJHZzhfUEY4WmN2bmNXN3p3Wmo1cnlCV3ZYZnJwemVrIiwgImNQWUpISVo4VnUtZjlDQ3lWdWIyVWZnRWs4anZ2WGV6d0sxcF9KbmVlWFEiLCAiZ2xUM2hyU1U3ZlNXZ3dGNVVEWm1Xd0JUdzMyZ25VbGRJaGk4aEdWQ2FWNCIsICJydkpkNmlxNlQ1ZWptc0JNb0d3dU5YaDlxQUFGQVRBY2k0MG9pZEVlVnNBIiwgInVOSG9XWWhYc1poVkpDTkUyRHF5LXpxdDd0NjlnSkt5NVFhRnY3R3JNWDQiXX0sICJfc2RfYWxnIjogInNoYS0yNTYifQ.gR6rSL7urX79CNEvTQnP1MH5xthG11ucIV44SqKFZ4Pvlu_u16RfvXQd4k4CAIBZNKn2aTI18TfvFwV97gJFoA~WyJHMDJOU3JRZmpGWFE3SW8wOXN5YWpBIiwgInJlZ2lvbiIsICJcdTZlMmZcdTUzM2EiXQ~WyJsa2x4RjVqTVlsR1RQVW92TU5JdkNBIiwgImNvdW50cnkiLCAiSlAiXQ~";

  /// Parsing splits out both disclosures and detects the missing key binding JWT.
  #[test]
  fn parse() {
    let parsed = SdJwt::parse(SD_JWT).unwrap();
    assert_eq!(parsed.disclosures.len(), 2);
    assert!(parsed.key_binding_jwt.is_none());
  }

  /// Serializing a parsed SD-JWT yields the original string.
  #[test]
  fn round_trip_ser_des() {
    let serialized = SdJwt::parse(SD_JWT).unwrap().to_string();
    assert_eq!(serialized, SD_JWT);
  }
}