Skip to main content

rustrails_record/
schema.rs

1use serde_json::Value;
2
/// Supported column types in the lightweight schema DSL.
///
/// Every payload carried by a variant is a plain (optionally absent) integer, so equality is
/// total and values can be hashed. The type therefore derives [`Eq`] and [`Hash`] alongside
/// [`PartialEq`], which allows column types to be used as keys in maps and members of sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ColumnType {
    /// A signed 32-bit integer column.
    Integer,
    /// A signed 64-bit integer column.
    BigInteger,
    /// A variable-length string column; `Some(n)` records a length limit, `None` records none.
    String(Option<usize>),
    /// An unbounded text column.
    Text,
    /// A boolean column.
    Boolean,
    /// A floating-point column.
    Float,
    /// A fixed-precision decimal column.
    Decimal {
        /// The total number of digits.
        precision: u32,
        /// The number of digits after the decimal point.
        scale: u32,
    },
    /// A timestamp column.
    DateTime,
    /// A date-only column.
    Date,
    /// A time-only column.
    Time,
    /// A binary blob column.
    Binary,
    /// A JSON column.
    Json,
    /// A UUID column.
    Uuid,
}
38
39/// A column definition in the lightweight schema DSL.
40#[derive(Debug, Clone, PartialEq)]
41pub struct ColumnDef {
42    /// The column name.
43    pub name: String,
44    /// The logical column type.
45    pub col_type: ColumnType,
46    /// Whether the column accepts `NULL` values.
47    pub nullable: bool,
48    /// The default value when one is configured.
49    pub default: Option<Value>,
50    /// Whether this column is the primary key.
51    pub primary_key: bool,
52    /// Whether the column auto-increments.
53    pub auto_increment: bool,
54    /// Whether the column is unique.
55    pub unique: bool,
56    /// Whether the column should be indexed.
57    pub index: bool,
58}
59
60impl ColumnDef {
61    /// Creates a new column definition with sensible defaults.
62    pub fn new(name: &str, col_type: ColumnType) -> Self {
63        Self {
64            name: name.to_owned(),
65            col_type,
66            nullable: false,
67            default: None,
68            primary_key: false,
69            auto_increment: false,
70            unique: false,
71            index: false,
72        }
73    }
74
75    /// Marks the column as nullable.
76    pub fn nullable(mut self) -> Self {
77        self.nullable = true;
78        self
79    }
80
81    /// Sets the default value for the column.
82    pub fn default(mut self, val: Value) -> Self {
83        self.default = Some(val);
84        self
85    }
86
87    /// Marks the column as the primary key.
88    pub fn primary_key(mut self) -> Self {
89        self.primary_key = true;
90        self
91    }
92
93    /// Marks the column as auto-incrementing.
94    pub fn auto_increment(mut self) -> Self {
95        self.auto_increment = true;
96        self
97    }
98
99    /// Marks the column as unique.
100    pub fn unique(mut self) -> Self {
101        self.unique = true;
102        self
103    }
104}
105
106/// A table definition assembled with the lightweight schema DSL.
107#[derive(Debug, Clone, PartialEq)]
108pub struct TableDef {
109    /// The table name.
110    pub name: String,
111    /// The column definitions for the table.
112    pub columns: Vec<ColumnDef>,
113    /// Whether timestamp columns are enabled.
114    pub timestamps: bool,
115}
116
117impl TableDef {
118    /// Creates a new empty table definition.
119    pub fn new(name: &str) -> Self {
120        Self {
121            name: name.to_owned(),
122            columns: Vec::new(),
123            timestamps: false,
124        }
125    }
126
127    /// Adds a column definition.
128    pub fn column(mut self, col: ColumnDef) -> Self {
129        self.columns.push(col);
130        self
131    }
132
133    /// Enables `created_at` and `updated_at` timestamp columns.
134    pub fn timestamps(mut self) -> Self {
135        self.timestamps = true;
136        Self::ensure_timestamps(&mut self.columns);
137        self
138    }
139
140    /// Finalizes the table definition.
141    pub fn build(mut self) -> Self {
142        if self.timestamps {
143            Self::ensure_timestamps(&mut self.columns);
144        }
145        self
146    }
147
148    fn ensure_timestamps(columns: &mut Vec<ColumnDef>) {
149        if !columns.iter().any(|column| column.name == "created_at") {
150            columns.push(ColumnDef::new("created_at", ColumnType::DateTime));
151        }
152        if !columns.iter().any(|column| column.name == "updated_at") {
153            columns.push(ColumnDef::new("updated_at", ColumnType::DateTime));
154        }
155    }
156}
157
#[cfg(test)]
mod tests {
    // Every test here is synchronous, so the plain `#[test]` harness is used throughout; no
    // async runtime is required to exercise the schema DSL.
    use serde_json::json;

