Skip to main content

securitydept_oauth_resource_server/models/
mod.rs

1use std::{collections::HashMap, time::Duration};
2
3use openidconnect::{IntrospectionUrl, IssuerUrl, JsonWebKeySetUrl, core::CoreJsonWebKeySet};
4use securitydept_creds::{JwtClaimsTrait, Scope, TokenData};
5use serde_json::Value;
6
7pub mod introspection;
8#[cfg(feature = "jwe")]
9pub mod jwe;
10
11pub use introspection::VerifiedOpaqueToken;
12#[cfg(feature = "jwe")]
13pub use jwe::LocalJweDecryptionKeySet;
14
/// Authorization-server metadata held by the resource server: the issuer
/// identifier, the JWKS location and an optional introspection endpoint.
#[derive(Debug, Clone)]
pub struct OAuthResourceServerMetadata {
    /// Issuer URL of the authorization server.
    pub issuer: IssuerUrl,
    /// URL of the JSON Web Key Set document.
    pub jwks_uri: JsonWebKeySetUrl,
    /// Token introspection endpoint; `None` when no endpoint is available.
    pub introspection_url: Option<IntrospectionUrl>,
}
21
/// Token acceptance policy: permitted audiences, mandatory scopes and a
/// clock-skew allowance.
///
/// All fields are private; build a value with [`VerificationPolicy::new`] and
/// read it back through the accessor methods.
// NOTE(review): how each field is enforced is not visible in this file —
// confirm against the verifier implementation.
#[derive(Debug, Clone)]
pub struct VerificationPolicy {
    // Audience values a token may carry.
    allowed_audiences: Vec<String>,
    // Scope values a token is required to carry.
    required_scopes: Vec<String>,
    // Tolerance applied to time-based checks (presumably `exp`/`nbf` — confirm).
    clock_skew: Duration,
}
28
/// Flattened, token-format-independent view of a verified token, produced by
/// `VerifiedToken::to_resource_token_principal`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResourceTokenPrincipal {
    /// Subject of the token, when present.
    pub subject: Option<String>,
    /// Issuer of the token, when present.
    pub issuer: Option<String>,
    /// Audience values of the token; empty when the token carries none.
    pub audiences: Vec<String>,
    /// Scope values of the token; empty when the token carries none.
    pub scopes: Vec<String>,
    /// Value of the `azp` claim for structured tokens; always `None` for
    /// opaque tokens.
    pub authorized_party: Option<String>,
    /// Additional claims of a structured token with sensitive keys removed;
    /// always empty for opaque tokens.
    pub claims: HashMap<String, Value>,
}
38
39impl VerificationPolicy {
40    pub fn new(
41        allowed_audiences: Vec<String>,
42        required_scopes: Vec<String>,
43        clock_skew: Duration,
44    ) -> Self {
45        Self {
46            allowed_audiences,
47            required_scopes,
48            clock_skew,
49        }
50    }
51
52    pub fn allowed_audiences(&self) -> &[String] {
53        &self.allowed_audiences
54    }
55
56    pub fn required_scopes(&self) -> &[String] {
57        &self.required_scopes
58    }
59
60    pub fn clock_skew(&self) -> Duration {
61        self.clock_skew
62    }
63}
64
/// A JSON Web Key Set paired with the monotonic instant recorded for it.
// NOTE(review): `fetched_at` is presumably used for cache-freshness checks —
// confirm against the code that refreshes the key set.
#[derive(Debug, Clone)]
pub struct JwksState {
    /// The key set itself.
    pub jwks: CoreJsonWebKeySet,
    /// Instant associated with retrieval of `jwks`.
    pub fetched_at: std::time::Instant,
}
70
/// A structured access token together with the authorization-server metadata
/// associated with it.
pub struct VerifiedAccessToken<CLAIMS>
where
    CLAIMS: JwtClaimsTrait,
{
    /// Decoded token contents. `structured_token_principal` expects this to
    /// be the `TokenData::JWT` variant and panics otherwise.
    pub token_data: TokenData<CLAIMS>,
    /// Metadata of the authorization server associated with this token.
    pub metadata: OAuthResourceServerMetadata,
}
78
/// A verified token in either of its two supported shapes.
///
/// Both payloads are boxed, so the enum itself stays pointer-sized regardless
/// of how large either variant's data is.
pub enum VerifiedToken<CLAIMS>
where
    CLAIMS: JwtClaimsTrait,
{
    /// A token whose claims are available as structured data.
    Structured(Box<VerifiedAccessToken<CLAIMS>>),
    /// An opaque token; its attributes are read through the accessor methods
    /// of [`VerifiedOpaqueToken`].
    Opaque(Box<VerifiedOpaqueToken>),
}
86
87impl<CLAIMS> From<VerifiedOpaqueToken> for VerifiedToken<CLAIMS>
88where
89    CLAIMS: JwtClaimsTrait,
90{
91    fn from(value: VerifiedOpaqueToken) -> Self {
92        Self::Opaque(Box::new(value))
93    }
94}
95
96impl<CLAIMS> From<VerifiedAccessToken<CLAIMS>> for VerifiedToken<CLAIMS>
97where
98    CLAIMS: JwtClaimsTrait,
99{
100    fn from(value: VerifiedAccessToken<CLAIMS>) -> Self {
101        Self::Structured(Box::new(value))
102    }
103}
104
105impl<CLAIMS> VerifiedToken<CLAIMS>
106where
107    CLAIMS: JwtClaimsTrait,
108{
109    pub fn to_resource_token_principal(&self) -> ResourceTokenPrincipal {
110        match self {
111            Self::Structured(token) => structured_token_principal(&token.token_data),
112            Self::Opaque(token) => ResourceTokenPrincipal {
113                subject: token.subject().map(str::to_string),
114                issuer: token.issuer().map(str::to_string),
115                audiences: token.audience().cloned().unwrap_or_default(),
116                scopes: token.scopes().unwrap_or_default(),
117                authorized_party: None,
118                claims: HashMap::new(),
119            },
120        }
121    }
122}
123
124fn structured_token_principal<CLAIMS>(token_data: &TokenData<CLAIMS>) -> ResourceTokenPrincipal
125where
126    CLAIMS: JwtClaimsTrait,
127{
128    let claims = match token_data {
129        TokenData::JWT(token) => &token.claims,
130        TokenData::Opaque => unreachable!("structured token data must not be opaque"),
131        #[allow(unreachable_patterns)]
132        _ => unreachable!("unexpected structured token variant"),
133    };
134    let additional = claims.get_additional().cloned().unwrap_or_default();
135    let projected_claims = project_additional_claims(additional.clone());
136    let audiences = claims
137        .get_audience()
138        .map(|audience| audience.iter().cloned().collect())
139        .unwrap_or_default();
140    let scopes = additional
141        .get("scope")
142        .or_else(|| additional.get("scp"))
143        .map(value_as_scope_list)
144        .unwrap_or_default();
145
146    ResourceTokenPrincipal {
147        subject: claims.get_subject().map(str::to_string),
148        issuer: claims.get_issuer().map(str::to_string),
149        audiences,
150        scopes,
151        authorized_party: additional
152            .get("azp")
153            .and_then(Value::as_str)
154            .map(str::to_string),
155        claims: projected_claims,
156    }
157}
158
159fn project_additional_claims(additional: HashMap<String, Value>) -> HashMap<String, Value> {
160    additional
161        .into_iter()
162        .filter(|(key, _)| !is_sensitive_additional_claim_key(key))
163        .collect()
164}
165
/// Reports whether an additional-claim key must be withheld from the
/// projected claims of a principal.
///
/// The key is split into lower-cased words by `claim_key_tokens`, so
/// `access_token`, `accessToken`, `ACCESS_TOKEN` and `access-token` are all
/// treated alike. A key is sensitive when either:
///
/// * two adjacent words form `access token`, `refresh token`, `id token` or
///   `client secret` anywhere in the key (for example
///   `provider_access_token`), or
/// * any single word is `authorization`, `password`, `secret`, `scope`, `scp`
///   or `azp`.
///
/// Matching is on whole words, so look-alikes such as `secretariat` or
/// `scoped_feature` are not flagged. A key without any ASCII alphanumeric
/// character is never sensitive.
fn is_sensitive_additional_claim_key(key: &str) -> bool {
    let tokens = claim_key_tokens(key);

    // Word pairs are searched across the whole key rather than requiring the
    // key to consist of exactly the pair, so prefixed or suffixed variants are
    // caught as well.
    let has_sensitive_pair = tokens.windows(2).any(|pair| {
        matches!(
            pair,
            [first, second]
                if matches!(
                    (first.as_str(), second.as_str()),
                    ("access" | "refresh" | "id", "token") | ("client", "secret")
                )
        )
    });

    has_sensitive_pair
        || tokens.iter().any(|token| {
            matches!(
                token.as_str(),
                "authorization" | "password" | "secret" | "scope" | "scp" | "azp"
            )
        })
}

