Skip to main content

stakpak_shared/oauth/
device_flow.rs

//! Device Authorization Grant (RFC 8628)
//!
//! Generic implementation of the OAuth 2.0 Device Authorization Grant.  The
//! flow has two stages:
//!
//! 1. **Request** — POST to the provider's device-code endpoint.  Returns a
//!    `device_code`, `user_code`, `verification_uri` and a polling `interval`.
//! 2. **Poll** — Repeatedly POST to the token endpoint until the user has
//!    authorised the device, then return the issued access token to the caller.
//!
//! Reference: <https://www.rfc-editor.org/rfc/rfc8628>

13use super::error::{OAuthError, OAuthResult};
14use reqwest::Client;
15use serde::{Deserialize, Serialize};
16use std::time::Duration;

// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------

22/// The initial response from the device-code endpoint
23#[derive(Debug, Clone, Deserialize)]
24pub struct DeviceCodeResponse {
25    /// Opaque code used in the polling phase
26    pub device_code: String,
27    /// Short code the user types on the verification page
28    pub user_code: String,
29    /// URL the user should visit
30    pub verification_uri: String,
31    /// Total time (seconds) the device code is valid
32    pub expires_in: u64,
33    /// Minimum seconds between poll attempts
34    pub interval: u64,
35}
36
37/// Successful token response from the polling endpoint
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct DeviceTokenResponse {
40    /// GitHub OAuth access token
41    pub access_token: String,
42    /// Token type (usually "bearer")
43    pub token_type: String,
44    /// Space-separated list of granted scopes
45    #[serde(default)]
46    pub scope: String,
47}
48
49/// Current state of an in-progress device flow
50#[derive(Debug, Clone)]
51pub enum DeviceFlowState {
52    /// Waiting for the user to complete authorisation
53    Pending {
54        user_code: String,
55        verification_uri: String,
56    },
57    /// Authorisation complete; contains the access token
58    Completed(DeviceTokenResponse),
59}
60
61#[derive(Debug, Deserialize)]
62struct PollRaw {
63    // success fields
64    access_token: Option<String>,
65    token_type: Option<String>,
66    #[serde(default)]
67    scope: String,
68    // error fields
69    error: Option<String>,
70    error_description: Option<String>,
71}
72
73/// Manages a complete RFC 8628 device-flow OAuth session.
74///
75/// Provider-agnostic: supply the device-code URL and token URL for any
76/// provider that supports the Device Authorization Grant.
77pub struct DeviceFlow {
78    client_id: String,
79    scopes: Vec<String>,
80    device_code_url: String,
81    token_url: String,
82    /// Reused across all HTTP calls in the flow (avoids re-creating TLS on every poll)
83    client: Client,
84}
85
86impl DeviceFlow {
87    /// Create a new device flow for the given OAuth app and provider endpoints.
88    ///
89    /// - `device_code_url`: the provider's device-authorization endpoint
90    /// - `token_url`: the provider's token endpoint used during polling
91    ///
92    /// Returns an error if the underlying TLS client cannot be constructed.
93    pub fn new(
94        client_id: impl Into<String>,
95        scopes: Vec<String>,
96        device_code_url: impl Into<String>,
97        token_url: impl Into<String>,
98    ) -> OAuthResult<Self> {
99        let client =
100            crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
101                .map_err(OAuthError::token_exchange_failed)?;
102        Ok(Self {
103            client_id: client_id.into(),
104            scopes,
105            device_code_url: device_code_url.into(),
106            token_url: token_url.into(),
107            client,
108        })
109    }
110
111    /// Step 1 — request a device code from the provider.
112    ///
113    /// Returns the `DeviceCodeResponse` that you should present to the user
114    /// (display `user_code` and `verification_uri`).
115    pub async fn request_device_code(&self) -> OAuthResult<DeviceCodeResponse> {
116        let scope = self.scopes.join(" ");
117
118        let response = self
119            .client
120            .post(&self.device_code_url)
121            .header("Accept", "application/json")
122            .form(&[("client_id", self.client_id.as_str()), ("scope", &scope)])
123            .send()
124            .await?;
125
126        if !response.status().is_success() {
127            let status = response.status();
128            let body = response.text().await.unwrap_or_default();
129            return Err(OAuthError::token_exchange_failed(format!(
130                "Device code request failed: HTTP {} — {}",
131                status, body
132            )));
133        }
134
135        response.json::<DeviceCodeResponse>().await.map_err(|e| {
136            OAuthError::token_exchange_failed(format!(
137                "Failed to parse device code response: {}",
138                e
139            ))
140        })
141    }
142
143    /// Step 2 — poll the provider until the user has authorised the device.
144    ///
145    /// Automatically respects the `interval` returned by step 1 and handles
146    /// `slow_down` responses (which add 5 s to the current interval per spec).
147    ///
148    /// Returns `Ok(DeviceTokenResponse)` once the user approves the request.
149    pub async fn poll_for_token(
150        &self,
151        device_code: &DeviceCodeResponse,
152    ) -> OAuthResult<DeviceTokenResponse> {
153        let mut interval_secs = device_code.interval;
154        let expires_at = std::time::Instant::now() + Duration::from_secs(device_code.expires_in);
155
156        loop {
157            if std::time::Instant::now() >= expires_at {
158                return Err(OAuthError::token_exchange_failed(
159                    "Device code expired before the user completed authorisation",
160                ));
161            }
162
163            let response = self
164                .client
165                .post(&self.token_url)
166                .header("Accept", "application/json")
167                .form(&[
168                    ("client_id", self.client_id.as_str()),
169                    ("device_code", device_code.device_code.as_str()),
170                    ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
171                ])
172                .send()
173                .await?;
174
175            if !response.status().is_success() {
176                let status = response.status();
177                let body = response.text().await.unwrap_or_default();
178                return Err(OAuthError::token_exchange_failed(format!(
179                    "Token polling failed: HTTP {} — {}",
180                    status, body
181                )));
182            }
183
184            let poll_raw: PollRaw = response.json().await.map_err(|e| {
185                OAuthError::token_exchange_failed(format!(
186                    "Failed to parse token poll response: {}",
187                    e
188                ))
189            })?;
190
191            // Error field takes priority — check it before looking at token fields.
192            if let Some(ref err) = poll_raw.error {
193                match err.as_str() {
194                    // Normal — user hasn't approved yet; wait then retry
195                    "authorization_pending" => {}
196                    // Provider asked us to back off — add 5 s per RFC 8628 §3.5
197                    "slow_down" => {
198                        interval_secs += 5;
199                    }
200                    // Terminal errors — stop immediately
201                    "access_denied" => {
202                        return Err(OAuthError::token_exchange_failed(
203                            "User denied the authorisation request",
204                        ));
205                    }
206                    "expired_token" | "token_expired" => {
207                        return Err(OAuthError::token_exchange_failed("Device code expired"));
208                    }
209                    "unsupported_grant_type" => {
210                        return Err(OAuthError::token_exchange_failed(
211                            "Unsupported grant type — grant_type must be \
212                             urn:ietf:params:oauth:grant-type:device_code",
213                        ));
214                    }
215                    "incorrect_client_credentials" => {
216                        return Err(OAuthError::token_exchange_failed(
217                            "Incorrect client credentials — check the client_id",
218                        ));
219                    }
220                    "incorrect_device_code" => {
221                        return Err(OAuthError::token_exchange_failed(
222                            "The device_code provided is not valid",
223                        ));
224                    }
225                    "device_flow_disabled" => {
226                        return Err(OAuthError::token_exchange_failed(
227                            "Device flow is not enabled for this OAuth app",
228                        ));
229                    }
230                    other => {
231                        return Err(OAuthError::token_exchange_failed(format!(
232                            "Unexpected error from provider: {} — {}",
233                            other,
234                            poll_raw.error_description.as_deref().unwrap_or("")
235                        )));
236                    }
237                }
238            } else if let Some(access_token) = poll_raw.access_token {
239                let token_type = poll_raw.token_type.unwrap_or_default();
240                return Ok(DeviceTokenResponse {
241                    access_token,
242                    token_type,
243                    scope: poll_raw.scope,
244                });
245            } else {
246                return Err(OAuthError::token_exchange_failed(
247                    "Token poll response contained neither an error nor an access token",
248                ));
249            }
250
251            // RFC 8628 only requires the minimum gap between requests.
252            tokio::time::sleep(Duration::from_secs(interval_secs)).await;
253        }
254    }
255}