stakpak_shared/oauth/
device_flow.rs1use super::error::{OAuthError, OAuthResult};
14use reqwest::Client;
15use serde::{Deserialize, Serialize};
16use std::time::Duration;
17
18#[derive(Debug, Clone, Deserialize)]
24pub struct DeviceCodeResponse {
25 pub device_code: String,
27 pub user_code: String,
29 pub verification_uri: String,
31 pub expires_in: u64,
33 pub interval: u64,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct DeviceTokenResponse {
40 pub access_token: String,
42 pub token_type: String,
44 #[serde(default)]
46 pub scope: String,
47}
48
49#[derive(Debug, Clone)]
51pub enum DeviceFlowState {
52 Pending {
54 user_code: String,
55 verification_uri: String,
56 },
57 Completed(DeviceTokenResponse),
59}
60
61#[derive(Debug, Deserialize)]
62struct PollRaw {
63 access_token: Option<String>,
65 token_type: Option<String>,
66 #[serde(default)]
67 scope: String,
68 error: Option<String>,
70 error_description: Option<String>,
71}
72
73pub struct DeviceFlow {
78 client_id: String,
79 scopes: Vec<String>,
80 device_code_url: String,
81 token_url: String,
82 client: Client,
84}
85
86impl DeviceFlow {
87 pub fn new(
94 client_id: impl Into<String>,
95 scopes: Vec<String>,
96 device_code_url: impl Into<String>,
97 token_url: impl Into<String>,
98 ) -> OAuthResult<Self> {
99 let client =
100 crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
101 .map_err(OAuthError::token_exchange_failed)?;
102 Ok(Self {
103 client_id: client_id.into(),
104 scopes,
105 device_code_url: device_code_url.into(),
106 token_url: token_url.into(),
107 client,
108 })
109 }
110
111 pub async fn request_device_code(&self) -> OAuthResult<DeviceCodeResponse> {
116 let scope = self.scopes.join(" ");
117
118 let response = self
119 .client
120 .post(&self.device_code_url)
121 .header("Accept", "application/json")
122 .form(&[("client_id", self.client_id.as_str()), ("scope", &scope)])
123 .send()
124 .await?;
125
126 if !response.status().is_success() {
127 let status = response.status();
128 let body = response.text().await.unwrap_or_default();
129 return Err(OAuthError::token_exchange_failed(format!(
130 "Device code request failed: HTTP {} — {}",
131 status, body
132 )));
133 }
134
135 response.json::<DeviceCodeResponse>().await.map_err(|e| {
136 OAuthError::token_exchange_failed(format!(
137 "Failed to parse device code response: {}",
138 e
139 ))
140 })
141 }
142
143 pub async fn poll_for_token(
150 &self,
151 device_code: &DeviceCodeResponse,
152 ) -> OAuthResult<DeviceTokenResponse> {
153 let mut interval_secs = device_code.interval;
154 let expires_at = std::time::Instant::now() + Duration::from_secs(device_code.expires_in);
155
156 loop {
157 if std::time::Instant::now() >= expires_at {
158 return Err(OAuthError::token_exchange_failed(
159 "Device code expired before the user completed authorisation",
160 ));
161 }
162
163 let response = self
164 .client
165 .post(&self.token_url)
166 .header("Accept", "application/json")
167 .form(&[
168 ("client_id", self.client_id.as_str()),
169 ("device_code", device_code.device_code.as_str()),
170 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
171 ])
172 .send()
173 .await?;
174
175 if !response.status().is_success() {
176 let status = response.status();
177 let body = response.text().await.unwrap_or_default();
178 return Err(OAuthError::token_exchange_failed(format!(
179 "Token polling failed: HTTP {} — {}",
180 status, body
181 )));
182 }
183
184 let poll_raw: PollRaw = response.json().await.map_err(|e| {
185 OAuthError::token_exchange_failed(format!(
186 "Failed to parse token poll response: {}",
187 e
188 ))
189 })?;
190
191 if let Some(ref err) = poll_raw.error {
193 match err.as_str() {
194 "authorization_pending" => {}
196 "slow_down" => {
198 interval_secs += 5;
199 }
200 "access_denied" => {
202 return Err(OAuthError::token_exchange_failed(
203 "User denied the authorisation request",
204 ));
205 }
206 "expired_token" | "token_expired" => {
207 return Err(OAuthError::token_exchange_failed("Device code expired"));
208 }
209 "unsupported_grant_type" => {
210 return Err(OAuthError::token_exchange_failed(
211 "Unsupported grant type — grant_type must be \
212 urn:ietf:params:oauth:grant-type:device_code",
213 ));
214 }
215 "incorrect_client_credentials" => {
216 return Err(OAuthError::token_exchange_failed(
217 "Incorrect client credentials — check the client_id",
218 ));
219 }
220 "incorrect_device_code" => {
221 return Err(OAuthError::token_exchange_failed(
222 "The device_code provided is not valid",
223 ));
224 }
225 "device_flow_disabled" => {
226 return Err(OAuthError::token_exchange_failed(
227 "Device flow is not enabled for this OAuth app",
228 ));
229 }
230 other => {
231 return Err(OAuthError::token_exchange_failed(format!(
232 "Unexpected error from provider: {} — {}",
233 other,
234 poll_raw.error_description.as_deref().unwrap_or("")
235 )));
236 }
237 }
238 } else if let Some(access_token) = poll_raw.access_token {
239 let token_type = poll_raw.token_type.unwrap_or_default();
240 return Ok(DeviceTokenResponse {
241 access_token,
242 token_type,
243 scope: poll_raw.scope,
244 });
245 } else {
246 return Err(OAuthError::token_exchange_failed(
247 "Token poll response contained neither an error nor an access token",
248 ));
249 }
250
251 tokio::time::sleep(Duration::from_secs(interval_secs)).await;
253 }
254 }
255}