Skip to main content

stack_auth/device_code/mod.rs

1mod protocol;
2
3use cts_common::{CtsServiceDiscovery, Region, ServiceDiscovery};
4use url::Url;
5
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use std::path::PathBuf;
9
10use stack_profile::ProfileStore;
11
12use crate::{ensure_trailing_slash, http_client, AuthError, DeviceIdentity, Token};
13use protocol::{
14    DeviceCode, DeviceCodeRequest, DeviceCodeResponse, ErrorResponse, TokenRequest, TokenResponse,
15};
16
17#[cfg(test)]
18mod tests;
19
20/// Authenticates with CipherStash using the
21/// [device code flow (RFC 8628)](https://datatracker.ietf.org/doc/html/rfc8628).
22///
23/// This is the primary entry point for CLI and browserless authentication.
24/// Create a strategy with [`DeviceCodeStrategy::new`], then call
25/// [`begin`](DeviceCodeStrategy::begin) to start the flow.
26///
27/// # Example
28///
29/// ```
30/// use stack_auth::DeviceCodeStrategy;
31/// use cts_common::Region;
32///
33/// let region = Region::aws("ap-southeast-2").unwrap();
34/// let strategy = DeviceCodeStrategy::new(region, "my-client-id").unwrap();
35/// ```
36pub struct DeviceCodeStrategy {
37    region: Region,
38    base_url: Url,
39    client_id: String,
40    profile_dir: Option<PathBuf>,
41    device_identity: Option<DeviceIdentity>,
42}
43
44impl DeviceCodeStrategy {
45    /// Create a new strategy for the given CipherStash region and OAuth client ID.
46    ///
47    /// The auth endpoint is resolved automatically via service discovery.
48    ///
49    /// # Example
50    ///
51    /// ```
52    /// use stack_auth::DeviceCodeStrategy;
53    /// use cts_common::Region;
54    ///
55    /// let strategy = DeviceCodeStrategy::new(
56    ///     Region::aws("ap-southeast-2").unwrap(),
57    ///     "my-client-id",
58    /// ).unwrap();
59    /// ```
60    pub fn new(region: Region, client_id: impl Into<String>) -> Result<Self, AuthError> {
61        Self::builder(region, client_id).build()
62    }
63
64    /// Return a builder for configuring a `DeviceCodeStrategy` before construction.
65    pub fn builder(region: Region, client_id: impl Into<String>) -> DeviceCodeStrategyBuilder {
66        DeviceCodeStrategyBuilder {
67            region,
68            client_id: client_id.into(),
69            base_url_override: None,
70            profile_dir: None,
71            device_identity: None,
72        }
73    }
74
75    /// Start the device code flow.
76    ///
77    /// Requests a device code from the CipherStash auth server and returns a
78    /// [`PendingDeviceCode`] with the user-facing codes and URIs. Show these
79    /// to the user, then call [`PendingDeviceCode::poll_for_token`] to wait
80    /// for authorization.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`AuthError::InvalidClient`] if the client ID is not recognized,
85    /// or [`AuthError::Request`] if the server is unreachable.
86    pub async fn begin(&self) -> Result<PendingDeviceCode, AuthError> {
87        let client = http_client();
88
89        let code_url = self.base_url.join("oauth/device/code")?;
90
91        tracing::debug!(url = %code_url, client_id = %self.client_id, "requesting device code");
92
93        let device_instance_id = self
94            .device_identity
95            .as_ref()
96            .map(|d| d.device_instance_id.to_string());
97
98        let code_resp = client
99            .post(code_url)
100            .form(&DeviceCodeRequest {
101                client_id: &self.client_id,
102                device_instance_id: device_instance_id.as_deref(),
103                device_name: self
104                    .device_identity
105                    .as_ref()
106                    .map(|d| d.device_name.as_str()),
107            })
108            .send()
109            .await?;
110
111        if !code_resp.status().is_success() {
112            let err: ErrorResponse = code_resp.json().await?;
113            tracing::debug!(error = %err.error, "device code request failed");
114            return Err(match err.error.as_str() {
115                "invalid_client" => AuthError::InvalidClient,
116                _ => AuthError::Server(err.error_description),
117            });
118        }
119
120        let code: DeviceCodeResponse = code_resp.json().await?;
121
122        let token_url = self.base_url.join("oauth/device/token")?;
123
124        tracing::debug!(
125            user_code = %code.user_code,
126            expires_in = code.expires_in,
127            "device code received"
128        );
129
130        Ok(PendingDeviceCode {
131            token_url,
132            region: self.region,
133            client_id: self.client_id.clone(),
134            device_code: code.device_code,
135            user_code: code.user_code,
136            verification_uri: code.verification_uri,
137            verification_uri_complete: code.verification_uri_complete,
138            expires_in: code.expires_in,
139            profile_dir: self.profile_dir.clone(),
140            device_identity: self.device_identity.clone(),
141        })
142    }
143}
144
145/// Builder for [`DeviceCodeStrategy`].
146///
147/// Created via [`DeviceCodeStrategy::builder`].
148pub struct DeviceCodeStrategyBuilder {
149    region: Region,
150    client_id: String,
151    base_url_override: Option<Url>,
152    profile_dir: Option<PathBuf>,
153    device_identity: Option<DeviceIdentity>,
154}
155
156impl DeviceCodeStrategyBuilder {
157    /// Override the base URL resolved by service discovery.
158    ///
159    /// Useful for pointing at a local or mock CTS instance during testing.
160    #[cfg(any(test, feature = "test-utils"))]
161    pub fn base_url(mut self, url: Url) -> Self {
162        self.base_url_override = Some(url);
163        self
164    }
165
166    /// Override the profile directory used to persist the token.
167    ///
168    /// By default tokens are saved to `~/.cipherstash/auth.json`. Use this in
169    /// tests to redirect writes to a temporary directory.
170    #[cfg(any(test, feature = "test-utils"))]
171    pub fn profile_dir(mut self, dir: impl Into<PathBuf>) -> Self {
172        self.profile_dir = Some(dir.into());
173        self
174    }
175
176    /// Set the device identity for this strategy.
177    ///
178    /// When set, the device instance ID and name are sent to the auth server
179    /// during the device code flow and persisted in the token.
180    pub fn device_identity(mut self, identity: DeviceIdentity) -> Self {
181        self.device_identity = Some(identity);
182        self
183    }
184
185    /// Build the [`DeviceCodeStrategy`].
186    ///
187    /// Resolves the base URL via service discovery unless overridden with
188    /// `base_url` (available when the `test-utils` feature is enabled).
189    pub fn build(self) -> Result<DeviceCodeStrategy, AuthError> {
190        let base_url = match self.base_url_override {
191            Some(url) => url,
192            None => crate::cts_base_url_from_env()?
193                .unwrap_or(CtsServiceDiscovery::endpoint(self.region)?),
194        };
195        Ok(DeviceCodeStrategy {
196            region: self.region,
197            base_url: ensure_trailing_slash(base_url),
198            client_id: self.client_id,
199            profile_dir: self.profile_dir,
200            device_identity: self.device_identity,
201        })
202    }
203}
204
/// A device code flow that is waiting for the user to authorize.
///
/// Returned by [`DeviceCodeStrategy::begin`]. Display the
/// [`user_code`](Self::user_code) and
/// [`verification_uri_complete`](Self::verification_uri_complete) to the user
/// (or call [`open_in_browser`](Self::open_in_browser)), then call
/// [`poll_for_token`](Self::poll_for_token) to wait for authorization.
///
/// # Example
///
/// ```no_run
/// # use stack_auth::DeviceCodeStrategy;
/// # use cts_common::Region;
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// # let strategy = DeviceCodeStrategy::new(Region::aws("ap-southeast-2")?, "cli")?;
/// let pending = strategy.begin().await?;
///
/// println!("Go to: {}", pending.verification_uri_complete());
/// println!("Enter code: {}", pending.user_code());
///
/// let token = pending.poll_for_token().await?;
/// # Ok(())
/// # }
/// ```
// NOTE(review): the derived `Debug` includes `device_code`; unless
// `DeviceCode`'s own `Debug` impl redacts it, the code will appear in debug
// output — TODO confirm in `protocol`.
#[derive(Debug)]
pub struct PendingDeviceCode {
    /// Token endpoint (`oauth/device/token` under the strategy's base URL).
    token_url: Url,
    /// Region recorded on the issued token.
    region: Region,
    /// OAuth client ID sent while polling and recorded on the issued token.
    client_id: String,
    /// Device code issued by the server; sent with every poll request.
    device_code: DeviceCode,
    /// The short code the user must enter to authorize this device.
    user_code: String,
    /// The base verification URI (without the user code embedded).
    verification_uri: String,
    /// The full verification URI with the user code pre-filled.
    verification_uri_complete: String,
    /// How many seconds the device code remains valid.
    expires_in: u64,
    /// Profile directory override. Falls back to `~/.cipherstash`.
    profile_dir: Option<PathBuf>,
    /// Device identity to associate with the token.
    device_identity: Option<DeviceIdentity>,
}
248
249impl PendingDeviceCode {
250    /// The short code the user must enter to authorize this device.
251    pub fn user_code(&self) -> &str {
252        &self.user_code
253    }
254
255    /// The base verification URI (without the user code embedded).
256    pub fn verification_uri(&self) -> &str {
257        &self.verification_uri
258    }
259
260    /// The full verification URI with the user code pre-filled.
261    pub fn verification_uri_complete(&self) -> &str {
262        &self.verification_uri_complete
263    }
264
265    /// How many seconds the device code remains valid.
266    pub fn expires_in(&self) -> u64 {
267        self.expires_in
268    }
269
270    /// Open the verification URI in the user's default browser.
271    ///
272    /// Returns `true` if the browser was opened successfully.
273    pub fn open_in_browser(&self) -> bool {
274        open::that(&self.verification_uri_complete).is_ok()
275    }
276
277    /// Poll the auth server until the user authorizes (or the code expires).
278    ///
279    /// This method consumes `self` and blocks asynchronously, polling at a
280    /// server-controlled interval (starting at 5 seconds). It returns a
281    /// [`Token`] on success.
282    ///
283    /// # Errors
284    ///
285    /// - [`AuthError::AccessDenied`] — the user rejected the request.
286    /// - [`AuthError::TokenExpired`] — the device code expired before the user
287    ///   authorized.
288    /// - [`AuthError::Request`] — a network error occurred while polling.
289    pub async fn poll_for_token(self) -> Result<Token, AuthError> {
290        let client = http_client();
291        let mut interval = tokio::time::Duration::from_secs(5);
292        let deadline =
293            tokio::time::Instant::now() + tokio::time::Duration::from_secs(self.expires_in);
294
295        tracing::debug!(
296            url = %self.token_url,
297            expires_in = self.expires_in,
298            "polling for token"
299        );
300
301        loop {
302            if tokio::time::Instant::now() >= deadline {
303                tracing::debug!("device code expired while polling");
304                return Err(AuthError::TokenExpired);
305            }
306
307            let resp = client
308                .post(self.token_url.clone())
309                .form(&TokenRequest {
310                    client_id: &self.client_id,
311                    device_code: &self.device_code,
312                    grant_type: "urn:ietf:params:oauth:grant-type:device_code",
313                })
314                .send()
315                .await?;
316
317            if resp.status().is_success() {
318                tracing::debug!("token received");
319                let token_resp: TokenResponse = resp.json().await?;
320                let now = SystemTime::now()
321                    .duration_since(UNIX_EPOCH)
322                    .unwrap_or_default()
323                    .as_secs();
324                let mut token = Token {
325                    access_token: token_resp.access_token,
326                    token_type: token_resp.token_type,
327                    expires_at: now + token_resp.expires_in,
328                    refresh_token: token_resp.refresh_token,
329                    region: None,
330                    client_id: None,
331                    device_instance_id: None,
332                };
333                token.set_region(self.region.identifier());
334                token.set_client_id(&self.client_id);
335                if let Some(ref identity) = self.device_identity {
336                    token.set_device_instance_id(identity.device_instance_id.to_string());
337                }
338
339                let store = match &self.profile_dir {
340                    Some(dir) => ProfileStore::new(dir),
341                    None => ProfileStore::resolve(None)?,
342                };
343                match store.save_profile(&token) {
344                    Ok(()) => tracing::debug!("token saved to disk"),
345                    Err(err) => tracing::warn!(%err, "failed to save token to disk"),
346                }
347
348                return Ok(token);
349            }
350
351            let err: ErrorResponse = resp.json().await?;
352            match err.error.as_str() {
353                "authorization_pending" => {
354                    tracing::debug!("authorization pending, retrying");
355                }
356                "slow_down" => {
357                    interval += tokio::time::Duration::from_secs(5);
358                    tracing::debug!(interval_secs = interval.as_secs(), "slowing down");
359                }
360                "expired_token" => return Err(AuthError::TokenExpired),
361                "access_denied" => return Err(AuthError::AccessDenied),
362                "invalid_grant" => return Err(AuthError::InvalidGrant),
363                "invalid_client" => return Err(AuthError::InvalidClient),
364                _ => return Err(AuthError::Server(err.error_description)),
365            }
366
367            tokio::time::sleep(interval).await;
368        }
369    }
370}