1use crate::{
2 ColumnTrait, Condition, ConnectionTrait, DbBackend, DbErr, EntityTrait, Identity, JoinType,
3 ModelTrait, QueryFilter, QuerySelect, Related, RelationType, Select, dynamic, error::*,
4};
5use async_trait::async_trait;
6use sea_query::{ColumnRef, DynIden, Expr, ExprTrait, IntoColumnRef, TableRef, ValueTuple};
7use std::{collections::HashMap, str::FromStr};
8
/// Something that can serve as the base query of a loader: either the entity
/// `E` itself or an already-built `Select<E>`.
///
/// Loader methods take `S: EntityOrSelect<R>` so callers can pass the bare
/// entity or a customised select; the loader then adds its own key condition
/// on top of whatever `select` returns.
pub trait EntityOrSelect<E: EntityTrait>: Send {
    /// Converts `self` into the `Select<E>` used as the base query.
    fn select(self) -> Select<E>;
}
16
/// Batch-loads related models for a whole collection of models at once.
///
/// Implemented in this module for `&[M]` and for `Vec<M>` (which forwards to
/// the slice). Every method returns one entry per input model, in input order;
/// an empty input yields an empty `Vec` without querying.
#[async_trait]
pub trait LoaderTrait {
    /// The model type of the collection that related rows are loaded for.
    type Model: ModelTrait;

    /// Loads the single related `R` model for each model (`None` where no row
    /// matched), using one query for the whole collection.
    ///
    /// The slice implementation errors unless `Related::<R>::to()` is a
    /// `HasOne` relation. `stmt` is either the entity `R` or a `Select<R>`.
    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        <Self::Model as ModelTrait>::Entity: Related<R>;

    /// Loads all related `R` models for each model (an empty `Vec` where none
    /// matched), using one query for the whole collection.
    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        <Self::Model as ModelTrait>::Entity: Related<R>;

    /// Loads related `R` models through the junction entity `via`.
    ///
    /// The slice implementation first queries the junction table, then the
    /// target table. It errors if the relation has no `via()` definition, if
    /// `to()` is not `HasOne`, or if `via`'s table is not the relation's
    /// junction table.
    async fn load_many_to_many<R, S, V, C>(
        &self,
        stmt: S,
        via: V,
        db: &C,
    ) -> Result<Vec<Vec<R::Model>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        V: EntityTrait,
        V::Model: Send + Sync,
        <Self::Model as ModelTrait>::Entity: Related<R>;
}
57
/// Variant of [`LoaderTrait`] whose results are converted into `R::ModelEx`
/// (via `From<R::Model>`).
///
/// Implemented for `&[M]` and `&[Option<M>]`; for the latter, `None` slots
/// receive the default result (`None` / empty `Vec`).
#[doc(hidden)]
#[async_trait]
pub trait LoaderTraitEx {
    /// The model type of the collection that related rows are loaded for.
    type Model: ModelTrait;

    /// Like `LoaderTrait::load_one`, yielding `R::ModelEx` values.
    async fn load_one_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::ModelEx>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        R::ModelEx: From<R::Model>,
        <Self::Model as ModelTrait>::Entity: Related<R>;

    /// Like `LoaderTrait::load_many`, yielding `R::ModelEx` values.
    async fn load_many_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::ModelEx>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        R::ModelEx: From<R::Model>,
        <Self::Model as ModelTrait>::Entity: Related<R>;
}
82
/// Loader for nested collections (implemented for `&[Vec<M>]`).
///
/// All inner models are loaded with a single batched query and the results are
/// regrouped so the output mirrors the input's nesting.
#[doc(hidden)]
#[async_trait]
pub trait NestedLoaderTrait {
    /// The model type of the inner collections.
    type Model: ModelTrait;

    /// Loads the single related `R::ModelEx` for every inner model.
    async fn load_one_ex<R, S, C>(
        &self,
        stmt: S,
        db: &C,
    ) -> Result<Vec<Vec<Option<R::ModelEx>>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        R::ModelEx: From<R::Model>,
        <Self::Model as ModelTrait>::Entity: Related<R>;

