stakpak_shared/models/
auth.rs1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
10#[serde(tag = "type", rename_all = "snake_case")]
11pub enum ProviderAuth {
12 Api {
14 key: String,
16 },
17
18 #[serde(rename = "oauth")]
20 OAuth {
21 access: String,
23 refresh: String,
25 expires: i64,
27 #[serde(skip_serializing_if = "Option::is_none")]
29 name: Option<String>,
30 },
31}
32
33impl ProviderAuth {
34 pub fn api_key(key: impl Into<String>) -> Self {
36 Self::Api { key: key.into() }
37 }
38
39 pub fn oauth(access: impl Into<String>, refresh: impl Into<String>, expires: i64) -> Self {
41 Self::OAuth {
42 access: access.into(),
43 refresh: refresh.into(),
44 expires,
45 name: None,
46 }
47 }
48
49 pub fn oauth_with_name(
51 access: impl Into<String>,
52 refresh: impl Into<String>,
53 expires: i64,
54 name: impl Into<String>,
55 ) -> Self {
56 Self::OAuth {
57 access: access.into(),
58 refresh: refresh.into(),
59 expires,
60 name: Some(name.into()),
61 }
62 }
63
64 pub fn needs_refresh(&self) -> bool {
66 match self {
67 Self::OAuth { expires, .. } => {
68 let now_ms = chrono::Utc::now().timestamp_millis();
69 let buffer_ms = 5 * 60 * 1000; *expires < (now_ms + buffer_ms)
71 }
72 Self::Api { .. } => false,
73 }
74 }
75
76 pub fn is_expired(&self) -> bool {
78 match self {
79 Self::OAuth { expires, .. } => *expires < chrono::Utc::now().timestamp_millis(),
80 Self::Api { .. } => false,
81 }
82 }
83
84 pub fn api_key_value(&self) -> Option<&str> {
86 match self {
87 Self::Api { key } => Some(key),
88 Self::OAuth { .. } => None,
89 }
90 }
91
92 pub fn access_token(&self) -> Option<&str> {
94 match self {
95 Self::OAuth { access, .. } => Some(access),
96 Self::Api { .. } => None,
97 }
98 }
99
100 pub fn refresh_token(&self) -> Option<&str> {
102 match self {
103 Self::OAuth { refresh, .. } => Some(refresh),
104 Self::Api { .. } => None,
105 }
106 }
107
108 pub fn is_oauth(&self) -> bool {
110 matches!(self, Self::OAuth { .. })
111 }
112
113 pub fn is_api_key(&self) -> bool {
115 matches!(self, Self::Api { .. })
116 }
117
118 pub fn auth_type_display(&self) -> &'static str {
120 match self {
121 Self::Api { .. } => "api_key",
122 Self::OAuth { .. } => "oauth",
123 }
124 }
125
126 pub fn subscription_name(&self) -> Option<&str> {
128 match self {
129 Self::OAuth { name, .. } => name.as_deref(),
130 Self::Api { .. } => None,
131 }
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// One hour in milliseconds.
    const HOUR_MS: i64 = 60 * 60 * 1000;

    /// Current wall-clock time as Unix milliseconds.
    fn now_ms() -> i64 {
        chrono::Utc::now().timestamp_millis()
    }

    #[test]
    fn test_api_key_creation() {
        let auth = ProviderAuth::api_key("sk-test-key");
        assert!(auth.is_api_key());
        assert!(!auth.is_oauth());
        assert_eq!(auth.api_key_value(), Some("sk-test-key"));
        assert_eq!(auth.access_token(), None);
        assert_eq!(auth.refresh_token(), None);
        assert_eq!(auth.subscription_name(), None);
    }

    #[test]
    fn test_oauth_creation() {
        let expires = now_ms() + HOUR_MS;
        let auth = ProviderAuth::oauth("access-token", "refresh-token", expires);
        assert!(auth.is_oauth());
        assert!(!auth.is_api_key());
        assert_eq!(auth.access_token(), Some("access-token"));
        assert_eq!(auth.refresh_token(), Some("refresh-token"));
        assert_eq!(auth.api_key_value(), None);
        assert_eq!(auth.subscription_name(), None);
    }

    #[test]
    fn test_oauth_with_name_creation() {
        let auth = ProviderAuth::oauth_with_name("access", "refresh", 0, "pro");
        assert!(auth.is_oauth());
        assert_eq!(auth.access_token(), Some("access"));
        assert_eq!(auth.refresh_token(), Some("refresh"));
        assert_eq!(auth.subscription_name(), Some("pro"));
    }

    #[test]
    fn test_oauth_needs_refresh() {
        // Inside the 5-minute buffer.
        let auth = ProviderAuth::oauth("access", "refresh", now_ms() + 2 * 60 * 1000);
        assert!(auth.needs_refresh());

        // Comfortably outside the buffer.
        let auth = ProviderAuth::oauth("access", "refresh", now_ms() + 10 * 60 * 1000);
        assert!(!auth.needs_refresh());

        // Already expired tokens must also be refreshed.
        let auth = ProviderAuth::oauth("access", "refresh", now_ms() - 1000);
        assert!(auth.needs_refresh());
    }

    #[test]
    fn test_oauth_is_expired() {
        let auth = ProviderAuth::oauth("access", "refresh", now_ms() - 1000);
        assert!(auth.is_expired());

        let auth = ProviderAuth::oauth("access", "refresh", now_ms() + HOUR_MS);
        assert!(!auth.is_expired());
    }

    #[test]
    fn test_api_key_never_needs_refresh() {
        let auth = ProviderAuth::api_key("sk-test");
        assert!(!auth.needs_refresh());
        assert!(!auth.is_expired());
    }

    #[test]
    fn test_serde_api_key() {
        let auth = ProviderAuth::api_key("sk-test-key");
        let json = serde_json::to_string(&auth).unwrap();
        assert!(json.contains("\"type\":\"api\""), "JSON was: {}", json);
        assert!(json.contains("\"key\":\"sk-test-key\""), "JSON was: {}", json);

        let parsed: ProviderAuth = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, parsed);
    }

    #[test]
    fn test_serde_oauth() {
        let auth = ProviderAuth::oauth("access-token", "refresh-token", 1735600000000);
        let json = serde_json::to_string(&auth).unwrap();
        assert!(json.contains("\"type\":\"oauth\""), "JSON was: {}", json);
        assert!(json.contains("\"access\":\"access-token\""), "JSON was: {}", json);
        assert!(json.contains("\"refresh\":\"refresh-token\""), "JSON was: {}", json);
        assert!(json.contains("\"expires\":1735600000000"), "JSON was: {}", json);
        // `name: None` is skipped entirely rather than serialized as null.
        assert!(!json.contains("\"name\""), "JSON was: {}", json);

        let parsed: ProviderAuth = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, parsed);
    }

    #[test]
    fn test_serde_oauth_with_name_roundtrip() {
        let auth = ProviderAuth::oauth_with_name("access", "refresh", 1735600000000, "max");
        let json = serde_json::to_string(&auth).unwrap();
        assert!(json.contains("\"name\":\"max\""), "JSON was: {}", json);

        let parsed: ProviderAuth = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, parsed);
        assert_eq!(parsed.subscription_name(), Some("max"));
    }

    #[test]
    fn test_serde_oauth_missing_name_deserializes_to_none() {
        let json = r#"{"type":"oauth","access":"a","refresh":"r","expires":42}"#;
        let parsed: ProviderAuth = serde_json::from_str(json).unwrap();
        assert_eq!(parsed, ProviderAuth::oauth("a", "r", 42));
        assert_eq!(parsed.subscription_name(), None);
    }

    #[test]
    fn test_auth_type_display() {
        let api = ProviderAuth::api_key("key");
        assert_eq!(api.auth_type_display(), "api_key");

        let oauth = ProviderAuth::oauth("access", "refresh", 0);
        assert_eq!(oauth.auth_type_display(), "oauth");
    }
}