sea_orm/query/
loader.rs

1use crate::{
2    error::*, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, QueryFilter,
3    Related, RelationType, Select,
4};
5use async_trait::async_trait;
6use sea_query::{ColumnRef, DynIden, Expr, IntoColumnRef, SimpleExpr, TableRef, ValueTuple};
7use std::{collections::HashMap, str::FromStr};
8
9/// Entity, or a Select<Entity>; to be used as parameters in [`LoaderTrait`]
10pub trait EntityOrSelect<E: EntityTrait>: Send {
11    /// If self is Entity, use Entity::find()
12    fn select(self) -> Select<E>;
13}
14
15/// This trait implements the Data Loader API
16#[async_trait]
17pub trait LoaderTrait {
18    /// Source model
19    type Model: ModelTrait;
20
21    /// Used to eager load has_one relations
22    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
23    where
24        C: ConnectionTrait,
25        R: EntityTrait,
26        R::Model: Send + Sync,
27        S: EntityOrSelect<R>,
28        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
29
30    /// Used to eager load has_many relations
31    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
32    where
33        C: ConnectionTrait,
34        R: EntityTrait,
35        R::Model: Send + Sync,
36        S: EntityOrSelect<R>,
37        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
38
39    /// Used to eager load many_to_many relations
40    async fn load_many_to_many<R, S, V, C>(
41        &self,
42        stmt: S,
43        via: V,
44        db: &C,
45    ) -> Result<Vec<Vec<R::Model>>, DbErr>
46    where
47        C: ConnectionTrait,
48        R: EntityTrait,
49        R::Model: Send + Sync,
50        S: EntityOrSelect<R>,
51        V: EntityTrait,
52        V::Model: Send + Sync,
53        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
54}
55
56impl<E> EntityOrSelect<E> for E
57where
58    E: EntityTrait,
59{
60    fn select(self) -> Select<E> {
61        E::find()
62    }
63}
64
65impl<E> EntityOrSelect<E> for Select<E>
66where
67    E: EntityTrait,
68{
69    fn select(self) -> Select<E> {
70        self
71    }
72}
73
74#[async_trait]
75impl<M> LoaderTrait for Vec<M>
76where
77    M: ModelTrait + Sync,
78{
79    type Model = M;
80
81    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
82    where
83        C: ConnectionTrait,
84        R: EntityTrait,
85        R::Model: Send + Sync,
86        S: EntityOrSelect<R>,
87        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
88    {
89        self.as_slice().load_one(stmt, db).await
90    }
91
92    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
93    where
94        C: ConnectionTrait,
95        R: EntityTrait,
96        R::Model: Send + Sync,
97        S: EntityOrSelect<R>,
98        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
99    {
100        self.as_slice().load_many(stmt, db).await
101    }
102
103    async fn load_many_to_many<R, S, V, C>(
104        &self,
105        stmt: S,
106        via: V,
107        db: &C,
108    ) -> Result<Vec<Vec<R::Model>>, DbErr>
109    where
110        C: ConnectionTrait,
111        R: EntityTrait,
112        R::Model: Send + Sync,
113        S: EntityOrSelect<R>,
114        V: EntityTrait,
115        V::Model: Send + Sync,
116        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
117    {
118        self.as_slice().load_many_to_many(stmt, via, db).await
119    }
120}
121
122#[async_trait]
123impl<M> LoaderTrait for &[M]
124where
125    M: ModelTrait + Sync,
126{
127    type Model = M;
128
129    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
130    where
131        C: ConnectionTrait,
132        R: EntityTrait,
133        R::Model: Send + Sync,
134        S: EntityOrSelect<R>,
135        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
136    {
137        // we verify that is HasOne relation
138        if <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via().is_some() {
139            return Err(query_err("Relation is ManytoMany instead of HasOne"));
140        }
141        let rel_def = <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
142        if rel_def.rel_type == RelationType::HasMany {
143            return Err(query_err("Relation is HasMany instead of HasOne"));
144        }
145
146        if self.is_empty() {
147            return Ok(Vec::new());
148        }
149
150        let keys: Vec<ValueTuple> = self
151            .iter()
152            .map(|model: &M| extract_key(&rel_def.from_col, model))
153            .collect();
154
155        let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
156
157        let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
158
159        let data = stmt.all(db).await?;
160
161        let hashmap: HashMap<ValueTuple, <R as EntityTrait>::Model> = data.into_iter().fold(
162            HashMap::new(),
163            |mut acc, value: <R as EntityTrait>::Model| {
164                {
165                    let key = extract_key(&rel_def.to_col, &value);
166                    acc.insert(key, value);
167                }
168
169                acc
170            },
171        );
172
173        let result: Vec<Option<<R as EntityTrait>::Model>> =
174            keys.iter().map(|key| hashmap.get(key).cloned()).collect();
175
176        Ok(result)
177    }
178
179    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
180    where
181        C: ConnectionTrait,
182        R: EntityTrait,
183        R::Model: Send + Sync,
184        S: EntityOrSelect<R>,
185        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
186    {
187        // we verify that is HasMany relation
188
189        if <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via().is_some() {
190            return Err(query_err("Relation is ManyToMany instead of HasMany"));
191        }
192        let rel_def = <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
193        if rel_def.rel_type == RelationType::HasOne {
194            return Err(query_err("Relation is HasOne instead of HasMany"));
195        }
196
197        if self.is_empty() {
198            return Ok(Vec::new());
199        }
200
201        let keys: Vec<ValueTuple> = self
202            .iter()
203            .map(|model: &M| extract_key(&rel_def.from_col, model))
204            .collect();
205
206        let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
207
208        let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
209
210        let data = stmt.all(db).await?;
211
212        let mut hashmap: HashMap<ValueTuple, Vec<<R as EntityTrait>::Model>> =
213            keys.iter()
214                .fold(HashMap::new(), |mut acc, key: &ValueTuple| {
215                    acc.insert(key.clone(), Vec::new());
216                    acc
217                });
218
219        data.into_iter()
220            .for_each(|value: <R as EntityTrait>::Model| {
221                let key = extract_key(&rel_def.to_col, &value);
222
223                let vec = hashmap
224                    .get_mut(&key)
225                    .expect("Failed at finding key on hashmap");
226
227                vec.push(value);
228            });
229
230        let result: Vec<Vec<R::Model>> = keys
231            .iter()
232            .map(|key: &ValueTuple| hashmap.get(key).cloned().unwrap_or_default())
233            .collect();
234
235        Ok(result)
236    }
237
238    async fn load_many_to_many<R, S, V, C>(
239        &self,
240        stmt: S,
241        via: V,
242        db: &C,
243    ) -> Result<Vec<Vec<R::Model>>, DbErr>
244    where
245        C: ConnectionTrait,
246        R: EntityTrait,
247        R::Model: Send + Sync,
248        S: EntityOrSelect<R>,
249        V: EntityTrait,
250        V::Model: Send + Sync,
251        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
252    {
253        if let Some(via_rel) =
254            <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via()
255        {
256            let rel_def =
257                <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
258            if rel_def.rel_type != RelationType::HasOne {
259                return Err(query_err("Relation to is not HasOne"));
260            }
261
262            if !cmp_table_ref(&via_rel.to_tbl, &via.table_ref()) {
263                return Err(query_err(format!(
264                    "The given via Entity is incorrect: expected: {:?}, given: {:?}",
265                    via_rel.to_tbl,
266                    via.table_ref()
267                )));
268            }
269
270            if self.is_empty() {
271                return Ok(Vec::new());
272            }
273
274            let pkeys: Vec<ValueTuple> = self
275                .iter()
276                .map(|model: &M| extract_key(&via_rel.from_col, model))
277                .collect();
278
279            // Map of M::PK -> Vec<R::PK>
280            let mut keymap: HashMap<ValueTuple, Vec<ValueTuple>> = Default::default();
281
282            let keys: Vec<ValueTuple> = {
283                let condition = prepare_condition(&via_rel.to_tbl, &via_rel.to_col, &pkeys);
284                let stmt = V::find().filter(condition);
285                let data = stmt.all(db).await?;
286                data.into_iter().for_each(|model| {
287                    let pk = extract_key(&via_rel.to_col, &model);
288                    let entry = keymap.entry(pk).or_default();
289
290                    let fk = extract_key(&rel_def.from_col, &model);
291                    entry.push(fk);
292                });
293
294                keymap.values().flatten().cloned().collect()
295            };
296
297            let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
298
299            let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
300
301            let data = stmt.all(db).await?;
302
303            // Map of R::PK -> R::Model
304            let data: HashMap<ValueTuple, <R as EntityTrait>::Model> = data
305                .into_iter()
306                .map(|model| {
307                    let key = extract_key(&rel_def.to_col, &model);
308                    (key, model)
309                })
310                .collect();
311
312            let result: Vec<Vec<R::Model>> = pkeys
313                .into_iter()
314                .map(|pkey| {
315                    let fkeys = keymap.get(&pkey).cloned().unwrap_or_default();
316
317                    let models: Vec<_> = fkeys
318                        .into_iter()
319                        .filter_map(|fkey| data.get(&fkey).cloned())
320                        .collect();
321
322                    models
323                })
324                .collect();
325
326            Ok(result)
327        } else {
328            return Err(query_err("Relation is not ManyToMany"));
329        }
330    }
331}
332
333fn cmp_table_ref(left: &TableRef, right: &TableRef) -> bool {
334    // not ideal; but
335    format!("{left:?}") == format!("{right:?}")
336}
337
338fn extract_key<Model>(target_col: &Identity, model: &Model) -> ValueTuple
339where
340    Model: ModelTrait,
341{
342    match target_col {
343        Identity::Unary(a) => {
344            let column_a =
345                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
346                    &a.to_string(),
347                )
348                .unwrap_or_else(|_| panic!("Failed at mapping string to column A:1"));
349            ValueTuple::One(model.get(column_a))
350        }
351        Identity::Binary(a, b) => {
352            let column_a =
353                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
354                    &a.to_string(),
355                )
356                .unwrap_or_else(|_| panic!("Failed at mapping string to column A:2"));
357            let column_b =
358                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
359                    &b.to_string(),
360                )
361                .unwrap_or_else(|_| panic!("Failed at mapping string to column B:2"));
362            ValueTuple::Two(model.get(column_a), model.get(column_b))
363        }
364        Identity::Ternary(a, b, c) => {
365            let column_a =
366                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
367                    &a.to_string(),
368                )
369                .unwrap_or_else(|_| panic!("Failed at mapping string to column A:3"));
370            let column_b =
371                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
372                    &b.to_string(),
373                )
374                .unwrap_or_else(|_| panic!("Failed at mapping string to column B:3"));
375            let column_c =
376                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
377                    &c.to_string(),
378                )
379                .unwrap_or_else(|_| panic!("Failed at mapping string to column C:3"));
380            ValueTuple::Three(
381                model.get(column_a),
382                model.get(column_b),
383                model.get(column_c),
384            )
385        }
386        Identity::Many(cols) => {
387            let values = cols.iter().map(|col| {
388                let col_name = col.to_string();
389                let column = <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
390                    &col_name,
391                )
392                .unwrap_or_else(|_| panic!("Failed at mapping '{}' to column", col_name));
393                model.get(column)
394            })
395            .collect();
396            ValueTuple::Many(values)
397        }
398    }
399}
400
401fn prepare_condition(table: &TableRef, col: &Identity, keys: &[ValueTuple]) -> Condition {
402    // TODO when value is hashable, retain only unique values
403    let keys = keys.to_owned();
404    match col {
405        Identity::Unary(column_a) => {
406            let column_a = table_column(table, column_a);
407            Condition::all().add(Expr::col(column_a).is_in(keys.into_iter().flatten()))
408        }
409        Identity::Binary(column_a, column_b) => Condition::all().add(
410            Expr::tuple([
411                SimpleExpr::Column(table_column(table, column_a)),
412                SimpleExpr::Column(table_column(table, column_b)),
413            ])
414            .in_tuples(keys),
415        ),
416        Identity::Ternary(column_a, column_b, column_c) => Condition::all().add(
417            Expr::tuple([
418                SimpleExpr::Column(table_column(table, column_a)),
419                SimpleExpr::Column(table_column(table, column_b)),
420                SimpleExpr::Column(table_column(table, column_c)),
421            ])
422            .in_tuples(keys),
423        ),
424        Identity::Many(cols) => {
425            let columns = cols
426                .iter()
427                .map(|col| SimpleExpr::Column(table_column(table, col)));
428            Condition::all().add(Expr::tuple(columns).in_tuples(keys))
429        }
430    }
431}
432
433fn table_column(tbl: &TableRef, col: &DynIden) -> ColumnRef {
434    match tbl.to_owned() {
435        TableRef::Table(tbl) => (tbl, col.clone()).into_column_ref(),
436        TableRef::SchemaTable(sch, tbl) => (sch, tbl, col.clone()).into_column_ref(),
437        val => unimplemented!("Unsupported TableRef {val:?}"),
438    }
439}
440
#[cfg(test)]
mod tests {
    use sea_orm::{
        entity::prelude::*, tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase,
    };

