sea_orm/query/loader.rs

1use crate::{
2    error::*, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, QueryFilter,
3    Related, RelationType, Select,
4};
5use async_trait::async_trait;
6use sea_query::{ColumnRef, DynIden, Expr, IntoColumnRef, SimpleExpr, TableRef, ValueTuple};
7use std::{
8    collections::{HashMap, HashSet},
9    str::FromStr,
10};
11
12/// Entity, or a Select<Entity>; to be used as parameters in [`LoaderTrait`]
13pub trait EntityOrSelect<E: EntityTrait>: Send {
14    /// If self is Entity, use Entity::find()
15    fn select(self) -> Select<E>;
16}
17
18/// This trait implements the Data Loader API
19#[async_trait]
20pub trait LoaderTrait {
21    /// Source model
22    type Model: ModelTrait;
23
24    /// Used to eager load has_one relations
25    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
26    where
27        C: ConnectionTrait,
28        R: EntityTrait,
29        R::Model: Send + Sync,
30        S: EntityOrSelect<R>,
31        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
32
33    /// Used to eager load has_many relations
34    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
35    where
36        C: ConnectionTrait,
37        R: EntityTrait,
38        R::Model: Send + Sync,
39        S: EntityOrSelect<R>,
40        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
41
42    /// Used to eager load many_to_many relations
43    async fn load_many_to_many<R, S, V, C>(
44        &self,
45        stmt: S,
46        via: V,
47        db: &C,
48    ) -> Result<Vec<Vec<R::Model>>, DbErr>
49    where
50        C: ConnectionTrait,
51        R: EntityTrait,
52        R::Model: Send + Sync,
53        S: EntityOrSelect<R>,
54        V: EntityTrait,
55        V::Model: Send + Sync,
56        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
57}
58
59impl<E> EntityOrSelect<E> for E
60where
61    E: EntityTrait,
62{
63    fn select(self) -> Select<E> {
64        E::find()
65    }
66}
67
68impl<E> EntityOrSelect<E> for Select<E>
69where
70    E: EntityTrait,
71{
72    fn select(self) -> Select<E> {
73        self
74    }
75}
76
77#[async_trait]
78impl<M> LoaderTrait for Vec<M>
79where
80    M: ModelTrait + Sync,
81{
82    type Model = M;
83
84    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
85    where
86        C: ConnectionTrait,
87        R: EntityTrait,
88        R::Model: Send + Sync,
89        S: EntityOrSelect<R>,
90        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
91    {
92        self.as_slice().load_one(stmt, db).await
93    }
94
95    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
96    where
97        C: ConnectionTrait,
98        R: EntityTrait,
99        R::Model: Send + Sync,
100        S: EntityOrSelect<R>,
101        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
102    {
103        self.as_slice().load_many(stmt, db).await
104    }
105
106    async fn load_many_to_many<R, S, V, C>(
107        &self,
108        stmt: S,
109        via: V,
110        db: &C,
111    ) -> Result<Vec<Vec<R::Model>>, DbErr>
112    where
113        C: ConnectionTrait,
114        R: EntityTrait,
115        R::Model: Send + Sync,
116        S: EntityOrSelect<R>,
117        V: EntityTrait,
118        V::Model: Send + Sync,
119        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
120    {
121        self.as_slice().load_many_to_many(stmt, via, db).await
122    }
123}
124
125#[async_trait]
126impl<M> LoaderTrait for &[M]
127where
128    M: ModelTrait + Sync,
129{
130    type Model = M;
131
132    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
133    where
134        C: ConnectionTrait,
135        R: EntityTrait,
136        R::Model: Send + Sync,
137        S: EntityOrSelect<R>,
138        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
139    {
140        // we verify that is HasOne relation
141        if <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via().is_some() {
142            return Err(query_err("Relation is ManytoMany instead of HasOne"));
143        }
144        let rel_def = <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
145        if rel_def.rel_type == RelationType::HasMany {
146            return Err(query_err("Relation is HasMany instead of HasOne"));
147        }
148
149        if self.is_empty() {
150            return Ok(Vec::new());
151        }
152
153        let keys = self
154            .iter()
155            .map(|model| extract_key(&rel_def.from_col, model))
156            .collect::<Result<Vec<_>, _>>()?;
157
158        let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
159
160        let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
161
162        let data = stmt.all(db).await?;
163
164        let hashmap = data.into_iter().try_fold(
165            HashMap::<ValueTuple, <R as EntityTrait>::Model>::new(),
166            |mut acc, value| {
167                extract_key(&rel_def.to_col, &value).map(|key| {
168                    acc.insert(key, value);
169
170                    acc
171                })
172            },
173        )?;
174
175        let result: Vec<Option<<R as EntityTrait>::Model>> =
176            keys.iter().map(|key| hashmap.get(key).cloned()).collect();
177
178        Ok(result)
179    }
180
181    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
182    where
183        C: ConnectionTrait,
184        R: EntityTrait,
185        R::Model: Send + Sync,
186        S: EntityOrSelect<R>,
187        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
188    {
189        // we verify that is HasMany relation
190
191        if <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via().is_some() {
192            return Err(query_err("Relation is ManyToMany instead of HasMany"));
193        }
194        let rel_def = <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
195        if rel_def.rel_type == RelationType::HasOne {
196            return Err(query_err("Relation is HasOne instead of HasMany"));
197        }
198
199        if self.is_empty() {
200            return Ok(Vec::new());
201        }
202
203        let keys = self
204            .iter()
205            .map(|model| extract_key(&rel_def.from_col, model))
206            .collect::<Result<Vec<_>, _>>()?;
207
208        let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
209
210        let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
211
212        let data = stmt.all(db).await?;
213
214        let mut hashmap: HashMap<ValueTuple, Vec<<R as EntityTrait>::Model>> =
215            keys.iter()
216                .fold(HashMap::new(), |mut acc, key: &ValueTuple| {
217                    acc.insert(key.clone(), Vec::new());
218                    acc
219                });
220
221        for value in data {
222            let key = extract_key(&rel_def.to_col, &value)?;
223
224            let vec = hashmap.get_mut(&key).ok_or_else(|| {
225                DbErr::RecordNotFound(format!("Loader: failed to find model for {key:?}"))
226            })?;
227
228            vec.push(value);
229        }
230
231        let result: Vec<Vec<R::Model>> = keys
232            .iter()
233            .map(|key: &ValueTuple| hashmap.get(key).cloned().unwrap_or_default())
234            .collect();
235
236        Ok(result)
237    }
238
239    async fn load_many_to_many<R, S, V, C>(
240        &self,
241        stmt: S,
242        via: V,
243        db: &C,
244    ) -> Result<Vec<Vec<R::Model>>, DbErr>
245    where
246        C: ConnectionTrait,
247        R: EntityTrait,
248        R::Model: Send + Sync,
249        S: EntityOrSelect<R>,
250        V: EntityTrait,
251        V::Model: Send + Sync,
252        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
253    {
254        if let Some(via_rel) =
255            <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via()
256        {
257            let rel_def =
258                <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
259            if rel_def.rel_type != RelationType::HasOne {
260                return Err(query_err("Relation to is not HasOne"));
261            }
262
263            if !cmp_table_ref(&via_rel.to_tbl, &via.table_ref()) {
264                return Err(query_err(format!(
265                    "The given via Entity is incorrect: expected: {:?}, given: {:?}",
266                    via_rel.to_tbl,
267                    via.table_ref()
268                )));
269            }
270
271            if self.is_empty() {
272                return Ok(Vec::new());
273            }
274
275            let pkeys = self
276                .iter()
277                .map(|model| extract_key(&via_rel.from_col, model))
278                .collect::<Result<Vec<_>, _>>()?;
279
280            // Map of M::PK -> Vec<R::PK>
281            let mut keymap: HashMap<ValueTuple, Vec<ValueTuple>> = Default::default();
282
283            let keys: Vec<ValueTuple> = {
284                let condition = prepare_condition(&via_rel.to_tbl, &via_rel.to_col, &pkeys);
285                let stmt = V::find().filter(condition);
286                let data = stmt.all(db).await?;
287                for model in data {
288                    let pk = extract_key(&via_rel.to_col, &model)?;
289                    let entry = keymap.entry(pk).or_default();
290
291                    let fk = extract_key(&rel_def.from_col, &model)?;
292                    entry.push(fk);
293                }
294
295                keymap.values().flatten().cloned().collect()
296            };
297
298            let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
299
300            let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
301
302            let models = stmt.all(db).await?;
303
304            // Map of R::PK -> R::Model
305            let data = models.into_iter().try_fold(
306                HashMap::<ValueTuple, <R as EntityTrait>::Model>::new(),
307                |mut acc, model| {
308                    extract_key(&rel_def.to_col, &model).map(|key| {
309                        acc.insert(key, model);
310
311                        acc
312                    })
313                },
314            )?;
315
316            let result: Vec<Vec<R::Model>> = pkeys
317                .into_iter()
318                .map(|pkey| {
319                    let fkeys = keymap.get(&pkey).cloned().unwrap_or_default();
320
321                    let models: Vec<_> = fkeys
322                        .into_iter()
323                        .filter_map(|fkey| data.get(&fkey).cloned())
324                        .collect();
325
326                    models
327                })
328                .collect();
329
330            Ok(result)
331        } else {
332            return Err(query_err("Relation is not ManyToMany"));
333        }
334    }
335}
336
337fn cmp_table_ref(left: &TableRef, right: &TableRef) -> bool {
338    // not ideal; but
339    format!("{left:?}") == format!("{right:?}")
340}
341
342fn extract_key<Model>(target_col: &Identity, model: &Model) -> Result<ValueTuple, DbErr>
343where
344    Model: ModelTrait,
345{
346    Ok(match target_col {
347        Identity::Unary(a) => {
348            let a = a.to_string();
349            let column_a =
350                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&a)
351                    .map_err(|_| DbErr::Type(format!("Failed at mapping '{a}' to column A:1")))?;
352            ValueTuple::One(model.get(column_a))
353        }
354        Identity::Binary(a, b) => {
355            let a = a.to_string();
356            let b = b.to_string();
357            let column_a =
358                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&a)
359                    .map_err(|_| DbErr::Type(format!("Failed at mapping '{a}' to column A:2")))?;
360            let column_b =
361                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&b)
362                    .map_err(|_| DbErr::Type(format!("Failed at mapping '{b}' to column B:2")))?;
363            ValueTuple::Two(model.get(column_a), model.get(column_b))
364        }
365        Identity::Ternary(a, b, c) => {
366            let a = a.to_string();
367            let b = b.to_string();
368            let c = c.to_string();
369            let column_a =
370                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
371                    &a.to_string(),
372                )
373                .map_err(|_| DbErr::Type(format!("Failed at mapping '{a}' to column A:3")))?;
374            let column_b =
375                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
376                    &b.to_string(),
377                )
378                .map_err(|_| DbErr::Type(format!("Failed at mapping '{b}' to column B:3")))?;
379            let column_c =
380                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
381                    &c.to_string(),
382                )
383                .map_err(|_| DbErr::Type(format!("Failed at mapping '{c}' to column C:3")))?;
384            ValueTuple::Three(
385                model.get(column_a),
386                model.get(column_b),
387                model.get(column_c),
388            )
389        }
390        Identity::Many(cols) => {
391            let mut values = Vec::new();
392            for col in cols {
393                let col_name = col.to_string();
394                let column =
395                    <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
396                        &col_name,
397                    )
398                    .map_err(|_| DbErr::Type(format!("Failed at mapping '{col_name}' to colum")))?;
399                values.push(model.get(column))
400            }
401            ValueTuple::Many(values)
402        }
403    })
404}
405
406fn prepare_condition(table: &TableRef, col: &Identity, keys: &[ValueTuple]) -> Condition {
407    let keys = if !keys.is_empty() {
408        let set: HashSet<_> = keys.iter().cloned().collect();
409        set.into_iter().collect()
410    } else {
411        Vec::new()
412    };
413
414    match col {
415        Identity::Unary(column_a) => {
416            let column_a = table_column(table, column_a);
417            Condition::all().add(Expr::col(column_a).is_in(keys.into_iter().flatten()))
418        }
419        Identity::Binary(column_a, column_b) => Condition::all().add(
420            Expr::tuple([
421                SimpleExpr::Column(table_column(table, column_a)),
422                SimpleExpr::Column(table_column(table, column_b)),
423            ])
424            .in_tuples(keys),
425        ),
426        Identity::Ternary(column_a, column_b, column_c) => Condition::all().add(
427            Expr::tuple([
428                SimpleExpr::Column(table_column(table, column_a)),
429                SimpleExpr::Column(table_column(table, column_b)),
430                SimpleExpr::Column(table_column(table, column_c)),
431            ])
432            .in_tuples(keys),
433        ),
434        Identity::Many(cols) => {
435            let columns = cols
436                .iter()
437                .map(|col| SimpleExpr::Column(table_column(table, col)));
438            Condition::all().add(Expr::tuple(columns).in_tuples(keys))
439        }
440    }
441}
442
443fn table_column(tbl: &TableRef, col: &DynIden) -> ColumnRef {
444    match tbl.to_owned() {
445        TableRef::Table(tbl) => (tbl, col.clone()).into_column_ref(),
446        TableRef::SchemaTable(sch, tbl) => (sch, tbl, col.clone()).into_column_ref(),
447        val => unimplemented!("Unsupported TableRef {val:?}"),
448    }
449}
450
// Loader tests driven by `MockDatabase`. Each `append_query_results` entry is consumed by
// one query, in the order the loader issues them.
#[cfg(test)]
mod tests {
    // Fixture: cake with a name chosen by id (ids 1-4; anything else gets an empty name).
    fn cake_model(id: i32) -> sea_orm::tests_cfg::cake::Model {
        let name = match id {
            1 => "apple cake",
            2 => "orange cake",
            3 => "fruit cake",
            4 => "chocolate cake",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::cake::Model { id, name }
    }

    // Fixture: fruit with a name chosen by id and an optional owning cake.
    fn fruit_model(id: i32, cake_id: Option<i32>) -> sea_orm::tests_cfg::fruit::Model {
        let name = match id {
            1 => "apple",
            2 => "orange",
            3 => "grape",
            4 => "strawberry",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::fruit::Model { id, name, cake_id }
    }

    // Fixture: filling with a name chosen by id; vendor and ignored attribute are fixed.
    fn filling_model(id: i32) -> sea_orm::tests_cfg::filling::Model {
        let name = match id {
            1 => "apple juice",
            2 => "orange jam",
            3 => "chocolate crust",
            4 => "strawberry jam",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::filling::Model {
            id,
            name,
            vendor_id: Some(1),
            ignored_attr: 0,
        }
    }

    // Fixture: junction row linking a cake to a filling.
    fn cake_filling_model(
        cake_id: i32,
        filling_id: i32,
    ) -> sea_orm::tests_cfg::cake_filling::Model {
        sea_orm::tests_cfg::cake_filling::Model {
            cake_id,
            filling_id,
        }
    }

    // A single fruit pointing at cake 1 resolves to cake 1, even though the mock also
    // returns cake 2.
    #[tokio::test]
    async fn test_load_one() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![fruit_model(1, Some(1))];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, [Some(cake_model(1))]);
    }

    // Two fruits sharing a cake each get their own copy of that cake.
    #[tokio::test]
    async fn test_load_one_same_cake() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, [Some(cake_model(1)), Some(cake_model(1))]);
    }

    // No source models yields an empty result.
    #[tokio::test]
    async fn test_load_one_empty() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits: Vec<fruit::Model> = vec![];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, []);
    }

    // Cake 1 gets its fruit; cake 2 has none and gets an empty list.
    #[tokio::test]
    async fn test_load_many() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[fruit_model(1, Some(1))]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(fruits, [vec![fruit_model(1, Some(1))], vec![]]);
    }

    // Several fruits of the same cake are grouped under that cake, in fetch order.
    #[tokio::test]
    async fn test_load_many_same_fruit() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[fruit_model(1, Some(1)), fruit_model(2, Some(1))]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(
            fruits,
            [
                vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))],
                vec![]
            ]
        );
    }

    // No query results are configured; the empty input is expected to return before any
    // query is issued.
    #[tokio::test]
    async fn test_load_many_empty() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres).into_connection();

        let cakes: Vec<cake::Model> = vec![];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        let empty_vec: Vec<Vec<fruit::Model>> = vec![];

        assert_eq!(fruits, empty_vec);
    }

    // One junction row (cake 1 -> filling 1), then the filling itself.
    #[tokio::test]
    async fn test_load_many_to_many_base() {
        use sea_orm::{tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [cake_filling_model(1, 1).into_mock_row()],
                [filling_model(1).into_mock_row()],
            ])
            .into_connection();

        let cakes = vec![cake_model(1)];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        assert_eq!(fillings, vec![vec![filling_model(1)]]);
    }

    // Junction rows link cake 1 to fillings 1-3 and cake 2 to fillings 1-2; cake 3 has no
    // links. Fillings 4 and 5 are returned by the mock but referenced by no junction row, so
    // they must not appear in the output.
    #[tokio::test]
    async fn test_load_many_to_many_complex() {
        use sea_orm::{tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [
                    cake_filling_model(1, 1).into_mock_row(),
                    cake_filling_model(1, 2).into_mock_row(),
                    cake_filling_model(1, 3).into_mock_row(),
                    cake_filling_model(2, 1).into_mock_row(),
                    cake_filling_model(2, 2).into_mock_row(),
                ],
                [
                    filling_model(1).into_mock_row(),
                    filling_model(2).into_mock_row(),
                    filling_model(3).into_mock_row(),
                    filling_model(4).into_mock_row(),
                    filling_model(5).into_mock_row(),
                ],
            ])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2), cake_model(3)];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        assert_eq!(
            fillings,
            vec![
                vec![filling_model(1), filling_model(2), filling_model(3)],
                vec![filling_model(1), filling_model(2)],
                vec![],
            ]
        );
    }

    // No source models: the loader returns early and the configured results go unused.
    #[tokio::test]
    async fn test_load_many_to_many_empty() {
        use sea_orm::{tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [cake_filling_model(1, 1).into_mock_row()],
                [filling_model(1).into_mock_row()],
            ])
            .into_connection();

        let cakes: Vec<cake::Model> = vec![];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        let empty_vec: Vec<Vec<filling::Model>> = vec![];

        assert_eq!(fillings, empty_vec);
    }

    // Four fruits share cake 1: every fruit resolves to it, and the logged query binds the
    // key only once (a single `$1` placeholder).
    #[tokio::test]
    async fn test_load_one_duplicate_keys() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![
            fruit_model(1, Some(1)),
            fruit_model(2, Some(1)),
            fruit_model(3, Some(1)),
            fruit_model(4, Some(1)),
        ];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes.len(), 4);
        for cake in &cakes {
            assert_eq!(cake, &Some(cake_model(1)));
        }
        let logs = db.into_transaction_log();
        let sql = format!("{:?}", logs[0]);

        let values_count = sql.matches("$1").count();
        assert_eq!(values_count, 1, "Duplicate values were not removed");
    }

    // Four cakes with only two distinct ids: the output keeps all four entries, while the
    // logged query binds just two placeholders (`$1`, `$2`).
    #[tokio::test]
    async fn test_load_many_duplicate_keys() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[
                fruit_model(1, Some(1)),
                fruit_model(2, Some(1)),
                fruit_model(3, Some(2)),
            ]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(1), cake_model(2), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(fruits.len(), 4);

        let logs = db.into_transaction_log();
        let sql = format!("{:?}", logs[0]);

        let values_count = sql.matches("$1").count() + sql.matches("$2").count();
        assert_eq!(values_count, 2, "Duplicate values were not removed");
    }
}