    use super::{ColumnDef, ColumnType, TableDef};

    /// Returns the table's column names in definition order.
    fn column_names(table: &TableDef) -> Vec<&str> {
        table
            .columns
            .iter()
            .map(|column| column.name.as_str())
            .collect()
    }

    /// Counts how many columns of the table carry the given name.
    fn count_named(table: &TableDef, name: &str) -> usize {
        table
            .columns
            .iter()
            .filter(|column| column.name == name)
            .count()
    }

    /// Looks up a column by name, panicking with a descriptive message when it is absent.
    fn find_column<'a>(table: &'a TableDef, name: &str) -> &'a ColumnDef {
        table
            .columns
            .iter()
            .find(|column| column.name == name)
            .unwrap_or_else(|| panic!("{name} should exist"))
    }

    #[test]
    fn column_definition_uses_expected_defaults() {
        let column = ColumnDef::new("name", ColumnType::String(Some(255)));

        assert_eq!(column.name, "name");
        assert_eq!(column.col_type, ColumnType::String(Some(255)));
        assert!(!column.nullable);
        assert!(column.default.is_none());
        assert!(!column.primary_key);
        assert!(!column.auto_increment);
        assert!(!column.unique);
    }

    #[test]
    fn column_builder_methods_set_flags() {
        let column = ColumnDef::new("id", ColumnType::BigInteger)
            .primary_key()
            .auto_increment()
            .unique();

        assert!(column.primary_key);
        assert!(column.auto_increment);
        assert!(column.unique);
    }

    #[test]
    fn nullable_and_default_are_recorded() {
        let column = ColumnDef::new("published", ColumnType::Boolean)
            .nullable()
            .default(json!(true));

        assert!(column.nullable);
        assert_eq!(column.default, Some(json!(true)));
    }

    #[test]
    fn table_builder_collects_columns() {
        let table = TableDef::new("posts")
            .column(ColumnDef::new("id", ColumnType::Integer).primary_key())
            .column(ColumnDef::new("title", ColumnType::String(Some(255))))
            .build();

        assert_eq!(table.name, "posts");
        assert_eq!(table.columns.len(), 2);
        assert_eq!(table.columns[0].name, "id");
        assert_eq!(table.columns[1].name, "title");
    }

    #[test]
    fn timestamps_add_created_and_updated_columns() {
        let table = TableDef::new("posts").timestamps().build();

        assert!(table.timestamps);
        assert!(count_named(&table, "created_at") > 0);
        assert!(count_named(&table, "updated_at") > 0);
        assert_eq!(table.columns.len(), 2);
    }

    #[test]
    fn column_definition_defaults_index_to_false() {
        let column = ColumnDef::new("email", ColumnType::String(None));

        assert!(!column.index);
    }

    #[test]
    fn table_definition_starts_without_columns_or_timestamps() {
        let table = TableDef::new("accounts");

        assert_eq!(table.name, "accounts");
        assert!(table.columns.is_empty());
        assert!(!table.timestamps);
    }

    #[test]
    fn integer_column_type_is_retained() {
        let column = ColumnDef::new("age", ColumnType::Integer);

        assert_eq!(column.col_type, ColumnType::Integer);
    }

    #[test]
    fn boolean_column_type_is_retained() {
        let column = ColumnDef::new("published", ColumnType::Boolean);

        assert_eq!(column.col_type, ColumnType::Boolean);
    }

    #[test]
    fn text_column_type_is_retained() {
        let column = ColumnDef::new("body", ColumnType::Text);

        assert_eq!(column.col_type, ColumnType::Text);
    }

    #[test]
    fn datetime_column_type_is_retained() {
        let column = ColumnDef::new("published_at", ColumnType::DateTime);

        assert_eq!(column.col_type, ColumnType::DateTime);
    }

    #[test]
    fn float_column_type_is_retained() {
        let column = ColumnDef::new("rating", ColumnType::Float);

        assert_eq!(column.col_type, ColumnType::Float);
    }

    #[test]
    fn decimal_column_type_retains_precision_and_scale() {
        let column = ColumnDef::new(
            "amount",
            ColumnType::Decimal {
                precision: 12,
                scale: 4,
            },
        );

        assert_eq!(
            column.col_type,
            ColumnType::Decimal {
                precision: 12,
                scale: 4,
            }
        );
    }

    #[test]
    fn default_value_can_store_strings() {
        let column = ColumnDef::new("status", ColumnType::String(Some(20))).default(json!("draft"));

        assert_eq!(column.default, Some(json!("draft")));
    }

    #[test]
    fn default_value_can_store_integers() {
        let column = ColumnDef::new("retries", ColumnType::Integer).default(json!(3));

        assert_eq!(column.default, Some(json!(3)));
    }

    #[test]
    fn default_value_can_store_objects() {
        let column = ColumnDef::new("settings", ColumnType::Json).default(json!({"theme": "dark"}));

        assert_eq!(column.default, Some(json!({"theme": "dark"})));
    }

    #[test]
    fn build_without_timestamps_does_not_add_timestamp_columns() {
        let table = TableDef::new("posts")
            .column(ColumnDef::new("title", ColumnType::String(Some(255))))
            .build();

        assert_eq!(table.columns.len(), 1);
        assert_eq!(count_named(&table, "created_at"), 0);
        assert_eq!(count_named(&table, "updated_at"), 0);
    }

    #[test]
    fn timestamps_preserve_existing_created_at_column() {
        let table = TableDef::new("posts")
            .column(ColumnDef::new("created_at", ColumnType::DateTime).nullable())
            .timestamps()
            .build();

        assert_eq!(count_named(&table, "created_at"), 1);
        assert!(find_column(&table, "created_at").nullable);
    }

    #[test]
    fn timestamps_preserve_existing_updated_at_column() {
        let table = TableDef::new("posts")
            .column(ColumnDef::new("updated_at", ColumnType::DateTime).nullable())
            .timestamps()
            .build();

        assert_eq!(count_named(&table, "updated_at"), 1);
        assert!(find_column(&table, "updated_at").nullable);
    }

    #[test]
    fn timestamps_add_only_missing_timestamp_column() {
        let table = TableDef::new("posts")
            .column(ColumnDef::new("created_at", ColumnType::DateTime))
            .timestamps()
            .build();

        assert_eq!(column_names(&table), vec!["created_at", "updated_at"]);
    }

    #[test]
    fn repeated_timestamps_calls_do_not_duplicate_columns() {
        let table = TableDef::new("posts").timestamps().timestamps().build();

        assert_eq!(count_named(&table, "created_at"), 1);
        assert_eq!(count_named(&table, "updated_at"), 1);
    }

    #[test]
    fn timestamps_append_columns_after_existing_definitions() {
        let table = TableDef::new("posts")
            .column(ColumnDef::new("id", ColumnType::Integer).primary_key())
            .column(ColumnDef::new("title", ColumnType::String(Some(255))))
            .timestamps()
            .build();

        assert_eq!(
            column_names(&table),
            vec!["id", "title", "created_at", "updated_at"]
        );
    }

    #[test]
    fn timestamp_columns_use_datetime_type_and_non_nullable_defaults() {
        let table = TableDef::new("posts").timestamps().build();

        let created_at = find_column(&table, "created_at");
        let updated_at = find_column(&table, "updated_at");

        assert_eq!(created_at.col_type, ColumnType::DateTime);
        assert_eq!(updated_at.col_type, ColumnType::DateTime);
        assert!(!created_at.nullable);
        assert!(!updated_at.nullable);
    }
}