synd_auth/device_flow/
mod.rs1use std::{borrow::Cow, time::Duration};
2
3use http::{StatusCode, Uri};
4use reqwest::{Client, Url};
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7
8use crate::USER_AGENT;
9
10pub mod provider;
11
12pub trait Provider: private::Sealed {
13 type DeviceAccessTokenRequest<'d>: Serialize + Send
14 where
15 Self: 'd;
16
17 fn device_authorization_endpoint(&self) -> Url;
18 fn token_endpoint(&self) -> Url;
19 fn device_authorization_request(&self) -> DeviceAuthorizationRequest;
20 fn device_access_token_request<'d, 'p: 'd>(
21 &'p self,
22 device_code: &'d str,
23 ) -> Self::DeviceAccessTokenRequest<'d>;
24}
25
26mod private {
27 use crate::device_flow::provider;
28
29 pub trait Sealed {}
30
31 impl Sealed for provider::Github {}
32 impl Sealed for provider::Google {}
33}
34
35#[derive(Clone)]
36pub struct DeviceFlow<P> {
37 provider: P,
38 client: Client,
39}
40
41impl<P> DeviceFlow<P> {
42 pub fn new(provider: P) -> Self {
43 let client = reqwest::ClientBuilder::new()
44 .user_agent(USER_AGENT)
45 .connect_timeout(Duration::from_secs(10))
46 .timeout(Duration::from_secs(10))
47 .build()
48 .unwrap();
49
50 Self { provider, client }
51 }
52}
53
54impl<P: Provider> DeviceFlow<P> {
55 #[tracing::instrument(skip(self))]
56 pub async fn device_authorize_request(&self) -> anyhow::Result<DeviceAuthorizationResponse> {
57 let response = self
58 .client
59 .post(self.provider.device_authorization_endpoint())
60 .header(http::header::ACCEPT, "application/json")
61 .form(&self.provider.device_authorization_request())
62 .send()
63 .await?
64 .error_for_status()?
65 .json::<DeviceAuthorizationResponse>()
66 .await?;
67
68 Ok(response)
69 }
70
71 pub async fn poll_device_access_token(
72 &self,
73 device_code: String,
74 interval: Option<i64>,
75 ) -> anyhow::Result<DeviceAccessTokenResponse> {
76 macro_rules! continue_or_abort {
78 ( $response_bytes:ident ) => {{
79 let err_response = serde_json::from_slice::<DeviceAccessTokenErrorResponse>(&$response_bytes)?;
80 if err_response.error.should_continue_to_poll() {
81 debug!(error_code=?err_response.error,interval, "Continue to poll");
82
83 let interval = interval.unwrap_or(5);
84
85 #[allow(clippy::cast_sign_loss)]
86 tokio::time::sleep(Duration::from_secs(interval as u64)).await;
87 } else {
88 anyhow::bail!(
89 "authorization server or oidc provider respond with {err_response:?}"
90 )
91 }
92 }};
93 }
94
95 let response = loop {
96 let response = self
97 .client
98 .post(self.provider.token_endpoint())
99 .header(http::header::ACCEPT, "application/json")
100 .form(&self.provider.device_access_token_request(&device_code))
101 .send()
102 .await?;
103
104 match response.status() {
105 StatusCode::OK => {
106 let full = response.bytes().await?;
107 if let Ok(response) = serde_json::from_slice::<DeviceAccessTokenResponse>(&full)
108 {
109 break response;
110 }
111 continue_or_abort!(full);
112 }
113 StatusCode::BAD_REQUEST | StatusCode::PRECONDITION_REQUIRED => {
115 let full = response.bytes().await?;
116 continue_or_abort!(full);
117 }
118 other => {
119 let error_msg = response.text().await.unwrap_or_default();
120 anyhow::bail!("Failed to authenticate. authorization server respond with {other} {error_msg}")
121 }
122 }
123 };
124
125 Ok(response)
126 }
127}
128
129#[derive(Serialize, Deserialize, Debug)]
131pub struct DeviceAuthorizationRequest<'s> {
132 pub client_id: Cow<'s, str>,
133 pub scope: Cow<'s, str>,
134}
135
136#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
138pub struct DeviceAuthorizationResponse {
139 pub device_code: String,
141 pub user_code: String,
143 #[serde(with = "http_serde_ext::uri::option", default)]
145 pub verification_uri: Option<Uri>,
146 #[serde(with = "http_serde_ext::uri::option", default)]
148 pub verification_url: Option<Uri>,
149 #[allow(unused)]
151 #[serde(with = "http_serde_ext::uri::option", default)]
152 pub verification_uri_complete: Option<Uri>,
153 #[allow(unused)]
155 pub expires_in: i64,
156 pub interval: Option<i64>,
159}
160
161impl DeviceAuthorizationResponse {
162 pub fn verification_uri(&self) -> &Uri {
163 self.verification_uri
164 .as_ref()
165 .or(self.verification_url.as_ref())
166 .expect("verification uri or url not found")
167 }
168}
169
170#[derive(Serialize, Deserialize)]
171pub struct DeviceAccessTokenRequest<'s> {
172 grant_type: Cow<'static, str>,
174 pub device_code: Cow<'s, str>,
176 pub client_id: Cow<'s, str>,
177}
178
179impl<'s> DeviceAccessTokenRequest<'s> {
180 const GRANT_TYPE: &'static str = "urn:ietf:params:oauth:grant-type:device_code";
181
182 #[must_use]
183 pub fn new(device_code: impl Into<Cow<'s, str>>, client_id: &'s str) -> Self {
184 Self {
185 grant_type: Self::GRANT_TYPE.into(),
186 device_code: device_code.into(),
187 client_id: client_id.into(),
188 }
189 }
190}
191
192#[derive(Serialize, Deserialize, Debug, Clone)]
195pub struct DeviceAccessTokenResponse {
196 pub access_token: String,
198 pub token_type: String,
199 pub expires_in: Option<i64>,
201
202 pub refresh_token: Option<String>,
204 pub id_token: Option<String>,
205}
206
207#[derive(Serialize, Deserialize, Debug)]
209pub struct DeviceAccessTokenErrorResponse {
210 pub error: DeviceAccessTokenErrorCode,
211 #[allow(unused)]
212 pub error_description: Option<String>,
213 #[allow(unused)]
215 #[serde(with = "http_serde_ext::uri::option", skip_deserializing)]
216 pub error_uri: Option<Uri>,
217}
218
219#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
220#[serde(rename_all = "snake_case")]
221pub enum DeviceAccessTokenErrorCode {
222 AuthorizationPending,
223 SlowDown,
224 AccessDenied,
225 ExpiredToken,
226 InvalidRequest,
227 InvalidClient,
228 InvalidGrant,
229 UnauthorizedClient,
230 UnsupportedGrantType,
231 InvalidScope,
232 IncorrectDeviceCode,
233}
234
235impl DeviceAccessTokenErrorCode {
236 pub fn should_continue_to_poll(&self) -> bool {
238 use DeviceAccessTokenErrorCode::{AuthorizationPending, SlowDown};
239 *self == AuthorizationPending || *self == SlowDown
240 }
241}