1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
4use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
5use serde::{Deserialize, Serialize};
6
7use crate::error::{Result, ValidationError};
8
9#[derive(Debug, Clone, Deserialize)]
10pub struct Jwk {
11 pub kty: String,
12 pub crv: String,
13 pub x: String,
14 pub y: String,
15 pub kid: String,
16}
17
18#[derive(Debug, Clone, Deserialize)]
19pub struct Jwks {
20 pub keys: Vec<Jwk>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct JwtHeader {
25 pub alg: String,
26 pub typ: Option<String>,
27 pub kid: Option<String>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct Attestation {
33 pub method: String,
35 #[serde(skip_serializing_if = "Option::is_none")]
37 pub app_id: Option<String>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct CaptureTrustClaims {
42 pub iss: String,
43 pub aud: String,
44 pub sub: String,
45 pub iat: i64,
46 pub capture_id: String,
47 pub publisher_id: String,
48 pub device_id: String,
49 pub attestation: Attestation,
50}
51
52#[derive(Debug, Clone)]
53pub struct ParsedJwt {
54 pub header: JwtHeader,
55 pub claims: CaptureTrustClaims,
56 pub signature: String,
57}
58
59pub fn parse_jwt(token: &str) -> Result<ParsedJwt> {
60 let parts: Vec<&str> = token.split('.').collect();
61 if parts.len() != 3 {
62 return Err(ValidationError::InvalidJwt(
63 "JWT must have 3 parts separated by dots".to_string(),
64 ));
65 }
66
67 let header = decode_part::<JwtHeader>(parts[0], "header")?;
68 let claims = decode_part::<CaptureTrustClaims>(parts[1], "claims")?;
69 let signature = parts[2].to_string();
70
71 validate_header(&header)?;
72 validate_claims(&claims)?;
73
74 Ok(ParsedJwt {
75 header,
76 claims,
77 signature,
78 })
79}
80
81fn decode_part<T: for<'de> Deserialize<'de>>(encoded: &str, part_name: &str) -> Result<T> {
82 let bytes = URL_SAFE_NO_PAD.decode(encoded).map_err(|e| {
83 ValidationError::JwtDecodeError(format!("Failed to decode {}: {}", part_name, e))
84 })?;
85
86 serde_json::from_slice(&bytes).map_err(|e| {
87 ValidationError::JwtDecodeError(format!("Failed to parse {}: {}", part_name, e))
88 })
89}
90
91fn validate_header(header: &JwtHeader) -> Result<()> {
92 if header.alg != "ES256" {
93 return Err(ValidationError::InvalidJwt(format!(
94 "Expected algorithm ES256, got {}",
95 header.alg
96 )));
97 }
98 Ok(())
99}
100
101fn validate_claims(claims: &CaptureTrustClaims) -> Result<()> {
102 if claims.aud != "signedshot" {
103 return Err(ValidationError::InvalidJwt(format!(
104 "Expected audience 'signedshot', got '{}'",
105 claims.aud
106 )));
107 }
108
109 let valid_methods = ["sandbox", "app_check", "app_attest"];
110 if !valid_methods.contains(&claims.attestation.method.as_str()) {
111 return Err(ValidationError::InvalidJwt(format!(
112 "Invalid attestation method '{}', expected one of: {:?}",
113 claims.attestation.method, valid_methods
114 )));
115 }
116
117 Ok(())
118}
119
120pub fn fetch_jwks(issuer: &str) -> Result<Jwks> {
121 let url = format!("{}/.well-known/jwks.json", issuer.trim_end_matches('/'));
122
123 let response = reqwest::blocking::get(&url)
124 .map_err(|e| ValidationError::JwksFetchError(format!("HTTP request failed: {}", e)))?;
125
126 if !response.status().is_success() {
127 return Err(ValidationError::JwksFetchError(format!(
128 "HTTP {} from {}",
129 response.status(),
130 url
131 )));
132 }
133
134 response
135 .json::<Jwks>()
136 .map_err(|e| ValidationError::JwksFetchError(format!("Failed to parse JWKS: {}", e)))
137}
138
139pub fn parse_jwks_json(jwks_json: &str) -> Result<Jwks> {
143 serde_json::from_str(jwks_json)
144 .map_err(|e| ValidationError::JwksFetchError(format!("Failed to parse JWKS JSON: {}", e)))
145}
146
147pub fn verify_signature(token: &str, jwks: &Jwks, kid: &str) -> Result<()> {
148 let jwk = jwks
149 .keys
150 .iter()
151 .find(|k| k.kid == kid)
152 .ok_or_else(|| ValidationError::KeyNotFound(kid.to_string()))?;
153
154 let x_bytes = URL_SAFE_NO_PAD
155 .decode(&jwk.x)
156 .map_err(|e| ValidationError::SignatureError(format!("Invalid x coordinate: {}", e)))?;
157 let y_bytes = URL_SAFE_NO_PAD
158 .decode(&jwk.y)
159 .map_err(|e| ValidationError::SignatureError(format!("Invalid y coordinate: {}", e)))?;
160
161 let mut public_key = Vec::with_capacity(1 + x_bytes.len() + y_bytes.len());
162 public_key.push(0x04);
163 public_key.extend_from_slice(&x_bytes);
164 public_key.extend_from_slice(&y_bytes);
165
166 let decoding_key = DecodingKey::from_ec_der(&public_key);
167
168 let mut validation = Validation::new(Algorithm::ES256);
169 validation.set_audience(&["signedshot"]);
170 validation.validate_exp = false;
171 validation.set_required_spec_claims::<&str>(&[]);
172
173 decode::<CaptureTrustClaims>(token, &decoding_key, &validation)
174 .map_err(|e| ValidationError::SignatureError(format!("{}", e)))?;
175
176 Ok(())
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    const HEADER_ES256_WITH_KID: &str = r#"{"alg":"ES256","typ":"JWT","kid":"test-key"}"#;
    const HEADER_ES256: &str = r#"{"alg":"ES256","typ":"JWT"}"#;
    const HEADER_HS256: &str = r#"{"alg":"HS256","typ":"JWT"}"#;
    const DEV_ISSUER: &str = "https://dev-api.signedshot.io";

    /// Builds an unsigned token: both segments base64url-encoded, followed by a
    /// placeholder signature segment.
    fn make_jwt(header: &str, payload: &str) -> String {
        [
            URL_SAFE_NO_PAD.encode(header),
            URL_SAFE_NO_PAD.encode(payload),
            "fake-signature".to_string(),
        ]
        .join(".")
    }

    /// Renders a claims JSON object with the given issuer, audience and raw
    /// attestation JSON; every other claim is fixed.
    fn payload(iss: &str, aud: &str, attestation: &str) -> String {
        format!(
            r#"{{"iss":"{}","aud":"{}","sub":"capture-service","iat":1705312200,"capture_id":"123","publisher_id":"456","device_id":"789","attestation":{}}}"#,
            iss, aud, attestation
        )
    }

    /// Asserts that parsing `token` fails with `ValidationError::InvalidJwt`.
    fn assert_invalid_jwt(token: &str) {
        let result = parse_jwt(token);
        assert!(matches!(result, Err(ValidationError::InvalidJwt(_))));
    }

    #[test]
    fn parse_valid_jwt() {
        let claims = payload(DEV_ISSUER, "signedshot", r#"{"method":"sandbox"}"#);
        let token = make_jwt(HEADER_ES256_WITH_KID, &claims);

        let parsed = parse_jwt(&token).unwrap();
        assert_eq!(parsed.header.alg, "ES256");
        assert_eq!(parsed.claims.capture_id, "123");
        assert_eq!(parsed.claims.attestation.method, "sandbox");
        assert!(parsed.claims.attestation.app_id.is_none());
    }

    #[test]
    fn parse_jwt_with_app_id() {
        let claims = payload(
            DEV_ISSUER,
            "signedshot",
            r#"{"method":"app_check","app_id":"io.foo.bar"}"#,
        );
        let token = make_jwt(HEADER_ES256_WITH_KID, &claims);

        let parsed = parse_jwt(&token).unwrap();
        assert_eq!(parsed.claims.attestation.method, "app_check");
        assert_eq!(
            parsed.claims.attestation.app_id.as_deref(),
            Some("io.foo.bar")
        );
    }

    #[test]
    fn reject_invalid_algorithm() {
        let claims = payload(DEV_ISSUER, "signedshot", r#"{"method":"sandbox"}"#);
        assert_invalid_jwt(&make_jwt(HEADER_HS256, &claims));
    }

    #[test]
    fn reject_invalid_audience() {
        let claims = payload("https://example.com", "wrong", r#"{"method":"sandbox"}"#);
        assert_invalid_jwt(&make_jwt(HEADER_ES256, &claims));
    }

    #[test]
    fn reject_invalid_method() {
        let claims = payload(DEV_ISSUER, "signedshot", r#"{"method":"invalid"}"#);
        assert_invalid_jwt(&make_jwt(HEADER_ES256, &claims));
    }

    #[test]
    fn reject_malformed_jwt() {
        // Four dot-separated segments instead of three.
        assert_invalid_jwt("not.a.valid.jwt");
    }

    #[test]
    fn reject_invalid_base64() {
        let result = parse_jwt("!!!.@@@.###");
        assert!(matches!(result, Err(ValidationError::JwtDecodeError(_))));
    }
}