1mod protocol;
2
3use cts_common::{CtsServiceDiscovery, Region, ServiceDiscovery};
4use url::Url;
5
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use std::path::PathBuf;
9
10use stack_profile::ProfileStore;
11
12use crate::{ensure_trailing_slash, http_client, AuthError, DeviceIdentity, Token};
13use protocol::{
14 DeviceCode, DeviceCodeRequest, DeviceCodeResponse, ErrorResponse, TokenRequest, TokenResponse,
15};
16
17#[cfg(test)]
18mod tests;
19
20pub struct DeviceCodeStrategy {
37 region: Region,
38 base_url: Url,
39 client_id: String,
40 profile_dir: Option<PathBuf>,
41 device_identity: Option<DeviceIdentity>,
42}
43
44impl DeviceCodeStrategy {
45 pub fn new(region: Region, client_id: impl Into<String>) -> Result<Self, AuthError> {
61 Self::builder(region, client_id).build()
62 }
63
64 pub fn builder(region: Region, client_id: impl Into<String>) -> DeviceCodeStrategyBuilder {
66 DeviceCodeStrategyBuilder {
67 region,
68 client_id: client_id.into(),
69 base_url_override: None,
70 profile_dir: None,
71 device_identity: None,
72 }
73 }
74
75 pub async fn begin(&self) -> Result<PendingDeviceCode, AuthError> {
87 let client = http_client();
88
89 let code_url = self.base_url.join("oauth/device/code")?;
90
91 tracing::debug!(url = %code_url, client_id = %self.client_id, "requesting device code");
92
93 let device_instance_id = self
94 .device_identity
95 .as_ref()
96 .map(|d| d.device_instance_id.to_string());
97
98 let code_resp = client
99 .post(code_url)
100 .form(&DeviceCodeRequest {
101 client_id: &self.client_id,
102 device_instance_id: device_instance_id.as_deref(),
103 device_name: self
104 .device_identity
105 .as_ref()
106 .map(|d| d.device_name.as_str()),
107 })
108 .send()
109 .await?;
110
111 if !code_resp.status().is_success() {
112 let err: ErrorResponse = code_resp.json().await?;
113 tracing::debug!(error = %err.error, "device code request failed");
114 return Err(match err.error.as_str() {
115 "invalid_client" => AuthError::InvalidClient,
116 _ => AuthError::Server(err.error_description),
117 });
118 }
119
120 let code: DeviceCodeResponse = code_resp.json().await?;
121
122 let token_url = self.base_url.join("oauth/device/token")?;
123
124 tracing::debug!(
125 user_code = %code.user_code,
126 expires_in = code.expires_in,
127 "device code received"
128 );
129
130 Ok(PendingDeviceCode {
131 token_url,
132 region: self.region,
133 client_id: self.client_id.clone(),
134 device_code: code.device_code,
135 user_code: code.user_code,
136 verification_uri: code.verification_uri,
137 verification_uri_complete: code.verification_uri_complete,
138 expires_in: code.expires_in,
139 profile_dir: self.profile_dir.clone(),
140 device_identity: self.device_identity.clone(),
141 })
142 }
143}
144
145pub struct DeviceCodeStrategyBuilder {
149 region: Region,
150 client_id: String,
151 base_url_override: Option<Url>,
152 profile_dir: Option<PathBuf>,
153 device_identity: Option<DeviceIdentity>,
154}
155
156impl DeviceCodeStrategyBuilder {
157 #[cfg(any(test, feature = "test-utils"))]
161 pub fn base_url(mut self, url: Url) -> Self {
162 self.base_url_override = Some(url);
163 self
164 }
165
166 #[cfg(any(test, feature = "test-utils"))]
171 pub fn profile_dir(mut self, dir: impl Into<PathBuf>) -> Self {
172 self.profile_dir = Some(dir.into());
173 self
174 }
175
176 pub fn device_identity(mut self, identity: DeviceIdentity) -> Self {
181 self.device_identity = Some(identity);
182 self
183 }
184
185 pub fn build(self) -> Result<DeviceCodeStrategy, AuthError> {
190 let base_url = match self.base_url_override {
191 Some(url) => url,
192 None => crate::cts_base_url_from_env()?
193 .unwrap_or(CtsServiceDiscovery::endpoint(self.region)?),
194 };
195 Ok(DeviceCodeStrategy {
196 region: self.region,
197 base_url: ensure_trailing_slash(base_url),
198 client_id: self.client_id,
199 profile_dir: self.profile_dir,
200 device_identity: self.device_identity,
201 })
202 }
203}
204
/// A device code issued by the server that is awaiting user approval.
///
/// Show the user code and a verification URI to the user, then call
/// `poll_for_token` to wait for the resulting token.
#[derive(Debug)]
pub struct PendingDeviceCode {
    /// Token endpoint (`oauth/device/token`) that `poll_for_token` posts to.
    token_url: Url,
    /// Region recorded on the resulting token.
    region: Region,
    /// OAuth client identifier sent with every poll.
    client_id: String,
    /// Device code sent with every poll.
    // NOTE(review): this appears in the derived `Debug` output unless `DeviceCode`'s own
    // `Debug` impl redacts it — confirm before logging this type.
    device_code: DeviceCode,
    /// User code from the device-code response, meant to be shown to the user.
    user_code: String,
    /// Verification URI from the device-code response.
    verification_uri: String,
    /// `verification_uri_complete` from the device-code response; `open_in_browser`
    /// opens this one.
    verification_uri_complete: String,
    /// Device-code lifetime in seconds from the response; bounds local polling.
    expires_in: u64,
    /// Profile directory override; `None` means the default profile store is resolved.
    profile_dir: Option<PathBuf>,
    /// Device identity whose instance ID is recorded on the resulting token.
    device_identity: Option<DeviceIdentity>,
}
248
249impl PendingDeviceCode {
250 pub fn user_code(&self) -> &str {
252 &self.user_code
253 }
254
255 pub fn verification_uri(&self) -> &str {
257 &self.verification_uri
258 }
259
260 pub fn verification_uri_complete(&self) -> &str {
262 &self.verification_uri_complete
263 }
264
265 pub fn expires_in(&self) -> u64 {
267 self.expires_in
268 }
269
270 pub fn open_in_browser(&self) -> bool {
274 open::that(&self.verification_uri_complete).is_ok()
275 }
276
277 pub async fn poll_for_token(self) -> Result<Token, AuthError> {
290 let client = http_client();
291 let mut interval = tokio::time::Duration::from_secs(5);
292 let deadline =
293 tokio::time::Instant::now() + tokio::time::Duration::from_secs(self.expires_in);
294
295 tracing::debug!(
296 url = %self.token_url,
297 expires_in = self.expires_in,
298 "polling for token"
299 );
300
301 loop {
302 if tokio::time::Instant::now() >= deadline {
303 tracing::debug!("device code expired while polling");
304 return Err(AuthError::TokenExpired);
305 }
306
307 let resp = client
308 .post(self.token_url.clone())
309 .form(&TokenRequest {
310 client_id: &self.client_id,
311 device_code: &self.device_code,
312 grant_type: "urn:ietf:params:oauth:grant-type:device_code",
313 })
314 .send()
315 .await?;
316
317 if resp.status().is_success() {
318 tracing::debug!("token received");
319 let token_resp: TokenResponse = resp.json().await?;
320 let now = SystemTime::now()
321 .duration_since(UNIX_EPOCH)
322 .unwrap_or_default()
323 .as_secs();
324 let mut token = Token {
325 access_token: token_resp.access_token,
326 token_type: token_resp.token_type,
327 expires_at: now + token_resp.expires_in,
328 refresh_token: token_resp.refresh_token,
329 region: None,
330 client_id: None,
331 device_instance_id: None,
332 };
333 token.set_region(self.region.identifier());
334 token.set_client_id(&self.client_id);
335 if let Some(ref identity) = self.device_identity {
336 token.set_device_instance_id(identity.device_instance_id.to_string());
337 }
338
339 let store = match &self.profile_dir {
340 Some(dir) => ProfileStore::new(dir),
341 None => ProfileStore::resolve(None)?,
342 };
343 match store.save_profile(&token) {
344 Ok(()) => tracing::debug!("token saved to disk"),
345 Err(err) => tracing::warn!(%err, "failed to save token to disk"),
346 }
347
348 return Ok(token);
349 }
350
351 let err: ErrorResponse = resp.json().await?;
352 match err.error.as_str() {
353 "authorization_pending" => {
354 tracing::debug!("authorization pending, retrying");
355 }
356 "slow_down" => {
357 interval += tokio::time::Duration::from_secs(5);
358 tracing::debug!(interval_secs = interval.as_secs(), "slowing down");
359 }
360 "expired_token" => return Err(AuthError::TokenExpired),
361 "access_denied" => return Err(AuthError::AccessDenied),
362 "invalid_grant" => return Err(AuthError::InvalidGrant),
363 "invalid_client" => return Err(AuthError::InvalidClient),
364 _ => return Err(AuthError::Server(err.error_description)),
365 }
366
367 tokio::time::sleep(interval).await;
368 }
369 }
370}