    /// Loads all related `R::ModelEx` values for every inner model.
    async fn load_many_ex<R, S, C>(
        &self,
        stmt: S,
        db: &C,
    ) -> Result<Vec<Vec<Vec<R::ModelEx>>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        R::ModelEx: From<R::Model>,
        <Self::Model as ModelTrait>::Entity: Related<R>;
}
115
116impl<E> EntityOrSelect<E> for E
117where
118 E: EntityTrait,
119{
120 fn select(self) -> Select<E> {
121 E::find()
122 }
123}
124
125impl<E> EntityOrSelect<E> for Select<E>
126where
127 E: EntityTrait,
128{
129 fn select(self) -> Select<E> {
130 self
131 }
132}
133
134#[async_trait]
135impl<M> LoaderTrait for Vec<M>
136where
137 M: ModelTrait + Sync,
138{
139 type Model = M;
140
141 async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
142 where
143 C: ConnectionTrait,
144 R: EntityTrait,
145 R::Model: Send + Sync,
146 S: EntityOrSelect<R>,
147 <Self::Model as ModelTrait>::Entity: Related<R>,
148 {
149 LoaderTrait::load_one(&self.as_slice(), stmt, db).await
150 }
151
152 async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
153 where
154 C: ConnectionTrait,
155 R: EntityTrait,
156 R::Model: Send + Sync,
157 S: EntityOrSelect<R>,
158 <Self::Model as ModelTrait>::Entity: Related<R>,
159 {
160 LoaderTrait::load_many(&self.as_slice(), stmt, db).await
161 }
162
163 async fn load_many_to_many<R, S, V, C>(
164 &self,
165 stmt: S,
166 via: V,
167 db: &C,
168 ) -> Result<Vec<Vec<R::Model>>, DbErr>
169 where
170 C: ConnectionTrait,
171 R: EntityTrait,
172 R::Model: Send + Sync,
173 S: EntityOrSelect<R>,
174 V: EntityTrait,
175 V::Model: Send + Sync,
176 <Self::Model as ModelTrait>::Entity: Related<R>,
177 {
178 LoaderTrait::load_many_to_many(&self.as_slice(), stmt, via, db).await
179 }
180}
181
182#[async_trait]
183impl<M> LoaderTrait for &[M]
184where
185 M: ModelTrait + Sync,
186{
187 type Model = M;
188
189 async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
190 where
191 C: ConnectionTrait,
192 R: EntityTrait,
193 R::Model: Send + Sync,
194 S: EntityOrSelect<R>,
195 <Self::Model as ModelTrait>::Entity: Related<R>,
196 {
197 let rel_def = <<Self::Model as ModelTrait>::Entity as Related<R>>::to();
198 if rel_def.rel_type != RelationType::HasOne {
199 return Err(query_err("Relation is HasMany instead of HasOne"));
200 }
201 loader_impl(self.iter(), stmt.select(), db).await
202 }
203
204 async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
205 where
206 C: ConnectionTrait,
207 R: EntityTrait,
208 R::Model: Send + Sync,
209 S: EntityOrSelect<R>,
210 <Self::Model as ModelTrait>::Entity: Related<R>,
211 {
212 loader_impl(self.iter(), stmt.select(), db).await
213 }
214
215 async fn load_many_to_many<R, S, V, C>(
216 &self,
217 stmt: S,
218 via: V,
219 db: &C,
220 ) -> Result<Vec<Vec<R::Model>>, DbErr>
221 where
222 C: ConnectionTrait,
223 R: EntityTrait,
224 R::Model: Send + Sync,
225 S: EntityOrSelect<R>,
226 V: EntityTrait,
227 V::Model: Send + Sync,
228 <Self::Model as ModelTrait>::Entity: Related<R>,
229 {
230 if let Some(via_rel) = <<Self::Model as ModelTrait>::Entity as Related<R>>::via() {
231 let rel_def = <<Self::Model as ModelTrait>::Entity as Related<R>>::to();
232 if rel_def.rel_type != RelationType::HasOne {
233 return Err(query_err("Relation to is not HasOne"));
234 }
235
236 if !cmp_table_ref(&via_rel.to_tbl, &via.table_ref()) {
237 return Err(query_err(format!(
238 "The given via Entity is incorrect: expected: {:?}, given: {:?}",
239 via_rel.to_tbl,
240 via.table_ref()
241 )));
242 }
243
244 if self.is_empty() {
245 return Ok(Vec::new());
246 }
247
248 let pkeys = self
249 .iter()
250 .map(|model| extract_key(&via_rel.from_col, model))
251 .collect::<Result<Vec<_>, _>>()?;
252
253 let mut keymap: HashMap<ValueTuple, Vec<ValueTuple>> = Default::default();
255
256 let keys: Vec<ValueTuple> = {
257 let condition = prepare_condition::<M>(
258 &via_rel.to_tbl,
259 &via_rel.from_col,
260 &via_rel.to_col,
261 &pkeys,
262 db,
263 )?;
264 let stmt = V::find().filter(condition);
265 let data = stmt.all(db).await?;
266 for model in data {
267 let pk = extract_key(&via_rel.to_col, &model)?;
268 let entry = keymap.entry(pk).or_default();
269
270 let fk = extract_key(&rel_def.from_col, &model)?;
271 entry.push(fk);
272 }
273
274 keymap.values().flatten().cloned().collect()
275 };
276
277 let condition = prepare_condition::<V::Model>(
278 &rel_def.to_tbl,
279 &rel_def.from_col,
280 &rel_def.to_col,
281 &keys,
282 db,
283 )?;
284
285 let stmt = QueryFilter::filter(stmt.select(), condition);
286
287 let models = stmt.all(db).await?;
288
289 let data = models.into_iter().try_fold(
291 HashMap::<ValueTuple, <R as EntityTrait>::Model>::new(),
292 |mut acc, model| {
293 extract_key(&rel_def.to_col, &model).map(|key| {
294 acc.insert(key, model);
295
296 acc
297 })
298 },
299 )?;
300
301 let result: Vec<Vec<R::Model>> = pkeys
302 .into_iter()
303 .map(|pkey| {
304 let fkeys = keymap.get(&pkey).cloned().unwrap_or_default();
305
306 let models: Vec<_> = fkeys
307 .into_iter()
308 .filter_map(|fkey| data.get(&fkey).cloned())
309 .collect();
310
311 models
312 })
313 .collect();
314
315 Ok(result)
316 } else {
317 return Err(query_err("Relation is not ManyToMany"));
318 }
319 }
320}
321
322#[async_trait]
323impl<M> LoaderTraitEx for &[M]
324where
325 M: ModelTrait + Sync,
326{
327 type Model = M;
328
329 async fn load_one_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::ModelEx>>, DbErr>
330 where
331 C: ConnectionTrait,
332 R: EntityTrait,
333 R::Model: Send + Sync,
334 S: EntityOrSelect<R>,
335 R::ModelEx: From<R::Model>,
336 <Self::Model as ModelTrait>::Entity: Related<R>,
337 {
338 let rel_def = <<Self::Model as ModelTrait>::Entity as Related<R>>::to();
339 if rel_def.rel_type != RelationType::HasOne {
340 return Err(query_err("Relation is HasMany instead of HasOne"));
341 }
342 loader_impl(self.iter(), stmt.select(), db).await
343 }
344
345 async fn load_many_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::ModelEx>>, DbErr>
346 where
347 C: ConnectionTrait,
348 R: EntityTrait,
349 R::Model: Send + Sync,
350 S: EntityOrSelect<R>,
351 R::ModelEx: From<R::Model>,
352 <Self::Model as ModelTrait>::Entity: Related<R>,
353 {
354 loader_impl(self.iter(), stmt.select(), db).await
355 }
356}
357
358#[async_trait]
359impl<M> LoaderTraitEx for &[Option<M>]
360where
361 M: ModelTrait + Sync,
362{
363 type Model = M;
364
365 async fn load_one_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::ModelEx>>, DbErr>
366 where
367 C: ConnectionTrait,
368 R: EntityTrait,
369 R::Model: Send + Sync,
370 S: EntityOrSelect<R>,
371 R::ModelEx: From<R::Model>,
372 <Self::Model as ModelTrait>::Entity: Related<R>,
373 {
374 let rel_def = <<Self::Model as ModelTrait>::Entity as Related<R>>::to();
375 if rel_def.rel_type != RelationType::HasOne {
376 return Err(query_err("Relation is HasMany instead of HasOne"));
377 }
378 let items: Vec<Option<R::ModelEx>> =
379 loader_impl(self.iter().filter_map(|o| o.as_ref()), stmt.select(), db).await?;
380 Ok(assemble_options(self, items))
381 }
382
383 async fn load_many_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::ModelEx>>, DbErr>
384 where
385 C: ConnectionTrait,
386 R: EntityTrait,
387 R::Model: Send + Sync,
388 S: EntityOrSelect<R>,
389 R::ModelEx: From<R::Model>,
390 <Self::Model as ModelTrait>::Entity: Related<R>,
391 {
392 let items: Vec<Vec<R::ModelEx>> =
393 loader_impl(self.iter().filter_map(|o| o.as_ref()), stmt.select(), db).await?;
394 Ok(assemble_options(self, items))
395 }
396}
397
398#[async_trait]
399impl<M> NestedLoaderTrait for &[Vec<M>]
400where
401 M: ModelTrait + Sync,
402{
403 type Model = M;
404
405 async fn load_one_ex<R, S, C>(
406 &self,
407 stmt: S,
408 db: &C,
409 ) -> Result<Vec<Vec<Option<R::ModelEx>>>, DbErr>
410 where
411 C: ConnectionTrait,
412 R: EntityTrait,
413 R::Model: Send + Sync,
414 S: EntityOrSelect<R>,
415 R::ModelEx: From<R::Model>,
416 <Self::Model as ModelTrait>::Entity: Related<R>,
417 {
418 let rel_def = <<Self::Model as ModelTrait>::Entity as Related<R>>::to();
419 if rel_def.rel_type != RelationType::HasOne {
420 return Err(query_err("Relation is HasMany instead of HasOne"));
421 }
422 let items: Vec<Option<R::ModelEx>> =
423 loader_impl(self.iter().flatten(), stmt.select(), db).await?;
424 Ok(assemble_vectors(self, items))
425 }
426
427 async fn load_many_ex<R, S, C>(
428 &self,
429 stmt: S,
430 db: &C,
431 ) -> Result<Vec<Vec<Vec<R::ModelEx>>>, DbErr>
432 where
433 C: ConnectionTrait,
434 R: EntityTrait,
435 R::Model: Send + Sync,
436 S: EntityOrSelect<R>,
437 R::ModelEx: From<R::Model>,
438 <Self::Model as ModelTrait>::Entity: Related<R>,
439 {
440 let items: Vec<Vec<R::ModelEx>> =
441 loader_impl(self.iter().flatten(), stmt.select(), db).await?;
442 Ok(assemble_vectors(self, items))
443 }
444}
445
/// Re-expands results loaded for the `Some` entries of `input` back to
/// `input`'s full length.
///
/// Each `Some` slot takes the next value from `items` (or `T::default()` once
/// `items` is exhausted). Each `None` slot gets `T::default()`.
fn assemble_options<I, T: Default>(input: &[Option<I>], items: Vec<T>) -> Vec<T> {
    let mut loaded = items.into_iter();
    input
        .iter()
        .map(|slot| match slot {
            Some(_) => loaded.next().unwrap_or_default(),
            None => T::default(),
        })
        .collect()
}
458
/// Splits a flat list of per-item results back into the shape of `input`.
///
/// Every element of every inner vector of `input` consumes the next value of
/// `items`, or `T::default()` once `items` is exhausted.
fn assemble_vectors<I, T: Default>(input: &[Vec<I>], items: Vec<T>) -> Vec<Vec<T>> {
    let mut loaded = items.into_iter();
    input
        .iter()
        .map(|group| {
            group
                .iter()
                .map(|_| loaded.next().unwrap_or_default())
                .collect::<Vec<T>>()
        })
        .collect()
}
477
/// Per-key holder that `loader_impl` fills with loaded rows.
///
/// `Vec` collects every row. `Option` keeps only the most recently added row.
trait Container: Default + Clone {
    type Item;
    fn add(&mut self, item: Self::Item);
}

