Skip to main content

securitydept_oauth_resource_server/models/
mod.rs

1use std::{collections::HashMap, time::Duration};
2
3use openidconnect::{IntrospectionUrl, IssuerUrl, JsonWebKeySetUrl, core::CoreJsonWebKeySet};
4use securitydept_creds::{JwtClaimsTrait, Scope, TokenData};
5use serde_json::Value;
6
7pub mod introspection;
8#[cfg(feature = "jwe")]
9pub mod jwe;
10
11pub use introspection::VerifiedOpaqueToken;
12#[cfg(feature = "jwe")]
13pub use jwe::LocalJweDecryptionKeySet;
14
15#[derive(Debug, Clone)]
16pub struct OAuthResourceServerMetadata {
17    pub issuer: IssuerUrl,
18    pub jwks_uri: JsonWebKeySetUrl,
19    pub introspection_url: Option<IntrospectionUrl>,
20}
21
/// Policy values consulted while verifying a token.
///
/// The fields are private; values are supplied through
/// `VerificationPolicy::new` and read back through the accessor methods.
/// NOTE(review): how each value is enforced is decided by the verifier code,
/// which is not visible in this module.
#[derive(Debug, Clone)]
pub struct VerificationPolicy {
    /// Audience values the policy allows.
    allowed_audiences: Vec<String>,
    /// Scopes the policy requires.
    required_scopes: Vec<String>,
    /// Clock skew configured for the policy.
    clock_skew: Duration,
}
28
29#[derive(Debug, Clone, PartialEq, Eq)]
30pub struct ResourceTokenPrincipal {
31    pub subject: Option<String>,
32    pub issuer: Option<String>,
33    pub audiences: Vec<String>,
34    pub scopes: Vec<String>,
35    pub authorized_party: Option<String>,
36    pub claims: HashMap<String, Value>,
37}
38
39impl VerificationPolicy {
40    pub fn new(
41        allowed_audiences: Vec<String>,
42        required_scopes: Vec<String>,
43        clock_skew: Duration,
44    ) -> Self {
45        Self {
46            allowed_audiences,
47            required_scopes,
48            clock_skew,
49        }
50    }
51
52    pub fn allowed_audiences(&self) -> &[String] {
53        &self.allowed_audiences
54    }
55
56    pub fn required_scopes(&self) -> &[String] {
57        &self.required_scopes
58    }
59
60    pub fn clock_skew(&self) -> Duration {
61        self.clock_skew
62    }
63}
64
65#[derive(Debug, Clone)]
66pub struct JwksState {
67    pub jwks: CoreJsonWebKeySet,
68    pub fetched_at: std::time::Instant,
69}
70
71pub struct VerifiedAccessToken<CLAIMS>
72where
73    CLAIMS: JwtClaimsTrait,
74{
75    pub token_data: TokenData<CLAIMS>,
76    pub metadata: OAuthResourceServerMetadata,
77}
78
79pub enum VerifiedToken<CLAIMS>
80where
81    CLAIMS: JwtClaimsTrait,
82{
83    Structured(Box<VerifiedAccessToken<CLAIMS>>),
84    Opaque(Box<VerifiedOpaqueToken>),
85}
86
87impl<CLAIMS> From<VerifiedOpaqueToken> for VerifiedToken<CLAIMS>
88where
89    CLAIMS: JwtClaimsTrait,
90{
91    fn from(value: VerifiedOpaqueToken) -> Self {
92        Self::Opaque(Box::new(value))
93    }
94}
95
96impl<CLAIMS> From<VerifiedAccessToken<CLAIMS>> for VerifiedToken<CLAIMS>
97where
98    CLAIMS: JwtClaimsTrait,
99{
100    fn from(value: VerifiedAccessToken<CLAIMS>) -> Self {
101        Self::Structured(Box::new(value))
102    }
103}
104
105impl<CLAIMS> VerifiedToken<CLAIMS>
106where
107    CLAIMS: JwtClaimsTrait,
108{
109    pub fn to_resource_token_principal(&self) -> ResourceTokenPrincipal {
110        match self {
111            Self::Structured(token) => structured_token_principal(&token.token_data),
112            Self::Opaque(token) => ResourceTokenPrincipal {
113                subject: token.subject().map(str::to_string),
114                issuer: token.issuer().map(str::to_string),
115                audiences: token.audience().cloned().unwrap_or_default(),
116                scopes: token.scopes().unwrap_or_default(),
117                authorized_party: None,
118                claims: HashMap::new(),
119            },
120        }
121    }
122}
123
124fn structured_token_principal<CLAIMS>(token_data: &TokenData<CLAIMS>) -> ResourceTokenPrincipal
125where
126    CLAIMS: JwtClaimsTrait,
127{
128    let claims = match token_data {
129        TokenData::JWT(token) => &token.claims,
130        TokenData::Opaque => unreachable!("structured token data must not be opaque"),
131        #[allow(unreachable_patterns)]
132        _ => unreachable!("unexpected structured token variant"),
133    };
134    let additional = claims.get_additional().cloned().unwrap_or_default();
135    let audiences = claims
136        .get_audience()
137        .map(|audience| audience.iter().cloned().collect())
138        .unwrap_or_default();
139    let scopes = additional
140        .get("scope")
141        .or_else(|| additional.get("scp"))
142        .map(value_as_scope_list)
143        .unwrap_or_default();
144
145    ResourceTokenPrincipal {
146        subject: claims.get_subject().map(str::to_string),
147        issuer: claims.get_issuer().map(str::to_string),
148        audiences,
149        scopes,
150        authorized_party: additional
151            .get("azp")
152            .and_then(Value::as_str)
153            .map(str::to_string),
154        claims: additional,
155    }
156}
157
158fn value_as_scope_list(value: &Value) -> Vec<String> {
159    match value {
160        Value::String(raw) => raw
161            .split_whitespace()
162            .filter(|scope| !scope.is_empty())
163            .map(str::to_string)
164            .collect(),
165        Value::Array(items) => items
166            .iter()
167            .filter_map(Value::as_str)
168            .map(str::to_string)
169            .collect(),
170        _ => Vec::new(),
171    }
172}
173
174pub fn scope_contains_all(scope: Option<&Scope>, required_scopes: &[String]) -> bool {
175    if required_scopes.is_empty() {
176        return true;
177    }
178
179    let Some(scope) = scope else {
180        return false;
181    };
182
183    required_scopes
184        .iter()
185        .all(|required_scope| scope.iter().any(|value| value == required_scope))
186}
187
#[cfg(test)]
mod tests {
    use securitydept_creds::Scope;

    use super::scope_contains_all;

    /// Deserializes a `Scope` from its JSON text, panicking on failure.
    fn parse_scope(json: &str) -> Scope {
        serde_json::from_str(json).expect("scope should parse")
    }

    /// Every required scope is present in the token's scope.
    #[test]
    fn scope_policy_accepts_required_scopes() {
        let granted = parse_scope("\"read write\"");
        let required = ["read".to_string(), "write".to_string()];

        assert!(scope_contains_all(Some(&granted), &required));
    }

    /// A required scope that the token lacks fails the check.
    #[test]
    fn scope_policy_rejects_missing_scope() {
        let granted = parse_scope("\"read\"");
        let required = ["write".to_string()];

        assert!(!scope_contains_all(Some(&granted), &required));
    }
}