1use crate::{
2 error::*, Condition, ConnectionTrait, DbErr, EntityTrait, Identity, ModelTrait, QueryFilter,
3 Related, RelationType, Select,
4};
5use async_trait::async_trait;
6use sea_query::{ColumnRef, DynIden, Expr, IntoColumnRef, SimpleExpr, TableRef, ValueTuple};
7use std::{collections::HashMap, str::FromStr};
8
/// A value that can be converted into a [`Select`] statement over entity `E`.
///
/// Within this file it is implemented for entity types themselves (which
/// yield `E::find()`) and for an existing [`Select<E>`] (which is returned
/// unchanged), so the loader methods can accept either form as their `stmt`
/// argument.
pub trait EntityOrSelect<E: EntityTrait>: Send {
    /// Consume `self` and return the select statement it represents.
    fn select(self) -> Select<E>;
}
14
/// Batch-loading of related models for a collection of models.
///
/// The slice implementation in this file issues one `IN (...)` query per
/// relation hop and returns results aligned with the input collection: the
/// element at index `i` of the output belongs to the model at index `i` of
/// the input.
#[async_trait]
pub trait LoaderTrait {
    /// The model type held by the collection this trait is implemented for.
    type Model: ModelTrait;

    /// Load at most one related `R` model per input model.
    ///
    /// `stmt` is either the related entity or a pre-built `Select<R>` whose
    /// conditions are kept; `db` is the connection the query is run on.
    /// The slice implementation in this file returns an error when the
    /// relation to `R` goes through a junction table or is `HasMany`.
    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;

    /// Load all related `R` models for each input model.
    ///
    /// `stmt` and `db` have the same meaning as in `load_one`.
    /// The slice implementation in this file returns an error when the
    /// relation to `R` goes through a junction table or is `HasOne`.
    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;

