rustauth_plugins/device_authorization/routes/
token.rs1use std::sync::Arc;
2
3use http::{header, Method, StatusCode};
4use rustauth_core::api::{
5 create_auth_endpoint, parse_request_body, ApiRequest, ApiResponse, AuthEndpointOptions,
6 BodyField, BodySchema, JsonSchemaType,
7};
8use rustauth_core::context::{request_state, AuthContext};
9use rustauth_core::db::{DbRecord, DbValue};
10use rustauth_core::error::RustAuthError;
11use rustauth_core::plugin::PluginEndpoint;
12use rustauth_core::session::CreateSessionInput;
13use serde::{Deserialize, Serialize};
14use time::OffsetDateTime;
15
16use crate::device_authorization::errors::{oauth_error_response, OAuthDeviceError};
17use crate::device_authorization::options::DeviceAuthorizationOptions;
18use crate::device_authorization::store::{
19 DeviceAuthorizationStatus, DeviceCodeRecord, DeviceCodeStore,
20};
21
/// The only `grant_type` value the device token endpoint accepts; any other
/// value is rejected with a 400 `InvalidRequest` error by `device_token`.
const DEVICE_CODE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code";
23
24#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
25pub struct DeviceTokenRequest {
26 pub grant_type: String,
27 pub device_code: String,
28 pub client_id: String,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
32pub struct DeviceTokenResponse {
33 pub access_token: String,
34 pub token_type: String,
35 pub expires_in: i64,
36 pub scope: String,
37}
38
39pub fn device_token(options: Arc<DeviceAuthorizationOptions>) -> PluginEndpoint {
40 create_auth_endpoint(
41 "/device/token",
42 Method::POST,
43 AuthEndpointOptions::new()
44 .operation_id("deviceToken")
45 .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
46 .openapi(super::openapi::device_token_operation())
47 .body_schema(BodySchema::object([
48 BodyField::new("grant_type", JsonSchemaType::String),
49 BodyField::new("device_code", JsonSchemaType::String),
50 BodyField::new("client_id", JsonSchemaType::String),
51 ])),
52 move |context, request| {
53 let options = Arc::clone(&options);
54 async move {
55 let body = parse_request_body::<DeviceTokenRequest>(&request)?;
56 if body.grant_type != DEVICE_CODE_GRANT_TYPE {
57 return token_oauth_error_response(
58 StatusCode::BAD_REQUEST,
59 OAuthDeviceError::InvalidRequest,
60 "Invalid grant type",
61 );
62 }
63 if let Some(validate_client) = &options.validate_client {
64 if !(validate_client)(body.client_id.clone()).await? {
65 return token_oauth_error_response(
66 StatusCode::BAD_REQUEST,
67 OAuthDeviceError::InvalidGrant,
68 "Invalid client ID",
69 );
70 }
71 }
72
73 let adapter = context.require_adapter()?;
74 let store = DeviceCodeStore::new(adapter.as_ref());
75 let Some(record) = store.find_by_device_code(&body.device_code).await? else {
76 return token_oauth_error_response(
77 StatusCode::BAD_REQUEST,
78 OAuthDeviceError::InvalidGrant,
79 "Invalid device code",
80 );
81 };
82 if record
83 .client_id
84 .as_ref()
85 .is_some_and(|client_id| client_id != &body.client_id)
86 {
87 return token_oauth_error_response(
88 StatusCode::BAD_REQUEST,
89 OAuthDeviceError::InvalidGrant,
90 "Client ID mismatch",
91 );
92 }
93 if polling_too_fast(&record) {
94 return token_oauth_error_response(
95 StatusCode::BAD_REQUEST,
96 OAuthDeviceError::SlowDown,
97 "Polling too frequently",
98 );
99 }
100 store.mark_polled(&record.id).await?;
101 if record.expires_at < OffsetDateTime::now_utc() {
102 store.delete(&record.id).await?;
103 return token_oauth_error_response(
104 StatusCode::BAD_REQUEST,
105 OAuthDeviceError::ExpiredToken,
106 "Device code has expired",
107 );
108 }
109
110 match record.status {
111 DeviceAuthorizationStatus::Pending => token_oauth_error_response(
112 StatusCode::BAD_REQUEST,
113 OAuthDeviceError::AuthorizationPending,
114 "Authorization pending",
115 ),
116 DeviceAuthorizationStatus::Denied => {
117 store.delete(&record.id).await?;
118 token_oauth_error_response(
119 StatusCode::BAD_REQUEST,
120 OAuthDeviceError::AccessDenied,
121 "Access denied",
122 )
123 }
124 DeviceAuthorizationStatus::Approved => {
125 approved_response(&context, &store, &record, &request).await
126 }
127 }
128 }
129 },
130 )
131}
132
133async fn approved_response(
134 context: &AuthContext,
135 store: &DeviceCodeStore<'_>,
136 record: &DeviceCodeRecord,
137 request: &ApiRequest,
138) -> Result<ApiResponse, RustAuthError> {
139 if !store.consume_approved(&record.id).await? {
140 return token_oauth_error_response(
141 StatusCode::BAD_REQUEST,
142 OAuthDeviceError::InvalidGrant,
143 "Device code has already been used",
144 );
145 }
146 let Some(user_id) = &record.user_id else {
147 return token_oauth_error_response(
148 StatusCode::INTERNAL_SERVER_ERROR,
149 OAuthDeviceError::ServerError,
150 "Invalid device code status",
151 );
152 };
153 let Some(user) = context.users()?.find_user_by_id(user_id).await? else {
154 return token_oauth_error_response(
155 StatusCode::INTERNAL_SERVER_ERROR,
156 OAuthDeviceError::ServerError,
157 "User not found",
158 );
159 };
160 let expires_at = OffsetDateTime::now_utc() + context.session_config.expires_in;
161 let mut input = CreateSessionInput::new(user.id.clone(), expires_at)
162 .additional_fields(additional_session_create_values(context));
163 if let Some(ip_address) = request_ip(request) {
164 input = input.ip_address(ip_address);
165 }
166 if let Some(user_agent) = request_user_agent(request) {
167 input = input.user_agent(user_agent);
168 }
169 let session = match context.sessions()?.create_session(input).await {
170 Ok(session) => session,
171 Err(_) => {
172 return token_oauth_error_response(
173 StatusCode::INTERNAL_SERVER_ERROR,
174 OAuthDeviceError::ServerError,
175 "Failed to create session",
176 );
177 }
178 };
179 if request_state::has_request_state() {
180 request_state::set_current_new_session(session.clone(), user)?;
181 }
182
183 let expires_in = (session.expires_at - OffsetDateTime::now_utc())
184 .whole_seconds()
185 .max(0);
186 let mut response = super::json_response(
187 StatusCode::OK,
188 &DeviceTokenResponse {
189 access_token: session.token,
190 token_type: "Bearer".to_owned(),
191 expires_in,
192 scope: record.scope.clone().unwrap_or_default(),
193 },
194 )?;
195 add_token_cache_headers(&mut response);
196 Ok(response)
197}
198
199fn polling_too_fast(record: &DeviceCodeRecord) -> bool {
200 let (Some(last_polled_at), Some(interval)) = (record.last_polled_at, record.polling_interval)
201 else {
202 return false;
203 };
204 let elapsed = OffsetDateTime::now_utc() - last_polled_at;
205 elapsed.whole_milliseconds() < i128::from(interval)
206}
207
208fn token_oauth_error_response(
209 status: StatusCode,
210 error: OAuthDeviceError,
211 description: &str,
212) -> Result<ApiResponse, RustAuthError> {
213 let mut response = oauth_error_response(status, error, description)?;
214 add_token_cache_headers(&mut response);
215 Ok(response)
216}
217
218fn add_token_cache_headers(response: &mut ApiResponse) {
219 response.headers_mut().insert(
220 header::CACHE_CONTROL,
221 http::HeaderValue::from_static("no-store"),
222 );
223 response
224 .headers_mut()
225 .insert(header::PRAGMA, http::HeaderValue::from_static("no-cache"));
226}
227
228fn additional_session_create_values(context: &AuthContext) -> DbRecord {
229 context
230 .options
231 .session
232 .additional_fields
233 .iter()
234 .map(|(name, field)| {
235 (
236 name.clone(),
237 field.default_value.clone().unwrap_or(DbValue::Null),
238 )
239 })
240 .collect()
241}
242
243fn request_user_agent(request: &ApiRequest) -> Option<String> {
244 request
245 .headers()
246 .get(header::USER_AGENT)
247 .and_then(|value| value.to_str().ok())
248 .map(str::to_owned)
249}
250
251fn request_ip(request: &ApiRequest) -> Option<String> {
252 request
253 .headers()
254 .get("x-forwarded-for")
255 .and_then(|value| value.to_str().ok())
256 .and_then(|value| value.split(',').next())
257 .map(str::trim)
258 .filter(|value| !value.is_empty())
259 .map(str::to_owned)
260 .or_else(|| {
261 request
262 .headers()
263 .get("x-real-ip")
264 .and_then(|value| value.to_str().ok())
265 .map(str::to_owned)
266 })
267}