Skip to main content

sqlmodel_query/
eager.rs

//! Eager loading infrastructure for relationships.
//!
//! This module provides the `EagerLoader` builder for configuring which
//! relationships to load with a query, plus helpers that generate the SQL
//! `LEFT JOIN` clauses and table-prefixed column aliases used to fetch related
//! objects in the same query.
6
7use sqlmodel_core::{Model, RelationshipInfo, RelationshipKind, Value};
8use std::marker::PhantomData;
9
10/// Builder for eager loading configuration.
11///
12/// # Example
13///
14/// ```ignore
15/// let heroes = select!(Hero)
16///     .eager(EagerLoader::new().include("team"))
17///     .all_eager(&conn)
18///     .await?;
19/// ```
20#[derive(Debug, Clone)]
21pub struct EagerLoader<T: Model> {
22    /// Relationships to eager-load.
23    includes: Vec<IncludePath>,
24    /// Model type marker.
25    _marker: PhantomData<T>,
26}
27
/// A path to a relationship to include, with any relationships to load
/// beneath it (e.g. `team` → `headquarters`).
///
/// Equality compares the relationship name and the whole nested tree.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IncludePath {
    /// Relationship name on parent.
    pub relationship: &'static str,
    /// Nested relationships to load, in insertion order.
    pub nested: Vec<IncludePath>,
}

impl IncludePath {
    /// Create a new include path for a single relationship with no nested
    /// relationships.
    #[must_use]
    pub fn new(relationship: &'static str) -> Self {
        Self {
            relationship,
            nested: Vec::new(),
        }
    }

