stakpak_shared/oauth/
provider.rs1use super::config::OAuthConfig;
4use super::device_flow::{DeviceCodeResponse, DeviceFlow, DeviceTokenResponse};
5use super::error::{OAuthError, OAuthResult};
6use super::flow::TokenResponse;
7use crate::models::auth::ProviderAuth;
8use async_trait::async_trait;
9use reqwest::header::HeaderMap;
10
/// One way a user can authenticate with a provider.
///
/// Values are normally built with [`AuthMethod::oauth`], [`AuthMethod::api_key`] or
/// [`AuthMethod::device_flow`], which fill in the matching [`AuthMethodType`].
#[derive(Debug, Clone)]
pub struct AuthMethod {
    /// Identifier of this method (for example `"claude-max"` or `"api-key"`).
    pub id: String,
    /// Short human-readable label.
    pub label: String,
    /// Optional longer description; [`AuthMethod::display`] appends it to the label.
    pub description: Option<String>,
    /// Kind of authentication this method performs.
    pub method_type: AuthMethodType,
}

impl AuthMethod {
    /// Shared construction logic behind the public, per-type constructors.
    fn with_type(
        id: impl Into<String>,
        label: impl Into<String>,
        description: Option<String>,
        method_type: AuthMethodType,
    ) -> Self {
        Self {
            id: id.into(),
            label: label.into(),
            description,
            method_type,
        }
    }

    /// Creates a method tagged [`AuthMethodType::OAuth`].
    pub fn oauth(
        id: impl Into<String>,
        label: impl Into<String>,
        description: Option<String>,
    ) -> Self {
        Self::with_type(id, label, description, AuthMethodType::OAuth)
    }

    /// Creates a method tagged [`AuthMethodType::ApiKey`].
    pub fn api_key(
        id: impl Into<String>,
        label: impl Into<String>,
        description: Option<String>,
    ) -> Self {
        Self::with_type(id, label, description, AuthMethodType::ApiKey)
    }

    /// Creates a method tagged [`AuthMethodType::DeviceFlow`].
    ///
    /// Counterpart of [`AuthMethod::oauth`] and [`AuthMethod::api_key`], so every
    /// [`AuthMethodType`] variant has a constructor and callers never need a struct literal.
    pub fn device_flow(
        id: impl Into<String>,
        label: impl Into<String>,
        description: Option<String>,
    ) -> Self {
        Self::with_type(id, label, description, AuthMethodType::DeviceFlow)
    }

    /// Returns the text shown for this method.
    ///
    /// The result is `"<label> - <description>"` when a description is present and just the
    /// label otherwise.
    pub fn display(&self) -> String {
        match self.description.as_deref() {
            Some(desc) => format!("{} - {}", self.label, desc),
            None => self.label.clone(),
        }
    }
}

/// Discriminates the kinds of authentication an [`AuthMethod`] can represent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMethodType {
    /// OAuth-based authentication.
    OAuth,
    /// Authentication with an API key.
    ApiKey,
    /// Device-flow authentication (the Device Authorization Grant).
    DeviceFlow,
}
72
73#[async_trait]
75pub trait OAuthProvider: Send + Sync {
76 fn id(&self) -> &'static str;
78
79 fn name(&self) -> &'static str;
81
82 fn auth_methods(&self) -> Vec<AuthMethod>;
84
85 fn oauth_config(&self, method_id: &str) -> Option<OAuthConfig>;
87
88 async fn post_authorize(
93 &self,
94 method_id: &str,
95 tokens: &TokenResponse,
96 ) -> OAuthResult<ProviderAuth>;
97
98 fn apply_auth_headers(&self, auth: &ProviderAuth, headers: &mut HeaderMap) -> OAuthResult<()>;
103
104 fn api_key_env_var(&self) -> Option<&'static str> {
106 None
107 }
108
109 fn device_flow(&self, method_id: &str) -> OAuthResult<DeviceFlow> {
116 Err(OAuthError::unknown_method(format!(
117 "Provider '{}' does not support the Device Authorization Grant for method '{}'",
118 self.id(),
119 method_id,
120 )))
121 }
122
123 async fn request_device_code(
125 &self,
126 method_id: &str,
127 ) -> OAuthResult<(DeviceFlow, DeviceCodeResponse)> {
128 let flow = self.device_flow(method_id)?;
129 let code = flow.request_device_code().await?;
130 Ok((flow, code))
131 }
132
133 async fn wait_for_token(
135 &self,
136 flow: &DeviceFlow,
137 device_code: &DeviceCodeResponse,
138 ) -> OAuthResult<DeviceTokenResponse> {
139 flow.poll_for_token(device_code).await
140 }
141
142 async fn post_device_authorize(
144 &self,
145 method_id: &str,
146 token: &DeviceTokenResponse,
147 ) -> OAuthResult<ProviderAuth> {
148 let _ = (method_id, token);
149 Err(OAuthError::unknown_method(format!(
150 "Provider '{}' does not support post_device_authorize for method '{}'",
151 self.id(),
152 method_id,
153 )))
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// `AuthMethod::oauth` stores its arguments and tags the method as OAuth.
    #[test]
    fn test_auth_method_oauth() {
        let subscription = AuthMethod::oauth(
            "claude-max",
            "Claude Pro/Max",
            Some(String::from("Use your subscription")),
        );

        assert_eq!(subscription.id, "claude-max");
        assert_eq!(subscription.label, "Claude Pro/Max");
        assert_eq!(
            subscription.description.as_deref(),
            Some("Use your subscription")
        );
        assert_eq!(subscription.method_type, AuthMethodType::OAuth);
    }

    /// `AuthMethod::api_key` stores its arguments and tags the method as an API key.
    #[test]
    fn test_auth_method_api_key() {
        let manual_key = AuthMethod::api_key("api-key", "Manual API Key", None);

        assert_eq!(manual_key.id, "api-key");
        assert_eq!(manual_key.label, "Manual API Key");
        assert!(manual_key.description.is_none());
        assert_eq!(manual_key.method_type, AuthMethodType::ApiKey);
    }

    /// `display` joins label and description with " - " and falls back to the bare label.
    #[test]
    fn test_auth_method_display() {
        let described =
            AuthMethod::oauth("test", "Test Method", Some(String::from("Description here")));
        assert_eq!(described.display(), "Test Method - Description here");

        let bare = AuthMethod::oauth("test", "Test Method", None);
        assert_eq!(bare.display(), "Test Method");
    }
}