1use crate::error::Error;
6use crate::utils::{base64_hash, base64url_decode, jwt_payload_decode};
7
8use error::Result;
9use serde::{Deserialize, Serialize};
10use serde_json::{Map, Value};
11use strum::Display;
12use std::collections::HashMap;
13pub use {holder::SDJWTHolder, issuer::SDJWTIssuer, issuer::ClaimsForSelectiveDisclosureStrategy, verifier::SDJWTVerifier};
14
15mod disclosure;
16pub mod error;
17pub mod holder;
18pub mod issuer;
19pub mod utils;
20pub mod verifier;
21
// ---------------------------------------------------------------------------
// Public defaults and separators
// ---------------------------------------------------------------------------

/// Default JWS `alg` value. NOTE(review): not read in this file; presumably the
/// issuer's fallback signing algorithm — confirm in `issuer`.
pub const DEFAULT_SIGNING_ALG: &str = "ES256";

/// Default digest algorithm name. NOTE(review): not read in this file; presumably
/// the value advertised under `DIGEST_ALG_KEY` — confirm.
pub const DEFAULT_DIGEST_ALG: &str = "sha-256";

/// Separator between the issuer JWT, the disclosures and the key-binding JWT in
/// the compact serialization (`<JWT>~<disclosure>~...~<KB-JWT>`).
pub const COMBINED_SERIALIZATION_FORMAT_SEPARATOR: &str = "~";

/// Separator between the `header.payload.signature` segments of a JWT.
const JWT_SEPARATOR: &str = ".";

// ---------------------------------------------------------------------------
// Reserved claim / member names
// ---------------------------------------------------------------------------

/// Reserved `_sd` key; `check_for_sd_claim` rejects claim objects containing it.
const SD_DIGESTS_KEY: &str = "_sd";
/// Reserved `_sd_alg` key (NOTE(review): presumably names the digest algorithm).
const DIGEST_ALG_KEY: &str = "_sd_alg";
/// Reserved `...` marker (NOTE(review): presumably used for array-element digests).
const SD_LIST_PREFIX: &str = "...";
/// Key-binding JWT claim `sd_hash` (NOTE(review): used outside this file).
const KB_DIGEST_KEY: &str = "sd_hash";
/// `cnf` claim name (NOTE(review): used outside this file).
const CNF_KEY: &str = "cnf";
/// `jwk` member name (NOTE(review): used outside this file).
const JWK_KEY: &str = "jwk";

// ---------------------------------------------------------------------------
// `typ` header values
// ---------------------------------------------------------------------------