impl<T: Clone> Container for Vec<T> {
    type Item = T;

    /// Appends `item`.
    fn add(&mut self, item: T) {
        Vec::push(self, item);
    }
}

impl<T: Clone> Container for Option<T> {
    type Item = T;

    /// Stores `item`, overwriting whatever was held before.
    fn add(&mut self, item: T) {
        *self = Some(item);
    }
}
496
497async fn loader_impl<'a, Model, Iter, R, C, T, Output>(
498 items: Iter,
499 stmt: Select<R>,
500 db: &C,
501) -> Result<Vec<T>, DbErr>
502where
503 Model: ModelTrait + Sync + 'a,
504 Iter: Iterator<Item = &'a Model> + 'a,
505 C: ConnectionTrait,
506 R: EntityTrait,
507 R::Model: Send + Sync,
508 Model::Entity: Related<R>,
509 Output: From<R::Model>,
510 T: Container<Item = Output>,
511{
512 let (keys, hashmap) = if let Some(via_def) = <Model::Entity as Related<R>>::via() {
513 let keys = items
514 .map(|model| extract_key(&via_def.from_col, model))
515 .collect::<Result<Vec<_>, _>>()?;
516
517 if keys.is_empty() {
518 return Ok(Vec::new());
519 }
520
521 let condition = prepare_condition::<Model>(
522 &via_def.to_tbl,
523 &via_def.from_col,
524 &via_def.to_col,
525 &keys,
526 db,
527 )?;
528
529 let stmt = QueryFilter::filter(
530 stmt.join_rev(JoinType::InnerJoin, <Model::Entity as Related<R>>::to()),
531 condition,
532 );
533
534 let data = stmt
544 .select_also_dyn_model(
545 via_def.to_tbl.sea_orm_table().clone(),
546 dynamic::ModelType {
547 fields: extract_col_type::<Model>(&via_def.from_col, &via_def.to_col)?,
549 },
550 )
551 .all(db)
552 .await?;
553
554 let mut hashmap: HashMap<ValueTuple, T> =
555 keys.iter()
556 .fold(HashMap::new(), |mut acc, key: &ValueTuple| {
557 acc.insert(key.clone(), T::default());
558 acc
559 });
560
561 for (item, key) in data {
562 let key = dyn_model_to_key(key)?;
563
564 let vec = hashmap.get_mut(&key).ok_or_else(|| {
565 DbErr::RecordNotFound(format!("Loader: failed to find model for {key:?}"))
566 })?;
567
568 vec.add(item.into());
569 }
570
571 (keys, hashmap)
572 } else {
573 let rel_def = <Model::Entity as Related<R>>::to();
574
575 let keys = items
576 .map(|model| extract_key(&rel_def.from_col, model))
577 .collect::<Result<Vec<_>, _>>()?;
578
579 if keys.is_empty() {
580 return Ok(Vec::new());
581 }
582
583 let condition = prepare_condition::<Model>(
584 &rel_def.to_tbl,
585 &rel_def.from_col,
586 &rel_def.to_col,
587 &keys,
588 db,
589 )?;
590
591 let stmt = QueryFilter::filter(stmt, condition);
592
593 let data = stmt.all(db).await?;
594
595 let mut hashmap: HashMap<ValueTuple, T> = Default::default();
596
597 for item in data {
598 let key = extract_key(&rel_def.to_col, &item)?;
599 let holder = hashmap.entry(key).or_default();
600 holder.add(item.into());
601 }
602
603 (keys, hashmap)
604 };
605
606 let result: Vec<T> = keys
607 .iter()
608 .map(|key: &ValueTuple| hashmap.get(key).cloned().unwrap_or_default())
609 .collect();
610
611 Ok(result)
612}
613
614fn cmp_table_ref(left: &TableRef, right: &TableRef) -> bool {
615 left == right
616}
617
618fn extract_key<Model>(target_col: &Identity, model: &Model) -> Result<ValueTuple, DbErr>
619where
620 Model: ModelTrait,
621{
622 let values = target_col
623 .iter()
624 .map(|col| {
625 let col_name = col.inner();
626 let column =
627 <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
628 &col_name,
629 )
630 .map_err(|_| DbErr::Type(format!("Failed at mapping '{col_name}' to column")))?;
631 Ok(model.get(column))
632 })
633 .collect::<Result<Vec<_>, DbErr>>()?;
634
635 Ok(match values.len() {
636 0 => return Err(DbErr::Type("Identity zero?".into())),
637 1 => ValueTuple::One(values.into_iter().next().expect("checked")),
638 2 => {
639 let mut it = values.into_iter();
640 ValueTuple::Two(it.next().expect("checked"), it.next().expect("checked"))
641 }
642 3 => {
643 let mut it = values.into_iter();
644 ValueTuple::Three(
645 it.next().expect("checked"),
646 it.next().expect("checked"),
647 it.next().expect("checked"),
648 )
649 }
650 _ => ValueTuple::Many(values),
651 })
652}
653
654fn extract_col_type<Model>(
655 left: &Identity,
656 right: &Identity,
657) -> Result<Vec<dynamic::FieldType>, DbErr>
658where
659 Model: ModelTrait,
660{
661 use itertools::Itertools;
662
663 if left.arity() != right.arity() {
664 return Err(DbErr::Type(format!(
665 "Identity mismatch: left: {} != right: {}",
666 left.arity(),
667 right.arity()
668 )));
669 }
670
671 let vec = left
672 .iter()
673 .zip_eq(right.iter())
674 .map(|(l, r)| {
675 let col_a =
676 <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
677 &l.inner(),
678 )
679 .map_err(|_| DbErr::Type(format!("Failed at mapping '{l}'")))?;
680 Ok(dynamic::FieldType::new(
681 r.clone(),
682 Model::get_value_type(col_a),
683 ))
684 })
685 .collect::<Result<Vec<_>, DbErr>>()?;
686
687 Ok(vec)
688}
689
690#[allow(clippy::unwrap_used)]
691fn dyn_model_to_key(dyn_model: dynamic::Model) -> Result<ValueTuple, DbErr> {
692 Ok(match dyn_model.fields.len() {
693 0 => return Err(DbErr::Type("Identity zero?".into())),
694 1 => ValueTuple::One(dyn_model.fields.into_iter().next().unwrap().value),
695 2 => {
696 let mut iter = dyn_model.fields.into_iter();
697 ValueTuple::Two(iter.next().unwrap().value, iter.next().unwrap().value)
698 }
699 3 => {
700 let mut iter = dyn_model.fields.into_iter();
701 ValueTuple::Three(
702 iter.next().unwrap().value,
703 iter.next().unwrap().value,
704 iter.next().unwrap().value,
705 )
706 }
707 _ => ValueTuple::Many(dyn_model.fields.into_iter().map(|v| v.value).collect()),
708 })
709}
710
711fn arity_mismatch(expected: usize, actual: &ValueTuple) -> DbErr {
712 DbErr::Type(format!(
713 "Loader: arity mismatch: expected {expected}, got {} in {actual:?}",
714 actual.arity()
715 ))
716}
717
718#[inline]
719fn prepare_condition<Model>(
720 table: &TableRef,
721 from: &Identity,
722 to: &Identity,
723 keys: &[ValueTuple],
724 db: &impl ConnectionTrait,
725) -> Result<Condition, DbErr>
726where
727 Model: ModelTrait,
728{
729 let db_backend = db.get_database_backend();
730 if matches!(db_backend, DbBackend::Postgres) {
731 prepare_condition_with_save_as::<Model>(table, from, to, keys)
732 } else {
733 prepare_condition_simple(table, to, keys, db_backend)
734 }
735}
736
737fn prepare_condition_with_save_as<Model>(
738 table: &TableRef,
739 from: &Identity,
740 to: &Identity,
741 keys: &[ValueTuple],
742) -> Result<Condition, DbErr>
743where
744 Model: ModelTrait,
745{
746 use itertools::Itertools;
747
748 let keys = keys.iter().unique();
749 let (from_cols, to_cols) = resolve_column_pairs::<Model>(table, from, to)?;
750
751 if from_cols.is_empty() || to_cols.is_empty() {
752 return Err(DbErr::Type(format!(
753 "Loader: resolved zero columns for identities {from:?} -> {to:?}"
754 )));
755 }
756
757 let arity = from_cols.len();
758
759 let value_tuples = keys
760 .map(|key| {
761 let key_arity = key.arity();
762 if arity != key_arity {
763 return Err(arity_mismatch(arity, key));
764 }
765
766 Ok(apply_save_as::<Model>(&from_cols, key.clone()))
768 })
769 .collect::<Result<Vec<_>, DbErr>>()?;
770
771 let expr = Expr::tuple(create_table_columns(table, to)).is_in(value_tuples);
773
774 Ok(expr.into())
775}
776
777fn prepare_condition_simple(
779 table: &TableRef,
780 to: &Identity,
781 keys: &[ValueTuple],
782 backend: DbBackend,
783) -> Result<Condition, DbErr> {
784 use itertools::Itertools;
785
786 let arity = to.arity();
787 let keys = keys.iter().unique();
788
789 let table_columns = create_table_columns(table, to);
790
791 if cfg!(feature = "sqlite-no-row-value-before-3_15") && matches!(backend, DbBackend::Sqlite) {
792 let mut outer = Condition::any();
795
796 for key in keys {
797 let key_arity = key.arity();
798 if arity != key_arity {
799 return Err(arity_mismatch(arity, key));
800 }
801
802 let table_columns = table_columns.iter().cloned();
803 let values = key.clone().into_iter().map(Expr::val);
804
805 let inner = table_columns
806 .zip(values)
807 .fold(Condition::all(), |cond, (column, value)| {
808 cond.add(column.eq(value))
809 });
810
811 outer = outer.add(inner);
813 }
814
815 Ok(outer)
816 } else {
817 let value_tuples = keys
819 .map(|key| {
820 let key_arity = key.arity();
821 if arity != key_arity {
822 return Err(arity_mismatch(arity, key));
823 }
824
825 let tuple_exprs = key.clone().into_iter().map(Expr::val);
826
827 Ok(Expr::tuple(tuple_exprs))
828 })
829 .collect::<Result<Vec<_>, DbErr>>()?;
830
831 let expr = Expr::tuple(table_columns).is_in(value_tuples);
833
834 Ok(expr.into())
835 }
836}
837
/// The column enum of the entity that model `M` belongs to.
type ModelColumn<M> = <<M as ModelTrait>::Entity as EntityTrait>::Column;

