sd_jwt_rs/
issuer.rs

1// Copyright (c) 2024 DSR Corporation, Denver, Colorado.
2// https://www.dsr-corporation.com
3// SPDX-License-Identifier: Apache-2.0
4
5use crate::{error, SDJWTJson};
6use error::Result;
7use std::collections::{HashMap, VecDeque};
8use std::str::FromStr;
9use std::vec::Vec;
10
11use jsonwebtoken::jwk::Jwk;
12use jsonwebtoken::{Algorithm, EncodingKey, Header};
13use rand::Rng;
14use serde_json::Value;
15use serde_json::{json, Map as SJMap, Map};
16
17use crate::disclosure::SDJWTDisclosure;
18use crate::error::Error;
19use crate::utils::{base64_hash, generate_salt};
20use crate::{
21    SDJWTCommon, CNF_KEY, COMBINED_SERIALIZATION_FORMAT_SEPARATOR, DEFAULT_DIGEST_ALG,
22    DEFAULT_SIGNING_ALG, DIGEST_ALG_KEY, JWK_KEY, SD_DIGESTS_KEY, SD_LIST_PREFIX,
23    SDJWTSerializationFormat,
24};
25
26pub struct SDJWTIssuer {
27    // parameters
28    sign_alg: String,
29    add_decoy_claims: bool,
30    extra_header_parameters: Option<HashMap<String, String>>,
31
32    // input data
33    issuer_key: EncodingKey,
34    holder_key: Option<Jwk>,
35
36    // internal fields
37    inner: SDJWTCommon,
38    all_disclosures: Vec<SDJWTDisclosure>,
39    sd_jwt_payload: SJMap<String, Value>,
40    signed_sd_jwt: String,
41    serialized_sd_jwt: String,
42}
43
/// Decides which claims the issuer makes selectively disclosable, i.e. which claims the
/// holder may later choose to reveal or withhold.
///
/// Claims that are not selectively disclosable stay in clear text in the signed payload,
/// so every presentation generated by the holder reveals them.
#[derive(PartialEq, Debug)]
pub enum ClaimsForSelectiveDisclosureStrategy<'a> {
    /// Nothing is selectively disclosable: all claims are always disclosed in presentations.
    NoSDClaims,
    /// Only top-level claims are selectively disclosable; a nested object is revealed in
    /// full whenever its parent claim is disclosed.
    TopLevel,
    /// Every claim is selectively disclosable, recursively, including nested object members
    /// and array elements.
    AllLevels,
    /// Only the claims addressed by the given JSONPaths are selectively disclosable; all
    /// other claims are always disclosed in presentations generated by the holder.
    ///
    /// Each path must start with `$.` and may combine dot-separated object keys with
    /// `[index]` array selectors (e.g. `$.addresses[1].country`). Issuance fails with an
    /// invalid-path error if a path lacks the `$.` prefix.
    /// # Examples
    /// ```
    /// use sd_jwt_rs::issuer::ClaimsForSelectiveDisclosureStrategy;
    ///
    /// let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec!["$.address", "$.address.street_address"]);
    /// ```
    Custom(Vec<&'a str>),
}
63
64impl<'a> ClaimsForSelectiveDisclosureStrategy<'a> {
65    fn finalize_input(&mut self) -> Result<()> {
66        match self {
67            ClaimsForSelectiveDisclosureStrategy::Custom(keys) => {
68                for key in keys.iter_mut() {
69                    if let Some(new_key) = key.strip_prefix("$.") {
70                        *key = new_key;
71                    } else {
72                        return Err(Error::InvalidPath("Invalid JSONPath".to_owned()));
73                    }
74                }
75                Ok(())
76            }
77            _ => Ok(()),
78        }
79    }
80
81    fn next_level(&self, key: &str) -> Self {
82        match self {
83            Self::NoSDClaims => Self::NoSDClaims,
84            Self::TopLevel => Self::NoSDClaims,
85            Self::AllLevels => Self::AllLevels,
86            Self::Custom(sd_keys) => {
87                let next_sd_keys = sd_keys
88                    .iter()
89                    .filter_map(|str| {
90                        str.strip_prefix(key).and_then(|str|
91                            match str.chars().next() {
92                                Some('.') => Some(&str[1..]), // next token
93                                Some('[') => Some(str),       // array index
94                                _ => None,
95                            }
96                        )
97                    })
98                    .collect();
99                Self::Custom(next_sd_keys)
100            }
101        }
102    }
103
104    fn sd_for_key(&self, key: &str) -> bool {
105        match self {
106            Self::NoSDClaims => false,
107            Self::TopLevel => true,
108            Self::AllLevels => true,
109            Self::Custom(sd_keys) => sd_keys.contains(&key),
110        }
111    }
112}
113
114impl SDJWTIssuer {
115    const DECOY_MIN_ELEMENTS: u32 = 2;
116    const DECOY_MAX_ELEMENTS: u32 = 5;
117
118    /// Creates a new SDJWTIssuer instance.
119    ///
120    /// The instance can be used mutliple times to issue SD-JWTs.
121    ///
122    /// # Arguments
123    /// * `issuer_key` - The key used to sign the SD-JWT.
124    /// * `sign_alg` - The signing algorithm used to sign the SD-JWT. If not provided, the default algorithm is used.
125    ///
126    /// # Returns
127    /// A new SDJWTIssuer instance.
128    pub fn new(issuer_key: EncodingKey, sign_alg: Option<String>) -> Self {
129        SDJWTIssuer {
130            sign_alg: sign_alg.unwrap_or(DEFAULT_SIGNING_ALG.to_owned()),
131            add_decoy_claims: false,
132            extra_header_parameters: None,
133            issuer_key,
134            holder_key: None,
135            inner: Default::default(),
136            all_disclosures: vec![],
137            sd_jwt_payload: Default::default(),
138            signed_sd_jwt: "".to_string(),
139            serialized_sd_jwt: "".to_string(),
140        }
141    }
142
143    fn reset(&mut self) {
144        self.extra_header_parameters = Default::default();
145        self.all_disclosures = Default::default();
146        self.sd_jwt_payload = Default::default();
147        self.signed_sd_jwt = Default::default();
148        self.serialized_sd_jwt = Default::default();
149    }
150
151    /// Issues a SD-JWT.
152    ///
153    /// # Arguments
154    /// * `user_claims` - The claims to be included in the SD-JWT.
155    /// * `sd_strategy` - The strategy to be used to determine which claims to be selectively disclosed. See [ClaimsForSelectiveDisclosureStrategy] for more details.
156    /// * `holder_key` - The key used to sign the SD-JWT. If not provided, no key binding is added to the SD-JWT.
157    /// * `add_decoy_claims` - If true, decoy claims are added to the SD-JWT.
158    /// * `serialization_format` - The serialization format to be used for the SD-JWT, see [SDJWTSerializationFormat].
159    ///
160    /// # Returns
161    /// The issued SD-JWT as a string in the requested serialization format.
162    pub fn issue_sd_jwt(
163        &mut self,
164        user_claims: Value,
165        mut sd_strategy: ClaimsForSelectiveDisclosureStrategy,
166        holder_key: Option<Jwk>,
167        add_decoy_claims: bool,
168        serialization_format: SDJWTSerializationFormat,
169        // extra_header_parameters: Option<HashMap<String, String>>,
170    ) -> Result<String> {
171        let inner = SDJWTCommon {
172            serialization_format,
173            ..Default::default()
174        };
175
176        sd_strategy.finalize_input()?;
177
178        SDJWTCommon::check_for_sd_claim(&user_claims)?;
179
180        self.reset();
181        self.inner = inner;
182        self.holder_key = holder_key;
183        self.add_decoy_claims = add_decoy_claims;
184
185        self.assemble_sd_jwt_payload(user_claims, sd_strategy)?;
186        self.create_signed_jws()?;
187        self.create_combined()?;
188
189        Ok(self.serialized_sd_jwt.clone())
190    }
191
192    fn assemble_sd_jwt_payload(
193        &mut self,
194        mut user_claims: Value,
195        sd_strategy: ClaimsForSelectiveDisclosureStrategy,
196    ) -> Result<()> {
197        let claims_obj_ref = user_claims
198            .as_object_mut()
199            .ok_or(Error::ConversionError("json object".to_string()))?;
200        let always_revealed_root_keys = vec!["iss", "iat", "exp"];
201        let mut always_revealed_claims: Map<String, Value> = always_revealed_root_keys
202            .into_iter()
203            .filter_map(|key| claims_obj_ref.shift_remove_entry(key))
204            .collect();
205
206        self.sd_jwt_payload = self
207            .create_sd_claims(&user_claims, sd_strategy)
208            .as_object()
209            .ok_or(Error::ConversionError("json object".to_string()))?
210            .clone();
211
212        self.sd_jwt_payload.insert(
213            DIGEST_ALG_KEY.to_owned(),
214            Value::String(DEFAULT_DIGEST_ALG.to_owned()),
215        ); //TODO
216        self.sd_jwt_payload.append(&mut always_revealed_claims);
217
218        if let Some(holder_key) = &self.holder_key {
219            self.sd_jwt_payload
220                .entry(CNF_KEY)
221                .or_insert_with(|| json!({JWK_KEY: holder_key}));
222        }
223
224        Ok(())
225    }
226
227    fn create_sd_claims(&mut self, user_claims: &Value, sd_strategy: ClaimsForSelectiveDisclosureStrategy) -> Value {
228        match user_claims {
229            Value::Array(list) => self.create_sd_claims_list(list, sd_strategy),
230            Value::Object(object) => self.create_sd_claims_object(object, sd_strategy),
231            _ => user_claims.to_owned(),
232        }
233    }
234
235    fn create_sd_claims_list(&mut self, list: &[Value], sd_strategy: ClaimsForSelectiveDisclosureStrategy) -> Value {
236        let mut claims = Vec::new();
237        for (idx, object) in list.iter().enumerate() {
238            let key = format!("[{idx}]");
239            let strategy_for_child = sd_strategy.next_level(&key);
240            let subtree = self.create_sd_claims(object, strategy_for_child);
241
242            if sd_strategy.sd_for_key(&key) {
243                let disclosure = SDJWTDisclosure::new(None, subtree);
244                claims.push(json!({ SD_LIST_PREFIX: disclosure.hash}));
245                self.all_disclosures.push(disclosure);
246            } else {
247                claims.push(subtree);
248            }
249        }
250        Value::Array(claims)
251    }
252
253    fn create_sd_claims_object(
254        &mut self,
255        user_claims: &SJMap<String, Value>,
256        sd_strategy: ClaimsForSelectiveDisclosureStrategy,
257    ) -> Value {
258        let mut claims = SJMap::new();
259
260        // to have the first key "_sd" in the ordered map
261        claims.insert(SD_DIGESTS_KEY.to_owned(), Value::Null);
262
263        let mut sd_claims = Vec::new();
264
265        for (key, value) in user_claims.iter() {
266            let strategy_for_child = sd_strategy.next_level(key);
267            let subtree_from_here = self.create_sd_claims(value, strategy_for_child);
268
269            if sd_strategy.sd_for_key(key) {
270                let disclosure = SDJWTDisclosure::new(Some(key.to_owned()), subtree_from_here);
271                sd_claims.push(disclosure.hash.clone());
272                self.all_disclosures.push(disclosure);
273            } else {
274                claims.insert(key.to_owned(), subtree_from_here);
275            }
276        }
277
278        if self.add_decoy_claims {
279            let num_decoy_elements =
280                rand::thread_rng().gen_range(Self::DECOY_MIN_ELEMENTS..Self::DECOY_MAX_ELEMENTS);
281            for _ in 0..num_decoy_elements {
282                sd_claims.push(self.create_decoy_claim_entry());
283            }
284        }
285
286        if !sd_claims.is_empty() {
287            sd_claims.sort();
288            claims.insert(
289                SD_DIGESTS_KEY.to_owned(),
290                Value::Array(sd_claims.into_iter().map(Value::String).collect()),
291            );
292        } else {
293            claims.shift_remove(SD_DIGESTS_KEY);
294        }
295
296        Value::Object(claims)
297    }
298
299    fn create_signed_jws(&mut self) -> Result<()> {
300        if let Some(extra_headers) = &self.extra_header_parameters {
301            let mut _protected_headers = extra_headers.clone();
302            for (key, value) in extra_headers.iter() {
303                _protected_headers.insert(key.to_string(), value.to_string());
304            }
305            unimplemented!("extra_headers are not supported for issuance");
306        }
307
308        let mut header = Header::new(
309            Algorithm::from_str(&self.sign_alg)
310                .map_err(|e| Error::DeserializationError(e.to_string()))?,
311        );
312        header.typ = self.inner.typ.clone();
313        self.signed_sd_jwt = jsonwebtoken::encode(&header, &self.sd_jwt_payload, &self.issuer_key)
314            .map_err(|e| Error::DeserializationError(e.to_string()))?;
315
316        Ok(())
317    }
318
319    fn create_combined(&mut self) -> Result<()> {
320        if self.inner.serialization_format == SDJWTSerializationFormat::Compact {
321            let mut disclosures: VecDeque<String> = self
322                .all_disclosures
323                .iter()
324                .map(|d| d.raw_b64.to_string())
325                .collect();
326            disclosures.push_front(self.signed_sd_jwt.clone());
327
328            let disclosures: Vec<&str> = disclosures.iter().map(|s| s.as_str()).collect();
329
330            self.serialized_sd_jwt = format!(
331                "{}{}",
332                disclosures.join(COMBINED_SERIALIZATION_FORMAT_SEPARATOR),
333                COMBINED_SERIALIZATION_FORMAT_SEPARATOR,
334            );
335        } else if self.inner.serialization_format == SDJWTSerializationFormat::JSON {
336            let jwt: Vec<&str> = self.signed_sd_jwt.split('.').collect();
337            if jwt.len() != 3 {
338                return Err(Error::InvalidInput(format!(
339                    "Invalid JWT, JWT must contain three parts after splitting with \".\": jwt {}",
340                    self.signed_sd_jwt
341                )));
342            }
343            let sd_jwt_json = SDJWTJson {
344                protected: jwt[0].to_owned(),
345                payload: jwt[1].to_owned(),
346                signature: jwt[2].to_owned(),
347                kb_jwt: None,
348                disclosures: self
349                    .all_disclosures
350                    .iter()
351                    .map(|d| d.raw_b64.to_string())
352                    .collect(),
353            };
354            self.serialized_sd_jwt = serde_json::to_string(&sd_jwt_json)
355                .map_err(|e| Error::DeserializationError(e.to_string()))?;
356        } else {
357            return Err(Error::InvalidInput(
358                format!("Unknown serialization format {}, only \"Compact\" or \"JSON\" formats are supported", self.inner.serialization_format)
359            ));
360        }
361
362        Ok(())
363    }
364
365    fn create_decoy_claim_entry(&mut self) -> String {
366        let digest = base64_hash(generate_salt().as_bytes()).to_string();
367        digest
368    }
369}
370
#[cfg(test)]
mod tests {
    use jsonwebtoken::EncodingKey;
    use log::trace;
    use serde_json::{json, Value};

