rustauth_plugins/device_authorization/routes/
code.rs1use std::sync::Arc;
2
3use http::{header, Method, StatusCode};
4use rand::rngs::OsRng;
5use rand::RngCore;
6use rustauth_core::api::{
7 create_auth_endpoint, parse_request_body, AuthEndpointOptions, BodyField, BodySchema,
8 JsonSchemaType,
9};
10use rustauth_core::crypto::random::generate_random_string;
11use rustauth_core::error::RustAuthError;
12use rustauth_core::plugin::PluginEndpoint;
13use serde::{Deserialize, Serialize};
14use time::OffsetDateTime;
15use url::Url;
16
17use crate::device_authorization::errors::{oauth_error_response, OAuthDeviceError};
18use crate::device_authorization::options::{AsyncDeviceCodeGenerator, DeviceAuthorizationOptions};
19use crate::device_authorization::store::{CreateDeviceCodeInput, DeviceCodeStore};
20
/// Alphabet used by the built-in user-code generator: upper-case letters and digits with the
/// easily confused `I`, `O`, `0` and `1` left out. It has 32 symbols, which divides 256 evenly.
const DEFAULT_USER_CODE_CHARSET: &[u8] = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789".as_bytes();
22
23#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
24pub struct DeviceCodeRequest {
25 pub client_id: String,
26 pub scope: Option<String>,
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
30pub struct DeviceCodeResponse {
31 pub device_code: String,
32 pub user_code: String,
33 pub verification_uri: String,
34 pub verification_uri_complete: String,
35 pub expires_in: i64,
36 pub interval: i64,
37}
38
39pub fn device_code(options: Arc<DeviceAuthorizationOptions>) -> PluginEndpoint {
40 create_auth_endpoint(
41 "/device/code",
42 Method::POST,
43 AuthEndpointOptions::new()
44 .operation_id("deviceCode")
45 .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
46 .openapi(super::openapi::device_code_operation())
47 .body_schema(BodySchema::object([
48 BodyField::new("client_id", JsonSchemaType::String),
49 BodyField::optional("scope", JsonSchemaType::String),
50 ])),
51 move |context, request| {
52 let options = Arc::clone(&options);
53 async move {
54 let body = parse_request_body::<DeviceCodeRequest>(&request)?;
55 if let Some(validate_client) = &options.validate_client {
56 if !(validate_client)(body.client_id.clone()).await? {
57 return oauth_error_response(
58 StatusCode::BAD_REQUEST,
59 OAuthDeviceError::InvalidClient,
60 "Invalid client ID",
61 );
62 }
63 }
64 if let Some(hook) = &options.on_device_auth_request {
65 (hook)(body.client_id.clone(), body.scope.clone()).await?;
66 }
67
68 let adapter = context.require_adapter()?;
69 let device_code = generate_code(
70 options.generate_device_code.as_ref(),
71 options.device_code_length,
72 default_device_code,
73 )
74 .await;
75 let user_code = generate_code(
76 options.generate_user_code.as_ref(),
77 options.user_code_length,
78 default_user_code,
79 )
80 .await;
81 let expires_in = options.expires_in.whole_seconds();
82 let interval = options.interval.whole_seconds();
83 let polling_interval = i64::try_from(options.interval.whole_milliseconds())
84 .map_err(|_| {
85 RustAuthError::InvalidConfig(
86 "device authorization interval is too large".to_owned(),
87 )
88 })?;
89
90 DeviceCodeStore::new(adapter.as_ref())
91 .create(CreateDeviceCodeInput {
92 device_code: device_code.clone(),
93 user_code: super::clean_user_code(&user_code),
94 expires_at: OffsetDateTime::now_utc() + options.expires_in,
95 polling_interval,
96 client_id: body.client_id,
97 scope: body.scope,
98 })
99 .await?;
100
101 let (verification_uri, verification_uri_complete) = build_verification_uris(
102 &options.verification_uri,
103 &context.base_url,
104 &user_code,
105 )?;
106 let mut response = super::json_response(
107 StatusCode::OK,
108 &DeviceCodeResponse {
109 device_code,
110 user_code,
111 verification_uri,
112 verification_uri_complete,
113 expires_in,
114 interval,
115 },
116 )?;
117 response.headers_mut().insert(
118 header::CACHE_CONTROL,
119 http::HeaderValue::from_static("no-store"),
120 );
121 Ok(response)
122 }
123 },
124 )
125}
126
127async fn generate_code(
128 generator: Option<&AsyncDeviceCodeGenerator>,
129 length: usize,
130 fallback: fn(usize) -> String,
131) -> String {
132 match generator {
133 Some(generator) => generator().await,
134 None => fallback(length),
135 }
136}
137
138fn default_device_code(length: usize) -> String {
139 generate_random_string(length)
140}
141
142fn default_user_code(length: usize) -> String {
143 let mut bytes = vec![0_u8; length];
144 OsRng.fill_bytes(&mut bytes);
145 bytes
146 .into_iter()
147 .map(|byte| {
148 let index = usize::from(byte) % DEFAULT_USER_CODE_CHARSET.len();
149 char::from(DEFAULT_USER_CODE_CHARSET[index])
150 })
151 .collect()
152}
153
154fn build_verification_uris(
155 verification_uri: &str,
156 base_url: &str,
157 user_code: &str,
158) -> Result<(String, String), RustAuthError> {
159 let verification_url = Url::parse(verification_uri)
160 .or_else(|_| Url::parse(base_url).and_then(|base| base.join(verification_uri)))
161 .map_err(|error| RustAuthError::InvalidConfig(error.to_string()))?;
162 let mut complete = verification_url.clone();
163 complete
164 .query_pairs_mut()
165 .append_pair("user_code", user_code);
166 Ok((verification_url.to_string(), complete.to_string()))
167}