Skip to main content

rustauth_core/db/
id.rs

1use super::{DbField, DbFieldType};
2use serde::{Deserialize, Serialize};
3
4/// ID generation strategy for core database models.
5#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
6pub enum IdGeneration {
7    /// RustAuth generates string IDs.
8    #[default]
9    Random,
10    /// Database generates IDs.
11    Disabled,
12    /// Database generates numeric serial IDs.
13    Serial,
14    /// UUID IDs are used. The database may generate them natively.
15    Uuid,
16}
17
18/// Normalized ID value.
19#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
20pub enum IdValue {
21    String(String),
22    Number(i64),
23}
24
25/// ID field and transform policy.
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
27pub struct IdPolicy {
28    generation: IdGeneration,
29    database_supports_uuid: bool,
30    force_allow_id: bool,
31}
32
33impl IdPolicy {
34    pub fn new(generation: IdGeneration) -> Self {
35        Self {
36            generation,
37            database_supports_uuid: false,
38            force_allow_id: false,
39        }
40    }
41
42    pub fn with_database_uuid_support(mut self, supports_uuid: bool) -> Self {
43        self.database_supports_uuid = supports_uuid;
44        self
45    }
46
47    pub fn with_force_allow_id(mut self, force_allow_id: bool) -> Self {
48        self.force_allow_id = force_allow_id;
49        self
50    }
51
52    pub fn field(self) -> DbField {
53        let field_type = match self.generation {
54            IdGeneration::Serial => DbFieldType::Number,
55            IdGeneration::Random | IdGeneration::Disabled | IdGeneration::Uuid => {
56                DbFieldType::String
57            }
58        };
59
60        let mut field = DbField::new("id", field_type).generated();
61        field.required = self.should_generate_id();
62        if self.database_generates_id() {
63            field.generated_id = Some(self.generation);
64        }
65        field
66    }
67
68    pub fn transform_input(self, value: Option<&str>) -> Option<IdValue> {
69        let value = value.filter(|value| !value.is_empty())?;
70
71        match self.generation {
72            IdGeneration::Disabled => None,
73            IdGeneration::Serial => value.parse::<i64>().ok().map(IdValue::Number),
74            IdGeneration::Random => Some(IdValue::String(value.to_owned())),
75            IdGeneration::Uuid => self.transform_uuid_input(value),
76        }
77    }
78
79    pub fn transform_output(self, value: Option<IdValue>) -> Option<String> {
80        match value? {
81            IdValue::String(value) => Some(value),
82            IdValue::Number(value) => Some(value.to_string()),
83        }
84    }
85
86    fn should_generate_id(self) -> bool {
87        match self.generation {
88            IdGeneration::Random => true,
89            IdGeneration::Disabled | IdGeneration::Serial => false,
90            IdGeneration::Uuid => !self.database_supports_uuid,
91        }
92    }
93
94    fn database_generates_id(self) -> bool {
95        match self.generation {
96            IdGeneration::Disabled | IdGeneration::Serial => true,
97            IdGeneration::Uuid => self.database_supports_uuid,
98            IdGeneration::Random => false,
99        }
100    }
101
102    fn transform_uuid_input(self, value: &str) -> Option<IdValue> {
103        if self.force_allow_id {
104            return is_uuid(value).then(|| IdValue::String(value.to_owned()));
105        }
106
107        if self.database_supports_uuid {
108            None
109        } else {
110            Some(IdValue::String(value.to_owned()))
111        }
112    }
113}
114
115impl Default for IdPolicy {
116    fn default() -> Self {
117        Self::new(IdGeneration::Random)
118    }
119}
120
/// Returns `true` when `value` is a canonical, hyphenated UUID string
/// (`xxxxxxxx-xxxx-Vxxx-Nxxx-xxxxxxxxxxxx`, hex digits in either case).
///
/// The version nibble `V` must be one of the RFC 9562 versions 1 through 8. This includes
/// the time-ordered v6/v7 and custom v8 layouts, which a `1..=5` check would wrongly reject.
/// The variant nibble `N` must be `8`, `9`, `a` or `b` (the RFC variant `10xx`).
/// The nil (all-zero) and max (all-`f`) UUIDs are therefore rejected.
fn is_uuid(value: &str) -> bool {
    let bytes = value.as_bytes();
    if bytes.len() != 36 {
        return false;
    }

    // Hyphens must sit exactly at the group boundaries; every other byte must be hex.
    let well_formed = bytes.iter().enumerate().all(|(index, &byte)| match index {
        8 | 13 | 18 | 23 => byte == b'-',
        _ => byte.is_ascii_hexdigit(),
    });

    // Indexing is in bounds: the length was checked to be exactly 36 above.
    well_formed
        && matches!(bytes[14], b'1'..=b'8')
        && matches!(bytes[19].to_ascii_lowercase(), b'8' | b'9' | b'a' | b'b')
}