/// Resolved `(source model columns, target table column refs)`; both vectors
/// have the same length and follow identity order.
type ColumnPairs<M> = (Vec<ModelColumn<M>>, Vec<ColumnRef>);
841
842fn resolve_column_pairs<Model>(
843 table: &TableRef,
844 from: &Identity,
845 to: &Identity,
846) -> Result<ColumnPairs<Model>, DbErr>
847where
848 Model: ModelTrait,
849 ModelColumn<Model>: ColumnTrait,
850{
851 let from_columns = parse_identity_columns::<Model>(from)?;
852 let to_columns = column_refs_from_identity(table, to);
853
854 if from_columns.len() != to_columns.len() {
855 return Err(DbErr::Type(format!(
856 "Loader: identity column count mismatch between {from:?} and {to:?}"
857 )));
858 }
859
860 Ok((from_columns, to_columns))
861}
862
863fn column_refs_from_identity(table: &TableRef, identity: &Identity) -> Vec<ColumnRef> {
864 identity
865 .iter()
866 .map(|col| table_column(table, col))
867 .collect()
868}
869
870fn parse_identity_columns<Model>(identity: &Identity) -> Result<Vec<ModelColumn<Model>>, DbErr>
871where
872 Model: ModelTrait,
873{
874 identity
875 .iter()
876 .map(|from_col| try_conv_ident_to_column::<Model>(from_col))
877 .collect()
878}
879
880fn try_conv_ident_to_column<Model>(ident: &DynIden) -> Result<ModelColumn<Model>, DbErr>
881where
882 Model: ModelTrait,
883{
884 let column_name = ident.inner();
885 ModelColumn::<Model>::from_str(&column_name)
886 .map_err(|_| DbErr::Type(format!("Failed at mapping '{column_name}' to column")))
887}
888
889fn table_column(tbl: &TableRef, col: &DynIden) -> ColumnRef {
890 (tbl.sea_orm_table().to_owned(), col.clone()).into_column_ref()
891}
892
893fn create_table_columns(table: &TableRef, cols: &Identity) -> Vec<Expr> {
895 cols.iter()
896 .cloned()
897 .map(|col| table_column(table, &col))
898 .map(Expr::col)
899 .collect()
900}
901
902fn apply_save_as<M: ModelTrait>(cols: &[ModelColumn<M>], values: ValueTuple) -> Expr {
904 let values_expr_iter = values.into_iter().map(Expr::val);
905
906 let tuple_exprs: Vec<_> = cols
907 .iter()
908 .zip(values_expr_iter)
909 .map(|(model_column, value)| model_column.save_as(value))
910 .collect();
911
912 Expr::tuple(tuple_exprs)
913}
914
#[cfg(test)]
mod tests {
    fn cake_model(id: i32) -> sea_orm::tests_cfg::cake::Model {
        let name = match id {
            1 => "apple cake",
            2 => "orange cake",
            3 => "fruit cake",
            4 => "chocolate cake",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::cake::Model { id, name }
    }

    fn fruit_model(id: i32, cake_id: Option<i32>) -> sea_orm::tests_cfg::fruit::Model {
        let name = match id {
            1 => "apple",
            2 => "orange",
            3 => "grape",
            4 => "strawberry",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::fruit::Model { id, name, cake_id }
    }

    fn filling_model(id: i32) -> sea_orm::tests_cfg::filling::Model {
        let name = match id {
            1 => "apple juice",
            2 => "orange jam",
            3 => "chocolate crust",
            4 => "strawberry jam",
            _ => "",
        }
        .to_string();
        sea_orm::tests_cfg::filling::Model {
            id,
            name,
            vendor_id: Some(1),
            ignored_attr: 0,
        }
    }

    fn cake_filling_model(
        cake_id: i32,
        filling_id: i32,
    ) -> sea_orm::tests_cfg::cake_filling::Model {
        sea_orm::tests_cfg::cake_filling::Model {
            cake_id,
            filling_id,
        }
    }

    #[tokio::test]
    async fn test_load_one() {
        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![fruit_model(1, Some(1))];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, [Some(cake_model(1))]);
    }

    #[tokio::test]
    async fn test_load_one_same_cake() {
        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, [Some(cake_model(1)), Some(cake_model(1))]);
    }

