// stakpak_shared/oauth/config.rs
//! OAuth configuration types.

/// Shape of the authorization request sent to a provider.
///
/// Providers differ in which query parameters they expect on the
/// authorization URL; this selects between the supported shapes.
/// Defaults to [`AuthorizationRequestMode::StandardPkce`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AuthorizationRequestMode {
    /// Standard OAuth 2.0 Authorization Code request with PKCE.
    #[default]
    StandardPkce,
    /// Legacy request shape that additionally carries `code=true`.
    LegacyCode,
}

/// Body encoding used when calling a provider's token endpoint.
///
/// Defaults to [`TokenRequestMode::Json`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TokenRequestMode {
    /// JSON request bodies, as used by legacy providers.
    #[default]
    Json,
    /// `application/x-www-form-urlencoded` request bodies.
    FormUrlEncoded,
}

23/// Configuration for an OAuth 2.0 provider
24#[derive(Debug, Clone)]
25pub struct OAuthConfig {
26    /// OAuth client ID
27    pub client_id: String,
28    /// Authorization endpoint URL
29    pub auth_url: String,
30    /// Token exchange endpoint URL
31    pub token_url: String,
32    /// Redirect URI for authorization callback
33    pub redirect_url: String,
34    /// Scopes to request
35    pub scopes: Vec<String>,
36    /// Provider-specific authorization request mode.
37    pub authorization_request_mode: AuthorizationRequestMode,
38    /// Additional provider-specific authorization query parameters.
39    pub authorization_params: Vec<(String, String)>,
40    /// Provider-specific token request encoding.
41    pub token_request_mode: TokenRequestMode,
42}
43
44impl OAuthConfig {
45    /// Create a new OAuth configuration
46    pub fn new(
47        client_id: impl Into<String>,
48        auth_url: impl Into<String>,
49        token_url: impl Into<String>,
50        redirect_url: impl Into<String>,
51        scopes: Vec<String>,
52    ) -> Self {
53        Self {
54            client_id: client_id.into(),
55            auth_url: auth_url.into(),
56            token_url: token_url.into(),
57            redirect_url: redirect_url.into(),
58            scopes,
59            authorization_request_mode: AuthorizationRequestMode::StandardPkce,
60            authorization_params: Vec::new(),
61            token_request_mode: TokenRequestMode::Json,
62        }
63    }
64
65    /// Override the authorization request mode for providers with non-standard requirements.
66    pub fn with_authorization_request_mode(mut self, mode: AuthorizationRequestMode) -> Self {
67        self.authorization_request_mode = mode;
68        self
69    }
70
71    /// Add provider-specific authorization query parameters.
72    pub fn with_authorization_params<K, V, I>(mut self, params: I) -> Self
73    where
74        K: Into<String>,
75        V: Into<String>,
76        I: IntoIterator<Item = (K, V)>,
77    {
78        self.authorization_params = params
79            .into_iter()
80            .map(|(key, value)| (key.into(), value.into()))
81            .collect();
82        self
83    }
84
85    /// Override the token request encoding for providers with non-standard requirements.
86    pub fn with_token_request_mode(mut self, mode: TokenRequestMode) -> Self {
87        self.token_request_mode = mode;
88        self
89    }
90
91    /// Get the scopes as a space-separated string
92    pub fn scopes_string(&self) -> String {
93        self.scopes.join(" ")
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    const CLIENT_ID: &str = "client-id";
    const AUTH_URL: &str = "https://example.com/auth";
    const TOKEN_URL: &str = "https://example.com/token";
    const CALLBACK_URL: &str = "https://example.com/callback";

    /// Build a config against the shared example endpoints with the given scopes.
    fn example_config(scopes: &[&str]) -> OAuthConfig {
        OAuthConfig::new(
            CLIENT_ID,
            AUTH_URL,
            TOKEN_URL,
            CALLBACK_URL,
            scopes.iter().copied().map(String::from).collect(),
        )
    }

    #[test]
    fn test_oauth_config_creation() {
        let config = example_config(&["scope1", "scope2"]);

        assert_eq!(config.client_id, "client-id");
        assert_eq!(config.auth_url, "https://example.com/auth");
        assert_eq!(config.token_url, "https://example.com/token");
        assert_eq!(config.redirect_url, "https://example.com/callback");
        assert_eq!(config.scopes, vec!["scope1", "scope2"]);
        assert_eq!(
            config.authorization_request_mode,
            AuthorizationRequestMode::StandardPkce
        );
        assert!(config.authorization_params.is_empty());
        assert_eq!(config.token_request_mode, TokenRequestMode::Json);
    }

    #[test]
    fn test_new_uses_enum_defaults() {
        let config = example_config(&["scope"]);

        assert_eq!(
            config.authorization_request_mode,
            AuthorizationRequestMode::default()
        );
        assert_eq!(config.token_request_mode, TokenRequestMode::default());
    }

    #[test]
    fn test_authorization_request_mode_builder() {
        let config = example_config(&["scope"])
            .with_authorization_request_mode(AuthorizationRequestMode::LegacyCode);

        assert_eq!(
            config.authorization_request_mode,
            AuthorizationRequestMode::LegacyCode
        );
    }

    #[test]
    fn test_authorization_params_builder() {
        let config = example_config(&["scope"])
            .with_authorization_params([("originator", "stakpak"), ("mode", "codex")]);

        assert_eq!(
            config.authorization_params,
            vec![
                ("originator".to_string(), "stakpak".to_string()),
                ("mode".to_string(), "codex".to_string()),
            ]
        );
    }

    #[test]
    fn test_authorization_params_builder_replaces_previous_params() {
        let config = example_config(&["scope"])
            .with_authorization_params([("originator", "stakpak")])
            .with_authorization_params([("mode", "codex")]);

        assert_eq!(
            config.authorization_params,
            vec![("mode".to_string(), "codex".to_string())]
        );
    }

    #[test]
    fn test_authorization_params_builder_accepts_empty_input() {
        let config = example_config(&["scope"])
            .with_authorization_params([("originator", "stakpak")])
            .with_authorization_params(Vec::<(String, String)>::new());

        assert!(config.authorization_params.is_empty());
    }

    #[test]
    fn test_token_request_mode_builder() {
        let config =
            example_config(&["scope"]).with_token_request_mode(TokenRequestMode::FormUrlEncoded);

        assert_eq!(config.token_request_mode, TokenRequestMode::FormUrlEncoded);
    }

    #[test]
    fn test_scopes_string() {
        let config = example_config(&["read", "write", "admin"]);

        assert_eq!(config.scopes_string(), "read write admin");
    }

    #[test]
    fn test_empty_scopes() {
        let config = example_config(&[]);

        assert_eq!(config.scopes_string(), "");
    }
}