    /// Cake fixture; ids 1-4 get a fixed name, any other id an empty name.
    fn cake_model(id: i32) -> sea_orm::tests_cfg::cake::Model {
        let label = match id {
            1 => "apple cake",
            2 => "orange cake",
            3 => "fruit cake",
            4 => "chocolate cake",
            _ => "",
        };
        sea_orm::tests_cfg::cake::Model {
            id,
            name: label.to_string(),
        }
    }

    /// Fruit fixture, optionally attached to a cake.
    fn fruit_model(id: i32, cake_id: Option<i32>) -> sea_orm::tests_cfg::fruit::Model {
        let label = match id {
            1 => "apple",
            2 => "orange",
            3 => "grape",
            4 => "strawberry",
            _ => "",
        };
        sea_orm::tests_cfg::fruit::Model {
            id,
            name: label.to_string(),
            cake_id,
        }
    }

    /// Filling fixture; always vendor 1.
    fn filling_model(id: i32) -> sea_orm::tests_cfg::filling::Model {
        let label = match id {
            1 => "apple juice",
            2 => "orange jam",
            3 => "chocolate crust",
            4 => "strawberry jam",
            _ => "",
        };
        sea_orm::tests_cfg::filling::Model {
            id,
            name: label.to_string(),
            vendor_id: Some(1),
            ignored_attr: 0,
        }
    }

