Skip to main content

rustrails_record/
inheritance.rs

1use std::collections::HashMap;
2
3use serde_json::Value;
4
/// Settings that control how single-table inheritance (STI) is applied.
///
/// Record types that share one table are told apart by a discriminator
/// column; this struct names that column.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InheritanceConfig {
    /// Name of the column that holds each row's subtype discriminator.
    pub inheritance_column: String,
}
11
12impl Default for InheritanceConfig {
13    fn default() -> Self {
14        Self {
15            inheritance_column: "type".to_owned(),
16        }
17    }
18}
19
20impl InheritanceConfig {
21    /// Creates a new inheritance configuration.
22    #[must_use]
23    pub fn new(inheritance_column: impl Into<String>) -> Self {
24        Self {
25            inheritance_column: inheritance_column.into(),
26        }
27    }
28}
29
/// A concrete subtype within a single-table-inheritance hierarchy.
pub trait StiType {
    /// The discriminator value that identifies rows of this subtype.
    fn sti_name() -> &'static str;
}
35
36/// Trait implemented by base records that participate in STI.
37pub trait SingleTableInheritance {
38    /// Returns STI configuration for the record hierarchy.
39    fn inheritance_config() -> &'static InheritanceConfig {
40        static DEFAULT_CONFIG: std::sync::LazyLock<InheritanceConfig> =
41            std::sync::LazyLock::new(InheritanceConfig::default);
42        &DEFAULT_CONFIG
43    }
44}
45
/// Converts `record` into the type `T` using `T`'s [`From`] implementation.
///
/// This is a thin wrapper: whatever the `From<R>` impl does is exactly what
/// happens, with no additional field handling.
#[must_use]
pub fn becomes<T, R>(record: R) -> T
where
    T: From<R>,
{
    record.into()
}
54
55/// Applies an STI discriminator filter to query conditions.
56#[must_use]
57pub fn scope_for_type<T>(
58    config: &InheritanceConfig,
59    mut conditions: HashMap<String, Value>,
60) -> HashMap<String, Value>
61where
62    T: StiType,
63{
64    conditions.insert(
65        config.inheritance_column.clone(),
66        Value::String(T::sti_name().to_owned()),
67    );
68    conditions
69}
70
71/// Returns `true` when the attribute hash matches the requested subtype.
72#[must_use]
73pub fn matches_type<T>(config: &InheritanceConfig, attributes: &HashMap<String, Value>) -> bool
74where
75    T: StiType,
76{
77    attributes
78        .get(&config.inheritance_column)
79        .and_then(Value::as_str)
80        == Some(T::sti_name())
81}
82
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use serde_json::json;

    use super::{
        InheritanceConfig, SingleTableInheritance, StiType, becomes, matches_type, scope_for_type,
    };

    /// Base record of the test hierarchy; its discriminator lives in `record_type`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct CompanyRecord {
        id: i64,
        name: String,
        record_type: String,
    }

    /// `Firm` subtype of `CompanyRecord`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct FirmRecord {
        id: i64,
        name: String,
    }

    /// `Client` subtype of `CompanyRecord`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct ClientRecord {
        id: i64,
        name: String,
    }

    /// Base record that keeps the trait's default configuration (no override).
    struct PlainRecord;

    impl From<CompanyRecord> for FirmRecord {
        fn from(value: CompanyRecord) -> Self {
            Self {
                id: value.id,
                name: value.name,
            }
        }
    }

    impl From<CompanyRecord> for ClientRecord {
        fn from(value: CompanyRecord) -> Self {
            Self {
                id: value.id,
                name: value.name,
            }
        }
    }

    impl StiType for FirmRecord {
        fn sti_name() -> &'static str {
            "Firm"
        }
    }

    impl StiType for ClientRecord {
        fn sti_name() -> &'static str {
            "Client"
        }
    }

    impl SingleTableInheritance for CompanyRecord {
        fn inheritance_config() -> &'static InheritanceConfig {
            static CONFIG: LazyLock<InheritanceConfig> =
                LazyLock::new(|| InheritanceConfig::new("record_type"));
            &CONFIG
        }
    }

    impl SingleTableInheritance for PlainRecord {}

    #[test]
    fn default_config_uses_type_column() {
        assert_eq!(InheritanceConfig::default().inheritance_column, "type");
    }

    #[test]
    fn trait_default_config_uses_type_column() {
        let config = PlainRecord::inheritance_config();
        assert_eq!(config, &InheritanceConfig::default());
        // The default implementation returns the same lazily-built instance each call.
        assert!(std::ptr::eq(config, PlainRecord::inheritance_config()));
    }

    #[test]
    fn custom_config_uses_custom_column() {
        assert_eq!(
            CompanyRecord::inheritance_config().inheritance_column,
            "record_type"
        );
    }

    #[test]
    fn becomes_casts_between_subtypes() {
        let company = CompanyRecord {
            id: 1,
            name: "Acme".to_owned(),
            record_type: "Firm".to_owned(),
        };

        let firm: FirmRecord = becomes(company);
        assert_eq!(firm.id, 1);
        assert_eq!(firm.name, "Acme");
    }

    #[test]
    fn becomes_casts_to_client_subtype() {
        let company = CompanyRecord {
            id: 2,
            name: "Globex".to_owned(),
            record_type: "Client".to_owned(),
        };
        assert_eq!(company.record_type, "Client");

        let client: ClientRecord = becomes(company);
        assert_eq!(client.id, 2);
        assert_eq!(client.name, "Globex");
    }

    #[test]
    fn scope_for_type_adds_discriminator() {
        let scope =
            scope_for_type::<FirmRecord>(CompanyRecord::inheritance_config(), HashMap::new());
        assert_eq!(scope.get("record_type"), Some(&json!("Firm")));
    }

    #[test]
    fn scope_for_type_preserves_existing_conditions() {
        let scope = scope_for_type::<ClientRecord>(
            CompanyRecord::inheritance_config(),
            HashMap::from([("active".to_owned(), json!(true))]),
        );

        assert_eq!(scope.get("active"), Some(&json!(true)));
        assert_eq!(scope.get("record_type"), Some(&json!("Client")));
    }

    #[test]
    fn scope_for_type_overwrites_existing_discriminator() {
        let scope = scope_for_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            HashMap::from([("record_type".to_owned(), json!("Client"))]),
        );

        assert_eq!(scope.len(), 1);
        assert_eq!(scope.get("record_type"), Some(&json!("Firm")));
    }

    #[test]
    fn scope_for_type_honours_default_column() {
        let scope = scope_for_type::<ClientRecord>(&InheritanceConfig::default(), HashMap::new());

        assert_eq!(scope.get("type"), Some(&json!("Client")));
        assert!(!scope.contains_key("record_type"));
    }

    #[test]
    fn matches_type_checks_discriminator_column() {
        let attrs = HashMap::from([("record_type".to_owned(), json!("Firm"))]);
        assert!(matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
        assert!(!matches_type::<ClientRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }

    #[test]
    fn matches_type_returns_false_when_discriminator_missing() {
        let attrs = HashMap::new();
        assert!(!matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }

    #[test]
    fn matches_type_rejects_non_string_discriminator() {
        let attrs = HashMap::from([("record_type".to_owned(), json!(42))]);
        assert!(!matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }

    #[test]
    fn matches_type_reads_only_configured_column() {
        // A discriminator under `type` must not satisfy a `record_type` config.
        let attrs = HashMap::from([("type".to_owned(), json!("Firm"))]);
        assert!(matches_type::<FirmRecord>(
            &InheritanceConfig::default(),
            &attrs
        ));
        assert!(!matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }
}