// sea_orm/query/loader.rs

1use crate::{
2    error::*, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, QueryFilter,
3    Related, RelationType, Select,
4};
5use async_trait::async_trait;
6use sea_query::{ColumnRef, DynIden, Expr, IntoColumnRef, SimpleExpr, TableRef, ValueTuple};
7use std::{
8    collections::{HashMap, HashSet},
9    str::FromStr,
10};
11
12/// Entity, or a Select<Entity>; to be used as parameters in [`LoaderTrait`]
13pub trait EntityOrSelect<E: EntityTrait>: Send {
14    /// If self is Entity, use Entity::find()
15    fn select(self) -> Select<E>;
16}
17
18/// This trait implements the Data Loader API
19#[async_trait]
20pub trait LoaderTrait {
21    /// Source model
22    type Model: ModelTrait;
23
24    /// Used to eager load has_one relations
25    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
26    where
27        C: ConnectionTrait,
28        R: EntityTrait,
29        R::Model: Send + Sync,
30        S: EntityOrSelect<R>,
31        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
32
33    /// Used to eager load has_many relations
34    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
35    where
36        C: ConnectionTrait,
37        R: EntityTrait,
38        R::Model: Send + Sync,
39        S: EntityOrSelect<R>,
40        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
41
42    /// Used to eager load many_to_many relations
43    async fn load_many_to_many<R, S, V, C>(
44        &self,
45        stmt: S,
46        via: V,
47        db: &C,
48    ) -> Result<Vec<Vec<R::Model>>, DbErr>
49    where
50        C: ConnectionTrait,
51        R: EntityTrait,
52        R::Model: Send + Sync,
53        S: EntityOrSelect<R>,
54        V: EntityTrait,
55        V::Model: Send + Sync,
56        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
57}
58
59impl<E> EntityOrSelect<E> for E
60where
61    E: EntityTrait,
62{
63    fn select(self) -> Select<E> {
64        E::find()
65    }
66}
67
68impl<E> EntityOrSelect<E> for Select<E>
69where
70    E: EntityTrait,
71{
72    fn select(self) -> Select<E> {
73        self
74    }
75}
76
77#[async_trait]
78impl<M> LoaderTrait for Vec<M>
79where
80    M: ModelTrait + Sync,
81{
82    type Model = M;
83
84    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
85    where
86        C: ConnectionTrait,
87        R: EntityTrait,
88        R::Model: Send + Sync,
89        S: EntityOrSelect<R>,
90        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
91    {
92        self.as_slice().load_one(stmt, db).await
93    }
94
95    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
96    where
97        C: ConnectionTrait,
98        R: EntityTrait,
99        R::Model: Send + Sync,
100        S: EntityOrSelect<R>,
101        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
102    {
103        self.as_slice().load_many(stmt, db).await
104    }
105
106    async fn load_many_to_many<R, S, V, C>(
107        &self,
108        stmt: S,
109        via: V,
110        db: &C,
111    ) -> Result<Vec<Vec<R::Model>>, DbErr>
112    where
113        C: ConnectionTrait,
114        R: EntityTrait,
115        R::Model: Send + Sync,
116        S: EntityOrSelect<R>,
117        V: EntityTrait,
118        V::Model: Send + Sync,
119        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
120    {
121        self.as_slice().load_many_to_many(stmt, via, db).await
122    }
123}
124
125#[async_trait]
126impl<M> LoaderTrait for &[M]
127where
128    M: ModelTrait + Sync,
129{
130    type Model = M;
131
132    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
133    where
134        C: ConnectionTrait,
135        R: EntityTrait,
136        R::Model: Send + Sync,
137        S: EntityOrSelect<R>,
138        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
139    {
140        // we verify that is HasOne relation
141        if <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via().is_some() {
142            return Err(query_err("Relation is ManytoMany instead of HasOne"));
143        }
144        let rel_def = <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
145        if rel_def.rel_type == RelationType::HasMany {
146            return Err(query_err("Relation is HasMany instead of HasOne"));
147        }
148
149        if self.is_empty() {
150            return Ok(Vec::new());
151        }
152
153        let mut keys: Vec<ValueTuple> = Default::default();
154        for model in self.iter() {
155            keys.push(extract_key(&rel_def.from_col, model)?);
156        }
157        let keys = keys; // un-mut
158
159        let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
160
161        let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
162
163        let data = stmt.all(db).await?;
164
165        let mut hashmap: HashMap<ValueTuple, <R as EntityTrait>::Model> = Default::default();
166        for value in data {
167            let key = extract_key(&rel_def.to_col, &value)?;
168            hashmap.insert(key, value);
169        }
170        let hashmap = hashmap; // un-mut
171
172        let result: Vec<Option<<R as EntityTrait>::Model>> =
173            keys.iter().map(|key| hashmap.get(key).cloned()).collect();
174
175        Ok(result)
176    }
177
178    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
179    where
180        C: ConnectionTrait,
181        R: EntityTrait,
182        R::Model: Send + Sync,
183        S: EntityOrSelect<R>,
184        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
185    {
186        // we verify that is HasMany relation
187
188        if <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via().is_some() {
189            return Err(query_err("Relation is ManyToMany instead of HasMany"));
190        }
191        let rel_def = <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
192        if rel_def.rel_type == RelationType::HasOne {
193            return Err(query_err("Relation is HasOne instead of HasMany"));
194        }
195
196        if self.is_empty() {
197            return Ok(Vec::new());
198        }
199
200        let mut keys: Vec<ValueTuple> = Default::default();
201        for model in self.iter() {
202            keys.push(extract_key(&rel_def.from_col, model)?);
203        }
204        let keys = keys; // un-mut
205
206        let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
207
208        let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
209
210        let data = stmt.all(db).await?;
211
212        let mut hashmap: HashMap<ValueTuple, Vec<<R as EntityTrait>::Model>> =
213            keys.iter()
214                .fold(HashMap::new(), |mut acc, key: &ValueTuple| {
215                    acc.insert(key.clone(), Vec::new());
216                    acc
217                });
218
219        for value in data {
220            let key = extract_key(&rel_def.to_col, &value)?;
221
222            let vec = hashmap.get_mut(&key).ok_or_else(|| {
223                DbErr::RecordNotFound(format!("Loader: failed to find model for {key:?}"))
224            })?;
225
226            vec.push(value);
227        }
228
229        let result: Vec<Vec<R::Model>> = keys
230            .iter()
231            .map(|key: &ValueTuple| hashmap.get(key).cloned().unwrap_or_default())
232            .collect();
233
234        Ok(result)
235    }
236
237    async fn load_many_to_many<R, S, V, C>(
238        &self,
239        stmt: S,
240        via: V,
241        db: &C,
242    ) -> Result<Vec<Vec<R::Model>>, DbErr>
243    where
244        C: ConnectionTrait,
245        R: EntityTrait,
246        R::Model: Send + Sync,
247        S: EntityOrSelect<R>,
248        V: EntityTrait,
249        V::Model: Send + Sync,
250        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
251    {
252        if let Some(via_rel) =
253            <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via()
254        {
255            let rel_def =
256                <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
257            if rel_def.rel_type != RelationType::HasOne {
258                return Err(query_err("Relation to is not HasOne"));
259            }
260
261            if !cmp_table_ref(&via_rel.to_tbl, &via.table_ref()) {
262                return Err(query_err(format!(
263                    "The given via Entity is incorrect: expected: {:?}, given: {:?}",
264                    via_rel.to_tbl,
265                    via.table_ref()
266                )));
267            }
268
269            if self.is_empty() {
270                return Ok(Vec::new());
271            }
272
273            let mut pkeys: Vec<ValueTuple> = Default::default();
274            for model in self.iter() {
275                pkeys.push(extract_key(&via_rel.from_col, model)?);
276            }
277            let pkeys = pkeys; // un-mut
278
279            // Map of M::PK -> Vec<R::PK>
280            let mut keymap: HashMap<ValueTuple, Vec<ValueTuple>> = Default::default();
281
282            let keys: Vec<ValueTuple> = {
283                let condition = prepare_condition(&via_rel.to_tbl, &via_rel.to_col, &pkeys);
284                let stmt = V::find().filter(condition);
285                let data = stmt.all(db).await?;
286                for model in data {
287                    let pk = extract_key(&via_rel.to_col, &model)?;
288                    let entry = keymap.entry(pk).or_default();
289
290                    let fk = extract_key(&rel_def.from_col, &model)?;
291                    entry.push(fk);
292                }
293
294                keymap.values().flatten().cloned().collect()
295            };
296
297            let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
298
299            let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
300
301            let models = stmt.all(db).await?;
302
303            // Map of R::PK -> R::Model
304            let mut data: HashMap<ValueTuple, <R as EntityTrait>::Model> = Default::default();
305            for model in models {
306                data.insert(extract_key(&rel_def.to_col, &model)?, model);
307            }
308            let data = data; // un-mut
309
310            let result: Vec<Vec<R::Model>> = pkeys
311                .into_iter()
312                .map(|pkey| {
313                    let fkeys = keymap.get(&pkey).cloned().unwrap_or_default();
314
315                    let models: Vec<_> = fkeys
316                        .into_iter()
317                        .filter_map(|fkey| data.get(&fkey).cloned())
318                        .collect();
319
320                    models
321                })
322                .collect();
323
324            Ok(result)
325        } else {
326            return Err(query_err("Relation is not ManyToMany"));
327        }
328    }
329}
330
331fn cmp_table_ref(left: &TableRef, right: &TableRef) -> bool {
332    // not ideal; but
333    format!("{left:?}") == format!("{right:?}")
334}
335
336fn extract_key<Model>(target_col: &Identity, model: &Model) -> Result<ValueTuple, DbErr>
337where
338    Model: ModelTrait,
339{
340    Ok(match target_col {
341        Identity::Unary(a) => {
342            let a = a.to_string();
343            let column_a =
344                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&a)
345                    .map_err(|_| DbErr::Type(format!("Failed at mapping '{a}' to column A:1")))?;
346            ValueTuple::One(model.get(column_a))
347        }
348        Identity::Binary(a, b) => {
349            let a = a.to_string();
350            let b = b.to_string();
351            let column_a =
352                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&a)
353                    .map_err(|_| DbErr::Type(format!("Failed at mapping '{a}' to column A:2")))?;
354            let column_b =
355                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(&b)
356                    .map_err(|_| DbErr::Type(format!("Failed at mapping '{b}' to column B:2")))?;
357            ValueTuple::Two(model.get(column_a), model.get(column_b))
358        }
359        Identity::Ternary(a, b, c) => {
360            let a = a.to_string();
361            let b = b.to_string();
362            let c = c.to_string();
363            let column_a =
364                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
365                    &a.to_string(),
366                )
367                .map_err(|_| DbErr::Type(format!("Failed at mapping '{a}' to column A:3")))?;
368            let column_b =
369                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
370                    &b.to_string(),
371                )
372                .map_err(|_| DbErr::Type(format!("Failed at mapping '{b}' to column B:3")))?;
373            let column_c =
374                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
375                    &c.to_string(),
376                )
377                .map_err(|_| DbErr::Type(format!("Failed at mapping '{c}' to column C:3")))?;
378            ValueTuple::Three(
379                model.get(column_a),
380                model.get(column_b),
381                model.get(column_c),
382            )
383        }
384        Identity::Many(cols) => {
385            let mut values = Vec::new();
386            for col in cols {
387                let col_name = col.to_string();
388                let column =
389                    <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
390                        &col_name,
391                    )
392                    .map_err(|_| DbErr::Type(format!("Failed at mapping '{col_name}' to colum")))?;
393                values.push(model.get(column))
394            }
395            ValueTuple::Many(values)
396        }
397    })
398}
399
400fn prepare_condition(table: &TableRef, col: &Identity, keys: &[ValueTuple]) -> Condition {
401    let keys = if !keys.is_empty() {
402        let set: HashSet<_> = keys.iter().cloned().collect();
403        set.into_iter().collect()
404    } else {
405        Vec::new()
406    };
407
408    match col {
409        Identity::Unary(column_a) => {
410            let column_a = table_column(table, column_a);
411            Condition::all().add(Expr::col(column_a).is_in(keys.into_iter().flatten()))
412        }
413        Identity::Binary(column_a, column_b) => Condition::all().add(
414            Expr::tuple([
415                SimpleExpr::Column(table_column(table, column_a)),
416                SimpleExpr::Column(table_column(table, column_b)),
417            ])
418            .in_tuples(keys),
419        ),
420        Identity::Ternary(column_a, column_b, column_c) => Condition::all().add(
421            Expr::tuple([
422                SimpleExpr::Column(table_column(table, column_a)),
423                SimpleExpr::Column(table_column(table, column_b)),
424                SimpleExpr::Column(table_column(table, column_c)),
425            ])
426            .in_tuples(keys),
427        ),
428        Identity::Many(cols) => {
429            let columns = cols
430                .iter()
431                .map(|col| SimpleExpr::Column(table_column(table, col)));
432            Condition::all().add(Expr::tuple(columns).in_tuples(keys))
433        }
434    }
435}
436
437fn table_column(tbl: &TableRef, col: &DynIden) -> ColumnRef {
438    match tbl.to_owned() {
439        TableRef::Table(tbl) => (tbl, col.clone()).into_column_ref(),
440        TableRef::SchemaTable(sch, tbl) => (sch, tbl, col.clone()).into_column_ref(),
441        val => unimplemented!("Unsupported TableRef {val:?}"),
442    }
443}
444
#[cfg(test)]
mod tests {
    //! Loader tests against `MockDatabase`, using the `tests_cfg` cake/fruit/filling schema.