    use crate::issuer::ClaimsForSelectiveDisclosureStrategy;
    use crate::{SDJWTIssuer, SDJWTSerializationFormat, COMBINED_SERIALIZATION_FORMAT_SEPARATOR};

    const PRIVATE_ISSUER_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgUr2bNKuBPOrAaxsR\nnbSH6hIhmNTxSGXshDSUD1a1y7ihRANCAARvbx3gzBkyPDz7TQIbjF+ef1IsxUwz\nX1KWpmlVv+421F7+c1sLqGk4HUuoVeN8iOoAcE547pJhUEJyf5Asc6pP\n-----END PRIVATE KEY-----\n";

    /// Claims shared by the issuance tests.
    ///
    /// Apart from the always-revealed `iss`, `iat` and `exp`, they contain two top-level
    /// claims (`sub`, `address`), and `address` has four nested members.
    fn sample_claims() -> Value {
        json!({
            "sub": "6c5c0a49-b589-431d-bae7-219122a9ec2c",
            "iss": "https://example.com/issuer",
            "iat": 1683000000,
            "exp": 1883000000,
            "address": {
                "street_address": "Schulstr. 12",
                "locality": "Schulpforta",
                "region": "Sachsen-Anhalt",
                "country": "DE"
            }
        })
    }

    fn issuer() -> SDJWTIssuer {
        let issuer_key = EncodingKey::from_ec_pem(PRIVATE_ISSUER_PEM.as_bytes()).unwrap();
        SDJWTIssuer::new(issuer_key, None)
    }