    /// Junction-row fixture linking a cake to a filling.
    fn cake_filling_model(
        cake_id: i32,
        filling_id: i32,
    ) -> sea_orm::tests_cfg::cake_filling::Model {
        sea_orm::tests_cfg::cake_filling::Model {
            cake_id,
            filling_id,
        }
    }

    #[tokio::test]
    async fn test_load_one() {
        let conn = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let source = vec![fruit_model(1, Some(1))];

        let loaded = source
            .load_one(cake::Entity::find(), &conn)
            .await
            .expect("Should return something");

        assert_eq!(loaded, [Some(cake_model(1))]);
    }

    #[tokio::test]
    async fn test_load_one_same_cake() {
        let conn = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let source = vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))];

        let loaded = source
            .load_one(cake::Entity::find(), &conn)
            .await
            .expect("Should return something");

        assert_eq!(loaded, [Some(cake_model(1)), Some(cake_model(1))]);
    }

    #[tokio::test]
    async fn test_load_one_empty() {
        let conn = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let source: Vec<fruit::Model> = Vec::new();

        let loaded = source
            .load_one(cake::Entity::find(), &conn)
            .await
            .expect("Should return something");

        assert_eq!(loaded, []);
    }

    #[tokio::test]
    async fn test_load_many() {
        let conn = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[fruit_model(1, Some(1))]])
            .into_connection();

        let source = vec![cake_model(1), cake_model(2)];

        let loaded = source
            .load_many(fruit::Entity::find(), &conn)
            .await
            .expect("Should return something");

        assert_eq!(loaded, [vec![fruit_model(1, Some(1))], vec![]]);
    }

    #[tokio::test]
    async fn test_load_many_same_fruit() {
        let conn = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[fruit_model(1, Some(1)), fruit_model(2, Some(1))]])
            .into_connection();

        let source = vec![cake_model(1), cake_model(2)];

        let loaded = source
            .load_many(fruit::Entity::find(), &conn)
            .await
            .expect("Should return something");

        assert_eq!(
            loaded,
            [
                vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))],
                vec![]
            ]
        );
    }

    #[tokio::test]
    async fn test_load_many_empty() {
        let conn = MockDatabase::new(DbBackend::Postgres).into_connection();

        let source: Vec<cake::Model> = Vec::new();

        let loaded = source
            .load_many(fruit::Entity::find(), &conn)
            .await
            .expect("Should return something");

        let expected: Vec<Vec<fruit::Model>> = Vec::new();

        assert_eq!(loaded, expected);
    }

    #[tokio::test]
    async fn test_load_many_to_many_base() {
        let conn = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [cake_filling_model(1, 1).into_mock_row()],
                [filling_model(1).into_mock_row()],
            ])
            .into_connection();

        let source = vec![cake_model(1)];

        let loaded = source
            .load_many_to_many(Filling, CakeFilling, &conn)
            .await
            .expect("Should return something");

        assert_eq!(loaded, vec![vec![filling_model(1)]]);
    }

    #[tokio::test]
    async fn test_load_many_to_many_complex() {
        let conn = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [
                    cake_filling_model(1, 1).into_mock_row(),
                    cake_filling_model(1, 2).into_mock_row(),
                    cake_filling_model(1, 3).into_mock_row(),
                    cake_filling_model(2, 1).into_mock_row(),
                    cake_filling_model(2, 2).into_mock_row(),
                ],
                [
                    filling_model(1).into_mock_row(),
                    filling_model(2).into_mock_row(),
                    filling_model(3).into_mock_row(),
                    filling_model(4).into_mock_row(),
                    filling_model(5).into_mock_row(),
                ],
            ])
            .into_connection();

        let source = vec![cake_model(1), cake_model(2), cake_model(3)];

        let loaded = source
            .load_many_to_many(Filling, CakeFilling, &conn)
            .await
            .expect("Should return something");

        assert_eq!(
            loaded,
            vec![
                vec![filling_model(1), filling_model(2), filling_model(3)],
                vec![filling_model(1), filling_model(2)],
                vec![],
            ]
        );
    }

    #[tokio::test]
    async fn test_load_many_to_many_empty() {
        let conn = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [cake_filling_model(1, 1).into_mock_row()],
                [filling_model(1).into_mock_row()],
            ])
            .into_connection();

        let source: Vec<cake::Model> = Vec::new();

        let loaded = source
            .load_many_to_many(Filling, CakeFilling, &conn)
            .await
            .expect("Should return something");

        let expected: Vec<Vec<filling::Model>> = Vec::new();

        assert_eq!(loaded, expected);
    }
}
691}