/// `typ` for SD-JWTs. The leading underscore suppresses Rust's unused-item lint.
const _SD_JWT_TYP_HEADER: &str = "sd+jwt";
/// `typ` for key-binding JWTs.
const KB_JWT_TYP_HEADER: &str = "kb+jwt";
34
/// Crate-internal error payload carrying a message string.
///
/// NOTE(review): judging by its name this presumably signals a claim set that
/// contains the reserved `_sd` key, but nothing in this file constructs it —
/// `SDJWTCommon::check_for_sd_claim` reports that condition through
/// `Error::DataFieldMismatch` instead. Confirm whether it is still needed.
#[derive(Debug)]
pub(crate) struct SDJWTHasSDClaimException(String);
39
40#[derive(Default, Clone, PartialEq, Debug, Display)]
42pub enum SDJWTSerializationFormat {
43 #[default]
45 JSON,
46 Compact,
48}
49
50#[derive(Default)]
51pub(crate) struct SDJWTCommon {
52 typ: Option<String>,
53 serialization_format: SDJWTSerializationFormat,
54 unverified_input_key_binding_jwt: Option<String>,
55 unverified_sd_jwt: Option<String>,
56 unverified_sd_jwt_json: Option<SDJWTJson>,
57 unverified_input_sd_jwt_payload: Option<Map<String, Value>>,
58 hash_to_decoded_disclosure: HashMap<String, Value>,
59 hash_to_disclosure: HashMap<String, String>,
60 input_disclosures: Vec<String>,
61 sign_alg: Option<String>,
62}
63
64#[derive(Default, Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
65pub struct SDJWTJson {
66 protected: String,
67 payload: String,
68 signature: String,
69 pub disclosures: Vec<String>,
70 pub kb_jwt: Option<String>,
71}
72
73impl SDJWTCommon {
75 fn create_hash_mappings(&mut self) -> Result<()> {
76 self.hash_to_decoded_disclosure = HashMap::new();
77 self.hash_to_disclosure = HashMap::new();
78
79 for disclosure in &self.input_disclosures {
80 let decoded_disclosure = base64url_decode(disclosure).map_err(|err| {
81 Error::InvalidDisclosure(format!(
82 "Error decoding disclosure {}: {}",
83 disclosure, err
84 ))
85 })?;
86 let decoded_disclosure: Value =
87 serde_json::from_slice(&decoded_disclosure).map_err(|err| {
88 Error::InvalidDisclosure(format!(
89 "Error parsing disclosure {}: {}",
90 disclosure, err
91 ))
92 })?;
93
94 let hash = base64_hash(disclosure.as_bytes());
95 if self.hash_to_decoded_disclosure.contains_key(&hash) {
96 return Err(Error::DuplicateDigestError(hash));
97 }
98 self.hash_to_decoded_disclosure
99 .insert(hash.clone(), decoded_disclosure);
100 self.hash_to_disclosure
101 .insert(hash.clone(), disclosure.to_owned());
102 }
103
104 Ok(())
105 }
106
107 fn check_for_sd_claim(the_object: &Value) -> Result<()> {
108 match the_object {
109 Value::Object(obj) => {
110 for (key, value) in obj.iter() {
111 if key == SD_DIGESTS_KEY {
112 return Err(Error::DataFieldMismatch(format!(
113 "Claim object cannot have `{}` field",
114 SD_DIGESTS_KEY
115 )));
116 } else {
117 Self::check_for_sd_claim(value)?;
118 }
119 }
120 }
121 Value::Array(arr) => {
122 for item in arr {
123 Self::check_for_sd_claim(item)?;
124 }
125 }
126 _ => {}
127 }
128
129 Ok(())
130 }
131
132 fn parse_compact_sd_jwt(&mut self, sd_jwt_with_disclosures: String) -> Result<()> {
133 let parts: Vec<&str> = sd_jwt_with_disclosures
134 .split(COMBINED_SERIALIZATION_FORMAT_SEPARATOR)
135 .collect();
136 if parts.len() < 2 { return Err(Error::InvalidInput(format!(
138 "Invalid SD-JWT length: {}",
139 parts.len()
140 )));
141 }
142 let idx = parts.len();
143 let mut parts = parts.into_iter();
144 let sd_jwt = parts.next().ok_or(Error::IndexOutOfBounds {
145 idx: 0,
146 length: parts.len(),
147 msg: format!("Invalid SD-JWT: {}", sd_jwt_with_disclosures),
148 })?;
149 self.sign_alg = Self::decode_header_and_get_sign_algorithm(&sd_jwt);
150 self.unverified_input_key_binding_jwt = Some(
151 parts
152 .next_back()
153 .ok_or(Error::IndexOutOfBounds {
154 idx: idx - 1,
155 length: idx,
156 msg: format!(
157 "Invalid SD-JWT. Key binding not found: {}",
158 sd_jwt_with_disclosures
159 ),
160 })?
161 .to_owned(),
162 );
163 self.input_disclosures = parts.map(str::to_owned).collect();
164 self.unverified_sd_jwt = Some(sd_jwt.to_owned());
165
166 let mut sd_jwt = sd_jwt.split(JWT_SEPARATOR);
167 sd_jwt.next();
168 let jwt_body = sd_jwt.next().ok_or(Error::IndexOutOfBounds {
169 idx: 1,
170 length: 3,
171 msg: format!(
172 "Invalid JWT: Cannot extract JWT payload: {}",
173 self.unverified_sd_jwt.to_owned().unwrap_or("".to_string())
174 ),
175 })?;
176 self.unverified_input_sd_jwt_payload = Some(jwt_payload_decode(jwt_body)?);
177 Ok(())
178 }
179
180 fn parse_json_sd_jwt(&mut self, sd_jwt_with_disclosures: String) -> Result<()> {
181 let parsed_sd_jwt_json: SDJWTJson = serde_json::from_str(&sd_jwt_with_disclosures)
182 .map_err(|e| Error::DeserializationError(e.to_string()))?;
183 self.unverified_sd_jwt_json = Some(parsed_sd_jwt_json.clone());
184 self.unverified_input_key_binding_jwt = parsed_sd_jwt_json.kb_jwt;
185 self.input_disclosures = parsed_sd_jwt_json.disclosures;
186 self.unverified_input_sd_jwt_payload =
187 Some(jwt_payload_decode(&parsed_sd_jwt_json.payload)?);
188 let sd_jwt = format!(
189 "{}.{}.{}",
190 parsed_sd_jwt_json.protected,
191 parsed_sd_jwt_json.payload,
192 parsed_sd_jwt_json.signature
193 );
194 self.unverified_sd_jwt = Some(sd_jwt.clone());
195 self.sign_alg = Self::decode_header_and_get_sign_algorithm(&sd_jwt);
196 Ok(())
197 }
198
199 fn parse_sd_jwt(&mut self, sd_jwt_with_disclosures: String) -> Result<()> {
200 match self.serialization_format {
201 SDJWTSerializationFormat::Compact => {
202 self.parse_compact_sd_jwt(sd_jwt_with_disclosures)
203 }
204 SDJWTSerializationFormat::JSON => {
205 self.parse_json_sd_jwt(sd_jwt_with_disclosures)
206 }
207 }
208 }
209 fn decode_header_and_get_sign_algorithm(sd_jwt: &str) -> Option<String> {
215 let parts: Vec<&str> = sd_jwt.split('.').collect();
216 if parts.len() < 2 {
217 return None;
218 }
219 let jwt_header = parts[0];
220 let decoded = base64url_decode(jwt_header).ok()?;
221 let decoded_str = std::str::from_utf8(&decoded).ok()?;
222 let json_sign_alg: Value = serde_json::from_str(decoded_str).ok()?;
223 let sign_alg = json_sign_alg.get("alg")
224 .and_then(Value::as_str)
225 .map(String::from);
226 sign_alg
227 }
228}
229
#[cfg(test)]
mod tests {
    use crate::error::Error;
    use crate::{utils, SDJWTCommon};

