Skip to main content

stakpak_shared/oauth/
provider.rs

1//! OAuth provider trait and authentication method types
2
3use super::config::OAuthConfig;
4use super::device_flow::{DeviceCodeResponse, DeviceFlow, DeviceTokenResponse};
5use super::error::{OAuthError, OAuthResult};
6use super::flow::TokenResponse;
7use crate::models::auth::ProviderAuth;
8use async_trait::async_trait;
9use reqwest::header::HeaderMap;
10
/// Authentication method offered by a provider
///
/// Prefer the constructors ([`AuthMethod::oauth`], [`AuthMethod::api_key`],
/// [`AuthMethod::device_flow`]), which set `method_type` for you.
#[derive(Debug, Clone)]
pub struct AuthMethod {
    /// Unique identifier for this method
    pub id: String,
    /// Human-readable label
    pub label: String,
    /// Description/hint for the user
    pub description: Option<String>,
    /// Type of authentication
    pub method_type: AuthMethodType,
}

impl AuthMethod {
    /// Shared constructor behind the public, type-specific constructors.
    fn with_method_type(
        id: impl Into<String>,
        label: impl Into<String>,
        description: Option<String>,
        method_type: AuthMethodType,
    ) -> Self {
        Self {
            id: id.into(),
            label: label.into(),
            description,
            method_type,
        }
    }

    /// Create a new OAuth authentication method
    pub fn oauth(
        id: impl Into<String>,
        label: impl Into<String>,
        description: Option<String>,
    ) -> Self {
        Self::with_method_type(id, label, description, AuthMethodType::OAuth)
    }

    /// Create a new API key authentication method
    pub fn api_key(
        id: impl Into<String>,
        label: impl Into<String>,
        description: Option<String>,
    ) -> Self {
        Self::with_method_type(id, label, description, AuthMethodType::ApiKey)
    }

    /// Create a new Device Authorization Grant (RFC 8628) authentication method
    ///
    /// Counterpart of [`AuthMethod::oauth`] and [`AuthMethod::api_key`] for the
    /// [`AuthMethodType::DeviceFlow`] variant.
    pub fn device_flow(
        id: impl Into<String>,
        label: impl Into<String>,
        description: Option<String>,
    ) -> Self {
        Self::with_method_type(id, label, description, AuthMethodType::DeviceFlow)
    }

    /// Get a display string combining label and description
    ///
    /// Returns `"<label> - <description>"` when a description is present,
    /// otherwise just the label.
    pub fn display(&self) -> String {
        match &self.description {
            Some(desc) => format!("{} - {}", self.label, desc),
            None => self.label.clone(),
        }
    }
}

/// Type of authentication method
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMethodType {
    /// OAuth 2.0 browser-based flow (PKCE)
    OAuth,
    /// Manual API key entry
    ApiKey,
    /// Device Authorization Grant (RFC 8628) — polling-based, no browser redirect
    DeviceFlow,
}
72
73/// Trait for providers that support authentication
74#[async_trait]
75pub trait OAuthProvider: Send + Sync {
76    /// Provider identifier (e.g., "anthropic")
77    fn id(&self) -> &'static str;
78
79    /// Human-readable provider name
80    fn name(&self) -> &'static str;
81
82    /// List available authentication methods
83    fn auth_methods(&self) -> Vec<AuthMethod>;
84
85    /// Get OAuth configuration for a specific method
86    fn oauth_config(&self, method_id: &str) -> Option<OAuthConfig>;
87
88    /// Post-authorization processing (e.g., exchange OAuth tokens for API key)
89    ///
90    /// This is called after the OAuth flow completes to convert the tokens
91    /// into the appropriate `ProviderAuth` type.
92    async fn post_authorize(
93        &self,
94        method_id: &str,
95        tokens: &TokenResponse,
96    ) -> OAuthResult<ProviderAuth>;
97
98    /// Apply authentication to HTTP request headers
99    ///
100    /// This method modifies the provided headers to include the appropriate
101    /// authentication headers for API requests.
102    fn apply_auth_headers(&self, auth: &ProviderAuth, headers: &mut HeaderMap) -> OAuthResult<()>;
103
104    /// Get the environment variable name for API key (if supported)
105    fn api_key_env_var(&self) -> Option<&'static str> {
106        None
107    }
108
109    /// Build a [`DeviceFlow`] for the given method.
110    ///
111    /// Override this for any method whose `method_type` is
112    /// [`AuthMethodType::DeviceFlow`].  The default implementation returns an
113    /// error so providers that don't support device flow fail with a clear
114    /// message rather than a panic.
115    fn device_flow(&self, method_id: &str) -> OAuthResult<DeviceFlow> {
116        Err(OAuthError::unknown_method(format!(
117            "Provider '{}' does not support the Device Authorization Grant for method '{}'",
118            self.id(),
119            method_id,
120        )))
121    }
122
123    /// Step 1 of the Device Authorization Grant: request a device code.
124    async fn request_device_code(
125        &self,
126        method_id: &str,
127    ) -> OAuthResult<(DeviceFlow, DeviceCodeResponse)> {
128        let flow = self.device_flow(method_id)?;
129        let code = flow.request_device_code().await?;
130        Ok((flow, code))
131    }
132
133    /// Step 2 of the Device Authorization Grant: poll until the user approves.
134    async fn wait_for_token(
135        &self,
136        flow: &DeviceFlow,
137        device_code: &DeviceCodeResponse,
138    ) -> OAuthResult<DeviceTokenResponse> {
139        flow.poll_for_token(device_code).await
140    }
141
142    /// Post-authorization processing for the Device Authorization Grant.
143    async fn post_device_authorize(
144        &self,
145        method_id: &str,
146        token: &DeviceTokenResponse,
147    ) -> OAuthResult<ProviderAuth> {
148        let _ = (method_id, token);
149        Err(OAuthError::unknown_method(format!(
150            "Provider '{}' does not support post_device_authorize for method '{}'",
151            self.id(),
152            method_id,
153        )))
154    }
155}
156
#[cfg(test)]
mod tests {
    //! Unit tests for the `AuthMethod` constructors and display helper.
    use super::*;

    #[test]
    fn test_auth_method_oauth() {
        let hint = String::from("Use your subscription");
        let method = AuthMethod::oauth("claude-max", "Claude Pro/Max", Some(hint.clone()));

        // Every field should reflect exactly what was passed in.
        assert_eq!(method.id, "claude-max");
        assert_eq!(method.label, "Claude Pro/Max");
        assert_eq!(method.description.as_deref(), Some(hint.as_str()));
        assert_eq!(method.method_type, AuthMethodType::OAuth);
    }

    #[test]
    fn test_auth_method_api_key() {
        let method = AuthMethod::api_key("api-key", "Manual API Key", None);

        assert_eq!(method.id, "api-key");
        assert_eq!(method.label, "Manual API Key");
        assert!(method.description.is_none());
        assert_eq!(method.method_type, AuthMethodType::ApiKey);
    }

    #[test]
    fn test_auth_method_display() {
        // Renders an OAuth method with a fixed id/label and the given description.
        let render = |description: Option<&str>| {
            AuthMethod::oauth("test", "Test Method", description.map(str::to_owned)).display()
        };

        assert_eq!(
            render(Some("Description here")),
            "Test Method - Description here"
        );
        assert_eq!(render(None), "Test Method");
    }
}