/// Splits a claim key into lower-cased ASCII alphanumeric words.
///
/// Word boundaries are:
///
/// * any character that is not ASCII alphanumeric (it is dropped),
/// * a lowercase letter followed by an uppercase letter (`clientSecret`),
/// * the last capital of an uppercase run that is followed by a lowercase
///   letter (`IDToken` splits into `id` and `token`).
///
/// Runs of capitals stay together, so `ACCESS_TOKEN` yields `access` and
/// `token`. Digits never start a new word on their own.
fn claim_key_tokens(key: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    // Case decisions must look at the character as written in the key;
    // `current` only ever holds lower-cased characters and cannot be used.
    let mut previous: Option<char> = None;
    let mut characters = key.chars().peekable();

    while let Some(character) = characters.next() {
        if !character.is_ascii_alphanumeric() {
            if !current.is_empty() {
                tokens.push(std::mem::take(&mut current));
            }
            previous = None;
            continue;
        }

        if character.is_ascii_uppercase() && !current.is_empty() {
            let follows_lowercase = previous.is_some_and(|last| last.is_ascii_lowercase());
            let closes_acronym = previous.is_some_and(|last| last.is_ascii_uppercase())
                && characters
                    .peek()
                    .is_some_and(|next| next.is_ascii_lowercase());

            if follows_lowercase || closes_acronym {
                tokens.push(std::mem::take(&mut current));
            }
        }

        current.push(character.to_ascii_lowercase());
        previous = Some(character);
    }

    if !current.is_empty() {
        tokens.push(current);
    }

    tokens
}
219
220fn value_as_scope_list(value: &Value) -> Vec<String> {
221    match value {
222        Value::String(raw) => raw
223            .split_whitespace()
224            .filter(|scope| !scope.is_empty())
225            .map(str::to_string)
226            .collect(),
227        Value::Array(items) => items
228            .iter()
229            .filter_map(Value::as_str)
230            .map(str::to_string)
231            .collect(),
232        _ => Vec::new(),
233    }
234}
235
236pub fn scope_contains_all(scope: Option<&Scope>, required_scopes: &[String]) -> bool {
237    if required_scopes.is_empty() {
238        return true;
239    }
240
241    let Some(scope) = scope else {
242        return false;
243    };
244
245    required_scopes
246        .iter()
247        .all(|required_scope| scope.iter().any(|value| value == required_scope))
248}
249
// Unit tests for scope checking, sensitive-key classification and the
// projection of structured-token claims into a principal.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use securitydept_creds::{CoreJwtClaims, JwtHeader, Scope, TokenData};
    use serde_json::json;

    use super::{
        claim_key_tokens, is_sensitive_additional_claim_key, scope_contains_all,
        structured_token_principal,
    };

    // A scope containing every required value satisfies the policy.
    #[test]
    fn scope_policy_accepts_required_scopes() {
        let scope: Scope = serde_json::from_str("\"read write\"").expect("scope should parse");

        assert!(scope_contains_all(
            Some(&scope),
            &["read".to_string(), "write".to_string()]
        ));
    }

    // A scope lacking a required value is rejected.
    #[test]
    fn scope_policy_rejects_missing_scope() {
        let scope: Scope = serde_json::from_str("\"read\"").expect("scope should parse");

        assert!(!scope_contains_all(Some(&scope), &["write".to_string()]));
    }

    // Keys are classified by whole words regardless of separator or camelCase;
    // the last two assertions guard against substring false positives.
    #[test]
    fn sensitive_claim_key_matching_is_case_insensitive_and_separator_agnostic() {
        assert_eq!(claim_key_tokens("clientSecret"), vec!["client", "secret"]);
        assert!(is_sensitive_additional_claim_key("access_token"));
        assert!(is_sensitive_additional_claim_key("refreshToken"));
        assert!(is_sensitive_additional_claim_key("id-token"));
        assert!(is_sensitive_additional_claim_key("Authorization"));
        assert!(is_sensitive_additional_claim_key("authorization_header"));
        assert!(is_sensitive_additional_claim_key("client_secret"));
        assert!(is_sensitive_additional_claim_key("client-secret"));
        assert!(is_sensitive_additional_claim_key("provider_secret"));
        assert!(!is_sensitive_additional_claim_key("scoped_feature"));
        assert!(!is_sensitive_additional_claim_key("secretariat"));
    }

    // End-to-end projection: sensitive additional claims are dropped, while
    // `scope`/`azp` still feed the dedicated principal fields and harmless
    // claims pass through unchanged.
    #[test]
    fn structured_token_principal_projects_only_safe_additional_claims() {
        let mut additional = HashMap::new();
        additional.insert("access_token".to_string(), json!("at-1"));
        additional.insert("refreshToken".to_string(), json!("rt-1"));
        additional.insert("id-token".to_string(), json!("id-1"));
        additional.insert("Authorization".to_string(), json!("Bearer test"));
        additional.insert("clientSecret".to_string(), json!("top-secret"));
        additional.insert("provider_secret".to_string(), json!("nested-secret"));
        additional.insert("password".to_string(), json!("p@ss"));
        additional.insert("scope".to_string(), json!("read write"));
        additional.insert("scp".to_string(), json!(["read", "write"]));
        additional.insert("azp".to_string(), json!("webui-client"));
        additional.insert("tenant".to_string(), json!("acme"));
        additional.insert("feature_flags".to_string(), json!(["alpha"]));

        let principal =
            structured_token_principal(&TokenData::JWT(Box::new(jsonwebtoken::TokenData {
                header: JwtHeader::default(),
                claims: CoreJwtClaims {
                    subject: Some("user-1".to_string()),
                    issuer: Some("https://issuer.example.com".to_string()),
                    audience: Some(
                        serde_json::from_value(json!(["api", "web"]))
                            .expect("audience should parse"),
                    ),
                    expiration_time: Some(1_234_567_890),
                    not_before: None,
                    additional,
                },
            })));

        // Dedicated fields are populated from the registered and additional claims.
        assert_eq!(principal.subject.as_deref(), Some("user-1"));
        assert_eq!(principal.authorized_party.as_deref(), Some("webui-client"));
        assert_eq!(
            principal.scopes,
            vec!["read".to_string(), "write".to_string()]
        );
        // Harmless additional claims survive the projection.
        assert_eq!(principal.claims.get("tenant"), Some(&json!("acme")));
        assert_eq!(
            principal.claims.get("feature_flags"),
            Some(&json!(["alpha"]))
        );
        // Sensitive or already-projected keys are absent from `claims`.
        assert!(!principal.claims.contains_key("access_token"));
        assert!(!principal.claims.contains_key("refreshToken"));
        assert!(!principal.claims.contains_key("id-token"));
        assert!(!principal.claims.contains_key("Authorization"));
        assert!(!principal.claims.contains_key("clientSecret"));
        assert!(!principal.claims.contains_key("provider_secret"));
        assert!(!principal.claims.contains_key("password"));
        assert!(!principal.claims.contains_key("scope"));
        assert!(!principal.claims.contains_key("scp"));
        assert!(!principal.claims.contains_key("azp"));
    }
}