// synd_auth/device_flow/mod.rs

1use std::{borrow::Cow, time::Duration};
2
3use http::{StatusCode, Uri};
4use reqwest::{Client, Url};
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7
8use crate::USER_AGENT;
9
10pub mod provider;
11
12pub trait Provider: private::Sealed {
13    type DeviceAccessTokenRequest<'d>: Serialize + Send
14    where
15        Self: 'd;
16
17    fn device_authorization_endpoint(&self) -> Url;
18    fn token_endpoint(&self) -> Url;
19    fn device_authorization_request(&self) -> DeviceAuthorizationRequest;
20    fn device_access_token_request<'d, 'p: 'd>(
21        &'p self,
22        device_code: &'d str,
23    ) -> Self::DeviceAccessTokenRequest<'d>;
24}
25
26mod private {
27    use crate::device_flow::provider;
28
29    pub trait Sealed {}
30
31    impl Sealed for provider::Github {}
32    impl Sealed for provider::Google {}
33}
34
35#[derive(Clone)]
36pub struct DeviceFlow<P> {
37    provider: P,
38    client: Client,
39}
40
41impl<P> DeviceFlow<P> {
42    pub fn new(provider: P) -> Self {
43        let client = reqwest::ClientBuilder::new()
44            .user_agent(USER_AGENT)
45            .connect_timeout(Duration::from_secs(10))
46            .timeout(Duration::from_secs(10))
47            .build()
48            .unwrap();
49
50        Self { provider, client }
51    }
52}
53
54impl<P: Provider> DeviceFlow<P> {
55    #[tracing::instrument(skip(self))]
56    pub async fn device_authorize_request(&self) -> anyhow::Result<DeviceAuthorizationResponse> {
57        let response = self
58            .client
59            .post(self.provider.device_authorization_endpoint())
60            .header(http::header::ACCEPT, "application/json")
61            .form(&self.provider.device_authorization_request())
62            .send()
63            .await?
64            .error_for_status()?
65            .json::<DeviceAuthorizationResponse>()
66            .await?;
67
68        Ok(response)
69    }
70
71    pub async fn poll_device_access_token(
72        &self,
73        device_code: String,
74        interval: Option<i64>,
75    ) -> anyhow::Result<DeviceAccessTokenResponse> {
76        // poll to check if user authorized the device
77        macro_rules! continue_or_abort {
78            ( $response_bytes:ident ) => {{
79                let err_response = serde_json::from_slice::<DeviceAccessTokenErrorResponse>(&$response_bytes)?;
80                if err_response.error.should_continue_to_poll() {
81                    debug!(error_code=?err_response.error,interval, "Continue to poll");
82
83                    let interval = interval.unwrap_or(5);
84
85                    #[allow(clippy::cast_sign_loss)]
86                    tokio::time::sleep(Duration::from_secs(interval as u64)).await;
87                } else {
88                    anyhow::bail!(
89                        "authorization server or oidc provider respond with {err_response:?}"
90                    )
91                }
92            }};
93        }
94
95        let response = loop {
96            let response = self
97                .client
98                .post(self.provider.token_endpoint())
99                .header(http::header::ACCEPT, "application/json")
100                .form(&self.provider.device_access_token_request(&device_code))
101                .send()
102                .await?;
103
104            match response.status() {
105                StatusCode::OK => {
106                    let full = response.bytes().await?;
107                    if let Ok(response) = serde_json::from_slice::<DeviceAccessTokenResponse>(&full)
108                    {
109                        break response;
110                    }
111                    continue_or_abort!(full);
112                }
113                // Google return 428(Precondition required)
114                StatusCode::BAD_REQUEST | StatusCode::PRECONDITION_REQUIRED => {
115                    let full = response.bytes().await?;
116                    continue_or_abort!(full);
117                }
118                other => {
119                    let error_msg = response.text().await.unwrap_or_default();
120                    anyhow::bail!("Failed to authenticate. authorization server respond with {other} {error_msg}")
121                }
122            }
123        };
124
125        Ok(response)
126    }
127}
128
129/// <https://datatracker.ietf.org/doc/html/rfc8628#section-3.1>
130#[derive(Serialize, Deserialize, Debug)]
131pub struct DeviceAuthorizationRequest<'s> {
132    pub client_id: Cow<'s, str>,
133    pub scope: Cow<'s, str>,
134}
135
136/// <https://datatracker.ietf.org/doc/html/rfc8628#section-3.2>
137#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
138pub struct DeviceAuthorizationResponse {
139    /// device verification code
140    pub device_code: String,
141    /// end user verification code
142    pub user_code: String,
143    /// end user verification uri on the authorization server
144    #[serde(with = "http_serde_ext::uri::option", default)]
145    pub verification_uri: Option<Uri>,
146    /// Google use verification_"url"
147    #[serde(with = "http_serde_ext::uri::option", default)]
148    pub verification_url: Option<Uri>,
149    /// a verification uri that includes `user_code` which is designed for non-textual transmission.
150    #[allow(unused)]
151    #[serde(with = "http_serde_ext::uri::option", default)]
152    pub verification_uri_complete: Option<Uri>,
153    /// the lifetime in seconds of the `device_code` and `user_code`
154    #[allow(unused)]
155    pub expires_in: i64,
156    /// the minimum amount of time in seconds that the client should wait between polling requests to the token endpoint
157    /// if no value is provided, clients must use 5 as the default
158    pub interval: Option<i64>,
159}
160
161impl DeviceAuthorizationResponse {
162    pub fn verification_uri(&self) -> &Uri {
163        self.verification_uri
164            .as_ref()
165            .or(self.verification_url.as_ref())
166            .expect("verification uri or url not found")
167    }
168}
169
170#[derive(Serialize, Deserialize)]
171pub struct DeviceAccessTokenRequest<'s> {
172    /// Value MUST be set to "urn:ietf:params:oauth:grant-type:device_code"
173    grant_type: Cow<'static, str>,
174    /// The device verification code, `device_code` from the device authorization response
175    pub device_code: Cow<'s, str>,
176    pub client_id: Cow<'s, str>,
177}
178
179impl<'s> DeviceAccessTokenRequest<'s> {
180    const GRANT_TYPE: &'static str = "urn:ietf:params:oauth:grant-type:device_code";
181
182    #[must_use]
183    pub fn new(device_code: impl Into<Cow<'s, str>>, client_id: &'s str) -> Self {
184        Self {
185            grant_type: Self::GRANT_TYPE.into(),
186            device_code: device_code.into(),
187            client_id: client_id.into(),
188        }
189    }
190}
191
192/// Successful Response
193/// <https://datatracker.ietf.org/doc/html/rfc6749#section-5.1>
194#[derive(Serialize, Deserialize, Debug, Clone)]
195pub struct DeviceAccessTokenResponse {
196    /// the access token issued by the authorization server
197    pub access_token: String,
198    pub token_type: String,
199    /// the lifetime in seconds of the access token
200    pub expires_in: Option<i64>,
201
202    // OIDC usecase
203    pub refresh_token: Option<String>,
204    pub id_token: Option<String>,
205}
206
207/// <https://datatracker.ietf.org/doc/html/rfc6749#section-5.2>
208#[derive(Serialize, Deserialize, Debug)]
209pub struct DeviceAccessTokenErrorResponse {
210    pub error: DeviceAccessTokenErrorCode,
211    #[allow(unused)]
212    pub error_description: Option<String>,
213    // error if there is no field on deserializing, maybe bug on http_serde_ext crate ?
214    #[allow(unused)]
215    #[serde(with = "http_serde_ext::uri::option", skip_deserializing)]
216    pub error_uri: Option<Uri>,
217}
218
219#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
220#[serde(rename_all = "snake_case")]
221pub enum DeviceAccessTokenErrorCode {
222    AuthorizationPending,
223    SlowDown,
224    AccessDenied,
225    ExpiredToken,
226    InvalidRequest,
227    InvalidClient,
228    InvalidGrant,
229    UnauthorizedClient,
230    UnsupportedGrantType,
231    InvalidScope,
232    IncorrectDeviceCode,
233}
234
235impl DeviceAccessTokenErrorCode {
236    ///  The `authorization_pending` and `slow_down` error codes define particularly unique behavior, as they indicate that the OAuth client should continue to poll the token endpoint by repeating the token request (implementing the precise behavior defined above)
237    pub fn should_continue_to_poll(&self) -> bool {
238        use DeviceAccessTokenErrorCode::{AuthorizationPending, SlowDown};
239        *self == AuthorizationPending || *self == SlowDown
240    }
241}