sd_jwt_rs/
lib.rs

1// Copyright (c) 2024 DSR Corporation, Denver, Colorado.
2// https://www.dsr-corporation.com
3// SPDX-License-Identifier: Apache-2.0
4
5use crate::error::Error;
6use crate::utils::{base64_hash, base64url_decode, jwt_payload_decode};
7
8use error::Result;
9use serde::{Deserialize, Serialize};
10use serde_json::{Map, Value};
11use strum::Display;
12use std::collections::HashMap;
13pub use {holder::SDJWTHolder, issuer::SDJWTIssuer, issuer::ClaimsForSelectiveDisclosureStrategy, verifier::SDJWTVerifier};
14
15mod disclosure;
16pub mod error;
17pub mod holder;
18pub mod issuer;
19pub mod utils;
20pub mod verifier;
21
/// Name of the JWS signing algorithm exported as the crate-wide default.
pub const DEFAULT_SIGNING_ALG: &str = "ES256";
/// Object key under which an SD-JWT payload stores the digests of its
/// selectively disclosable claims; user-supplied claims must not use it.
const SD_DIGESTS_KEY: &str = "_sd";
/// Payload key naming the hash algorithm used to compute disclosure digests.
const DIGEST_ALG_KEY: &str = "_sd_alg";
/// Name of the digest algorithm exported as the crate-wide default.
pub const DEFAULT_DIGEST_ALG: &str = "sha-256";
/// Key of the single-entry object that replaces a selectively disclosable
/// array element.
const SD_LIST_PREFIX: &str = "...";
/// `typ` header value for an SD-JWT. The leading underscore marks the constant
/// as intentionally unused for now.
const _SD_JWT_TYP_HEADER: &str = "sd+jwt";
/// `typ` header value for a key binding JWT.
const KB_JWT_TYP_HEADER: &str = "kb+jwt";
/// Key binding JWT claim that carries the digest of the presented SD-JWT.
const KB_DIGEST_KEY: &str = "sd_hash";
/// Separator between the JWT, each disclosure and the key binding JWT in the
/// compact (combined) serialization.
pub const COMBINED_SERIALIZATION_FORMAT_SEPARATOR: &str = "~";
/// Separator between the header, payload and signature segments of a JWT.
const JWT_SEPARATOR: &str = ".";
/// Payload key of the confirmation claim.
const CNF_KEY: &str = "cnf";
/// Key, looked up inside the confirmation claim, that holds a JSON Web Key.
const JWK_KEY: &str = "jwk";
34
/// Crate-internal marker type wrapping a message string.
///
/// Nothing in this file constructs it, and it has no inherent methods.
/// NOTE(review): presumably kept for parity with a reference implementation's
/// exception of the same name; confirm whether it is still needed.
#[derive(Debug)]
pub(crate) struct SDJWTHasSDClaimException(String);
39
40/// SDJWTSerializationFormat is used to determine how an SD-JWT is serialized to String
41#[derive(Default, Clone, PartialEq, Debug, Display)]
42pub enum SDJWTSerializationFormat {
43    /// JSON-encoded representation
44    #[default]
45    JSON,
46    /// Base64-encoded representation
47    Compact,
48}
49
50#[derive(Default)]
51pub(crate) struct SDJWTCommon {
52    typ: Option<String>,
53    serialization_format: SDJWTSerializationFormat,
54    unverified_input_key_binding_jwt: Option<String>,
55    unverified_sd_jwt: Option<String>,
56    unverified_sd_jwt_json: Option<SDJWTJson>,
57    unverified_input_sd_jwt_payload: Option<Map<String, Value>>,
58    hash_to_decoded_disclosure: HashMap<String, Value>,
59    hash_to_disclosure: HashMap<String, String>,
60    input_disclosures: Vec<String>,
61    sign_alg: Option<String>,
62}
63
64#[derive(Default, Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
65pub struct SDJWTJson {
66    protected: String,
67    payload: String,
68    signature: String,
69    pub disclosures: Vec<String>,
70    pub kb_jwt: Option<String>,
71}
72
73// Define the SDJWTCommon struct to hold common properties.
74impl SDJWTCommon {
75    fn create_hash_mappings(&mut self) -> Result<()> {
76        self.hash_to_decoded_disclosure = HashMap::new();
77        self.hash_to_disclosure = HashMap::new();
78
79        for disclosure in &self.input_disclosures {
80            let decoded_disclosure = base64url_decode(disclosure).map_err(|err| {
81                Error::InvalidDisclosure(format!(
82                    "Error decoding disclosure {}: {}",
83                    disclosure, err
84                ))
85            })?;
86            let decoded_disclosure: Value =
87                serde_json::from_slice(&decoded_disclosure).map_err(|err| {
88                    Error::InvalidDisclosure(format!(
89                        "Error parsing disclosure {}: {}",
90                        disclosure, err
91                    ))
92                })?;
93
94            let hash = base64_hash(disclosure.as_bytes());
95            if self.hash_to_decoded_disclosure.contains_key(&hash) {
96                return Err(Error::DuplicateDigestError(hash));
97            }
98            self.hash_to_decoded_disclosure
99                .insert(hash.clone(), decoded_disclosure);
100            self.hash_to_disclosure
101                .insert(hash.clone(), disclosure.to_owned());
102        }
103
104        Ok(())
105    }
106
107    fn check_for_sd_claim(the_object: &Value) -> Result<()> {
108        match the_object {
109            Value::Object(obj) => {
110                for (key, value) in obj.iter() {
111                    if key == SD_DIGESTS_KEY {
112                        return Err(Error::DataFieldMismatch(format!(
113                            "Claim object cannot have `{}` field",
114                            SD_DIGESTS_KEY
115                        )));
116                    } else {
117                        Self::check_for_sd_claim(value)?;
118                    }
119                }
120            }
121            Value::Array(arr) => {
122                for item in arr {
123                    Self::check_for_sd_claim(item)?;
124                }
125            }
126            _ => {}
127        }
128
129        Ok(())
130    }
131
132    fn parse_compact_sd_jwt(&mut self, sd_jwt_with_disclosures: String) -> Result<()> {
133        let parts: Vec<&str> = sd_jwt_with_disclosures
134            .split(COMBINED_SERIALIZATION_FORMAT_SEPARATOR)
135            .collect();
136        if parts.len() < 2 { // minimal number of SD-JWT parts according to the standard
137            return Err(Error::InvalidInput(format!(
138                "Invalid SD-JWT length: {}",
139                parts.len()
140            )));
141        }
142        let idx = parts.len();
143        let mut parts = parts.into_iter();
144        let sd_jwt = parts.next().ok_or(Error::IndexOutOfBounds {
145            idx: 0,
146            length: parts.len(),
147            msg: format!("Invalid SD-JWT: {}", sd_jwt_with_disclosures),
148        })?;
149        self.sign_alg = Self::decode_header_and_get_sign_algorithm(&sd_jwt);
150        self.unverified_input_key_binding_jwt = Some(
151            parts
152                .next_back()
153                .ok_or(Error::IndexOutOfBounds {
154                    idx: idx - 1,
155                    length: idx,
156                    msg: format!(
157                        "Invalid SD-JWT. Key binding not found: {}",
158                        sd_jwt_with_disclosures
159                    ),
160                })?
161                .to_owned(),
162        );
163        self.input_disclosures = parts.map(str::to_owned).collect();
164        self.unverified_sd_jwt = Some(sd_jwt.to_owned());
165
166        let mut sd_jwt = sd_jwt.split(JWT_SEPARATOR);
167        sd_jwt.next();
168        let jwt_body = sd_jwt.next().ok_or(Error::IndexOutOfBounds {
169            idx: 1,
170            length: 3,
171            msg: format!(
172                "Invalid JWT: Cannot extract JWT payload: {}",
173                self.unverified_sd_jwt.to_owned().unwrap_or("".to_string())
174            ),
175        })?;
176        self.unverified_input_sd_jwt_payload = Some(jwt_payload_decode(jwt_body)?);
177        Ok(())
178    }
179
180    fn parse_json_sd_jwt(&mut self, sd_jwt_with_disclosures: String) -> Result<()> {
181        let parsed_sd_jwt_json: SDJWTJson = serde_json::from_str(&sd_jwt_with_disclosures)
182            .map_err(|e| Error::DeserializationError(e.to_string()))?;
183        self.unverified_sd_jwt_json = Some(parsed_sd_jwt_json.clone());
184        self.unverified_input_key_binding_jwt = parsed_sd_jwt_json.kb_jwt;
185        self.input_disclosures = parsed_sd_jwt_json.disclosures;
186        self.unverified_input_sd_jwt_payload =
187            Some(jwt_payload_decode(&parsed_sd_jwt_json.payload)?);
188        let sd_jwt = format!(
189            "{}.{}.{}",
190            parsed_sd_jwt_json.protected,
191            parsed_sd_jwt_json.payload,
192            parsed_sd_jwt_json.signature
193        );    
194        self.unverified_sd_jwt = Some(sd_jwt.clone());
195        self.sign_alg = Self::decode_header_and_get_sign_algorithm(&sd_jwt);    
196        Ok(())
197    }
198
199    fn parse_sd_jwt(&mut self, sd_jwt_with_disclosures: String) -> Result<()> {
200        match self.serialization_format {
201            SDJWTSerializationFormat::Compact => {
202                self.parse_compact_sd_jwt(sd_jwt_with_disclosures)
203            }
204            SDJWTSerializationFormat::JSON => {
205                self.parse_json_sd_jwt(sd_jwt_with_disclosures)
206            }
207        }
208    }
209    /// Decodes a header jwt string and extracts the "alg" field from the JSON object.
210    /// # Arguments
211    /// * `sd_jwt` - jwt format string.
212    /// # Returns
213    /// * `Option<String>` - The result containing the algorithm String e.g ES256 or on failure None.
214    fn decode_header_and_get_sign_algorithm(sd_jwt: &str) -> Option<String> {
215        let parts: Vec<&str> = sd_jwt.split('.').collect();
216        if parts.len() < 2 {
217            return None;
218        }
219        let jwt_header = parts[0];
220        let decoded = base64url_decode(jwt_header).ok()?;
221        let decoded_str = std::str::from_utf8(&decoded).ok()?;
222        let json_sign_alg: Value = serde_json::from_str(decoded_str).ok()?;
223        let sign_alg = json_sign_alg.get("alg")
224            .and_then(Value::as_str)
225            .map(String::from);
226        sign_alg
227    }
228}
229
#[cfg(test)]
mod tests {
    use crate::{utils, SDJWTCommon};

