rustrails_record/
inheritance.rs1use std::collections::HashMap;
2
3use serde_json::Value;
4
/// Settings for single-table inheritance (STI).
///
/// Names the column whose value identifies which concrete subtype a stored
/// row belongs to. The [`Default`] implementation uses the column `"type"`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InheritanceConfig {
    /// Name of the discriminator column holding each row's STI type name.
    pub inheritance_column: String,
}
11
12impl Default for InheritanceConfig {
13 fn default() -> Self {
14 Self {
15 inheritance_column: "type".to_owned(),
16 }
17 }
18}
19
20impl InheritanceConfig {
21 #[must_use]
23 pub fn new(inheritance_column: impl Into<String>) -> Self {
24 Self {
25 inheritance_column: inheritance_column.into(),
26 }
27 }
28}
29
/// Associates a Rust type with the value stored in the discriminator column
/// for rows of that single-table-inheritance subtype.
pub trait StiType {
    /// Returns the discriminator value identifying this subtype.
    ///
    /// [`scope_for_type`] writes this value into query conditions, and
    /// [`matches_type`] compares stored attributes against it.
    fn sti_name() -> &'static str;
}
35
36pub trait SingleTableInheritance {
38 fn inheritance_config() -> &'static InheritanceConfig {
40 static DEFAULT_CONFIG: std::sync::LazyLock<InheritanceConfig> =
41 std::sync::LazyLock::new(InheritanceConfig::default);
42 &DEFAULT_CONFIG
43 }
44}
45
/// Converts `record` into another record type, e.g. a base STI record into
/// one of its subtypes.
///
/// Any conversion available through [`Into`] is accepted. Every
/// `impl From<R> for T` also provides `R: Into<T>`, so callers relying on a
/// `From` implementation keep working. Types that only implement `Into<T>` are
/// now accepted as well.
#[must_use]
pub fn becomes<T, R>(record: R) -> T
where
    R: Into<T>,
{
    record.into()
}
54
55#[must_use]
57pub fn scope_for_type<T>(
58 config: &InheritanceConfig,
59 mut conditions: HashMap<String, Value>,
60) -> HashMap<String, Value>
61where
62 T: StiType,
63{
64 conditions.insert(
65 config.inheritance_column.clone(),
66 Value::String(T::sti_name().to_owned()),
67 );
68 conditions
69}
70
71#[must_use]
73pub fn matches_type<T>(config: &InheritanceConfig, attributes: &HashMap<String, Value>) -> bool
74where
75 T: StiType,
76{
77 attributes
78 .get(&config.inheritance_column)
79 .and_then(Value::as_str)
80 == Some(T::sti_name())
81}
82
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use serde_json::json;

    use super::{
        InheritanceConfig, SingleTableInheritance, StiType, becomes, matches_type, scope_for_type,
    };

    /// Base record of the test hierarchy; `record_type` is its discriminator column.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct CompanyRecord {
        id: i64,
        name: String,
        record_type: String,
    }

    /// Subtype identified by the discriminator `"Firm"`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct FirmRecord {
        id: i64,
        name: String,
    }

    /// Subtype identified by the discriminator `"Client"`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct ClientRecord {
        id: i64,
        name: String,
    }

    /// Hierarchy root that keeps the trait's default configuration.
    struct PlainRecord;

    impl From<CompanyRecord> for FirmRecord {
        fn from(value: CompanyRecord) -> Self {
            Self {
                id: value.id,
                name: value.name,
            }
        }
    }

    impl From<CompanyRecord> for ClientRecord {
        fn from(value: CompanyRecord) -> Self {
            Self {
                id: value.id,
                name: value.name,
            }
        }
    }

    impl StiType for FirmRecord {
        fn sti_name() -> &'static str {
            "Firm"
        }
    }

    impl StiType for ClientRecord {
        fn sti_name() -> &'static str {
            "Client"
        }
    }

    impl SingleTableInheritance for CompanyRecord {
        fn inheritance_config() -> &'static InheritanceConfig {
            static CONFIG: LazyLock<InheritanceConfig> =
                LazyLock::new(|| InheritanceConfig::new("record_type"));
            &CONFIG
        }
    }

    // Relies on the trait's default `inheritance_config`.
    impl SingleTableInheritance for PlainRecord {}

    #[test]
    fn default_config_uses_type_column() {
        assert_eq!(InheritanceConfig::default().inheritance_column, "type");
    }

    #[test]
    fn default_trait_config_uses_type_column() {
        assert_eq!(PlainRecord::inheritance_config().inheritance_column, "type");
    }

    #[test]
    fn custom_config_uses_custom_column() {
        assert_eq!(
            CompanyRecord::inheritance_config().inheritance_column,
            "record_type"
        );
    }

    #[test]
    fn becomes_casts_between_subtypes() {
        let company = CompanyRecord {
            id: 1,
            name: "Acme".to_owned(),
            record_type: "Firm".to_owned(),
        };
        assert_eq!(company.record_type, FirmRecord::sti_name());

        let firm: FirmRecord = becomes(company);
        assert_eq!(firm.id, 1);
        assert_eq!(firm.name, "Acme");
    }

    #[test]
    fn becomes_casts_to_client_subtype() {
        let company = CompanyRecord {
            id: 2,
            name: "Globex".to_owned(),
            record_type: "Client".to_owned(),
        };
        assert_eq!(company.record_type, ClientRecord::sti_name());

        let client: ClientRecord = becomes(company);
        assert_eq!(
            client,
            ClientRecord {
                id: 2,
                name: "Globex".to_owned(),
            }
        );
    }

    #[test]
    fn scope_for_type_adds_discriminator() {
        let scope =
            scope_for_type::<FirmRecord>(CompanyRecord::inheritance_config(), HashMap::new());
        assert_eq!(scope.get("record_type"), Some(&json!("Firm")));
    }

    #[test]
    fn scope_for_type_preserves_existing_conditions() {
        let scope = scope_for_type::<ClientRecord>(
            CompanyRecord::inheritance_config(),
            HashMap::from([("active".to_owned(), json!(true))]),
        );

        assert_eq!(scope.get("active"), Some(&json!(true)));
        assert_eq!(scope.get("record_type"), Some(&json!("Client")));
    }

    #[test]
    fn scope_for_type_overwrites_existing_discriminator() {
        let scope = scope_for_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            HashMap::from([("record_type".to_owned(), json!("Client"))]),
        );

        assert_eq!(scope.len(), 1);
        assert_eq!(scope.get("record_type"), Some(&json!("Firm")));
    }

    #[test]
    fn matches_type_checks_discriminator_column() {
        let attrs = HashMap::from([("record_type".to_owned(), json!("Firm"))]);
        assert!(matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
        assert!(!matches_type::<ClientRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }

    #[test]
    fn matches_type_returns_false_when_discriminator_missing() {
        let attrs = HashMap::new();
        assert!(!matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }

    #[test]
    fn matches_type_rejects_non_string_discriminator() {
        let attrs = HashMap::from([("record_type".to_owned(), json!(1))]);
        assert!(!matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }

    #[test]
    fn matches_type_ignores_default_column_when_custom_configured() {
        let attrs = HashMap::from([("type".to_owned(), json!("Firm"))]);
        assert!(!matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }
}