    #[test]
    fn test_parse_compact_sd_jwt() {
        let mut sdjwt = SDJWTCommon::default();
        let encoded_empty_object = utils::base64url_encode("{}".as_bytes());
        sdjwt
            .parse_compact_sd_jwt(format!("jwt1.{encoded_empty_object}.jwt3~disc1~disc2~kbjwt"))
            .unwrap();
        assert_eq!(sdjwt.unverified_sd_jwt.unwrap(), format!("jwt1.{encoded_empty_object}.jwt3"));
        assert_eq!(sdjwt.unverified_input_key_binding_jwt.unwrap(), "kbjwt");
        assert_eq!(sdjwt.input_disclosures, vec!["disc1".to_string(), "disc2".to_string()]);
    }

    #[test]
    fn test_parse_compact_sd_jwt_without_key_binding() {
        let mut sdjwt = SDJWTCommon::default();
        let encoded_empty_object = utils::base64url_encode("{}".as_bytes());
        sdjwt
            .parse_compact_sd_jwt(format!("jwt1.{encoded_empty_object}.jwt3~disc1~"))
            .unwrap();
        // A trailing `~` yields an empty key-binding segment, not `None`.
        assert_eq!(sdjwt.unverified_input_key_binding_jwt.unwrap(), "");
        assert_eq!(sdjwt.input_disclosures, vec!["disc1".to_string()]);
    }

    #[test]
    fn test_parse_compact_sd_jwt_rejects_missing_separator() {
        let mut sdjwt = SDJWTCommon::default();
        let encoded_empty_object = utils::base64url_encode("{}".as_bytes());
        let result = sdjwt.parse_compact_sd_jwt(format!("jwt1.{encoded_empty_object}.jwt3"));
        assert!(matches!(result, Err(Error::InvalidInput(_))));
    }

    #[test]
    fn test_parse_compact_sd_jwt_rejects_jwt_without_payload() {
        let mut sdjwt = SDJWTCommon::default();
        let result = sdjwt.parse_compact_sd_jwt("jwt1~disc1~kbjwt".to_string());
        assert!(matches!(result, Err(Error::IndexOutOfBounds { .. })));
    }