    fn cake_model(id: i32) -> sea_orm::tests_cfg::cake::Model {
        let name = match id {
            1 => "apple cake",
            2 => "orange cake",
            3 => "fruit cake",
            4 => "chocolate cake",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::cake::Model { id, name }
    }

    fn fruit_model(id: i32, cake_id: Option<i32>) -> sea_orm::tests_cfg::fruit::Model {
        let name = match id {
            1 => "apple",
            2 => "orange",
            3 => "grape",
            4 => "strawberry",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::fruit::Model { id, name, cake_id }
    }

    fn filling_model(id: i32) -> sea_orm::tests_cfg::filling::Model {
        let name = match id {
            1 => "apple juice",
            2 => "orange jam",
            3 => "chocolate crust",
            4 => "strawberry jam",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::filling::Model {
            id,
            name,
            vendor_id: Some(1),
            ignored_attr: 0,
        }
    }

    fn cake_filling_model(
        cake_id: i32,
        filling_id: i32,
    ) -> sea_orm::tests_cfg::cake_filling::Model {
        sea_orm::tests_cfg::cake_filling::Model {
            cake_id,
            filling_id,
        }
    }

    #[tokio::test]
    async fn test_load_one() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![fruit_model(1, Some(1))];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, [Some(cake_model(1))]);
    }

    #[tokio::test]
    async fn test_load_one_same_cake() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, [Some(cake_model(1)), Some(cake_model(1))]);
    }

    #[tokio::test]
    async fn test_load_one_empty() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits: Vec<fruit::Model> = vec![];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, []);
    }

    #[tokio::test]
    async fn test_load_many() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[fruit_model(1, Some(1))]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(fruits, [vec![fruit_model(1, Some(1))], vec![]]);
    }

    #[tokio::test]
    async fn test_load_many_same_fruit() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[fruit_model(1, Some(1)), fruit_model(2, Some(1))]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(
            fruits,
            [
                vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))],
                vec![]
            ]
        );
    }

    #[tokio::test]
    async fn test_load_many_empty() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres).into_connection();

        let cakes: Vec<cake::Model> = vec![];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        let empty_vec: Vec<Vec<fruit::Model>> = vec![];

        assert_eq!(fruits, empty_vec);
    }

    #[tokio::test]
    async fn test_load_many_to_many_base() {
        use sea_orm::{tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [cake_filling_model(1, 1).into_mock_row()],
                [filling_model(1).into_mock_row()],
            ])
            .into_connection();

        let cakes = vec![cake_model(1)];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        assert_eq!(fillings, vec![vec![filling_model(1)]]);
    }

    #[tokio::test]
    async fn test_load_many_to_many_complex() {
        use sea_orm::{tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [
                    cake_filling_model(1, 1).into_mock_row(),
                    cake_filling_model(1, 2).into_mock_row(),
                    cake_filling_model(1, 3).into_mock_row(),
                    cake_filling_model(2, 1).into_mock_row(),
                    cake_filling_model(2, 2).into_mock_row(),
                ],
                [
                    filling_model(1).into_mock_row(),
                    filling_model(2).into_mock_row(),
                    filling_model(3).into_mock_row(),
                    filling_model(4).into_mock_row(),
                    filling_model(5).into_mock_row(),
                ],
            ])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2), cake_model(3)];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        assert_eq!(
            fillings,
            vec![
                vec![filling_model(1), filling_model(2), filling_model(3)],
                vec![filling_model(1), filling_model(2)],
                vec![],
            ]
        );
    }

    #[tokio::test]
    async fn test_load_many_to_many_empty() {
        use sea_orm::{tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [cake_filling_model(1, 1).into_mock_row()],
                [filling_model(1).into_mock_row()],
            ])
            .into_connection();

        let cakes: Vec<cake::Model> = vec![];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        let empty_vec: Vec<Vec<filling::Model>> = vec![];

        assert_eq!(fillings, empty_vec);
    }

    /// Cakes with no junction rows must each get an empty list. A filling row is queued for a
    /// possible second query so the test does not depend on whether that query is issued.
    #[tokio::test]
    async fn test_load_many_to_many_no_junction_rows() {
        use sea_orm::{tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let no_junctions: Vec<cake_filling::Model> = vec![];
        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([no_junctions])
            .append_query_results([[filling_model(1)]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2)];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        let expected: Vec<Vec<filling::Model>> = vec![vec![], vec![]];

        assert_eq!(fillings, expected);
    }

    #[tokio::test]
    async fn test_load_one_duplicate_keys() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![
            fruit_model(1, Some(1)),
            fruit_model(2, Some(1)),
            fruit_model(3, Some(1)),
            fruit_model(4, Some(1)),
        ];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes.len(), 4);
        for cake in &cakes {
            assert_eq!(cake, &Some(cake_model(1)));
        }
        let logs = db.into_transaction_log();
        let sql = format!("{:?}", logs[0]);

        let values_count = sql.matches("$1").count();
        assert_eq!(values_count, 1, "Duplicate values were not removed");
    }

    #[tokio::test]
    async fn test_load_many_duplicate_keys() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[
                fruit_model(1, Some(1)),
                fruit_model(2, Some(1)),
                fruit_model(3, Some(2)),
            ]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(1), cake_model(2), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(fruits.len(), 4);

        let logs = db.into_transaction_log();
        let sql = format!("{:?}", logs[0]);

        let values_count = sql.matches("$1").count() + sql.matches("$2").count();
        assert_eq!(values_count, 2, "Duplicate values were not removed");
    }
}