use std::{borrow::Cow, time::Duration};
use http::{StatusCode, Uri};
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};
use tracing::debug;
use crate::USER_AGENT;
pub mod provider;
/// An OAuth 2.0 / OpenID Connect provider that supports the device authorization grant
/// (RFC 8628).
///
/// Implementors supply the endpoint URLs and the form bodies that [`DeviceFlow`] POSTs to
/// them.
pub trait Provider {
    /// Form body sent to [`Provider::token_endpoint`] on every poll for an access token.
    ///
    /// The `'d` lifetime lets the request borrow the device code (and, through
    /// `device_access_token_request`'s `'p: 'd` bound, data owned by the provider).
    type DeviceAccessTokenRequest<'d>: Serialize + Send
    where
        Self: 'd;

    /// URL of the device authorization endpoint (RFC 8628 §3.1).
    fn device_authorization_endpoint(&self) -> Url;

    /// URL of the token endpoint that is polled for the access token (RFC 8628 §3.4).
    fn token_endpoint(&self) -> Url;

    /// Form body for the initial device authorization request.
    fn device_authorization_request(&self) -> DeviceAuthorizationRequest;

    /// Builds the token-request form body for `device_code`.
    ///
    /// The result may borrow from both `self` and `device_code`.
    fn device_access_token_request<'d, 'p: 'd>(
        &'p self,
        device_code: &'d str,
    ) -> Self::DeviceAccessTokenRequest<'d>;
}
/// Client that runs the OAuth 2.0 device authorization grant against a provider `P`.
#[derive(Clone)]
pub struct DeviceFlow<P> {
    /// Supplies endpoint URLs and request bodies (see [`Provider`]).
    provider: P,
    /// HTTP client used for every request the flow makes.
    client: Client,
}
impl<P> DeviceFlow<P> {
    /// Creates a device-flow client for `provider`.
    ///
    /// The underlying HTTP client identifies itself with the crate's `USER_AGENT`. It applies
    /// the same 10 second limit both to establishing a connection and to each complete
    /// request.
    ///
    /// # Panics
    ///
    /// Panics if reqwest fails to build the HTTP client. This is treated as an unrecoverable
    /// environment problem rather than a caller error.
    pub fn new(provider: P) -> Self {
        // One constant for both limits so they cannot silently drift apart.
        const HTTP_TIMEOUT: Duration = Duration::from_secs(10);

        let client = reqwest::ClientBuilder::new()
            .user_agent(USER_AGENT)
            .connect_timeout(HTTP_TIMEOUT)
            .timeout(HTTP_TIMEOUT)
            .build()
            .expect("failed to build the HTTP client for the device authorization flow");
        Self { provider, client }
    }
}
impl<P: Provider> DeviceFlow<P> {
#[tracing::instrument(skip(self))]
pub async fn device_authorize_request(&self) -> anyhow::Result<DeviceAuthorizationResponse> {
let response = self
.client
.post(self.provider.device_authorization_endpoint())
.header(http::header::ACCEPT, "application/json")
.form(&self.provider.device_authorization_request())
.send()
.await?
.error_for_status()?
.json::<DeviceAuthorizationResponse>()
.await?;
Ok(response)
}
pub async fn poll_device_access_token(
&self,
device_code: String,
interval: Option<i64>,
) -> anyhow::Result<DeviceAccessTokenResponse> {
macro_rules! continue_or_abort {
( $response_bytes:ident ) => {{
let err_response = serde_json::from_slice::<DeviceAccessTokenErrorResponse>(&$response_bytes)?;
if err_response.error.should_continue_to_poll() {
debug!(error_code=?err_response.error,interval, "Continue to poll");
let interval = interval.unwrap_or(5);
tokio::time::sleep(Duration::from_secs(interval as u64)).await;
} else {
anyhow::bail!(
"authorization server or oidc provider respond with {err_response:?}"
)
}
}};
}
let response = loop {
let response = self
.client
.post(self.provider.token_endpoint())
.header(http::header::ACCEPT, "application/json")
.form(&self.provider.device_access_token_request(&device_code))
.send()
.await?;
match response.status() {
StatusCode::OK => {
let full = response.bytes().await?;
if let Ok(response) = serde_json::from_slice::<DeviceAccessTokenResponse>(&full)
{
break response;
}
continue_or_abort!(full);
}
StatusCode::BAD_REQUEST | StatusCode::PRECONDITION_REQUIRED => {
let full = response.bytes().await?;
continue_or_abort!(full);
}
other => {
let error_msg = response.text().await.unwrap_or_default();
anyhow::bail!("Failed to authenticate. authorization server respond with {other} {error_msg}")
}
}
};
Ok(response)
}
}
/// Form body of the device authorization request (RFC 8628 §3.1).
#[derive(Serialize, Deserialize, Debug)]
pub struct DeviceAuthorizationRequest<'s> {
    /// The OAuth client identifier.
    pub client_id: Cow<'s, str>,
    /// Requested scope string, sent exactly as given.
    pub scope: Cow<'s, str>,
}
/// JSON body returned by the device authorization endpoint (RFC 8628 §3.2).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DeviceAuthorizationResponse {
    /// Code the client presents when polling the token endpoint.
    pub device_code: String,
    /// Code the user enters on the verification page.
    pub user_code: String,
    /// Standard RFC 8628 verification URI.
    ///
    /// Optional here because some providers send `verification_url` instead. Read it through
    /// the `verification_uri()` method, which falls back to that field.
    #[serde(with = "http_serde_ext::uri::option", default)]
    pub verification_uri: Option<Uri>,
    /// Non-standard spelling of `verification_uri` used by some providers.
    #[serde(with = "http_serde_ext::uri::option", default)]
    pub verification_url: Option<Uri>,
    /// Verification URI with the user code already embedded, if the server sent one.
    #[allow(unused)]
    #[serde(with = "http_serde_ext::uri::option", default)]
    pub verification_uri_complete: Option<Uri>,
    /// Lifetime of `device_code` and `user_code`, in seconds.
    #[allow(unused)]
    pub expires_in: i64,
    /// Minimum polling interval in seconds; `None` when the server omitted it.
    pub interval: Option<i64>,
}
impl DeviceAuthorizationResponse {
    /// Returns the URI the user should open to enter `user_code`.
    ///
    /// Uses the standard `verification_uri` field when present. Otherwise it falls back to
    /// the non-standard `verification_url` spelling.
    ///
    /// # Panics
    ///
    /// Panics if the server sent neither field. Both are optional during deserialization, so
    /// a non-conforming response parses successfully and only fails here.
    pub fn verification_uri(&self) -> &Uri {
        match (&self.verification_uri, &self.verification_url) {
            (Some(uri), _) | (None, Some(uri)) => uri,
            (None, None) => panic!("verification uri or url not found"),
        }
    }
}
/// Standard form body for polling the token endpoint (RFC 8628 §3.4).
///
/// Build it with `DeviceAccessTokenRequest::new`, which fills in `grant_type`.
#[derive(Serialize, Deserialize)]
pub struct DeviceAccessTokenRequest<'s> {
    /// Set to the device-code grant type URN by `new`.
    ///
    /// Private, so code outside this module cannot change it.
    grant_type: Cow<'static, str>,
    /// The `device_code` from the device authorization response.
    pub device_code: Cow<'s, str>,
    /// The OAuth client identifier.
    pub client_id: Cow<'s, str>,
}
impl<'s> DeviceAccessTokenRequest<'s> {
const GRANT_TYPE: &'static str = "urn:ietf:params:oauth:grant-type:device_code";
#[must_use]
pub fn new(device_code: impl Into<Cow<'s, str>>, client_id: &'s str) -> Self {
Self {
grant_type: Self::GRANT_TYPE.into(),
device_code: device_code.into(),
client_id: client_id.into(),
}
}
}
/// Successful token endpoint response (RFC 6749 §5.1), plus the OpenID Connect `id_token`.
///
/// NOTE(review): the derived `Debug` prints every token in plain text, so avoid logging this
/// value.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DeviceAccessTokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// Token type reported by the server (for example `Bearer`).
    pub token_type: String,
    /// Access token lifetime in seconds, if the server provided it.
    pub expires_in: Option<i64>,
    /// Refresh token, if one was issued.
    pub refresh_token: Option<String>,
    /// OpenID Connect ID token, if the provider returned one.
    pub id_token: Option<String>,
}
/// Error body returned by the token endpoint (RFC 6749 §5.2, extended by RFC 8628 §3.5).
#[derive(Deserialize, Debug)]
pub struct DeviceAccessTokenErrorResponse {
    /// Machine-readable error code. An unlisted code makes deserialization fail.
    pub error: DeviceAccessTokenErrorCode,
    /// Optional human-readable description from the server.
    #[allow(unused)]
    pub error_description: Option<String>,
    /// Always `None`: `skip_deserializing` ignores any `error_uri` in the body.
    ///
    /// NOTE(review): presumably this is so a malformed URI cannot make error parsing fail.
    /// Confirm before enabling it.
    #[allow(unused)]
    #[serde(with = "http_serde_ext::uri::option", skip_deserializing)]
    pub error_uri: Option<Uri>,
}
/// `error` codes the token endpoint may return while the client polls.
///
/// Variants are deserialized from their snake_case wire names (for example
/// `authorization_pending`). A code not listed here fails deserialization.
#[derive(PartialEq, Eq, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DeviceAccessTokenErrorCode {
    /// The user has not completed authorization yet; keep polling (RFC 8628 §3.5).
    AuthorizationPending,
    /// Keep polling, but at a longer interval (RFC 8628 §3.5).
    SlowDown,
    /// The user denied the authorization request (RFC 8628 §3.5).
    AccessDenied,
    /// The device code has expired (RFC 8628 §3.5).
    ExpiredToken,
    /// The request is malformed (RFC 6749 §5.2).
    InvalidRequest,
    /// Client authentication failed (RFC 6749 §5.2).
    InvalidClient,
    /// The grant is invalid, expired or revoked (RFC 6749 §5.2).
    InvalidGrant,
    /// The client may not use this grant type (RFC 6749 §5.2).
    UnauthorizedClient,
    /// The server does not support this grant type (RFC 6749 §5.2).
    UnsupportedGrantType,
    /// The requested scope is invalid (RFC 6749 §5.2).
    InvalidScope,
    /// Non-standard code.
    ///
    /// NOTE(review): presumably provider-specific (GitHub's device flow documents
    /// `incorrect_device_code`). Confirm.
    IncorrectDeviceCode,
}
impl DeviceAccessTokenErrorCode {
pub fn should_continue_to_poll(&self) -> bool {
use DeviceAccessTokenErrorCode::{AuthorizationPending, SlowDown};
*self == AuthorizationPending || *self == SlowDown
}
}