    /// Load all `R` models related to each input model through the junction
    /// entity `via`.
    ///
    /// The slice implementation in this file returns an error when the
    /// relation to `R` has no `via` definition, or when `via` does not name
    /// the junction table that the relation declares.
    async fn load_many_to_many<R, S, V, C>(
        &self,
        stmt: S,
        via: V,
        db: &C,
    ) -> Result<Vec<Vec<R::Model>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        V: EntityTrait,
        V::Model: Send + Sync,
        <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>;
}
55
56impl<E> EntityOrSelect<E> for E
57where
58 E: EntityTrait,
59{
60 fn select(self) -> Select<E> {
61 E::find()
62 }
63}
64
65impl<E> EntityOrSelect<E> for Select<E>
66where
67 E: EntityTrait,
68{
69 fn select(self) -> Select<E> {
70 self
71 }
72}
73
74#[async_trait]
75impl<M> LoaderTrait for Vec<M>
76where
77 M: ModelTrait + Sync,
78{
79 type Model = M;
80
81 async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
82 where
83 C: ConnectionTrait,
84 R: EntityTrait,
85 R::Model: Send + Sync,
86 S: EntityOrSelect<R>,
87 <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
88 {
89 self.as_slice().load_one(stmt, db).await
90 }
91
92 async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
93 where
94 C: ConnectionTrait,
95 R: EntityTrait,
96 R::Model: Send + Sync,
97 S: EntityOrSelect<R>,
98 <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
99 {
100 self.as_slice().load_many(stmt, db).await
101 }
102
103 async fn load_many_to_many<R, S, V, C>(
104 &self,
105 stmt: S,
106 via: V,
107 db: &C,
108 ) -> Result<Vec<Vec<R::Model>>, DbErr>
109 where
110 C: ConnectionTrait,
111 R: EntityTrait,
112 R::Model: Send + Sync,
113 S: EntityOrSelect<R>,
114 V: EntityTrait,
115 V::Model: Send + Sync,
116 <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
117 {
118 self.as_slice().load_many_to_many(stmt, via, db).await
119 }
120}
121
122#[async_trait]
123impl<M> LoaderTrait for &[M]
124where
125 M: ModelTrait + Sync,
126{
127 type Model = M;
128
129 async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
130 where
131 C: ConnectionTrait,
132 R: EntityTrait,
133 R::Model: Send + Sync,
134 S: EntityOrSelect<R>,
135 <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
136 {
137 if <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via().is_some() {
139 return Err(query_err("Relation is ManytoMany instead of HasOne"));
140 }
141 let rel_def = <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
142 if rel_def.rel_type == RelationType::HasMany {
143 return Err(query_err("Relation is HasMany instead of HasOne"));
144 }
145
146 if self.is_empty() {
147 return Ok(Vec::new());
148 }
149
150 let keys: Vec<ValueTuple> = self
151 .iter()
152 .map(|model: &M| extract_key(&rel_def.from_col, model))
153 .collect();
154
155 let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
156
157 let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
158
159 let data = stmt.all(db).await?;
160
161 let hashmap: HashMap<ValueTuple, <R as EntityTrait>::Model> = data.into_iter().fold(
162 HashMap::new(),
163 |mut acc, value: <R as EntityTrait>::Model| {
164 {
165 let key = extract_key(&rel_def.to_col, &value);
166 acc.insert(key, value);
167 }
168
169 acc
170 },
171 );
172
173 let result: Vec<Option<<R as EntityTrait>::Model>> =
174 keys.iter().map(|key| hashmap.get(key).cloned()).collect();
175
176 Ok(result)
177 }
178
179 async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
180 where
181 C: ConnectionTrait,
182 R: EntityTrait,
183 R::Model: Send + Sync,
184 S: EntityOrSelect<R>,
185 <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
186 {
187 if <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via().is_some() {
190 return Err(query_err("Relation is ManyToMany instead of HasMany"));
191 }
192 let rel_def = <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
193 if rel_def.rel_type == RelationType::HasOne {
194 return Err(query_err("Relation is HasOne instead of HasMany"));
195 }
196
197 if self.is_empty() {
198 return Ok(Vec::new());
199 }
200
201 let keys: Vec<ValueTuple> = self
202 .iter()
203 .map(|model: &M| extract_key(&rel_def.from_col, model))
204 .collect();
205
206 let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
207
208 let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
209
210 let data = stmt.all(db).await?;
211
212 let mut hashmap: HashMap<ValueTuple, Vec<<R as EntityTrait>::Model>> =
213 keys.iter()
214 .fold(HashMap::new(), |mut acc, key: &ValueTuple| {
215 acc.insert(key.clone(), Vec::new());
216 acc
217 });
218
219 data.into_iter()
220 .for_each(|value: <R as EntityTrait>::Model| {
221 let key = extract_key(&rel_def.to_col, &value);
222
223 let vec = hashmap
224 .get_mut(&key)
225 .expect("Failed at finding key on hashmap");
226
227 vec.push(value);
228 });
229
230 let result: Vec<Vec<R::Model>> = keys
231 .iter()
232 .map(|key: &ValueTuple| hashmap.get(key).cloned().unwrap_or_default())
233 .collect();
234
235 Ok(result)
236 }
237
238 async fn load_many_to_many<R, S, V, C>(
239 &self,
240 stmt: S,
241 via: V,
242 db: &C,
243 ) -> Result<Vec<Vec<R::Model>>, DbErr>
244 where
245 C: ConnectionTrait,
246 R: EntityTrait,
247 R::Model: Send + Sync,
248 S: EntityOrSelect<R>,
249 V: EntityTrait,
250 V::Model: Send + Sync,
251 <<Self as LoaderTrait>::Model as ModelTrait>::Entity: Related<R>,
252 {
253 if let Some(via_rel) =
254 <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::via()
255 {
256 let rel_def =
257 <<<Self as LoaderTrait>::Model as ModelTrait>::Entity as Related<R>>::to();
258 if rel_def.rel_type != RelationType::HasOne {
259 return Err(query_err("Relation to is not HasOne"));
260 }
261
262 if !cmp_table_ref(&via_rel.to_tbl, &via.table_ref()) {
263 return Err(query_err(format!(
264 "The given via Entity is incorrect: expected: {:?}, given: {:?}",
265 via_rel.to_tbl,
266 via.table_ref()
267 )));
268 }
269
270 if self.is_empty() {
271 return Ok(Vec::new());
272 }
273
274 let pkeys: Vec<ValueTuple> = self
275 .iter()
276 .map(|model: &M| extract_key(&via_rel.from_col, model))
277 .collect();
278
279 let mut keymap: HashMap<ValueTuple, Vec<ValueTuple>> = Default::default();
281
282 let keys: Vec<ValueTuple> = {
283 let condition = prepare_condition(&via_rel.to_tbl, &via_rel.to_col, &pkeys);
284 let stmt = V::find().filter(condition);
285 let data = stmt.all(db).await?;
286 data.into_iter().for_each(|model| {
287 let pk = extract_key(&via_rel.to_col, &model);
288 let entry = keymap.entry(pk).or_default();
289
290 let fk = extract_key(&rel_def.from_col, &model);
291 entry.push(fk);
292 });
293
294 keymap.values().flatten().cloned().collect()
295 };
296
297 let condition = prepare_condition(&rel_def.to_tbl, &rel_def.to_col, &keys);
298
299 let stmt = <Select<R> as QueryFilter>::filter(stmt.select(), condition);
300
301 let data = stmt.all(db).await?;
302
303 let data: HashMap<ValueTuple, <R as EntityTrait>::Model> = data
305 .into_iter()
306 .map(|model| {
307 let key = extract_key(&rel_def.to_col, &model);
308 (key, model)
309 })
310 .collect();
311
312 let result: Vec<Vec<R::Model>> = pkeys
313 .into_iter()
314 .map(|pkey| {
315 let fkeys = keymap.get(&pkey).cloned().unwrap_or_default();
316
317 let models: Vec<_> = fkeys
318 .into_iter()
319 .filter_map(|fkey| data.get(&fkey).cloned())
320 .collect();
321
322 models
323 })
324 .collect();
325
326 Ok(result)
327 } else {
328 return Err(query_err("Relation is not ManyToMany"));
329 }
330 }
331}
332
333fn cmp_table_ref(left: &TableRef, right: &TableRef) -> bool {
334 format!("{left:?}") == format!("{right:?}")
336}
337
338fn extract_key<Model>(target_col: &Identity, model: &Model) -> ValueTuple
339where
340 Model: ModelTrait,
341{
342 match target_col {
343 Identity::Unary(a) => {
344 let column_a =
345 <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
346 &a.to_string(),
347 )
348 .unwrap_or_else(|_| panic!("Failed at mapping string to column A:1"));
349 ValueTuple::One(model.get(column_a))
350 }
351 Identity::Binary(a, b) => {
352 let column_a =
353 <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
354 &a.to_string(),
355 )
356 .unwrap_or_else(|_| panic!("Failed at mapping string to column A:2"));
357 let column_b =
358 <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
359 &b.to_string(),
360 )
361 .unwrap_or_else(|_| panic!("Failed at mapping string to column B:2"));
362 ValueTuple::Two(model.get(column_a), model.get(column_b))
363 }
364 Identity::Ternary(a, b, c) => {
365 let column_a =
366 <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
367 &a.to_string(),
368 )
369 .unwrap_or_else(|_| panic!("Failed at mapping string to column A:3"));
370 let column_b =
371 <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
372 &b.to_string(),
373 )
374 .unwrap_or_else(|_| panic!("Failed at mapping string to column B:3"));
375 let column_c =
376 <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
377 &c.to_string(),
378 )
379 .unwrap_or_else(|_| panic!("Failed at mapping string to column C:3"));
380 ValueTuple::Three(
381 model.get(column_a),
382 model.get(column_b),
383 model.get(column_c),
384 )
385 }
386 Identity::Many(cols) => {
387 let values = cols.iter().map(|col| {
388 let col_name = col.to_string();
389 let column = <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
390 &col_name,
391 )
392 .unwrap_or_else(|_| panic!("Failed at mapping '{}' to column", col_name));
393 model.get(column)
394 })
395 .collect();
396 ValueTuple::Many(values)
397 }
398 }
399}
400
401fn prepare_condition(table: &TableRef, col: &Identity, keys: &[ValueTuple]) -> Condition {
402 let keys = keys.to_owned();
404 match col {
405 Identity::Unary(column_a) => {
406 let column_a = table_column(table, column_a);
407 Condition::all().add(Expr::col(column_a).is_in(keys.into_iter().flatten()))
408 }
409 Identity::Binary(column_a, column_b) => Condition::all().add(
410 Expr::tuple([
411 SimpleExpr::Column(table_column(table, column_a)),
412 SimpleExpr::Column(table_column(table, column_b)),
413 ])
414 .in_tuples(keys),
415 ),
416 Identity::Ternary(column_a, column_b, column_c) => Condition::all().add(
417 Expr::tuple([
418 SimpleExpr::Column(table_column(table, column_a)),
419 SimpleExpr::Column(table_column(table, column_b)),
420 SimpleExpr::Column(table_column(table, column_c)),
421 ])
422 .in_tuples(keys),
423 ),
424 Identity::Many(cols) => {
425 let columns = cols
426 .iter()
427 .map(|col| SimpleExpr::Column(table_column(table, col)));
428 Condition::all().add(Expr::tuple(columns).in_tuples(keys))
429 }
430 }
431}
432
433fn table_column(tbl: &TableRef, col: &DynIden) -> ColumnRef {
434 match tbl.to_owned() {
435 TableRef::Table(tbl) => (tbl, col.clone()).into_column_ref(),
436 TableRef::SchemaTable(sch, tbl) => (sch, tbl, col.clone()).into_column_ref(),
437 val => unimplemented!("Unsupported TableRef {val:?}"),
438 }
439}
440
// Tests for the loader methods, run against a mock database whose query
// results are queued up front and handed out in the order they were appended.
#[cfg(test)]
mod tests {
    // Cake fixture: ids 1-4 get a fixed name, any other id an empty name.
    fn cake_model(id: i32) -> sea_orm::tests_cfg::cake::Model {
        let name = match id {
            1 => "apple cake",
            2 => "orange cake",
            3 => "fruit cake",
            4 => "chocolate cake",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::cake::Model { id, name }
    }

    // Fruit fixture pointing at `cake_id`: ids 1-4 get a fixed name, any
    // other id an empty name.
    fn fruit_model(id: i32, cake_id: Option<i32>) -> sea_orm::tests_cfg::fruit::Model {
        let name = match id {
            1 => "apple",
            2 => "orange",
            3 => "grape",
            4 => "strawberry",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::fruit::Model { id, name, cake_id }
    }

    // Filling fixture: ids 1-4 get a fixed name, any other id an empty name;
    // the remaining fields are constant.
    fn filling_model(id: i32) -> sea_orm::tests_cfg::filling::Model {
        let name = match id {
            1 => "apple juice",
            2 => "orange jam",
            3 => "chocolate crust",
            4 => "strawberry jam",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::filling::Model {
            id,
            name,
            vendor_id: Some(1),
            ignored_attr: 0,
        }
    }

    // Junction-row fixture linking one cake to one filling.
    fn cake_filling_model(
        cake_id: i32,
        filling_id: i32,
    ) -> sea_orm::tests_cfg::cake_filling::Model {
        sea_orm::tests_cfg::cake_filling::Model {
            cake_id,
            filling_id,
        }
    }

    // The mock returns cakes 1 and 2; only cake 1 matches the single fruit,
    // so cake 2 must not appear in the output.
    #[tokio::test]
    async fn test_load_one() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![fruit_model(1, Some(1))];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, [Some(cake_model(1))]);
    }

    // Two fruits reference the same cake: both output slots get that cake.
    #[tokio::test]
    async fn test_load_one_same_cake() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, [Some(cake_model(1)), Some(cake_model(1))]);
    }

    // Empty input yields an empty output, even though results are queued on
    // the mock.
    #[tokio::test]
    async fn test_load_one_empty() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits: Vec<fruit::Model> = vec![];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, []);
    }

    // Cake 1 owns the single returned fruit; cake 2 has none and gets an
    // empty vector.
    #[tokio::test]
    async fn test_load_many() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[fruit_model(1, Some(1))]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(fruits, [vec![fruit_model(1, Some(1))], vec![]]);
    }

    // Both returned fruits belong to cake 1 and are kept in the order the
    // mock returned them.
    #[tokio::test]
    async fn test_load_many_same_fruit() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[fruit_model(1, Some(1)), fruit_model(2, Some(1))]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(
            fruits,
            [
                vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))],
                vec![]
            ]
        );
    }

    // Empty input yields an empty output; no results are queued on the mock.
    #[tokio::test]
    async fn test_load_many_empty() {
        use sea_orm::{entity::prelude::*, tests_cfg::*, DbBackend, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres).into_connection();

        let cakes: Vec<cake::Model> = vec![];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        let empty_vec: Vec<Vec<fruit::Model>> = vec![];

        assert_eq!(fruits, empty_vec);
    }

    // Queued results are consumed in order: first the junction rows, then
    // the fillings.
    #[tokio::test]
    async fn test_load_many_to_many_base() {
        use sea_orm::{tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [cake_filling_model(1, 1).into_mock_row()],
                [filling_model(1).into_mock_row()],
            ])
            .into_connection();

        let cakes = vec![cake_model(1)];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        assert_eq!(fillings, vec![vec![filling_model(1)]]);
    }

    // Cake 1 links to fillings 1-3, cake 2 to fillings 1-2, and cake 3 has
    // no junction rows. Fillings 4 and 5 are returned by the mock but no
    // junction row references them, so they must not show up.
    #[tokio::test]
    async fn test_load_many_to_many_complex() {
        use sea_orm::{tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [
                    cake_filling_model(1, 1).into_mock_row(),
                    cake_filling_model(1, 2).into_mock_row(),
                    cake_filling_model(1, 3).into_mock_row(),
                    cake_filling_model(2, 1).into_mock_row(),
                    cake_filling_model(2, 2).into_mock_row(),
                ],
                [
                    filling_model(1).into_mock_row(),
                    filling_model(2).into_mock_row(),
                    filling_model(3).into_mock_row(),
                    filling_model(4).into_mock_row(),
                    filling_model(5).into_mock_row(),
                ],
            ])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2), cake_model(3)];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        assert_eq!(
            fillings,
            vec![
                vec![filling_model(1), filling_model(2), filling_model(3)],
                vec![filling_model(1), filling_model(2)],
                vec![],
            ]
        );
    }

    // Empty input yields an empty output, even though results are queued on
    // the mock.
    #[tokio::test]
    async fn test_load_many_to_many_empty() {
        use sea_orm::{tests_cfg::*, DbBackend, IntoMockRow, LoaderTrait, MockDatabase};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [cake_filling_model(1, 1).into_mock_row()],
                [filling_model(1).into_mock_row()],
            ])
            .into_connection();

        let cakes: Vec<cake::Model> = vec![];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        let empty_vec: Vec<Vec<filling::Model>> = vec![];

        assert_eq!(fillings, empty_vec);
    }
}