    /// Issues a compact SD-JWT and splits it on the combined-format separator.
    ///
    /// The parts are: the issuer-signed JWT, one part per disclosure, then a trailing empty
    /// part.
    fn issue_compact_parts(
        user_claims: Value,
        strategy: ClaimsForSelectiveDisclosureStrategy<'_>,
    ) -> Vec<String> {
        let sd_jwt = issuer()
            .issue_sd_jwt(
                user_claims,
                strategy,
                None,
                false,
                SDJWTSerializationFormat::Compact,
            )
            .unwrap();
        trace!("{:?}", sd_jwt);
        sd_jwt
            .split(COMBINED_SERIALIZATION_FORMAT_SEPARATOR)
            .map(str::to_owned)
            .collect()
    }

    #[test]
    fn test_assembly_sd_full_recursive() {
        let parts = issue_compact_parts(sample_claims(), ClaimsForSelectiveDisclosureStrategy::AllLevels);
        // `sub`, `address` and its four members each yield a disclosure; iss/iat/exp do not.
        assert_eq!(parts.len(), 1 + 6 + 1);
        assert_eq!(parts[0].split('.').count(), 3);
        assert!(parts[1..7].iter().all(|d| !d.is_empty()));
        assert_eq!(parts[7], "");
    }

    #[test]
    fn test_assembly_top_level() {
        let parts = issue_compact_parts(sample_claims(), ClaimsForSelectiveDisclosureStrategy::TopLevel);
        // Only `sub` and `address` become disclosures; `address` members stay inside it.
        assert_eq!(parts.len(), 1 + 2 + 1);
        assert_eq!(parts[3], "");
    }