    #[tokio::test]
    async fn test_load_one_empty() {
        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits: Vec<fruit::Model> = vec![];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, []);
    }

    #[tokio::test]
    async fn test_load_many() {
        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[fruit_model(1, Some(1))]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(fruits, [vec![fruit_model(1, Some(1))], vec![]]);
    }

    #[tokio::test]
    async fn test_load_many_same_fruit() {
        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[fruit_model(1, Some(1)), fruit_model(2, Some(1))]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(
            fruits,
            [
                vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))],
                vec![]
            ]
        );
    }

    #[tokio::test]
    async fn test_load_many_empty() {
        use sea_orm::{DbBackend, MockDatabase, entity::prelude::*, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres).into_connection();

        let cakes: Vec<cake::Model> = vec![];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        let empty_vec: Vec<Vec<fruit::Model>> = vec![];

        assert_eq!(fruits, empty_vec);
    }

    #[tokio::test]
    async fn test_load_many_to_many_base() {
        use sea_orm::{DbBackend, IntoMockRow, LoaderTrait, MockDatabase, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [cake_filling_model(1, 1).into_mock_row()],
                [filling_model(1).into_mock_row()],
            ])
            .into_connection();

        let cakes = vec![cake_model(1)];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        assert_eq!(fillings, vec![vec![filling_model(1)]]);
    }

    #[tokio::test]
    async fn test_load_many_to_many_complex() {
        use sea_orm::{DbBackend, IntoMockRow, LoaderTrait, MockDatabase, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [
                    cake_filling_model(1, 1).into_mock_row(),
                    cake_filling_model(1, 2).into_mock_row(),
                    cake_filling_model(1, 3).into_mock_row(),
                    cake_filling_model(2, 1).into_mock_row(),
                    cake_filling_model(2, 2).into_mock_row(),
                ],
                [
                    filling_model(1).into_mock_row(),
                    filling_model(2).into_mock_row(),
                    filling_model(3).into_mock_row(),
                    filling_model(4).into_mock_row(),
                    filling_model(5).into_mock_row(),
                ],
            ])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(2), cake_model(3)];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        assert_eq!(
            fillings,
            vec![
                vec![filling_model(1), filling_model(2), filling_model(3)],
                vec![filling_model(1), filling_model(2)],
                vec![],
            ]
        );
    }

    #[tokio::test]
    async fn test_load_many_to_many_empty() {
        use sea_orm::{DbBackend, IntoMockRow, LoaderTrait, MockDatabase, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([
                [cake_filling_model(1, 1).into_mock_row()],
                [filling_model(1).into_mock_row()],
            ])
            .into_connection();

        let cakes: Vec<cake::Model> = vec![];

        let fillings = cakes
            .load_many_to_many(Filling, CakeFilling, &db)
            .await
            .expect("Should return something");

        let empty_vec: Vec<Vec<filling::Model>> = vec![];

        assert_eq!(fillings, empty_vec);
    }

    #[tokio::test]
    async fn test_load_one_duplicate_keys() {
        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![
            fruit_model(1, Some(1)),
            fruit_model(2, Some(1)),
            fruit_model(3, Some(1)),
            fruit_model(4, Some(1)),
        ];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes.len(), 4);
        for cake in &cakes {
            assert_eq!(cake, &Some(cake_model(1)));
        }
        let logs = db.into_transaction_log();
        let sql = format!("{:?}", logs[0]);

        let values_count = sql.matches("$1").count();
        assert_eq!(values_count, 1, "Duplicate values were not removed");
    }

    #[tokio::test]
    async fn test_load_many_duplicate_keys() {
        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[
                fruit_model(1, Some(1)),
                fruit_model(2, Some(1)),
                fruit_model(3, Some(2)),
            ]])
            .into_connection();

        let cakes = vec![cake_model(1), cake_model(1), cake_model(2), cake_model(2)];

        let fruits = cakes
            .load_many(fruit::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(fruits.len(), 4);

        let logs = db.into_transaction_log();
        let sql = format!("{:?}", logs[0]);

        let values_count = sql.matches("$1").count() + sql.matches("$2").count();
        assert_eq!(values_count, 2, "Duplicate values were not removed");
    }

    #[test]
    fn test_assemble_vectors() {
        use super::assemble_vectors;

        assert_eq!(
            assemble_vectors(&[vec![1], vec![], vec![2, 3], vec![]], vec![11, 22, 33]),
            [vec![11], vec![], vec![22, 33], vec![]]
        );
    }

    #[test]
    fn test_assemble_options() {
        use super::assemble_options;

        // `Some` slots consume loaded items in order; `None` slots get the default.
        assert_eq!(
            assemble_options(&[Some(1), None, Some(2), None], vec![11, 22]),
            [11, 0, 22, 0]
        );
        // Fewer loaded items than `Some` slots: the remainder falls back to the default.
        assert_eq!(assemble_options(&[Some(1), Some(2)], vec![11]), [11, 0]);
    }
}