    /// Append `path` as a nested relationship and return the updated path.
    #[must_use]
    pub fn nest(mut self, path: IncludePath) -> Self {
        self.nested.push(path);
        self
    }
}
54
55impl<T: Model> EagerLoader<T> {
56    /// Create a new empty eager loader.
57    #[must_use]
58    pub fn new() -> Self {
59        Self {
60            includes: Vec::new(),
61            _marker: PhantomData,
62        }
63    }
64
65    /// Include a relationship in eager loading.
66    ///
67    /// # Example
68    ///
69    /// ```ignore
70    /// EagerLoader::<Hero>::new().include("team")
71    /// ```
72    #[must_use]
73    pub fn include(mut self, relationship: &'static str) -> Self {
74        self.includes.push(IncludePath::new(relationship));
75        self
76    }
77
78    /// Include a nested relationship (e.g., "team.headquarters").
79    ///
80    /// # Example
81    ///
82    /// ```ignore
83    /// EagerLoader::<Hero>::new().include_nested("team.headquarters")
84    /// ```
85    #[must_use]
86    pub fn include_nested(mut self, path: &'static str) -> Self {
87        // Handle empty or whitespace-only paths
88        let path = path.trim();
89        if path.is_empty() {
90            return self;
91        }
92
93        let parts: Vec<&'static str> = path.split('.').collect();
94        // split('.') on non-empty string always returns at least one element
95        // but we should still guard against [""] from paths like "."
96        if parts.iter().all(|p| p.is_empty()) {
97            return self;
98        }
99
100        // Filter out empty parts (handles cases like "team..headquarters")
101        let parts: Vec<&'static str> = parts.into_iter().filter(|p| !p.is_empty()).collect();
102        if parts.is_empty() {
103            return self;
104        }
105
106        // Build nested IncludePath structure
107        let include = Self::build_nested_path(&parts);
108        self.includes.push(include);
109        self
110    }
111
112    /// Build a nested IncludePath from path parts.
113    fn build_nested_path(parts: &[&'static str]) -> IncludePath {
114        if parts.len() == 1 {
115            IncludePath::new(parts[0])
116        } else {
117            let mut path = IncludePath::new(parts[0]);
118            path.nested.push(Self::build_nested_path(&parts[1..]));
119            path
120        }
121    }
122
123    /// Get the include paths.
124    #[must_use]
125    pub fn includes(&self) -> &[IncludePath] {
126        &self.includes
127    }
128
129    /// Check if any relationships are included.
130    #[must_use]
131    pub fn has_includes(&self) -> bool {
132        !self.includes.is_empty()
133    }
134}
135
136impl<T: Model> Default for EagerLoader<T> {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142/// Find a relationship by name in a model's RELATIONSHIPS.
143#[must_use]
144pub fn find_relationship<M: Model>(name: &str) -> Option<&'static RelationshipInfo> {
145    M::RELATIONSHIPS.iter().find(|r| r.name == name)
146}
147
148/// Generate a JOIN clause for a relationship.
149#[must_use]
150pub fn build_join_clause(
151    parent_table: &str,
152    rel: &RelationshipInfo,
153    _param_offset: usize,
154) -> (String, Vec<Value>) {
155    let params = Vec::new();
156
157    // Get the primary key column name from the relationship, defaulting to "id"
158    let remote_pk = rel.remote_key.unwrap_or("id");
159
160    let sql = match rel.kind {
161        RelationshipKind::ManyToOne | RelationshipKind::OneToOne => {
162            // LEFT JOIN related_table ON parent.fk = related.pk
163            let local_key = rel.local_key.unwrap_or("id");
164            format!(
165                " LEFT JOIN {} ON {}.{} = {}.{}",
166                rel.related_table, parent_table, local_key, rel.related_table, remote_pk
167            )
168        }
169        RelationshipKind::OneToMany => {
170            // LEFT JOIN related_table ON related.fk = parent.pk
171            // For OneToMany, remote_key is the FK on the related table pointing to us
172            let fk_on_related = rel.remote_key.unwrap_or("id");
173            // And we need local_key as our PK (default "id")
174            let local_pk = rel.local_key.unwrap_or("id");
175            format!(
176                " LEFT JOIN {} ON {}.{} = {}.{}",
177                rel.related_table, rel.related_table, fk_on_related, parent_table, local_pk
178            )
179        }
180        RelationshipKind::ManyToMany => {
181            // LEFT JOIN link_table ON parent.pk = link.local_col
182            // LEFT JOIN related_table ON link.remote_col = related.pk
183            if let Some(link) = &rel.link_table {
184                let local_pk = rel.local_key.unwrap_or("id");
185                let Some(link_local_col) = link.local_cols().first().copied() else {
186                    return (String::new(), params);
187                };
188                let Some(link_remote_col) = link.remote_cols().first().copied() else {
189                    return (String::new(), params);
190                };
191                format!(
192                    " LEFT JOIN {} ON {}.{} = {}.{} LEFT JOIN {} ON {}.{} = {}.{}",
193                    link.table_name,
194                    parent_table,
195                    local_pk,
196                    link.table_name,
197                    link_local_col,
198                    rel.related_table,
199                    link.table_name,
200                    link_remote_col,
201                    rel.related_table,
202                    remote_pk
203                )
204            } else {
205                String::new()
206            }
207        }
208    };
209
210    (sql, params)
211}
212
/// Build one `table.column AS table__column` select item per column.
///
/// The `table__column` alias keeps columns from different joined tables
/// distinct in the result set. Output order matches `columns`.
#[must_use]
pub fn build_aliased_column_parts(table_name: &str, columns: &[&str]) -> Vec<String> {
    let mut parts = Vec::with_capacity(columns.len());
    for &column in columns {
        parts.push(format!("{table_name}.{column} AS {table_name}__{column}"));
    }
    parts
}
223
224/// Generate aliased column list for eager loading.
225///
226/// Prefixes each column with the table name to avoid conflicts.
227#[must_use]
228pub fn build_aliased_columns(table_name: &str, columns: &[&str]) -> String {
229    build_aliased_column_parts(table_name, columns).join(", ")
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use sqlmodel_core::{Error, FieldInfo, Model, Result, Row, Value};

    /// Minimal model with a single ManyToOne `team` relationship, used to
    /// exercise the eager-loading helpers without a database.
    #[derive(Debug, Clone)]
    struct TestHero;

    impl Model for TestHero {
        const TABLE_NAME: &'static str = "heroes";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const RELATIONSHIPS: &'static [RelationshipInfo] =
            &[
                RelationshipInfo::new("team", "teams", RelationshipKind::ManyToOne)
                    .local_key("team_id"),
            ];

        fn fields() -> &'static [FieldInfo] {
            &[]
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![]
        }

        fn from_row(_row: &Row) -> Result<Self> {
            Err(Error::Custom("not used".to_string()))
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![]
        }

        fn is_new(&self) -> bool {
            true
        }
    }

    #[test]
    fn test_eager_loader_new() {
        let loader = EagerLoader::<TestHero>::new();
        assert!(!loader.has_includes());
        assert!(loader.includes().is_empty());
    }

    #[test]
    fn test_eager_loader_include() {
        let loader = EagerLoader::<TestHero>::new().include("team");
        assert!(loader.has_includes());
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
    }

    #[test]
    fn test_eager_loader_multiple_includes() {
        let loader = EagerLoader::<TestHero>::new()
            .include("team")
            .include("powers");
        assert_eq!(loader.includes().len(), 2);
    }

    #[test]
    fn test_eager_loader_include_nested() {
        let loader = EagerLoader::<TestHero>::new().include_nested("team.headquarters");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested.len(), 1);
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");
    }

    #[test]
    fn test_eager_loader_include_deeply_nested() {
        let loader =
            EagerLoader::<TestHero>::new().include_nested("team.headquarters.city.country");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");
        assert_eq!(
            loader.includes()[0].nested[0].nested[0].relationship,
            "city"
        );
        assert_eq!(
            loader.includes()[0].nested[0].nested[0].nested[0].relationship,
            "country"
        );
    }

    #[test]
    fn test_eager_loader_include_nested_ignores_empty_paths() {
        for path in ["", "   ", ".", ".."] {
            let loader = EagerLoader::<TestHero>::new().include_nested(path);
            assert!(!loader.has_includes(), "path {path:?} should add nothing");
        }
    }

    #[test]
    fn test_eager_loader_include_nested_skips_empty_segments() {
        let loader = EagerLoader::<TestHero>::new().include_nested("team..headquarters");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert_eq!(loader.includes()[0].nested.len(), 1);
        assert_eq!(loader.includes()[0].nested[0].relationship, "headquarters");

        let loader = EagerLoader::<TestHero>::new().include_nested(".team.");
        assert_eq!(loader.includes().len(), 1);
        assert_eq!(loader.includes()[0].relationship, "team");
        assert!(loader.includes()[0].nested.is_empty());
    }

    #[test]
    fn test_find_relationship() {
        let rel = find_relationship::<TestHero>("team").expect("`team` is declared on TestHero");
        assert_eq!(rel.name, "team");
        assert_eq!(rel.related_table, "teams");
    }

    #[test]
    fn test_find_relationship_not_found() {
        let rel = find_relationship::<TestHero>("nonexistent");
        assert!(rel.is_none());
    }

    #[test]
    fn test_build_join_many_to_one() {
        let rel = RelationshipInfo::new("team", "teams", RelationshipKind::ManyToOne)
            .local_key("team_id");

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert_eq!(sql, " LEFT JOIN teams ON heroes.team_id = teams.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_one_to_one() {
        let rel = RelationshipInfo::new("profile", "profiles", RelationshipKind::OneToOne)
            .local_key("profile_id");

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert_eq!(sql, " LEFT JOIN profiles ON heroes.profile_id = profiles.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_one_to_many() {
        let rel = RelationshipInfo::new("heroes", "heroes", RelationshipKind::OneToMany)
            .remote_key("team_id");

        let (sql, params) = build_join_clause("teams", &rel, 0);

        assert_eq!(sql, " LEFT JOIN heroes ON heroes.team_id = teams.id");
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_many_to_many() {
        let rel =
            RelationshipInfo::new("powers", "powers", RelationshipKind::ManyToMany).link_table(
                sqlmodel_core::LinkTableInfo::new("hero_powers", "hero_id", "power_id"),
            );

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert!(sql.contains("LEFT JOIN hero_powers"));
        assert!(sql.contains("LEFT JOIN powers"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_join_many_to_many_without_link_table() {
        let rel = RelationshipInfo::new("powers", "powers", RelationshipKind::ManyToMany);

        let (sql, params) = build_join_clause("heroes", &rel, 0);

        assert!(sql.is_empty());
        assert!(params.is_empty());
    }

    #[test]
    fn test_build_aliased_columns() {
        let result = build_aliased_columns("heroes", &["id", "name", "team_id"]);
        assert!(result.contains("heroes.id AS heroes__id"));
        assert!(result.contains("heroes.name AS heroes__name"));
        assert!(result.contains("heroes.team_id AS heroes__team_id"));
    }

    #[test]
    fn test_build_aliased_columns_exact_and_empty() {
        assert_eq!(
            build_aliased_columns("heroes", &["id", "name"]),
            "heroes.id AS heroes__id, heroes.name AS heroes__name"
        );
        assert_eq!(build_aliased_columns("heroes", &[]), "");
    }

    #[test]
    fn test_eager_loader_default() {
        let loader: EagerLoader<TestHero> = EagerLoader::default();
        assert!(!loader.has_includes());
    }

    #[test]
    fn test_include_path_new() {
        let path = IncludePath::new("team");
        assert_eq!(path.relationship, "team");
        assert!(path.nested.is_empty());
    }

    #[test]
    fn test_include_path_nest() {
        let path = IncludePath::new("team").nest(IncludePath::new("headquarters"));
        assert_eq!(path.nested.len(), 1);
        assert_eq!(path.nested[0].relationship, "headquarters");
    }
}