Skip to main content

rustauth_plugins/device_authorization/
store.rs

1use rustauth_core::crypto::random::generate_random_string;
2use rustauth_core::db::{
3    Create, DbAdapter, DbRecord, DbValue, Delete, DeleteMany, FindOne, Update, Where,
4};
5use rustauth_core::error::RustAuthError;
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8
9use super::schema::DEVICE_CODE_MODEL;
10
/// Columns requested from the adapter for every device-code read or write, in request order.
const DEVICE_CODE_FIELDS: [&str; 12] = [
    "id", "device_code", "user_code", "user_id",
    "expires_at", "status", "last_polled_at", "polling_interval",
    "client_id", "scope", "created_at", "updated_at",
];
/// Length of the random `id` generated for each newly created device-code row.
const DEFAULT_ID_LENGTH: usize = 32;
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28#[serde(rename_all = "lowercase")]
29pub enum DeviceAuthorizationStatus {
30    Pending,
31    Approved,
32    Denied,
33}
34
35impl DeviceAuthorizationStatus {
36    pub fn as_str(self) -> &'static str {
37        match self {
38            Self::Pending => "pending",
39            Self::Approved => "approved",
40            Self::Denied => "denied",
41        }
42    }
43}
44
45impl TryFrom<&str> for DeviceAuthorizationStatus {
46    type Error = RustAuthError;
47
48    fn try_from(value: &str) -> Result<Self, Self::Error> {
49        match value {
50            "pending" => Ok(Self::Pending),
51            "approved" => Ok(Self::Approved),
52            "denied" => Ok(Self::Denied),
53            _ => Err(RustAuthError::Adapter(format!(
54                "device code status `{value}` is invalid"
55            ))),
56        }
57    }
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
61pub struct DeviceCodeRecord {
62    pub id: String,
63    pub device_code: String,
64    pub user_code: String,
65    pub user_id: Option<String>,
66    pub expires_at: OffsetDateTime,
67    pub status: DeviceAuthorizationStatus,
68    pub last_polled_at: Option<OffsetDateTime>,
69    pub polling_interval: Option<i64>,
70    pub client_id: Option<String>,
71    pub scope: Option<String>,
72    pub created_at: OffsetDateTime,
73    pub updated_at: OffsetDateTime,
74}
75
76#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct CreateDeviceCodeInput {
78    pub device_code: String,
79    pub user_code: String,
80    pub expires_at: OffsetDateTime,
81    pub polling_interval: i64,
82    pub client_id: String,
83    pub scope: Option<String>,
84}
85
86#[derive(Clone, Copy)]
87pub struct DeviceCodeStore<'a> {
88    adapter: &'a dyn DbAdapter,
89}
90
91impl<'a> DeviceCodeStore<'a> {
92    pub fn new(adapter: &'a dyn DbAdapter) -> Self {
93        Self { adapter }
94    }
95
96    pub async fn create(
97        &self,
98        input: CreateDeviceCodeInput,
99    ) -> Result<DeviceCodeRecord, RustAuthError> {
100        let now = OffsetDateTime::now_utc();
101        let record = self
102            .adapter
103            .create(
104                Create::new(DEVICE_CODE_MODEL)
105                    .data(
106                        "id",
107                        DbValue::String(generate_random_string(DEFAULT_ID_LENGTH)),
108                    )
109                    .data("device_code", DbValue::String(input.device_code))
110                    .data("user_code", DbValue::String(input.user_code))
111                    .data("user_id", DbValue::Null)
112                    .data("expires_at", DbValue::Timestamp(input.expires_at))
113                    .data(
114                        "status",
115                        DbValue::String(DeviceAuthorizationStatus::Pending.as_str().to_owned()),
116                    )
117                    .data("last_polled_at", DbValue::Null)
118                    .data("polling_interval", DbValue::Number(input.polling_interval))
119                    .data("client_id", DbValue::String(input.client_id))
120                    .data("scope", optional_string(input.scope))
121                    .data("created_at", DbValue::Timestamp(now))
122                    .data("updated_at", DbValue::Timestamp(now))
123                    .select(DEVICE_CODE_FIELDS)
124                    .force_allow_id(),
125            )
126            .await?;
127        record_from_db(record)
128    }
129
130    pub async fn find_by_device_code(
131        &self,
132        device_code: &str,
133    ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
134        self.find_one(Where::new(
135            "device_code",
136            DbValue::String(device_code.to_owned()),
137        ))
138        .await
139    }
140
141    pub async fn find_by_user_code(
142        &self,
143        user_code: &str,
144    ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
145        self.find_one(Where::new(
146            "user_code",
147            DbValue::String(user_code.to_owned()),
148        ))
149        .await
150    }
151
152    pub async fn mark_polled(&self, id: &str) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
153        self.update(
154            id,
155            DbRecord::from([(
156                "last_polled_at".to_owned(),
157                DbValue::Timestamp(OffsetDateTime::now_utc()),
158            )]),
159        )
160        .await
161    }
162
163    pub async fn approve(
164        &self,
165        id: &str,
166        user_id: &str,
167    ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
168        self.update(
169            id,
170            DbRecord::from([
171                (
172                    "status".to_owned(),
173                    DbValue::String(DeviceAuthorizationStatus::Approved.as_str().to_owned()),
174                ),
175                ("user_id".to_owned(), DbValue::String(user_id.to_owned())),
176            ]),
177        )
178        .await
179    }
180
181    pub async fn deny(
182        &self,
183        id: &str,
184        user_id: &str,
185    ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
186        self.update(
187            id,
188            DbRecord::from([
189                (
190                    "status".to_owned(),
191                    DbValue::String(DeviceAuthorizationStatus::Denied.as_str().to_owned()),
192                ),
193                ("user_id".to_owned(), DbValue::String(user_id.to_owned())),
194            ]),
195        )
196        .await
197    }
198
199    pub async fn delete(&self, id: &str) -> Result<(), RustAuthError> {
200        self.adapter
201            .delete(Delete::new(DEVICE_CODE_MODEL).where_clause(id_where(id)))
202            .await
203    }
204
205    /// Atomically consumes an approved device code before token minting.
206    ///
207    /// Parallel callers racing on the same approved code only observe a
208    /// successful consume once: the delete is keyed by both row id and
209    /// `status = approved`, so later attempts delete zero rows.
210    pub async fn consume_approved(&self, id: &str) -> Result<bool, RustAuthError> {
211        let deleted = self
212            .adapter
213            .delete_many(
214                DeleteMany::new(DEVICE_CODE_MODEL)
215                    .where_clause(id_where(id))
216                    .where_clause(Where::new(
217                        "status",
218                        DbValue::String(DeviceAuthorizationStatus::Approved.as_str().to_owned()),
219                    )),
220            )
221            .await?;
222        Ok(deleted == 1)
223    }
224
225    async fn find_one(
226        &self,
227        where_clause: Where,
228    ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
229        self.adapter
230            .find_one(
231                FindOne::new(DEVICE_CODE_MODEL)
232                    .where_clause(where_clause)
233                    .select(DEVICE_CODE_FIELDS),
234            )
235            .await?
236            .map(record_from_db)
237            .transpose()
238    }
239
240    async fn update(
241        &self,
242        id: &str,
243        data: DbRecord,
244    ) -> Result<Option<DeviceCodeRecord>, RustAuthError> {
245        let mut query = Update::new(DEVICE_CODE_MODEL)
246            .where_clause(id_where(id))
247            .data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
248        for (field, value) in data {
249            query = query.data(field, value);
250        }
251
252        self.adapter
253            .update(query)
254            .await?
255            .map(record_from_db)
256            .transpose()
257    }
258}
259
260fn id_where(id: &str) -> Where {
261    Where::new("id", DbValue::String(id.to_owned()))
262}
263
264fn optional_string(value: Option<String>) -> DbValue {
265    value.map(DbValue::String).unwrap_or(DbValue::Null)
266}
267
268fn record_from_db(record: DbRecord) -> Result<DeviceCodeRecord, RustAuthError> {
269    Ok(DeviceCodeRecord {
270        id: required_string(&record, "id")?.to_owned(),
271        device_code: required_string(&record, "device_code")?.to_owned(),
272        user_code: required_string(&record, "user_code")?.to_owned(),
273        user_id: optional_string_field(&record, "user_id")?,
274        expires_at: required_timestamp(&record, "expires_at")?,
275        status: DeviceAuthorizationStatus::try_from(required_string(&record, "status")?)?,
276        last_polled_at: optional_timestamp(&record, "last_polled_at")?,
277        polling_interval: optional_number(&record, "polling_interval")?,
278        client_id: optional_string_field(&record, "client_id")?,
279        scope: optional_string_field(&record, "scope")?,
280        created_at: required_timestamp(&record, "created_at")?,
281        updated_at: required_timestamp(&record, "updated_at")?,
282    })
283}
284
285fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, RustAuthError> {
286    match record.get(field) {
287        Some(DbValue::String(value)) => Ok(value),
288        Some(_) => Err(invalid_field(field, "string")),
289        None => Err(missing_field(field)),
290    }
291}
292
293fn optional_string_field(record: &DbRecord, field: &str) -> Result<Option<String>, RustAuthError> {
294    match record.get(field) {
295        Some(DbValue::String(value)) => Ok(Some(value.to_owned())),
296        Some(DbValue::Null) | None => Ok(None),
297        Some(_) => Err(invalid_field(field, "string or null")),
298    }
299}
300
301fn required_timestamp(record: &DbRecord, field: &str) -> Result<OffsetDateTime, RustAuthError> {
302    match record.get(field) {
303        Some(DbValue::Timestamp(value)) => Ok(*value),
304        Some(_) => Err(invalid_field(field, "timestamp")),
305        None => Err(missing_field(field)),
306    }
307}
308
309fn optional_timestamp(
310    record: &DbRecord,
311    field: &str,
312) -> Result<Option<OffsetDateTime>, RustAuthError> {
313    match record.get(field) {
314        Some(DbValue::Timestamp(value)) => Ok(Some(*value)),
315        Some(DbValue::Null) | None => Ok(None),
316        Some(_) => Err(invalid_field(field, "timestamp or null")),
317    }
318}
319
320fn optional_number(record: &DbRecord, field: &str) -> Result<Option<i64>, RustAuthError> {
321    match record.get(field) {
322        Some(DbValue::Number(value)) => Ok(Some(*value)),
323        Some(DbValue::Null) | None => Ok(None),
324        Some(_) => Err(invalid_field(field, "number or null")),
325    }
326}
327
328fn missing_field(field: &str) -> RustAuthError {
329    RustAuthError::Adapter(format!("device code record is missing `{field}`"))
330}
331
332fn invalid_field(field: &str, expected: &str) -> RustAuthError {
333    RustAuthError::Adapter(format!(
334        "device code record field `{field}` must be {expected}"
335    ))
336}