Skip to main content

rustauth_core/db/
schema.rs

1use indexmap::IndexMap;
2use serde::{Deserialize, Serialize};
3
4use super::{DbRecord, DbValue, IdGeneration, IdPolicy};
5use crate::error::RustAuthError;
6
7mod builder;
8pub use builder::auth_schema;
9
10/// Storage backend selected for rate limit counters.
11#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
12pub enum RateLimitStorage {
13    #[default]
14    Memory,
15    Database,
16    SecondaryStorage,
17}
18
19/// Per-table schema overrides.
20#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
21pub struct TableOptions {
22    pub name: Option<String>,
23    pub field_names: IndexMap<String, String>,
24    pub additional_fields: IndexMap<String, DbField>,
25}
26
27impl TableOptions {
28    /// Return a copy of these options with a custom database table name.
29    pub fn with_name(mut self, name: impl Into<String>) -> Self {
30        self.name = Some(name.into());
31        self
32    }
33
34    /// Return a copy of these options with a custom database column name.
35    pub fn with_field_name(
36        mut self,
37        logical_name: impl Into<String>,
38        db_name: impl Into<String>,
39    ) -> Self {
40        self.field_names.insert(logical_name.into(), db_name.into());
41        self
42    }
43
44    /// Return a copy of these options with an additional logical field.
45    pub fn with_field(mut self, logical_name: impl Into<String>, field: DbField) -> Self {
46        self.additional_fields.insert(logical_name.into(), field);
47        self
48    }
49
50    fn field_name(&self, logical_name: &str) -> String {
51        self.field_names
52            .get(logical_name)
53            .cloned()
54            .unwrap_or_else(|| logical_name.to_owned())
55    }
56}
57
58/// Options used to build RustAuth's core database schema metadata.
59#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
60pub struct AuthSchemaOptions {
61    pub id_policy: IdPolicy,
62    pub user: TableOptions,
63    pub account: TableOptions,
64    pub session: TableOptions,
65    pub verification: TableOptions,
66    pub rate_limit: TableOptions,
67    pub has_secondary_storage: bool,
68    pub store_session_in_database: bool,
69    pub store_verification_in_database: bool,
70    pub rate_limit_storage: RateLimitStorage,
71}
72
73/// Supported database field kinds for core schema metadata.
74#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
75pub enum DbFieldType {
76    String,
77    Number,
78    Boolean,
79    Timestamp,
80    Json,
81    StringArray,
82    NumberArray,
83}
84
85/// Foreign key delete behavior.
86#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
87pub enum OnDelete {
88    NoAction,
89    Restrict,
90    Cascade,
91    SetNull,
92    SetDefault,
93}
94
95/// Foreign key metadata for adapter and migration implementations.
96#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
97pub struct ForeignKey {
98    pub table: String,
99    pub field: String,
100    pub on_delete: OnDelete,
101}
102
103impl ForeignKey {
104    pub fn new(table: impl Into<String>, field: impl Into<String>, on_delete: OnDelete) -> Self {
105        Self {
106            table: table.into(),
107            field: field.into(),
108            on_delete,
109        }
110    }
111}
112
113/// Field metadata used by adapters and migrations.
114#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
115pub struct DbField {
116    pub name: String,
117    pub field_type: DbFieldType,
118    pub required: bool,
119    pub unique: bool,
120    pub index: bool,
121    pub returned: bool,
122    pub input: bool,
123    pub foreign_key: Option<ForeignKey>,
124    #[serde(default)]
125    pub generated_id: Option<IdGeneration>,
126}
127
128impl DbField {
129    /// Create a required, returned, input-accepted field.
130    pub fn new(name: impl Into<String>, field_type: DbFieldType) -> Self {
131        Self {
132            name: name.into(),
133            field_type,
134            required: true,
135            unique: false,
136            index: false,
137            returned: true,
138            input: true,
139            foreign_key: None,
140            generated_id: None,
141        }
142    }
143
144    pub fn optional(mut self) -> Self {
145        self.required = false;
146        self
147    }
148
149    pub fn unique(mut self) -> Self {
150        self.unique = true;
151        self
152    }
153
154    pub fn indexed(mut self) -> Self {
155        self.index = true;
156        self
157    }
158
159    pub fn hidden(mut self) -> Self {
160        self.returned = false;
161        self
162    }
163
164    pub fn generated(mut self) -> Self {
165        self.input = false;
166        self
167    }
168
169    pub fn generated_id(mut self, generation: IdGeneration) -> Self {
170        self.generated_id = Some(generation);
171        self.generated()
172    }
173
174    pub fn references(mut self, foreign_key: ForeignKey) -> Self {
175        self.foreign_key = Some(foreign_key);
176        self
177    }
178}
179
180/// Table metadata keyed by logical field name.
181#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
182pub struct DbTable {
183    pub name: String,
184    pub fields: IndexMap<String, DbField>,
185    pub order: Option<u16>,
186}
187
188impl DbTable {
189    pub fn field(&self, logical_name: &str) -> Option<&DbField> {
190        self.fields.get(logical_name)
191    }
192
193    pub(crate) fn logical_field_name(&self, field: &str) -> Option<&str> {
194        if let Some((logical, _)) = self.fields.get_key_value(field) {
195            return Some(logical.as_str());
196        }
197        self.fields
198            .iter()
199            .find_map(|(logical, metadata)| (metadata.name == field).then_some(logical.as_str()))
200    }
201
202    fn resolve_field(&self, field: &str) -> Option<&DbField> {
203        self.fields
204            .get(field)
205            .or_else(|| self.fields.values().find(|metadata| metadata.name == field))
206    }
207}
208
209/// Schema metadata keyed by logical table name.
210#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
211pub struct DbSchema {
212    tables: IndexMap<String, DbTable>,
213}
214
215impl DbSchema {
216    pub fn table(&self, logical_name: &str) -> Option<&DbTable> {
217        self.tables.get(logical_name)
218    }
219
220    /// Resolve a logical or physical table name to its physical database name.
221    pub fn table_name(&self, table: &str) -> Result<&str, RustAuthError> {
222        self.resolve_table(table)
223            .map(|(_, table)| table.name.as_str())
224            .ok_or_else(|| RustAuthError::TableNotFound {
225                table: table.to_owned(),
226            })
227    }
228
229    /// Resolve a logical or physical field name to its physical database column name.
230    pub fn field_name(&self, table: &str, field: &str) -> Result<&str, RustAuthError> {
231        self.field(table, field)
232            .map(|field| field.name.as_str())
233            .map_err(|_| RustAuthError::FieldNotFound {
234                table: table.to_owned(),
235                field: field.to_owned(),
236            })
237    }
238
239    /// Resolve field metadata from logical or physical table and field names.
240    pub fn field(&self, table: &str, field: &str) -> Result<&DbField, RustAuthError> {
241        let (_, table_metadata) =
242            self.resolve_table(table)
243                .ok_or_else(|| RustAuthError::TableNotFound {
244                    table: table.to_owned(),
245                })?;
246
247        table_metadata
248            .resolve_field(field)
249            .ok_or_else(|| RustAuthError::FieldNotFound {
250                table: table.to_owned(),
251                field: field.to_owned(),
252            })
253    }
254
255    /// Map physical database column keys in a record to logical field names.
256    ///
257    /// Unknown columns are preserved. Nested join records keyed by logical table
258    /// names are mapped recursively. Idempotent when keys are already logical.
259    pub fn map_record_to_logical(
260        &self,
261        table: &str,
262        record: DbRecord,
263    ) -> Result<DbRecord, RustAuthError> {
264        let (_, table_metadata) =
265            self.resolve_table(table)
266                .ok_or_else(|| RustAuthError::TableNotFound {
267                    table: table.to_owned(),
268                })?;
269
270        let mut mapped = DbRecord::new();
271        for (key, value) in record {
272            if let Some(logical) = table_metadata.logical_field_name(&key) {
273                mapped.insert(logical.to_owned(), value);
274            } else if self.resolve_table(&key).is_some() {
275                mapped.insert(key.clone(), map_join_value(self, &key, value)?);
276            } else {
277                mapped.insert(key, value);
278            }
279        }
280        Ok(mapped)
281    }
282
283    pub fn tables(&self) -> impl Iterator<Item = (&str, &DbTable)> {
284        self.tables
285            .iter()
286            .map(|(logical_name, table)| (logical_name.as_str(), table))
287    }
288
289    pub fn insert_plugin_table(
290        &mut self,
291        logical_name: String,
292        table: DbTable,
293    ) -> Result<(), RustAuthError> {
294        if let Some(existing) = self.tables.get(&logical_name) {
295            if existing == &table {
296                return Ok(());
297            }
298            return Err(RustAuthError::InvalidConfig(format!(
299                "plugin schema table `{logical_name}` conflicts with an existing table"
300            )));
301        }
302        if self
303            .tables
304            .values()
305            .any(|existing| existing.name == table.name)
306        {
307            return Err(RustAuthError::InvalidConfig(format!(
308                "plugin schema table `{logical_name}` uses existing database table `{}`",
309                table.name
310            )));
311        }
312        self.tables.insert(logical_name, table);
313        Ok(())
314    }
315
316    pub fn insert_plugin_field(
317        &mut self,
318        table: &str,
319        logical_name: String,
320        field: DbField,
321    ) -> Result<(), RustAuthError> {
322        let (_, table_metadata) =
323            self.resolve_table_mut(table)
324                .ok_or_else(|| RustAuthError::TableNotFound {
325                    table: table.to_owned(),
326                })?;
327
328        if let Some(existing) = table_metadata.fields.get(&logical_name) {
329            if existing == &field {
330                return Ok(());
331            }
332            return Err(RustAuthError::InvalidConfig(format!(
333                "plugin schema field `{logical_name}` conflicts with table `{table}`"
334            )));
335        }
336        if table_metadata
337            .fields
338            .values()
339            .any(|existing| existing.name == field.name)
340        {
341            return Err(RustAuthError::InvalidConfig(format!(
342                "plugin schema field `{logical_name}` uses existing database field `{}` on table `{table}`",
343                field.name
344            )));
345        }
346        table_metadata.fields.insert(logical_name, field);
347        Ok(())
348    }
349
350    fn resolve_table(&self, table: &str) -> Option<(&str, &DbTable)> {
351        self.tables
352            .get_key_value(table)
353            .map(|(logical_name, table)| (logical_name.as_str(), table))
354            .or_else(|| {
355                self.tables
356                    .iter()
357                    .find(|(_, table_metadata)| table_metadata.name == table)
358                    .map(|(logical_name, table)| (logical_name.as_str(), table))
359            })
360    }
361
362    fn resolve_table_mut(&mut self, table: &str) -> Option<(&str, &mut DbTable)> {
363        if self.tables.contains_key(table) {
364            let (logical_name, table_metadata) = self.tables.get_key_value_mut(table)?;
365            return Some((logical_name.as_str(), table_metadata));
366        }
367        self.tables
368            .iter_mut()
369            .find(|(_, table_metadata)| table_metadata.name == table)
370            .map(|(logical_name, table)| (logical_name.as_str(), table))
371    }
372
373    fn insert(&mut self, logical_name: impl Into<String>, table: DbTable) {
374        self.tables.insert(logical_name.into(), table);
375    }
376}
377
378fn map_join_value(
379    schema: &DbSchema,
380    logical_table: &str,
381    value: DbValue,
382) -> Result<DbValue, RustAuthError> {
383    match value {
384        DbValue::Record(record) => Ok(DbValue::Record(
385            schema.map_record_to_logical(logical_table, record)?,
386        )),
387        DbValue::RecordArray(records) => Ok(DbValue::RecordArray(
388            records
389                .into_iter()
390                .map(|record| schema.map_record_to_logical(logical_table, record))
391                .collect::<Result<Vec<_>, _>>()?,
392        )),
393        DbValue::Null => Ok(DbValue::Null),
394        other => Ok(other),
395    }
396}