Skip to main content

rustauth_plugins/device_authorization/routes/
code.rs

1use std::sync::Arc;
2
3use http::{header, Method, StatusCode};
4use rand::rngs::OsRng;
5use rand::RngCore;
6use rustauth_core::api::{
7    create_auth_endpoint, parse_request_body, AuthEndpointOptions, BodyField, BodySchema,
8    JsonSchemaType,
9};
10use rustauth_core::crypto::random::generate_random_string;
11use rustauth_core::error::RustAuthError;
12use rustauth_core::plugin::PluginEndpoint;
13use serde::{Deserialize, Serialize};
14use time::OffsetDateTime;
15use url::Url;
16
17use crate::device_authorization::errors::{oauth_error_response, OAuthDeviceError};
18use crate::device_authorization::options::{AsyncDeviceCodeGenerator, DeviceAuthorizationOptions};
19use crate::device_authorization::store::{CreateDeviceCodeInput, DeviceCodeStore};
20
/// Alphabet used by [`default_user_code`] when no custom generator is configured.
///
/// Upper-case letters and digits with `I`, `O`, `0` and `1` left out. It holds
/// exactly 32 symbols, so reducing a random byte modulo its length maps all 256
/// byte values evenly onto the alphabet.
const DEFAULT_USER_CODE_CHARSET: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789";
22
23#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
24pub struct DeviceCodeRequest {
25    pub client_id: String,
26    pub scope: Option<String>,
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
30pub struct DeviceCodeResponse {
31    pub device_code: String,
32    pub user_code: String,
33    pub verification_uri: String,
34    pub verification_uri_complete: String,
35    pub expires_in: i64,
36    pub interval: i64,
37}
38
39pub fn device_code(options: Arc<DeviceAuthorizationOptions>) -> PluginEndpoint {
40    create_auth_endpoint(
41        "/device/code",
42        Method::POST,
43        AuthEndpointOptions::new()
44            .operation_id("deviceCode")
45            .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
46            .openapi(super::openapi::device_code_operation())
47            .body_schema(BodySchema::object([
48                BodyField::new("client_id", JsonSchemaType::String),
49                BodyField::optional("scope", JsonSchemaType::String),
50            ])),
51        move |context, request| {
52            let options = Arc::clone(&options);
53            async move {
54                let body = parse_request_body::<DeviceCodeRequest>(&request)?;
55                if let Some(validate_client) = &options.validate_client {
56                    if !(validate_client)(body.client_id.clone()).await? {
57                        return oauth_error_response(
58                            StatusCode::BAD_REQUEST,
59                            OAuthDeviceError::InvalidClient,
60                            "Invalid client ID",
61                        );
62                    }
63                }
64                if let Some(hook) = &options.on_device_auth_request {
65                    (hook)(body.client_id.clone(), body.scope.clone()).await?;
66                }
67
68                let adapter = context.require_adapter()?;
69                let device_code = generate_code(
70                    options.generate_device_code.as_ref(),
71                    options.device_code_length,
72                    default_device_code,
73                )
74                .await;
75                let user_code = generate_code(
76                    options.generate_user_code.as_ref(),
77                    options.user_code_length,
78                    default_user_code,
79                )
80                .await;
81                let expires_in = options.expires_in.whole_seconds();
82                let interval = options.interval.whole_seconds();
83                let polling_interval = i64::try_from(options.interval.whole_milliseconds())
84                    .map_err(|_| {
85                        RustAuthError::InvalidConfig(
86                            "device authorization interval is too large".to_owned(),
87                        )
88                    })?;
89
90                DeviceCodeStore::new(adapter.as_ref())
91                    .create(CreateDeviceCodeInput {
92                        device_code: device_code.clone(),
93                        user_code: super::clean_user_code(&user_code),
94                        expires_at: OffsetDateTime::now_utc() + options.expires_in,
95                        polling_interval,
96                        client_id: body.client_id,
97                        scope: body.scope,
98                    })
99                    .await?;
100
101                let (verification_uri, verification_uri_complete) = build_verification_uris(
102                    &options.verification_uri,
103                    &context.base_url,
104                    &user_code,
105                )?;
106                let mut response = super::json_response(
107                    StatusCode::OK,
108                    &DeviceCodeResponse {
109                        device_code,
110                        user_code,
111                        verification_uri,
112                        verification_uri_complete,
113                        expires_in,
114                        interval,
115                    },
116                )?;
117                response.headers_mut().insert(
118                    header::CACHE_CONTROL,
119                    http::HeaderValue::from_static("no-store"),
120                );
121                Ok(response)
122            }
123        },
124    )
125}
126
127async fn generate_code(
128    generator: Option<&AsyncDeviceCodeGenerator>,
129    length: usize,
130    fallback: fn(usize) -> String,
131) -> String {
132    match generator {
133        Some(generator) => generator().await,
134        None => fallback(length),
135    }
136}
137
138fn default_device_code(length: usize) -> String {
139    generate_random_string(length)
140}
141
142fn default_user_code(length: usize) -> String {
143    let mut bytes = vec![0_u8; length];
144    OsRng.fill_bytes(&mut bytes);
145    bytes
146        .into_iter()
147        .map(|byte| {
148            let index = usize::from(byte) % DEFAULT_USER_CODE_CHARSET.len();
149            char::from(DEFAULT_USER_CODE_CHARSET[index])
150        })
151        .collect()
152}
153
154fn build_verification_uris(
155    verification_uri: &str,
156    base_url: &str,
157    user_code: &str,
158) -> Result<(String, String), RustAuthError> {
159    let verification_url = Url::parse(verification_uri)
160        .or_else(|_| Url::parse(base_url).and_then(|base| base.join(verification_uri)))
161        .map_err(|error| RustAuthError::InvalidConfig(error.to_string()))?;
162    let mut complete = verification_url.clone();
163    complete
164        .query_pairs_mut()
165        .append_pair("user_code", user_code);
166    Ok((verification_url.to_string(), complete.to_string()))
167}