stakpak_shared/models/
auth.rs

//! Authentication credentials for LLM providers
//!
//! This module defines the `ProviderAuth` enum which represents different
//! authentication methods for LLM providers (API key or OAuth tokens).
5
6use serde::{Deserialize, Serialize};
7
8/// Authentication credentials for an LLM provider
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
10#[serde(tag = "type", rename_all = "snake_case")]
11pub enum ProviderAuth {
12    /// API key authentication
13    Api {
14        /// The API key
15        key: String,
16    },
17
18    /// OAuth 2.0 authentication with refresh tokens
19    #[serde(rename = "oauth")]
20    OAuth {
21        /// Access token for API requests
22        access: String,
23        /// Refresh token for obtaining new access tokens
24        refresh: String,
25        /// Expiration timestamp in milliseconds since Unix epoch
26        expires: i64,
27        /// Optional name for the subscription (e.g. "Claude Pro", "Claude Max")
28        #[serde(skip_serializing_if = "Option::is_none")]
29        name: Option<String>,
30    },
31}
32
33impl ProviderAuth {
34    /// Create a new API key authentication
35    pub fn api_key(key: impl Into<String>) -> Self {
36        Self::Api { key: key.into() }
37    }
38
39    /// Create a new OAuth authentication
40    pub fn oauth(access: impl Into<String>, refresh: impl Into<String>, expires: i64) -> Self {
41        Self::OAuth {
42            access: access.into(),
43            refresh: refresh.into(),
44            expires,
45            name: None,
46        }
47    }
48
49    /// Create a new OAuth authentication with subscription name
50    pub fn oauth_with_name(
51        access: impl Into<String>,
52        refresh: impl Into<String>,
53        expires: i64,
54        name: impl Into<String>,
55    ) -> Self {
56        Self::OAuth {
57            access: access.into(),
58            refresh: refresh.into(),
59            expires,
60            name: Some(name.into()),
61        }
62    }
63
64    /// Check if OAuth token needs refresh (within 5 minutes of expiry)
65    pub fn needs_refresh(&self) -> bool {
66        match self {
67            Self::OAuth { expires, .. } => {
68                let now_ms = chrono::Utc::now().timestamp_millis();
69                let buffer_ms = 5 * 60 * 1000; // 5 minutes
70                *expires < (now_ms + buffer_ms)
71            }
72            Self::Api { .. } => false,
73        }
74    }
75
76    /// Check if OAuth token is expired
77    pub fn is_expired(&self) -> bool {
78        match self {
79            Self::OAuth { expires, .. } => *expires < chrono::Utc::now().timestamp_millis(),
80            Self::Api { .. } => false,
81        }
82    }
83
84    /// Get the API key if this is an API key auth
85    pub fn api_key_value(&self) -> Option<&str> {
86        match self {
87            Self::Api { key } => Some(key),
88            Self::OAuth { .. } => None,
89        }
90    }
91
92    /// Get the access token if this is an OAuth auth
93    pub fn access_token(&self) -> Option<&str> {
94        match self {
95            Self::OAuth { access, .. } => Some(access),
96            Self::Api { .. } => None,
97        }
98    }
99
100    /// Get the refresh token if this is an OAuth auth
101    pub fn refresh_token(&self) -> Option<&str> {
102        match self {
103            Self::OAuth { refresh, .. } => Some(refresh),
104            Self::Api { .. } => None,
105        }
106    }
107
108    /// Check if this is an OAuth authentication
109    pub fn is_oauth(&self) -> bool {
110        matches!(self, Self::OAuth { .. })
111    }
112
113    /// Check if this is an API key authentication
114    pub fn is_api_key(&self) -> bool {
115        matches!(self, Self::Api { .. })
116    }
117
118    /// Get a display-safe representation of the auth type
119    pub fn auth_type_display(&self) -> &'static str {
120        match self {
121            Self::Api { .. } => "api_key",
122            Self::OAuth { .. } => "oauth",
123        }
124    }
125
126    /// Get the subscription name if this is an OAuth auth with a name
127    pub fn subscription_name(&self) -> Option<&str> {
128        match self {
129            Self::OAuth { name, .. } => name.as_deref(),
130            Self::Api { .. } => None,
131        }
132    }
133}
134
#[cfg(test)]
mod tests {
    //! Unit tests for `ProviderAuth`: constructors, accessors, expiry checks
    //! and the serde wire format.

    use super::*;

    #[test]
    fn test_api_key_creation() {
        let auth = ProviderAuth::api_key("sk-test-key");
        assert!(auth.is_api_key());
        assert!(!auth.is_oauth());
        assert_eq!(auth.api_key_value(), Some("sk-test-key"));
        assert_eq!(auth.access_token(), None);
    }

    #[test]
    fn test_api_key_has_no_oauth_fields() {
        let auth = ProviderAuth::api_key("sk-test-key");
        assert_eq!(auth.refresh_token(), None);
        assert_eq!(auth.subscription_name(), None);
    }

