sd_jwt_payload/
key_binding_jwt_claims.rs1use crate::jwt::Jwt;
5use crate::Error;
6use crate::Hasher;
7use crate::JsonObject;
8use crate::JwsSigner;
9use crate::SdJwt;
10use crate::SHA_ALG_NAME;
11use anyhow::Context as _;
12use serde::Deserialize;
13use serde::Serialize;
14use serde_json::Value;
15use std::borrow::Cow;
16use std::fmt::Display;
17use std::ops::Deref;
18use std::str::FromStr;
19
/// Value that the `typ` header parameter of a KB-JWT must carry; parsing via
/// `KeyBindingJwt::from_str` rejects any other value.
pub const KB_JWT_HEADER_TYP: &str = "kb+jwt";
21
22#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct KeyBindingJwt(Jwt<KeyBindingJwtClaims>);
25
26impl Display for KeyBindingJwt {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 write!(f, "{}", &self.0)
29 }
30}
31
32impl FromStr for KeyBindingJwt {
33 type Err = Error;
34 fn from_str(s: &str) -> Result<Self, Self::Err> {
35 let jwt = Jwt::<KeyBindingJwtClaims>::from_str(s)?;
36 let valid_jwt_type = jwt.header.get("typ").is_some_and(|typ| typ == KB_JWT_HEADER_TYP);
37 if !valid_jwt_type {
38 return Err(Error::DeserializationError(format!(
39 "invalid KB-JWT: typ must be \"{KB_JWT_HEADER_TYP}\""
40 )));
41 }
42 let valid_alg = jwt.header.get("alg").is_some_and(|alg| alg != "none");
43 if !valid_alg {
44 return Err(Error::DeserializationError(
45 "invalid KB-JWT: alg must be set and cannot be \"none\"".to_string(),
46 ));
47 }
48
49 Ok(Self(jwt))
50 }
51}
52
53impl KeyBindingJwt {
54 pub fn builder() -> KeyBindingJwtBuilder {
56 KeyBindingJwtBuilder::default()
57 }
58 pub fn claims(&self) -> &KeyBindingJwtClaims {
60 &self.0.claims
61 }
62}
63
64#[derive(Debug, Default, Clone)]
66pub struct KeyBindingJwtBuilder {
67 header: JsonObject,
68 payload: JsonObject,
69}
70
71impl KeyBindingJwtBuilder {
72 pub fn new() -> Self {
74 Self::default()
75 }
76
77 pub fn from_object(object: JsonObject) -> Self {
79 Self {
80 header: JsonObject::default(),
81 payload: object,
82 }
83 }
84
85 pub fn header(mut self, header: JsonObject) -> Self {
87 self.header = header;
88 self
89 }
90
91 pub fn iat(mut self, iat: i64) -> Self {
93 self.payload.insert("iat".to_string(), iat.into());
94 self
95 }
96
97 pub fn aud<'a, S>(mut self, aud: S) -> Self
99 where
100 S: Into<Cow<'a, str>>,
101 {
102 self.payload.insert("aud".to_string(), aud.into().into_owned().into());
103 self
104 }
105
106 pub fn nonce<'a, S>(mut self, nonce: S) -> Self
108 where
109 S: Into<Cow<'a, str>>,
110 {
111 self
112 .payload
113 .insert("nonce".to_string(), nonce.into().into_owned().into());
114 self
115 }
116
117 pub fn insert_property(mut self, name: &str, value: Value) -> Self {
119 self.payload.insert(name.to_string(), value);
120 self
121 }
122
123 pub async fn finish<S>(
125 self,
126 sd_jwt: &SdJwt,
127 hasher: &dyn Hasher,
128 alg: &str,
129 signer: &S,
130 ) -> Result<KeyBindingJwt, Error>
131 where
132 S: JwsSigner,
133 {
134 let mut claims = self.payload;
135 if alg == "none" {
136 return Err(Error::DataTypeMismatch(
137 "A KeyBindingJwt cannot use algorithm \"none\"".to_string(),
138 ));
139 }
140 if sd_jwt.key_binding_jwt().is_some() {
141 return Err(Error::DataTypeMismatch(
142 "the provided SD-JWT already has a KB-JWT attached".to_string(),
143 ));
144 }
145 if sd_jwt.claims()._sd_alg.as_deref().unwrap_or(SHA_ALG_NAME) != hasher.alg_name() {
146 return Err(Error::InvalidHasher(format!(
147 "invalid hashing algorithm \"{}\"",
148 hasher.alg_name()
149 )));
150 }
151 let sd_hash = hasher.encoded_digest(&sd_jwt.to_string());
152 claims.insert("sd_hash".to_string(), sd_hash.into());
153
154 let mut header = self.header;
155 header.insert("alg".to_string(), alg.to_owned().into());
156 header
157 .entry("typ")
158 .or_insert_with(|| KB_JWT_HEADER_TYP.to_owned().into());
159
160 let parsed_claims = serde_json::from_value::<KeyBindingJwtClaims>(claims.clone().into())
162 .map_err(|e| Error::DeserializationError(format!("invalid KB-JWT claims: {e}")))?;
163 let jws = signer
164 .sign(&header, &claims)
165 .await
166 .map_err(|e| anyhow::anyhow!("{e}"))
167 .and_then(|jws_bytes| String::from_utf8(jws_bytes).context("invalid JWS"))
168 .map_err(|e| Error::JwsSignerFailure(e.to_string()))?;
169
170 Ok(KeyBindingJwt(Jwt {
171 header,
172 claims: parsed_claims,
173 jws,
174 }))
175 }
176}
177
178#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
180pub struct KeyBindingJwtClaims {
181 pub iat: i64,
182 pub aud: String,
183 pub nonce: String,
184 pub sd_hash: String,
185 #[serde(flatten)]
186 properties: JsonObject,
187}
188
189impl Deref for KeyBindingJwtClaims {
190 type Target = JsonObject;
191 fn deref(&self) -> &Self::Target {
192 &self.properties
193 }
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
198#[serde(rename_all = "camelCase")]
199pub enum RequiredKeyBinding {
200 Jwk(JsonObject),
202 Jwe(String),
204 Kid(String),
206 Jwu {
208 jwu: String,
210 kid: String,
212 },
213 #[serde(untagged)]
215 Custom(Value),
216}