1use sqlmodel_core::{Model, RelationshipInfo, RelationshipKind, Value};
8use std::marker::PhantomData;
9
10#[derive(Debug, Clone)]
21pub struct EagerLoader<T: Model> {
22 includes: Vec<IncludePath>,
24 _marker: PhantomData<T>,
26}
27
/// One relationship to eager-load, optionally with further relationships nested
/// beneath it (e.g. `headquarters` under `team` for the path `"team.headquarters"`).
///
/// `PartialEq`, `Eq` and `Hash` are derived so include trees can be compared or
/// deduplicated by callers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IncludePath {
    /// Relationship name, as passed to the loader's `include` / `include_nested`.
    pub relationship: &'static str,
    /// Child include paths nested beneath this relationship.
    pub nested: Vec<IncludePath>,
}
36
37impl IncludePath {
38 #[must_use]
40 pub fn new(relationship: &'static str) -> Self {
41 Self {
42 relationship,
43 nested: Vec::new(),
44 }
45 }
46
47 #[must_use]
49 pub fn nest(mut self, path: IncludePath) -> Self {
50 self.nested.push(path);
51 self
52 }
53}
54
55impl<T: Model> EagerLoader<T> {
56 #[must_use]
58 pub fn new() -> Self {
59 Self {
60 includes: Vec::new(),
61 _marker: PhantomData,
62 }
63 }
64
65 #[must_use]
73 pub fn include(mut self, relationship: &'static str) -> Self {
74 self.includes.push(IncludePath::new(relationship));
75 self
76 }
77
78 #[must_use]
86 pub fn include_nested(mut self, path: &'static str) -> Self {
87 let path = path.trim();
89 if path.is_empty() {
90 return self;
91 }
92
93 let parts: Vec<&'static str> = path.split('.').collect();
94 if parts.iter().all(|p| p.is_empty()) {
97 return self;
98 }
99
100 let parts: Vec<&'static str> = parts.into_iter().filter(|p| !p.is_empty()).collect();
102 if parts.is_empty() {
103 return self;
104 }
105
106 let include = Self::build_nested_path(&parts);
108 self.includes.push(include);
109 self
110 }
111
112 fn build_nested_path(parts: &[&'static str]) -> IncludePath {
114 if parts.len() == 1 {
115 IncludePath::new(parts[0])
116 } else {
117 let mut path = IncludePath::new(parts[0]);
118 path.nested.push(Self::build_nested_path(&parts[1..]));
119 path
120 }
121 }
122
123 #[must_use]
125 pub fn includes(&self) -> &[IncludePath] {
126 &self.includes
127 }
128
129 #[must_use]
131 pub fn has_includes(&self) -> bool {
132 !self.includes.is_empty()
133 }
134}
135
136impl<T: Model> Default for EagerLoader<T> {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142#[must_use]
144pub fn find_relationship<M: Model>(name: &str) -> Option<&'static RelationshipInfo> {
145 M::RELATIONSHIPS.iter().find(|r| r.name == name)
146}
147
148#[must_use]
150pub fn build_join_clause(
151 parent_table: &str,
152 rel: &RelationshipInfo,
153 _param_offset: usize,
154) -> (String, Vec<Value>) {
155 let params = Vec::new();
156
157 let remote_pk = rel.remote_key.unwrap_or("id");
159
160 let sql = match rel.kind {
161 RelationshipKind::ManyToOne | RelationshipKind::OneToOne => {
162 let local_key = rel.local_key.unwrap_or("id");
164 format!(
165 " LEFT JOIN {} ON {}.{} = {}.{}",
166 rel.related_table, parent_table, local_key, rel.related_table, remote_pk
167 )
168 }
169 RelationshipKind::OneToMany => {
170 let fk_on_related = rel.remote_key.unwrap_or("id");
173 let local_pk = rel.local_key.unwrap_or("id");
175 format!(
176 " LEFT JOIN {} ON {}.{} = {}.{}",
177 rel.related_table, rel.related_table, fk_on_related, parent_table, local_pk
178 )
179 }
180 RelationshipKind::ManyToMany => {
181 if let Some(link) = &rel.link_table {
184 let local_pk = rel.local_key.unwrap_or("id");
185 let Some(link_local_col) = link.local_cols().first().copied() else {
186 return (String::new(), params);
187 };
188 let Some(link_remote_col) = link.remote_cols().first().copied() else {
189 return (String::new(), params);
190 };
191 format!(
192 " LEFT JOIN {} ON {}.{} = {}.{} LEFT JOIN {} ON {}.{} = {}.{}",
193 link.table_name,
194 parent_table,
195 local_pk,
196 link.table_name,
197 link_local_col,
198 rel.related_table,
199 link.table_name,
200 link_remote_col,
201 rel.related_table,
202 remote_pk
203 )
204 } else {
205 String::new()
206 }
207 }
208 };
209
210 (sql, params)
211}
212
/// Returns one `table.column AS table__column` select-list entry per column, in
/// the order given.
///
/// The `table__column` alias lets columns from several joined tables share one
/// result row without name clashes. Identifiers are not quoted.
#[must_use]
pub fn build_aliased_column_parts(table_name: &str, columns: &[&str]) -> Vec<String> {
    let mut parts = Vec::with_capacity(columns.len());
    for column in columns {
        parts.push(format!("{table_name}.{column} AS {table_name}__{column}"));
    }
    parts
}
223
224#[must_use]
228pub fn build_aliased_columns(table_name: &str, columns: &[&str]) -> String {
229 build_aliased_column_parts(table_name, columns).join(", ")
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use sqlmodel_core::{Error, FieldInfo, Model, Result, Row, Value};

    // Minimal model used only for its `RELATIONSHIPS` table; every other trait
    // method is a stub.
    #[derive(Debug, Clone)]
    struct TestHero;

    impl Model for TestHero {
        const TABLE_NAME: &'static str = "heroes";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const RELATIONSHIPS: &'static [RelationshipInfo] =
            &[
                RelationshipInfo::new("team", "teams", RelationshipKind::ManyToOne)
                    .local_key("team_id"),
            ];

        fn fields() -> &'static [FieldInfo] {
            &[]
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![]
        }

        fn from_row(_row: &Row) -> Result<Self> {
            Err(Error::Custom("not used".to_string()))
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![]
        }

        fn is_new(&self) -> bool {
            true
        }
    }

    // ---- EagerLoader construction -------------------------------------------

    #[test]
    fn test_eager_loader_new() {
        let loader = EagerLoader::<TestHero>::new();
        assert!(!loader.has_includes());
        assert!(loader.includes().is_empty());
    }

    #[test]
    fn test_eager_loader_include() {
        let loader = EagerLoader::<TestHero>::new().include("team");
        assert!(loader.has_includes());
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
    }

    #[test]
    fn test_eager_loader_multiple_includes() {
        let loader = EagerLoader::<TestHero>::new()
            .include("team")
            .include("powers");
        assert_eq!(loader.includes().len(), 2);
    }

    #[test]
    fn test_eager_loader_mixed_includes_preserve_order() {
        let loader = EagerLoader::<TestHero>::new()
            .include("powers")
            .include_nested("team.headquarters");
        assert_eq!(loader.includes().len(), 2);
        assert_eq!(loader.includes()[0].relationship, "powers");
        assert_eq!(loader.includes()[1].relationship, "team");
        assert_eq!(loader.includes()[1].nested[0].relationship, "headquarters");
    }

    #[test]
    fn test_eager_loader_clone_keeps_includes() {
        let loader = EagerLoader::<TestHero>::new().include("team");
        let cloned = loader.clone();
        assert_eq!(cloned.includes().len(), 1);
        assert_eq!(cloned.includes()[0].relationship, "team");
    }

    #[test]
    fn test_eager_loader_default() {
        let loader: EagerLoader<TestHero> = EagerLoader::default();
        assert!(!loader.has_includes());
    }

    // ---- include_nested path parsing ----------------------------------------

    #[test]
    fn test_eager_loader_include_nested() {
        let loader = EagerLoader::<TestHero>::new().include_nested("team.headquarters");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested.len(), 1);
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");
    }

    #[test]
    fn test_eager_loader_include_deeply_nested() {
        let loader =
            EagerLoader::<TestHero>::new().include_nested("team.headquarters.city.country");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");
        assert_eq!(
            loader.includes()[0].nested[0].nested[0].relationship,
            "city"
        );
        assert_eq!(
            loader.includes()[0].nested[0].nested[0].nested[0].relationship,
            "country"
        );
    }

    #[test]
    fn test_eager_loader_include_nested_single_segment() {
        let loader = EagerLoader::<TestHero>::new().include_nested("team");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert!(loader.includes()[0].nested.is_empty());
    }

    #[test]
    fn test_eager_loader_include_nested_ignores_blank_paths() {
        for path in ["", "   ", ".", "..."] {
            let loader = EagerLoader::<TestHero>::new().include_nested(path);
            assert!(!loader.has_includes(), "path {path:?} should add nothing");
        }
    }

    #[test]
    fn test_eager_loader_include_nested_skips_empty_segments() {
        let loader = EagerLoader::<TestHero>::new().include_nested(".team..headquarters.");
        assert_eq!(loader.includes().len(), 1);
        let team = &loader.includes()[0];
        assert_eq!(team.relationship, "team");
        assert_eq!(team.nested.len(), 1);
        assert_eq!(team.nested[0].relationship, "headquarters");
        assert!(team.nested[0].nested.is_empty());
    }

    #[test]
    fn test_eager_loader_include_nested_trims_outer_whitespace() {
        let loader = EagerLoader::<TestHero>::new().include_nested("  team.headquarters \n");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");
    }

    // ---- relationship lookup ------------------------------------------------

    #[test]
    fn test_find_relationship() {
        let rel = find_relationship::<TestHero>("team");
        assert!(rel.is_some());
        assert_eq!(rel.unwrap().name, "team");
        assert_eq!(rel.unwrap().related_table, "teams");
    }

    #[test]
    fn test_find_relationship_not_found() {
        let rel = find_relationship::<TestHero>("nonexistent");
        assert!(rel.is_none());
    }

    // ---- JOIN generation ----------------------------------------------------

    #[test]
    fn test_build_join_many_to_one() {
        let rel = RelationshipInfo::new("team", "teams", RelationshipKind::ManyToOne)
            .local_key("team_id");

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert_eq!(sql, " LEFT JOIN teams ON heroes.team_id = teams.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_one_to_one() {
        let rel = RelationshipInfo::new("profile", "profiles", RelationshipKind::OneToOne)
            .local_key("profile_id");

        let (sql, params) = build_join_clause("users", &rel, 0);

        assert_eq!(sql, " LEFT JOIN profiles ON users.profile_id = profiles.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_one_to_many() {
        let rel = RelationshipInfo::new("heroes", "heroes", RelationshipKind::OneToMany)
            .remote_key("team_id");

        let (sql, params) = build_join_clause("teams", &rel, 0);

        assert_eq!(sql, " LEFT JOIN heroes ON heroes.team_id = teams.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_many_to_many() {
        let rel =
            RelationshipInfo::new("powers", "powers", RelationshipKind::ManyToMany).link_table(
                sqlmodel_core::LinkTableInfo::new("hero_powers", "hero_id", "power_id"),
            );

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert!(sql.contains("LEFT JOIN hero_powers"));
        assert!(sql.contains("LEFT JOIN powers"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_many_to_many_without_link_table() {
        // Assumes `RelationshipInfo::new` leaves `link_table` unset until
        // `.link_table(..)` is called.
        let rel = RelationshipInfo::new("powers", "powers", RelationshipKind::ManyToMany);

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert!(sql.is_empty());
        assert!(params.is_empty());
    }

    // ---- aliased select lists -----------------------------------------------

    #[test]
    fn test_build_aliased_columns() {
        let result = build_aliased_columns("heroes", &["id", "name", "team_id"]);
        assert!(result.contains("heroes.id AS heroes__id"));
        assert!(result.contains("heroes.name AS heroes__name"));
        assert!(result.contains("heroes.team_id AS heroes__team_id"));
    }

    #[test]
    fn test_build_aliased_columns_exact_and_empty() {
        assert_eq!(
            build_aliased_columns("heroes", &["id", "name"]),
            "heroes.id AS heroes__id, heroes.name AS heroes__name"
        );
        assert_eq!(build_aliased_columns("heroes", &[]), "");
        assert!(build_aliased_column_parts("heroes", &[]).is_empty());
    }

    // ---- IncludePath builder ------------------------------------------------

    #[test]
    fn test_include_path_new() {
        let path = IncludePath::new("team");
        assert_eq!(path.relationship, "team");
        assert!(path.nested.is_empty());
    }

    #[test]
    fn test_include_path_nest() {
        let path = IncludePath::new("team").nest(IncludePath::new("headquarters"));
        assert_eq!(path.nested.len(), 1);
        assert_eq!(path.nested[0].relationship, "headquarters");
    }
}