1use rustauth_core::crypto::random::generate_random_string;
2use rustauth_core::db::{
3 Create, DbAdapter, DbRecord, DbValue, Delete, DeleteMany, FindOne, Update, Where,
4};
5use rustauth_core::error::RustAuthError;
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8
9use super::schema::DEVICE_CODE_MODEL;
10
/// Columns selected whenever a device-code row is returned from the adapter.
///
/// Listed in the same order as the fields of `DeviceCodeRecord`.
const DEVICE_CODE_FIELDS: [&'static str; 12] = [
    "id",
    "device_code",
    "user_code",
    "user_id",
    "expires_at",
    "status",
    "last_polled_at",
    "polling_interval",
    "client_id",
    "scope",
    "created_at",
    "updated_at",
];

/// Length passed to `generate_random_string` when minting the id of a new row.
const DEFAULT_ID_LENGTH: usize = 32;
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28#[serde(rename_all = "lowercase")]
29pub enum DeviceAuthorizationStatus {
30 Pending,
31 Approved,
32 Denied,
33}
34
35impl DeviceAuthorizationStatus {
36 pub fn as_str(self) -> &'static str {
37 match self {
38 Self::Pending => "pending",
39 Self::Approved => "approved",
40 Self::Denied => "denied",
41 }
42 }
43}
44
45impl TryFrom<&str> for DeviceAuthorizationStatus {
46 type Error = RustAuthError;
47
48 fn try_from(value: &str) -> Result<Self, Self::Error> {
49 match value {
50 "pending" => Ok(Self::Pending),
51 "approved" => Ok(Self::Approved),
52 "denied" => Ok(Self::Denied),
53 _ => Err(RustAuthError::Adapter(format!(
54 "device code status `{value}` is invalid"
55 ))),
56 }
57 }
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
61pub struct DeviceCodeRecord {
62 pub id: String,
63 pub device_code: String,
64 pub user_code: String,
65 pub user_id: Option<String>,
66 pub expires_at: OffsetDateTime,
67 pub status: DeviceAuthorizationStatus,
68 pub last_polled_at: Option<OffsetDateTime>,
69 pub polling_interval: Option<i64>,
70 pub client_id: Option<String>,
71 pub scope: Option<String>,
72 pub created_at: OffsetDateTime,
73 pub updated_at: OffsetDateTime,
74}
75
76#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct CreateDeviceCodeInput {
78 pub device_code: String,
79 pub user_code: String,
80 pub expires_at: OffsetDateTime,
81 pub polling_interval: i64,
82 pub client_id: String,
83 pub scope: Option<String>,
84}
85
86#[derive(Clone, Copy)]
87pub struct DeviceCodeStore<'a> {
88 adapter: &'a dyn DbAdapter,
89}
90
91impl<'a> DeviceCodeStore<'a> {
92 pub fn new(adapter: &'a dyn DbAdapter) -> Self {
93 Self { adapter }
94 }
95
96 pub async fn create(
97 &self,
98 input: CreateDeviceCodeInput,
99 ) -> Result<DeviceCodeRecord, RustAuthError> {
100 let now = OffsetDateTime::now_utc();
101 let record = self
102 .adapter
103 .create(
104 Create::new(DEVICE_CODE_MODEL)
105 .data(
106 "id",
107 DbValue::String(generate_random_string(DEFAULT_ID_LENGTH)),
108 )
109 .data("device_code", DbValue::String(input.device_code))
110 .data("user_code", DbValue::String(input.user_code))
111 .data("user_id", DbValue::Null)
112 .data("expires_at", DbValue::Timestamp(input.expires_at))
113 .data(
114 "status",
115 DbValue::String(DeviceAuthorizationStatus::Pending.as_str().to_owned()),
116 )
117 .data("last_polled_at", DbValue::Null)
118 .data("polling_interval", DbValue::Number(input.polling_interval))
119 .data("client_id", DbValue::String(input.client_id))
120 .data("scope", optional_string(input.scope))
121 .data("created_at", DbValue::Timestamp(now))
122 .data("updated_at", DbValue::Timestamp(now))
123 .select(DEVICE_CODE_FIELDS)
124 .force_allow_id(),
125 )
126 .await?;
127 record_from_db(record)
128 }
129
130 pub async fn find_by_device_code(
131 &self,
132 device_code: &str,
133 ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
134 self.find_one(Where::new(
135 "device_code",
136 DbValue::String(device_code.to_owned()),
137 ))
138 .await
139 }
140
141 pub async fn find_by_user_code(
142 &self,
143 user_code: &str,
144 ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
145 self.find_one(Where::new(
146 "user_code",
147 DbValue::String(user_code.to_owned()),
148 ))
149 .await
150 }
151
152 pub async fn mark_polled(&self, id: &str) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
153 self.update(
154 id,
155 DbRecord::from([(
156 "last_polled_at".to_owned(),
157 DbValue::Timestamp(OffsetDateTime::now_utc()),
158 )]),
159 )
160 .await
161 }
162
163 pub async fn approve(
164 &self,
165 id: &str,
166 user_id: &str,
167 ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
168 self.update(
169 id,
170 DbRecord::from([
171 (
172 "status".to_owned(),
173 DbValue::String(DeviceAuthorizationStatus::Approved.as_str().to_owned()),
174 ),
175 ("user_id".to_owned(), DbValue::String(user_id.to_owned())),
176 ]),
177 )
178 .await
179 }
180
181 pub async fn deny(
182 &self,
183 id: &str,
184 user_id: &str,
185 ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
186 self.update(
187 id,
188 DbRecord::from([
189 (
190 "status".to_owned(),
191 DbValue::String(DeviceAuthorizationStatus::Denied.as_str().to_owned()),
192 ),
193 ("user_id".to_owned(), DbValue::String(user_id.to_owned())),
194 ]),
195 )
196 .await
197 }
198
199 pub async fn delete(&self, id: &str) -> Result<(), RustAuthError> {
200 self.adapter
201 .delete(Delete::new(DEVICE_CODE_MODEL).where_clause(id_where(id)))
202 .await
203 }
204
205 pub async fn consume_approved(&self, id: &str) -> Result<bool, RustAuthError> {
211 let deleted = self
212 .adapter
213 .delete_many(
214 DeleteMany::new(DEVICE_CODE_MODEL)
215 .where_clause(id_where(id))
216 .where_clause(Where::new(
217 "status",
218 DbValue::String(DeviceAuthorizationStatus::Approved.as_str().to_owned()),
219 )),
220 )
221 .await?;
222 Ok(deleted == 1)
223 }
224
225 async fn find_one(
226 &self,
227 where_clause: Where,
228 ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
229 self.adapter
230 .find_one(
231 FindOne::new(DEVICE_CODE_MODEL)
232 .where_clause(where_clause)
233 .select(DEVICE_CODE_FIELDS),
234 )
235 .await?
236 .map(record_from_db)
237 .transpose()
238 }
239
240 async fn update(
241 &self,
242 id: &str,
243 data: DbRecord,
244 ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
245 let mut query = Update::new(DEVICE_CODE_MODEL)
246 .where_clause(id_where(id))
247 .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
248 for (field, value) in data {
249 query = query.data(field, value);
250 }
251
252 self.adapter
253 .update(query)
254 .await?
255 .map(record_from_db)
256 .transpose()
257 }
258}
259
260fn id_where(id: &str) -> Where {
261 Where::new("id", DbValue::String(id.to_owned()))
262}
263
264fn optional_string(value: Option<String>) -> DbValue {
265 value.map(DbValue::String).unwrap_or(DbValue::Null)
266}
267
268fn record_from_db(record: DbRecord) -> Result<DeviceCodeRecord, RustAuthError> {
269 Ok(DeviceCodeRecord {
270 id: required_string(&record, "id")?.to_owned(),
271 device_code: required_string(&record, "device_code")?.to_owned(),
272 user_code: required_string(&record, "user_code")?.to_owned(),
273 user_id: optional_string_field(&record, "user_id")?,
274 expires_at: required_timestamp(&record, "expires_at")?,
275 status: DeviceAuthorizationStatus::try_from(required_string(&record, "status")?)?,
276 last_polled_at: optional_timestamp(&record, "last_polled_at")?,
277 polling_interval: optional_number(&record, "polling_interval")?,
278 client_id: optional_string_field(&record, "client_id")?,
279 scope: optional_string_field(&record, "scope")?,
280 created_at: required_timestamp(&record, "created_at")?,
281 updated_at: required_timestamp(&record, "updated_at")?,
282 })
283}
284
285fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, RustAuthError> {
286 match record.get(field) {
287 Some(DbValue::String(value)) => Ok(value),
288 Some(_) => Err(invalid_field(field, "string")),
289 None => Err(missing_field(field)),
290 }
291}
292
293fn optional_string_field(record: &DbRecord, field: &str) -> Result<Option<String>, RustAuthError> {
294 match record.get(field) {
295 Some(DbValue::String(value)) => Ok(Some(value.to_owned())),
296 Some(DbValue::Null) | None => Ok(None),
297 Some(_) => Err(invalid_field(field, "string or null")),
298 }
299}
300
301fn required_timestamp(record: &DbRecord, field: &str) -> Result<OffsetDateTime, RustAuthError> {
302 match record.get(field) {
303 Some(DbValue::Timestamp(value)) => Ok(*value),
304 Some(_) => Err(invalid_field(field, "timestamp")),
305 None => Err(missing_field(field)),
306 }
307}
308
309fn optional_timestamp(
310 record: &DbRecord,
311 field: &str,
312) -> Result<Option<OffsetDateTime>, RustAuthError> {
313 match record.get(field) {
314 Some(DbValue::Timestamp(value)) => Ok(Some(*value)),
315 Some(DbValue::Null) | None => Ok(None),
316 Some(_) => Err(invalid_field(field, "timestamp or null")),
317 }
318}
319
320fn optional_number(record: &DbRecord, field: &str) -> Result<Option<i64>, RustAuthError> {
321 match record.get(field) {
322 Some(DbValue::Number(value)) => Ok(Some(*value)),
323 Some(DbValue::Null) | None => Ok(None),
324 Some(_) => Err(invalid_field(field, "number or null")),
325 }
326}
327
328fn missing_field(field: &str) -> RustAuthError {
329 RustAuthError::Adapter(format!("device code record is missing `{field}`"))
330}
331
332fn invalid_field(field: &str, expected: &str) -> RustAuthError {
333 RustAuthError::Adapter(format!(
334 "device code record field `{field}` must be {expected}"
335 ))
336}