1use indexmap::IndexMap;
2use serde::{Deserialize, Serialize};
3
4use super::{DbRecord, DbValue, IdGeneration, IdPolicy};
5use crate::error::RustAuthError;
6
7mod builder;
8pub use builder::auth_schema;
9
10#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
12pub enum RateLimitStorage {
13 #[default]
14 Memory,
15 Database,
16 SecondaryStorage,
17}
18
19#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
21pub struct TableOptions {
22 pub name: Option<String>,
23 pub field_names: IndexMap<String, String>,
24 pub additional_fields: IndexMap<String, DbField>,
25}
26
27impl TableOptions {
28 pub fn with_name(mut self, name: impl Into<String>) -> Self {
30 self.name = Some(name.into());
31 self
32 }
33
34 pub fn with_field_name(
36 mut self,
37 logical_name: impl Into<String>,
38 db_name: impl Into<String>,
39 ) -> Self {
40 self.field_names.insert(logical_name.into(), db_name.into());
41 self
42 }
43
44 pub fn with_field(mut self, logical_name: impl Into<String>, field: DbField) -> Self {
46 self.additional_fields.insert(logical_name.into(), field);
47 self
48 }
49
50 fn field_name(&self, logical_name: &str) -> String {
51 self.field_names
52 .get(logical_name)
53 .cloned()
54 .unwrap_or_else(|| logical_name.to_owned())
55 }
56}
57
58#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
60pub struct AuthSchemaOptions {
61 pub id_policy: IdPolicy,
62 pub user: TableOptions,
63 pub account: TableOptions,
64 pub session: TableOptions,
65 pub verification: TableOptions,
66 pub rate_limit: TableOptions,
67 pub has_secondary_storage: bool,
68 pub store_session_in_database: bool,
69 pub store_verification_in_database: bool,
70 pub rate_limit_storage: RateLimitStorage,
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
75pub enum DbFieldType {
76 String,
77 Number,
78 Boolean,
79 Timestamp,
80 Json,
81 StringArray,
82 NumberArray,
83}
84
85#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
87pub enum OnDelete {
88 NoAction,
89 Restrict,
90 Cascade,
91 SetNull,
92 SetDefault,
93}
94
95#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
97pub struct ForeignKey {
98 pub table: String,
99 pub field: String,
100 pub on_delete: OnDelete,
101}
102
103impl ForeignKey {
104 pub fn new(table: impl Into<String>, field: impl Into<String>, on_delete: OnDelete) -> Self {
105 Self {
106 table: table.into(),
107 field: field.into(),
108 on_delete,
109 }
110 }
111}
112
113#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
115pub struct DbField {
116 pub name: String,
117 pub field_type: DbFieldType,
118 pub required: bool,
119 pub unique: bool,
120 pub index: bool,
121 pub returned: bool,
122 pub input: bool,
123 pub foreign_key: Option<ForeignKey>,
124 #[serde(default)]
125 pub generated_id: Option<IdGeneration>,
126}
127
128impl DbField {
129 pub fn new(name: impl Into<String>, field_type: DbFieldType) -> Self {
131 Self {
132 name: name.into(),
133 field_type,
134 required: true,
135 unique: false,
136 index: false,
137 returned: true,
138 input: true,
139 foreign_key: None,
140 generated_id: None,
141 }
142 }
143
144 pub fn optional(mut self) -> Self {
145 self.required = false;
146 self
147 }
148
149 pub fn unique(mut self) -> Self {
150 self.unique = true;
151 self
152 }
153
154 pub fn indexed(mut self) -> Self {
155 self.index = true;
156 self
157 }
158
159 pub fn hidden(mut self) -> Self {
160 self.returned = false;
161 self
162 }
163
164 pub fn generated(mut self) -> Self {
165 self.input = false;
166 self
167 }
168
169 pub fn generated_id(mut self, generation: IdGeneration) -> Self {
170 self.generated_id = Some(generation);
171 self.generated()
172 }
173
174 pub fn references(mut self, foreign_key: ForeignKey) -> Self {
175 self.foreign_key = Some(foreign_key);
176 self
177 }
178}
179
180#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
182pub struct DbTable {
183 pub name: String,
184 pub fields: IndexMap<String, DbField>,
185 pub order: Option<u16>,
186}
187
188impl DbTable {
189 pub fn field(&self, logical_name: &str) -> Option<&DbField> {
190 self.fields.get(logical_name)
191 }
192
193 pub(crate) fn logical_field_name(&self, field: &str) -> Option<&str> {
194 if let Some((logical, _)) = self.fields.get_key_value(field) {
195 return Some(logical.as_str());
196 }
197 self.fields
198 .iter()
199 .find_map(|(logical, metadata)| (metadata.name == field).then_some(logical.as_str()))
200 }
201
202 fn resolve_field(&self, field: &str) -> Option<&DbField> {
203 self.fields
204 .get(field)
205 .or_else(|| self.fields.values().find(|metadata| metadata.name == field))
206 }
207}
208
209#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
211pub struct DbSchema {
212 tables: IndexMap<String, DbTable>,
213}
214
215impl DbSchema {
216 pub fn table(&self, logical_name: &str) -> Option<&DbTable> {
217 self.tables.get(logical_name)
218 }
219
220 pub fn table_name(&self, table: &str) -> Result<&str, RustAuthError> {
222 self.resolve_table(table)
223 .map(|(_, table)| table.name.as_str())
224 .ok_or_else(|| RustAuthError::TableNotFound {
225 table: table.to_owned(),
226 })
227 }
228
229 pub fn field_name(&self, table: &str, field: &str) -> Result<&str, RustAuthError> {
231 self.field(table, field)
232 .map(|field| field.name.as_str())
233 .map_err(|_| RustAuthError::FieldNotFound {
234 table: table.to_owned(),
235 field: field.to_owned(),
236 })
237 }
238
239 pub fn field(&self, table: &str, field: &str) -> Result<&DbField, RustAuthError> {
241 let (_, table_metadata) =
242 self.resolve_table(table)
243 .ok_or_else(|| RustAuthError::TableNotFound {
244 table: table.to_owned(),
245 })?;
246
247 table_metadata
248 .resolve_field(field)
249 .ok_or_else(|| RustAuthError::FieldNotFound {
250 table: table.to_owned(),
251 field: field.to_owned(),
252 })
253 }
254
255 pub fn map_record_to_logical(
260 &self,
261 table: &str,
262 record: DbRecord,
263 ) -> Result<DbRecord, RustAuthError> {
264 let (_, table_metadata) =
265 self.resolve_table(table)
266 .ok_or_else(|| RustAuthError::TableNotFound {
267 table: table.to_owned(),
268 })?;
269
270 let mut mapped = DbRecord::new();
271 for (key, value) in record {
272 if let Some(logical) = table_metadata.logical_field_name(&key) {
273 mapped.insert(logical.to_owned(), value);
274 } else if self.resolve_table(&key).is_some() {
275 mapped.insert(key.clone(), map_join_value(self, &key, value)?);
276 } else {
277 mapped.insert(key, value);
278 }
279 }
280 Ok(mapped)
281 }
282
283 pub fn tables(&self) -> impl Iterator<Item = (&str, &DbTable)> {
284 self.tables
285 .iter()
286 .map(|(logical_name, table)| (logical_name.as_str(), table))
287 }
288
289 pub fn insert_plugin_table(
290 &mut self,
291 logical_name: String,
292 table: DbTable,
293 ) -> Result<(), RustAuthError> {
294 if let Some(existing) = self.tables.get(&logical_name) {
295 if existing == &table {
296 return Ok(());
297 }
298 return Err(RustAuthError::InvalidConfig(format!(
299 "plugin schema table `{logical_name}` conflicts with an existing table"
300 )));
301 }
302 if self
303 .tables
304 .values()
305 .any(|existing| existing.name == table.name)
306 {
307 return Err(RustAuthError::InvalidConfig(format!(
308 "plugin schema table `{logical_name}` uses existing database table `{}`",
309 table.name
310 )));
311 }
312 self.tables.insert(logical_name, table);
313 Ok(())
314 }
315
316 pub fn insert_plugin_field(
317 &mut self,
318 table: &str,
319 logical_name: String,
320 field: DbField,
321 ) -> Result<(), RustAuthError> {
322 let (_, table_metadata) =
323 self.resolve_table_mut(table)
324 .ok_or_else(|| RustAuthError::TableNotFound {
325 table: table.to_owned(),
326 })?;
327
328 if let Some(existing) = table_metadata.fields.get(&logical_name) {
329 if existing == &field {
330 return Ok(());
331 }
332 return Err(RustAuthError::InvalidConfig(format!(
333 "plugin schema field `{logical_name}` conflicts with table `{table}`"
334 )));
335 }
336 if table_metadata
337 .fields
338 .values()
339 .any(|existing| existing.name == field.name)
340 {
341 return Err(RustAuthError::InvalidConfig(format!(
342 "plugin schema field `{logical_name}` uses existing database field `{}` on table `{table}`",
343 field.name
344 )));
345 }
346 table_metadata.fields.insert(logical_name, field);
347 Ok(())
348 }
349
350 fn resolve_table(&self, table: &str) -> Option<(&str, &DbTable)> {
351 self.tables
352 .get_key_value(table)
353 .map(|(logical_name, table)| (logical_name.as_str(), table))
354 .or_else(|| {
355 self.tables
356 .iter()
357 .find(|(_, table_metadata)| table_metadata.name == table)
358 .map(|(logical_name, table)| (logical_name.as_str(), table))
359 })
360 }
361
362 fn resolve_table_mut(&mut self, table: &str) -> Option<(&str, &mut DbTable)> {
363 if self.tables.contains_key(table) {
364 let (logical_name, table_metadata) = self.tables.get_key_value_mut(table)?;
365 return Some((logical_name.as_str(), table_metadata));
366 }
367 self.tables
368 .iter_mut()
369 .find(|(_, table_metadata)| table_metadata.name == table)
370 .map(|(logical_name, table)| (logical_name.as_str(), table))
371 }
372
373 fn insert(&mut self, logical_name: impl Into<String>, table: DbTable) {
374 self.tables.insert(logical_name.into(), table);
375 }
376}
377
378fn map_join_value(
379 schema: &DbSchema,
380 logical_table: &str,
381 value: DbValue,
382) -> Result<DbValue, RustAuthError> {
383 match value {
384 DbValue::Record(record) => Ok(DbValue::Record(
385 schema.map_record_to_logical(logical_table, record)?,
386 )),
387 DbValue::RecordArray(records) => Ok(DbValue::RecordArray(
388 records
389 .into_iter()
390 .map(|record| schema.map_record_to_logical(logical_table, record))
391 .collect::<Result<Vec<_>, _>>()?,
392 )),
393 DbValue::Null => Ok(DbValue::Null),
394 other => Ok(other),
395 }
396}