Skip to main content

rustauth_plugins/device_authorization/routes/
token.rs

1use std::sync::Arc;
2
3use http::{header, Method, StatusCode};
4use rustauth_core::api::{
5    create_auth_endpoint, parse_request_body, ApiRequest, ApiResponse, AuthEndpointOptions,
6    BodyField, BodySchema, JsonSchemaType,
7};
8use rustauth_core::context::{request_state, AuthContext};
9use rustauth_core::db::{DbRecord, DbValue};
10use rustauth_core::error::RustAuthError;
11use rustauth_core::plugin::PluginEndpoint;
12use rustauth_core::session::CreateSessionInput;
13use serde::{Deserialize, Serialize};
14use time::OffsetDateTime;
15
16use crate::device_authorization::errors::{oauth_error_response, OAuthDeviceError};
17use crate::device_authorization::options::DeviceAuthorizationOptions;
18use crate::device_authorization::store::{
19    DeviceAuthorizationStatus, DeviceCodeRecord, DeviceCodeStore,
20};
21
/// The only `grant_type` value the device token endpoint accepts
/// (the OAuth 2.0 device code grant type URN).
const DEVICE_CODE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code";
23
/// Request body of the `/device/token` endpoint.
///
/// The endpoint accepts it as JSON or as a URL-encoded form.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct DeviceTokenRequest {
    /// Must equal [`DEVICE_CODE_GRANT_TYPE`]; any other value is rejected by the endpoint.
    pub grant_type: String,
    /// Device code used to look up the pending authorization record.
    pub device_code: String,
    /// Client identifier; checked against the optional client validator and against the
    /// client id stored on the device code record, when one is stored.
    pub client_id: String,
}
30
/// Successful response body of the `/device/token` endpoint.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct DeviceTokenResponse {
    /// Token of the session created for the approving user.
    pub access_token: String,
    /// Token type; `approved_response` always sets this to `"Bearer"`.
    pub token_type: String,
    /// Whole seconds until the session expires, never negative.
    pub expires_in: i64,
    /// Scope stored on the device code record, or an empty string when none was stored.
    pub scope: String,
}
38
39pub fn device_token(options: Arc<DeviceAuthorizationOptions>) -> PluginEndpoint {
40    create_auth_endpoint(
41        "/device/token",
42        Method::POST,
43        AuthEndpointOptions::new()
44            .operation_id("deviceToken")
45            .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
46            .openapi(super::openapi::device_token_operation())
47            .body_schema(BodySchema::object([
48                BodyField::new("grant_type", JsonSchemaType::String),
49                BodyField::new("device_code", JsonSchemaType::String),
50                BodyField::new("client_id", JsonSchemaType::String),
51            ])),
52        move |context, request| {
53            let options = Arc::clone(&options);
54            async move {
55                let body = parse_request_body::<DeviceTokenRequest>(&request)?;
56                if body.grant_type != DEVICE_CODE_GRANT_TYPE {
57                    return token_oauth_error_response(
58                        StatusCode::BAD_REQUEST,
59                        OAuthDeviceError::InvalidRequest,
60                        "Invalid grant type",
61                    );
62                }
63                if let Some(validate_client) = &options.validate_client {
64                    if !(validate_client)(body.client_id.clone()).await? {
65                        return token_oauth_error_response(
66                            StatusCode::BAD_REQUEST,
67                            OAuthDeviceError::InvalidGrant,
68                            "Invalid client ID",
69                        );
70                    }
71                }
72
73                let adapter = context.require_adapter()?;
74                let store = DeviceCodeStore::new(adapter.as_ref());
75                let Some(record) = store.find_by_device_code(&body.device_code).await? else {
76                    return token_oauth_error_response(
77                        StatusCode::BAD_REQUEST,
78                        OAuthDeviceError::InvalidGrant,
79                        "Invalid device code",
80                    );
81                };
82                if record
83                    .client_id
84                    .as_ref()
85                    .is_some_and(|client_id| client_id != &body.client_id)
86                {
87                    return token_oauth_error_response(
88                        StatusCode::BAD_REQUEST,
89                        OAuthDeviceError::InvalidGrant,
90                        "Client ID mismatch",
91                    );
92                }
93                if polling_too_fast(&record) {
94                    return token_oauth_error_response(
95                        StatusCode::BAD_REQUEST,
96                        OAuthDeviceError::SlowDown,
97                        "Polling too frequently",
98                    );
99                }
100                store.mark_polled(&record.id).await?;
101                if record.expires_at < OffsetDateTime::now_utc() {
102                    store.delete(&record.id).await?;
103                    return token_oauth_error_response(
104                        StatusCode::BAD_REQUEST,
105                        OAuthDeviceError::ExpiredToken,
106                        "Device code has expired",
107                    );
108                }
109
110                match record.status {
111                    DeviceAuthorizationStatus::Pending => token_oauth_error_response(
112                        StatusCode::BAD_REQUEST,
113                        OAuthDeviceError::AuthorizationPending,
114                        "Authorization pending",
115                    ),
116                    DeviceAuthorizationStatus::Denied => {
117                        store.delete(&record.id).await?;
118                        token_oauth_error_response(
119                            StatusCode::BAD_REQUEST,
120                            OAuthDeviceError::AccessDenied,
121                            "Access denied",
122                        )
123                    }
124                    DeviceAuthorizationStatus::Approved => {
125                        approved_response(&context, &store, &record, &request).await
126                    }
127                }
128            }
129        },
130    )
131}
132
133async fn approved_response(
134    context: &AuthContext,
135    store: &DeviceCodeStore<'_>,
136    record: &DeviceCodeRecord,
137    request: &ApiRequest,
138) -> Result<ApiResponse, RustAuthError> {
139    if !store.consume_approved(&record.id).await? {
140        return token_oauth_error_response(
141            StatusCode::BAD_REQUEST,
142            OAuthDeviceError::InvalidGrant,
143            "Device code has already been used",
144        );
145    }
146    let Some(user_id) = &record.user_id else {
147        return token_oauth_error_response(
148            StatusCode::INTERNAL_SERVER_ERROR,
149            OAuthDeviceError::ServerError,
150            "Invalid device code status",
151        );
152    };
153    let Some(user) = context.users()?.find_user_by_id(user_id).await? else {
154        return token_oauth_error_response(
155            StatusCode::INTERNAL_SERVER_ERROR,
156            OAuthDeviceError::ServerError,
157            "User not found",
158        );
159    };
160    let expires_at = OffsetDateTime::now_utc() + context.session_config.expires_in;
161    let mut input = CreateSessionInput::new(user.id.clone(), expires_at)
162        .additional_fields(additional_session_create_values(context));
163    if let Some(ip_address) = request_ip(request) {
164        input = input.ip_address(ip_address);
165    }
166    if let Some(user_agent) = request_user_agent(request) {
167        input = input.user_agent(user_agent);
168    }
169    let session = match context.sessions()?.create_session(input).await {
170        Ok(session) => session,
171        Err(_) => {
172            return token_oauth_error_response(
173                StatusCode::INTERNAL_SERVER_ERROR,
174                OAuthDeviceError::ServerError,
175                "Failed to create session",
176            );
177        }
178    };
179    if request_state::has_request_state() {
180        request_state::set_current_new_session(session.clone(), user)?;
181    }
182
183    let expires_in = (session.expires_at - OffsetDateTime::now_utc())
184        .whole_seconds()
185        .max(0);
186    let mut response = super::json_response(
187        StatusCode::OK,
188        &DeviceTokenResponse {
189            access_token: session.token,
190            token_type: "Bearer".to_owned(),
191            expires_in,
192            scope: record.scope.clone().unwrap_or_default(),
193        },
194    )?;
195    add_token_cache_headers(&mut response);
196    Ok(response)
197}
198
199fn polling_too_fast(record: &DeviceCodeRecord) -> bool {
200    let (Some(last_polled_at), Some(interval)) = (record.last_polled_at, record.polling_interval)
201    else {
202        return false;
203    };
204    let elapsed = OffsetDateTime::now_utc() - last_polled_at;
205    elapsed.whole_milliseconds() < i128::from(interval)
206}
207
208fn token_oauth_error_response(
209    status: StatusCode,
210    error: OAuthDeviceError,
211    description: &str,
212) -> Result<ApiResponse, RustAuthError> {
213    let mut response = oauth_error_response(status, error, description)?;
214    add_token_cache_headers(&mut response);
215    Ok(response)
216}
217
218fn add_token_cache_headers(response: &mut ApiResponse) {
219    response.headers_mut().insert(
220        header::CACHE_CONTROL,
221        http::HeaderValue::from_static("no-store"),
222    );
223    response
224        .headers_mut()
225        .insert(header::PRAGMA, http::HeaderValue::from_static("no-cache"));
226}
227
228fn additional_session_create_values(context: &AuthContext) -> DbRecord {
229    context
230        .options
231        .session
232        .additional_fields
233        .iter()
234        .map(|(name, field)| {
235            (
236                name.clone(),
237                field.default_value.clone().unwrap_or(DbValue::Null),
238            )
239        })
240        .collect()
241}
242
243fn request_user_agent(request: &ApiRequest) -> Option<String> {
244    request
245        .headers()
246        .get(header::USER_AGENT)
247        .and_then(|value| value.to_str().ok())
248        .map(str::to_owned)
249}
250
251fn request_ip(request: &ApiRequest) -> Option<String> {
252    request
253        .headers()
254        .get("x-forwarded-for")
255        .and_then(|value| value.to_str().ok())
256        .and_then(|value| value.split(',').next())
257        .map(str::trim)
258        .filter(|value| !value.is_empty())
259        .map(str::to_owned)
260        .or_else(|| {
261            request
262                .headers()
263                .get("x-real-ip")
264                .and_then(|value| value.to_str().ok())
265                .map(str::to_owned)
266        })
267}