    #[test]
    fn test_oauth_creation() {
        let expires = chrono::Utc::now().timestamp_millis() + 3600000; // 1 hour from now
        let auth = ProviderAuth::oauth("access-token", "refresh-token", expires);
        assert!(auth.is_oauth());
        assert!(!auth.is_api_key());
        assert_eq!(auth.access_token(), Some("access-token"));
        assert_eq!(auth.refresh_token(), Some("refresh-token"));
        assert_eq!(auth.api_key_value(), None);
        // The plain constructor never attaches a subscription name.
        assert_eq!(auth.subscription_name(), None);
    }

    #[test]
    fn test_oauth_with_name_creation() {
        let auth = ProviderAuth::oauth_with_name("access", "refresh", 0, "Claude Pro");
        assert!(auth.is_oauth());
        assert_eq!(auth.access_token(), Some("access"));
        assert_eq!(auth.refresh_token(), Some("refresh"));
        assert_eq!(auth.subscription_name(), Some("Claude Pro"));
    }

    #[test]
    fn test_oauth_needs_refresh() {
        // Token expiring in 2 minutes - should need refresh
        let expires = chrono::Utc::now().timestamp_millis() + 2 * 60 * 1000;
        let auth = ProviderAuth::oauth("access", "refresh", expires);
        assert!(auth.needs_refresh());

        // Token expiring in 10 minutes - should not need refresh
        let expires = chrono::Utc::now().timestamp_millis() + 10 * 60 * 1000;
        let auth = ProviderAuth::oauth("access", "refresh", expires);
        assert!(!auth.needs_refresh());
    }

    #[test]
    fn test_expired_oauth_needs_refresh() {
        // A token that is already past its expiry is also inside the refresh window.
        let expires = chrono::Utc::now().timestamp_millis() - 1000;
        let auth = ProviderAuth::oauth("access", "refresh", expires);
        assert!(auth.is_expired());
        assert!(auth.needs_refresh());
    }

    #[test]
    fn test_oauth_is_expired() {
        // Expired token
        let expires = chrono::Utc::now().timestamp_millis() - 1000;
        let auth = ProviderAuth::oauth("access", "refresh", expires);
        assert!(auth.is_expired());

        // Valid token
        let expires = chrono::Utc::now().timestamp_millis() + 3600000;
        let auth = ProviderAuth::oauth("access", "refresh", expires);
        assert!(!auth.is_expired());
    }

    #[test]
    fn test_api_key_never_needs_refresh() {
        let auth = ProviderAuth::api_key("sk-test");
        assert!(!auth.needs_refresh());
        assert!(!auth.is_expired());
    }

    #[test]
    fn test_serde_api_key() {
        let auth = ProviderAuth::api_key("sk-test-key");
        let json = serde_json::to_string(&auth).unwrap();
        assert!(json.contains("\"type\":\"api\""));
        assert!(json.contains("\"key\":\"sk-test-key\""));

        let parsed: ProviderAuth = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, parsed);
    }

    #[test]
    fn test_serde_oauth() {
        let auth = ProviderAuth::oauth("access-token", "refresh-token", 1735600000000);
        let json = serde_json::to_string(&auth).unwrap();
        assert!(json.contains("\"type\":\"oauth\""), "JSON was: {}", json);
        assert!(json.contains("\"access\":\"access-token\""));
        assert!(json.contains("\"refresh\":\"refresh-token\""));
        assert!(json.contains("\"expires\":1735600000000"));
        // `name` is `None`, so `skip_serializing_if` must leave the key out.
        assert!(!json.contains("\"name\""), "JSON was: {}", json);

        let parsed: ProviderAuth = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, parsed);
    }

    #[test]
    fn test_serde_oauth_with_name() {
        let auth =
            ProviderAuth::oauth_with_name("access-token", "refresh-token", 1735600000000, "Claude Max");
        let json = serde_json::to_string(&auth).unwrap();
        assert!(json.contains("\"type\":\"oauth\""), "JSON was: {}", json);
        assert!(json.contains("\"name\":\"Claude Max\""), "JSON was: {}", json);

        let parsed: ProviderAuth = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, parsed);
        assert_eq!(parsed.subscription_name(), Some("Claude Max"));
    }

    #[test]
    fn test_serde_oauth_missing_name_deserializes() {
        // Stored credentials with no `name` key must still load.
        let json = r#"{"type":"oauth","access":"a","refresh":"r","expires":1}"#;
        let parsed: ProviderAuth = serde_json::from_str(json).unwrap();
        assert_eq!(parsed, ProviderAuth::oauth("a", "r", 1));
    }

    #[test]
    fn test_auth_type_display() {
        let api = ProviderAuth::api_key("key");
        assert_eq!(api.auth_type_display(), "api_key");

        let oauth = ProviderAuth::oauth("access", "refresh", 0);
        assert_eq!(oauth.auth_type_display(), "oauth");
    }
}