    #[test]
    fn test_assembly_custom() {
        let parts = issue_compact_parts(
            sample_claims(),
            ClaimsForSelectiveDisclosureStrategy::Custom(vec!["$.address", "$.address.street_address"]),
        );
        assert_eq!(parts.len(), 1 + 2 + 1);
    }

    #[test]
    fn test_assembly_no_sd_claims() {
        let parts = issue_compact_parts(sample_claims(), ClaimsForSelectiveDisclosureStrategy::NoSDClaims);
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].split('.').count(), 3);
        assert_eq!(parts[1], "");
    }

    #[test]
    fn test_issue_rejects_path_without_root_prefix() {
        let result = issuer().issue_sd_jwt(
            sample_claims(),
            ClaimsForSelectiveDisclosureStrategy::Custom(vec!["address"]),
            None,
            false,
            SDJWTSerializationFormat::Compact,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_finalize_input_strips_root_prefix() {
        let mut strategy =
            ClaimsForSelectiveDisclosureStrategy::Custom(vec!["$.address", "$.address.street_address"]);
        assert!(strategy.finalize_input().is_ok());
        assert_eq!(
            strategy,
            ClaimsForSelectiveDisclosureStrategy::Custom(vec!["address", "address.street_address"])
        );

        let mut invalid = ClaimsForSelectiveDisclosureStrategy::Custom(vec!["address"]);
        assert!(invalid.finalize_input().is_err());
    }

    #[test]
    fn test_builtin_strategies() {
        assert_eq!(
            ClaimsForSelectiveDisclosureStrategy::TopLevel.next_level("a"),
            ClaimsForSelectiveDisclosureStrategy::NoSDClaims
        );
        assert_eq!(
            ClaimsForSelectiveDisclosureStrategy::AllLevels.next_level("a"),
            ClaimsForSelectiveDisclosureStrategy::AllLevels
        );
        assert_eq!(
            ClaimsForSelectiveDisclosureStrategy::NoSDClaims.next_level("a"),
            ClaimsForSelectiveDisclosureStrategy::NoSDClaims
        );
        assert!(ClaimsForSelectiveDisclosureStrategy::TopLevel.sd_for_key("a"));
        assert!(ClaimsForSelectiveDisclosureStrategy::AllLevels.sd_for_key("[0]"));
        assert!(!ClaimsForSelectiveDisclosureStrategy::NoSDClaims.sd_for_key("a"));
    }

    #[test]
    fn test_next_level_array() {
        let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec![
            "name",
            "addresses[1]",
            "addresses[1].country",
            "nationalities[0]",
        ]);

        let next_strategy = strategy.next_level("addresses");
        assert_eq!(&next_strategy, &ClaimsForSelectiveDisclosureStrategy::Custom(vec!["[1]", "[1].country"]));
        let next_strategy = next_strategy.next_level("[1]");
        assert_eq!(&next_strategy, &ClaimsForSelectiveDisclosureStrategy::Custom(vec!["country"]));
    }

    #[test]
    fn test_next_level_object() {
        let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec![
            "address.street_address",
            "address.locality",
            "address.region",
            "address.country",
        ]);

        let next_strategy = strategy.next_level("address");
        assert_eq!(&next_strategy, &ClaimsForSelectiveDisclosureStrategy::Custom(vec![
            "street_address",
            "locality",
            "region",
            "country"
        ]));
    }

    #[test]
    fn test_next_level_does_not_match_key_prefix() {
        let strategy = ClaimsForSelectiveDisclosureStrategy::Custom(vec!["addresses[0]", "address.country"]);
        assert_eq!(
            strategy.next_level("address"),
            ClaimsForSelectiveDisclosureStrategy::Custom(vec!["country"])
        );
        assert!(!strategy.sd_for_key("address"));
    }
}