// scythe_codegen/overrides.rs

/// A type override that replaces the inferred neutral type for a column or SQL type.
///
/// An entry selects columns in one of two ways, as decided by
/// [`TypeOverride::matches`]:
///
/// * If `column` is set, the entry matches only that exact `"table.column"`
///   reference (e.g. `"users.metadata"`). `db_type` is then ignored entirely,
///   even when the column comparison fails.
/// * Otherwise, if `db_type` is set, the entry matches any column whose
///   neutral type equals it, compared ASCII case-insensitively.
///
/// An entry with neither selector set never matches. When several entries are
/// passed to [`find_override`], they are tried in slice order.
#[derive(Debug, Clone)]
pub struct TypeOverride {
    /// Fully-qualified column reference in `"table.column"` format.
    /// Compared for exact equality.
    pub column: Option<String>,
    /// SQL type name (matched case-insensitively against the column's neutral type).
    /// Only consulted when `column` is `None`.
    pub db_type: Option<String>,
    /// Target neutral type to substitute (e.g. `"string"`, `"json"`).
    pub neutral_type: Option<String>,
}
15
16impl TypeOverride {
17    /// Check if this override matches a column.
18    ///
19    /// `column_match` is `"table_name.column_name"` (empty string if unknown).
20    /// `col_neutral_type` is the neutral type inferred by the analyzer.
21    pub fn matches(&self, column_match: &str, col_neutral_type: &str) -> bool {
22        if let Some(ref col) = self.column {
23            return col == column_match;
24        }
25        if let Some(ref dt) = self.db_type {
26            return dt.eq_ignore_ascii_case(col_neutral_type);
27        }
28        false
29    }
30}
31
32/// Find the first override that matches a column and return its neutral type.
33///
34/// Returns `None` when no override matches — the caller should fall through to
35/// the default type-resolution path.
36pub fn find_override<'a>(
37    overrides: &'a [TypeOverride],
38    column_match: &str,
39    col_neutral_type: &str,
40) -> Option<&'a str> {
41    overrides.iter().find_map(|o| {
42        if o.matches(column_match, col_neutral_type) {
43            o.neutral_type.as_deref()
44        } else {
45            None
46        }
47    })
48}
49
/// Unit tests for [`TypeOverride::matches`] and [`find_override`].
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_column_override_matches() {
        let o = TypeOverride {
            column: Some("users.metadata".to_string()),
            db_type: None,
            neutral_type: Some("json".to_string()),
        };
        assert!(o.matches("users.metadata", "jsonb"));
        assert!(!o.matches("posts.metadata", "jsonb"));
    }

    #[test]
    fn test_db_type_override_matches() {
        let o = TypeOverride {
            column: None,
            db_type: Some("ltree".to_string()),
            neutral_type: Some("string".to_string()),
        };
        assert!(o.matches("", "ltree"));
        // db_type comparison ignores ASCII case
        assert!(o.matches("any.col", "LTREE"));
        assert!(!o.matches("any.col", "text"));
    }

    #[test]
    fn test_column_takes_priority_over_db_type() {
        let o = TypeOverride {
            column: Some("users.name".to_string()),
            db_type: Some("text".to_string()),
            neutral_type: Some("custom".to_string()),
        };
        // column match succeeds regardless of db_type
        assert!(o.matches("users.name", "int32"));
        // column mismatch means no match (db_type not checked when column is set)
        assert!(!o.matches("other.name", "text"));
    }

    #[test]
    fn test_override_without_selectors_never_matches() {
        let o = TypeOverride {
            column: None,
            db_type: None,
            neutral_type: Some("string".to_string()),
        };
        assert!(!o.matches("users.name", "text"));
        assert!(!o.matches("", ""));
    }

    #[test]
    fn test_find_override_first_match_wins() {
        let overrides = vec![
            TypeOverride {
                column: Some("users.metadata".to_string()),
                db_type: None,
                neutral_type: Some("json".to_string()),
            },
            TypeOverride {
                column: None,
                db_type: Some("jsonb".to_string()),
                neutral_type: Some("string".to_string()),
            },
        ];
        // both entries match this column; the column entry wins because it is listed first
        assert_eq!(
            find_override(&overrides, "users.metadata", "jsonb"),
            Some("json")
        );
        // db_type fallback for non-column-matched columns
        assert_eq!(
            find_override(&overrides, "posts.data", "jsonb"),
            Some("string")
        );
        // no match
        assert_eq!(find_override(&overrides, "posts.data", "text"), None);
    }

    #[test]
    fn test_find_override_skips_match_without_neutral_type() {
        let overrides = vec![
            TypeOverride {
                column: None,
                db_type: Some("jsonb".to_string()),
                neutral_type: None,
            },
            TypeOverride {
                column: None,
                db_type: Some("jsonb".to_string()),
                neutral_type: Some("json".to_string()),
            },
        ];
        // the first entry matches but has nothing to substitute, so the search continues
        assert_eq!(
            find_override(&overrides, "posts.data", "jsonb"),
            Some("json")
        );
    }

    #[test]
    fn test_find_override_empty_list() {
        assert_eq!(find_override(&[], "users.id", "int32"), None);
    }
}