stakpak_shared/oauth/
provider.rs

1//! OAuth provider trait and authentication method types
2
3use super::config::OAuthConfig;
4use super::error::OAuthResult;
5use super::flow::TokenResponse;
6use crate::models::auth::ProviderAuth;
7use async_trait::async_trait;
8use reqwest::header::HeaderMap;
9
10/// Authentication method offered by a provider
11#[derive(Debug, Clone)]
12pub struct AuthMethod {
13    /// Unique identifier for this method
14    pub id: String,
15    /// Human-readable label
16    pub label: String,
17    /// Description/hint for the user
18    pub description: Option<String>,
19    /// Type of authentication
20    pub method_type: AuthMethodType,
21}
22
23impl AuthMethod {
24    /// Create a new OAuth authentication method
25    pub fn oauth(
26        id: impl Into<String>,
27        label: impl Into<String>,
28        description: Option<String>,
29    ) -> Self {
30        Self {
31            id: id.into(),
32            label: label.into(),
33            description,
34            method_type: AuthMethodType::OAuth,
35        }
36    }
37
38    /// Create a new API key authentication method
39    pub fn api_key(
40        id: impl Into<String>,
41        label: impl Into<String>,
42        description: Option<String>,
43    ) -> Self {
44        Self {
45            id: id.into(),
46            label: label.into(),
47            description,
48            method_type: AuthMethodType::ApiKey,
49        }
50    }
51
52    /// Get a display string combining label and description
53    pub fn display(&self) -> String {
54        match &self.description {
55            Some(desc) => format!("{} - {}", self.label, desc),
56            None => self.label.clone(),
57        }
58    }
59}
60
/// Type of authentication method.
///
/// Being a fieldless enum, it derives `Copy` (cheap to pass by value) and
/// `Hash` (usable as a `HashMap`/`HashSet` key) alongside equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthMethodType {
    /// OAuth 2.0 browser-based flow
    OAuth,
    /// Manual API key entry
    ApiKey,
}
69
70/// Trait for providers that support authentication
71#[async_trait]
72pub trait OAuthProvider: Send + Sync {
73    /// Provider identifier (e.g., "anthropic")
74    fn id(&self) -> &'static str;
75
76    /// Human-readable provider name
77    fn name(&self) -> &'static str;
78
79    /// List available authentication methods
80    fn auth_methods(&self) -> Vec<AuthMethod>;
81
82    /// Get OAuth configuration for a specific method
83    fn oauth_config(&self, method_id: &str) -> Option<OAuthConfig>;
84
85    /// Post-authorization processing (e.g., exchange OAuth tokens for API key)
86    ///
87    /// This is called after the OAuth flow completes to convert the tokens
88    /// into the appropriate `ProviderAuth` type.
89    async fn post_authorize(
90        &self,
91        method_id: &str,
92        tokens: &TokenResponse,
93    ) -> OAuthResult<ProviderAuth>;
94
95    /// Apply authentication to HTTP request headers
96    ///
97    /// This method modifies the provided headers to include the appropriate
98    /// authentication headers for API requests.
99    fn apply_auth_headers(&self, auth: &ProviderAuth, headers: &mut HeaderMap) -> OAuthResult<()>;
100
101    /// Get the environment variable name for API key (if supported)
102    fn api_key_env_var(&self) -> Option<&'static str> {
103        None
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_method_oauth() {
        // Destructure so that every field is checked explicitly.
        let AuthMethod {
            id,
            label,
            description,
            method_type,
        } = AuthMethod::oauth(
            "claude-max",
            "Claude Pro/Max",
            Some("Use your subscription".to_string()),
        );

        assert_eq!(id, "claude-max");
        assert_eq!(label, "Claude Pro/Max");
        assert_eq!(description.as_deref(), Some("Use your subscription"));
        assert_eq!(method_type, AuthMethodType::OAuth);
    }

    #[test]
    fn test_auth_method_api_key() {
        let AuthMethod {
            id,
            label,
            description,
            method_type,
        } = AuthMethod::api_key("api-key", "Manual API Key", None);

        assert_eq!(id, "api-key");
        assert_eq!(label, "Manual API Key");
        assert!(description.is_none());
        assert_eq!(method_type, AuthMethodType::ApiKey);
    }

    #[test]
    fn test_auth_method_display() {
        // (description, expected display output) pairs.
        let cases = [
            (
                Some("Description here".to_string()),
                "Test Method - Description here",
            ),
            (None, "Test Method"),
        ];

        for (description, expected) in cases {
            let method = AuthMethod::oauth("test", "Test Method", description);
            assert_eq!(method.display(), expected);
        }
    }
}