    /// Compact input is split into JWT, disclosures and key binding JWT.
    #[test]
    fn test_parse_compact_sd_jwt() {
        let encoded_empty_object = utils::base64url_encode("{}".as_bytes());
        let expected_jwt = format!("jwt1.{encoded_empty_object}.jwt3");

        let mut common = SDJWTCommon::default();
        common
            .parse_compact_sd_jwt(format!("{expected_jwt}~disc1~disc2~kbjwt"))
            .unwrap();

        assert_eq!(
            common.unverified_sd_jwt.as_deref(),
            Some(expected_jwt.as_str())
        );
        assert_eq!(
            common.unverified_input_key_binding_jwt.as_deref(),
            Some("kbjwt")
        );
        assert_eq!(common.input_disclosures, ["disc1", "disc2"]);
    }

    /// JSON input yields the same parts as the equivalent compact input.
    #[test]
    fn test_parse_json_sd_jwt() {
        let encoded_empty_object = utils::base64url_encode("{}".as_bytes());
        let expected_jwt = format!("jwt1.{encoded_empty_object}.jwt3");
        let input = format!(
            r#"{{"protected":"jwt1","payload":"{encoded_empty_object}","signature":"jwt3","disclosures":["disc1","disc2"],"kb_jwt":"kbjwt"}}"#
        );

        let mut common = SDJWTCommon::default();
        common.parse_json_sd_jwt(input).unwrap();

        assert_eq!(
            common.unverified_sd_jwt.as_deref(),
            Some(expected_jwt.as_str())
        );
        assert_eq!(
            common.unverified_input_key_binding_jwt.as_deref(),
            Some("kbjwt")
        );
        assert_eq!(common.input_disclosures, ["disc1", "disc2"]);
    }
}