    #[test]
    fn test_parse_json_sd_jwt() {
        let mut sdjwt = SDJWTCommon::default();
        let encoded_empty_object = utils::base64url_encode("{}".as_bytes());
        sdjwt.parse_json_sd_jwt(format!(
            "{{\"protected\":\"jwt1\",\"payload\":\"{encoded_empty_object}\",\"signature\":\"jwt3\",\"disclosures\":[\"disc1\",\"disc2\"],\"kb_jwt\":\"kbjwt\"}}"
        )).unwrap();
        assert_eq!(sdjwt.unverified_sd_jwt.unwrap(), format!("jwt1.{encoded_empty_object}.jwt3"));
        assert_eq!(sdjwt.unverified_input_key_binding_jwt.unwrap(), "kbjwt");
        assert_eq!(sdjwt.input_disclosures, vec!["disc1".to_string(), "disc2".to_string()]);
    }

    #[test]
    fn test_parse_json_sd_jwt_rejects_malformed_json() {
        let mut sdjwt = SDJWTCommon::default();
        let result = sdjwt.parse_json_sd_jwt("not json".to_string());
        assert!(matches!(result, Err(Error::DeserializationError(_))));
    }

    #[test]
    fn test_decode_header_and_get_sign_algorithm() {
        let header = utils::base64url_encode(r#"{"alg":"ES256"}"#.as_bytes());
        assert_eq!(
            SDJWTCommon::decode_header_and_get_sign_algorithm(&format!("{header}.payload.sig")),
            Some("ES256".to_string())
        );

        let header_without_alg = utils::base64url_encode(r#"{"typ":"sd+jwt"}"#.as_bytes());
        assert_eq!(
            SDJWTCommon::decode_header_and_get_sign_algorithm(&format!(
                "{header_without_alg}.payload.sig"
            )),
            None
        );

        // No `.` separator at all.
        assert_eq!(SDJWTCommon::decode_header_and_get_sign_algorithm("headeronly"), None);
    }

    #[test]
    fn test_check_for_sd_claim() {
        let clean = serde_json::json!({
            "given_name": "John",
            "address": {"street": "Main"},
            "list": [1, {"a": true}]
        });
        assert!(SDJWTCommon::check_for_sd_claim(&clean).is_ok());

        let nested = serde_json::json!({"list": [{"inner": {"_sd": []}}]});
        assert!(matches!(
            SDJWTCommon::check_for_sd_claim(&nested),
            Err(Error::DataFieldMismatch(_))
        ));
    }

    #[test]
    fn test_create_hash_mappings() {
        let mut sdjwt = SDJWTCommon::default();
        let disclosures = vec![
            utils::base64url_encode(r#"["salt1","given_name","John"]"#.as_bytes()),
            utils::base64url_encode(r#"["salt2","family_name","Doe"]"#.as_bytes()),
        ];
        sdjwt.input_disclosures = disclosures.clone();
        sdjwt.create_hash_mappings().unwrap();

        assert_eq!(sdjwt.hash_to_disclosure.len(), 2);
        assert_eq!(sdjwt.hash_to_decoded_disclosure.len(), 2);

        let first_hash = utils::base64_hash(disclosures[0].as_bytes());
        assert_eq!(sdjwt.hash_to_disclosure.get(&first_hash), Some(&disclosures[0]));
        assert_eq!(
            sdjwt.hash_to_decoded_disclosure[&first_hash],
            serde_json::json!(["salt1", "given_name", "John"])
        );
    }

    #[test]
    fn test_create_hash_mappings_rejects_duplicates() {
        let mut sdjwt = SDJWTCommon::default();
        let disclosure = utils::base64url_encode(r#"["salt","given_name","John"]"#.as_bytes());
        sdjwt.input_disclosures = vec![disclosure.clone(), disclosure];
        assert!(matches!(
            sdjwt.create_hash_mappings(),
            Err(Error::DuplicateDigestError(_))
        ));
    }

    #[test]
    fn test_create_hash_mappings_rejects_non_json_disclosure() {
        let mut sdjwt = SDJWTCommon::default();
        sdjwt.input_disclosures = vec![utils::base64url_encode("not json".as_bytes())];
        assert!(matches!(
            sdjwt.create_hash_mappings(),
            Err(Error::InvalidDisclosure(_))
        ));
    }
}