sea_orm/query/
loader.rs

1use crate::{
2    ColumnTrait, Condition, ConnectionTrait, DbBackend, DbErr, EntityTrait, Identity, JoinType,
3    ModelTrait, QueryFilter, QuerySelect, Related, RelationType, Select, dynamic, error::*,
4};
5use async_trait::async_trait;
6use sea_query::{ColumnRef, DynIden, Expr, ExprTrait, IntoColumnRef, TableRef, ValueTuple};
7use std::{collections::HashMap, str::FromStr};
8
// TODO: Replace DynIden::inner with a better API that does not require a clone
10
/// Either an entity `E` or a `Select<E>`; used as the `stmt` parameter of the
/// [`LoaderTrait`] methods.
///
/// This file implements it for every `E: EntityTrait` (yielding `E::find()`)
/// and for `Select<E>` (returned unchanged).
pub trait EntityOrSelect<E: EntityTrait>: Send {
    /// Turn `self` into a `Select<E>`. If `self` is an entity, `E::find()` is used.
    fn select(self) -> Select<E>;
}
16
/// This trait implements the Data Loader API.
///
/// It is implemented in this file for `Vec<M>` and `&[M]`, where `M` is a model.
/// Each method issues queries through `db` and returns one result slot per
/// source model, in the same order as the source collection.
#[async_trait]
pub trait LoaderTrait {
    /// Source model
    type Model: ModelTrait;

    /// Used to eager load has_one relations.
    ///
    /// `stmt` is either the related entity `R` or a `Select<R>` that narrows the
    /// rows to load. The slice implementation returns an error when the relation
    /// from the source entity to `R` is not `RelationType::HasOne`.
    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        <Self::Model as ModelTrait>::Entity: Related<R>;

    /// Used to eager load has_many relations.
    ///
    /// Yields, for every source model, the (possibly empty) list of related
    /// `R` models.
    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        <Self::Model as ModelTrait>::Entity: Related<R>;

    /// Used to eager load many_to_many relations.
    ///
    /// `via` is the junction entity. The slice implementation returns an error
    /// when the relation to `R` defines no `via()` link, or when the table of
    /// `via` differs from the table that link points to.
    async fn load_many_to_many<R, S, V, C>(
        &self,
        stmt: S,
        via: V,
        db: &C,
    ) -> Result<Vec<Vec<R::Model>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        V: EntityTrait,
        V::Model: Send + Sync,
        <Self::Model as ModelTrait>::Entity: Related<R>;
}
57
/// Variant of [`LoaderTrait`] whose results are converted into `R::ModelEx`
/// (via `From<R::Model>`) instead of being returned as `R::Model`.
///
/// Implemented in this file for `&[M]` and `&[Option<M>]`.
#[doc(hidden)]
#[async_trait]
pub trait LoaderTraitEx {
    /// Source model
    type Model: ModelTrait;

    /// Load at most one related `R::ModelEx` per source model.
    async fn load_one_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::ModelEx>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        R::ModelEx: From<R::Model>,
        <Self::Model as ModelTrait>::Entity: Related<R>;

    /// Load the list of related `R::ModelEx` for each source model.
    async fn load_many_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::ModelEx>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        R::ModelEx: From<R::Model>,
        <Self::Model as ModelTrait>::Entity: Related<R>;
}
82
/// Loader for a nested source collection. It is implemented in this file for
/// `&[Vec<M>]`, and the result carries one extra level of `Vec` compared to
/// [`LoaderTraitEx`], matching the nesting of the source.
#[doc(hidden)]
#[async_trait]
pub trait NestedLoaderTrait {
    /// Source model
    type Model: ModelTrait;

    /// Load at most one related `R::ModelEx` for every model of every inner list.
    async fn load_one_ex<R, S, C>(
        &self,
        stmt: S,
        db: &C,
    ) -> Result<Vec<Vec<Option<R::ModelEx>>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        R::ModelEx: From<R::Model>,
        <Self::Model as ModelTrait>::Entity: Related<R>;

    /// Load the list of related `R::ModelEx` for every model of every inner list.
    async fn load_many_ex<R, S, C>(
        &self,
        stmt: S,
        db: &C,
    ) -> Result<Vec<Vec<Vec<R::ModelEx>>>, DbErr>
    where
        C: ConnectionTrait,
        R: EntityTrait,
        R::Model: Send + Sync,
        S: EntityOrSelect<R>,
        R::ModelEx: From<R::Model>,
        <Self::Model as ModelTrait>::Entity: Related<R>;
}
115
116impl<E> EntityOrSelect<E> for E
117where
118    E: EntityTrait,
119{
120    fn select(self) -> Select<E> {
121        E::find()
122    }
123}
124
125impl<E> EntityOrSelect<E> for Select<E>
126where
127    E: EntityTrait,
128{
129    fn select(self) -> Select<E> {
130        self
131    }
132}
133
134#[async_trait]
135impl<M> LoaderTrait for Vec<M>
136where
137    M: ModelTrait + Sync,
138{
139    type Model = M;
140
141    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
142    where
143        C: ConnectionTrait,
144        R: EntityTrait,
145        R::Model: Send + Sync,
146        S: EntityOrSelect<R>,
147        <Self::Model as ModelTrait>::Entity: Related<R>,
148    {
149        LoaderTrait::load_one(&self.as_slice(), stmt, db).await
150    }
151
152    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
153    where
154        C: ConnectionTrait,
155        R: EntityTrait,
156        R::Model: Send + Sync,
157        S: EntityOrSelect<R>,
158        <Self::Model as ModelTrait>::Entity: Related<R>,
159    {
160        LoaderTrait::load_many(&self.as_slice(), stmt, db).await
161    }
162
163    async fn load_many_to_many<R, S, V, C>(
164        &self,
165        stmt: S,
166        via: V,
167        db: &C,
168    ) -> Result<Vec<Vec<R::Model>>, DbErr>
169    where
170        C: ConnectionTrait,
171        R: EntityTrait,
172        R::Model: Send + Sync,
173        S: EntityOrSelect<R>,
174        V: EntityTrait,
175        V::Model: Send + Sync,
176        <Self::Model as ModelTrait>::Entity: Related<R>,
177    {
178        LoaderTrait::load_many_to_many(&self.as_slice(), stmt, via, db).await
179    }
180}
181
182#[async_trait]
183impl<M> LoaderTrait for &[M]
184where
185    M: ModelTrait + Sync,
186{
187    type Model = M;
188
189    async fn load_one<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::Model>>, DbErr>
190    where
191        C: ConnectionTrait,
192        R: EntityTrait,
193        R::Model: Send + Sync,
194        S: EntityOrSelect<R>,
195        <Self::Model as ModelTrait>::Entity: Related<R>,
196    {
197        let rel_def = <<Self::Model as ModelTrait>::Entity as Related<R>>::to();
198        if rel_def.rel_type != RelationType::HasOne {
199            return Err(query_err("Relation is HasMany instead of HasOne"));
200        }
201        loader_impl(self.iter(), stmt.select(), db).await
202    }
203
204    async fn load_many<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::Model>>, DbErr>
205    where
206        C: ConnectionTrait,
207        R: EntityTrait,
208        R::Model: Send + Sync,
209        S: EntityOrSelect<R>,
210        <Self::Model as ModelTrait>::Entity: Related<R>,
211    {
212        loader_impl(self.iter(), stmt.select(), db).await
213    }
214
215    async fn load_many_to_many<R, S, V, C>(
216        &self,
217        stmt: S,
218        via: V,
219        db: &C,
220    ) -> Result<Vec<Vec<R::Model>>, DbErr>
221    where
222        C: ConnectionTrait,
223        R: EntityTrait,
224        R::Model: Send + Sync,
225        S: EntityOrSelect<R>,
226        V: EntityTrait,
227        V::Model: Send + Sync,
228        <Self::Model as ModelTrait>::Entity: Related<R>,
229    {
230        if let Some(via_rel) = <<Self::Model as ModelTrait>::Entity as Related<R>>::via() {
231            let rel_def = <<Self::Model as ModelTrait>::Entity as Related<R>>::to();
232            if rel_def.rel_type != RelationType::HasOne {
233                return Err(query_err("Relation to is not HasOne"));
234            }
235
236            if !cmp_table_ref(&via_rel.to_tbl, &via.table_ref()) {
237                return Err(query_err(format!(
238                    "The given via Entity is incorrect: expected: {:?}, given: {:?}",
239                    via_rel.to_tbl,
240                    via.table_ref()
241                )));
242            }
243
244            if self.is_empty() {
245                return Ok(Vec::new());
246            }
247
248            let pkeys = self
249                .iter()
250                .map(|model| extract_key(&via_rel.from_col, model))
251                .collect::<Result<Vec<_>, _>>()?;
252
253            // Map of M::PK -> Vec<R::PK>
254            let mut keymap: HashMap<ValueTuple, Vec<ValueTuple>> = Default::default();
255
256            let keys: Vec<ValueTuple> = {
257                let condition = prepare_condition::<M>(
258                    &via_rel.to_tbl,
259                    &via_rel.from_col,
260                    &via_rel.to_col,
261                    &pkeys,
262                    db,
263                )?;
264                let stmt = V::find().filter(condition);
265                let data = stmt.all(db).await?;
266                for model in data {
267                    let pk = extract_key(&via_rel.to_col, &model)?;
268                    let entry = keymap.entry(pk).or_default();
269
270                    let fk = extract_key(&rel_def.from_col, &model)?;
271                    entry.push(fk);
272                }
273
274                keymap.values().flatten().cloned().collect()
275            };
276
277            let condition = prepare_condition::<V::Model>(
278                &rel_def.to_tbl,
279                &rel_def.from_col,
280                &rel_def.to_col,
281                &keys,
282                db,
283            )?;
284
285            let stmt = QueryFilter::filter(stmt.select(), condition);
286
287            let models = stmt.all(db).await?;
288
289            // Map of R::PK -> R::Model
290            let data = models.into_iter().try_fold(
291                HashMap::<ValueTuple, <R as EntityTrait>::Model>::new(),
292                |mut acc, model| {
293                    extract_key(&rel_def.to_col, &model).map(|key| {
294                        acc.insert(key, model);
295
296                        acc
297                    })
298                },
299            )?;
300
301            let result: Vec<Vec<R::Model>> = pkeys
302                .into_iter()
303                .map(|pkey| {
304                    let fkeys = keymap.get(&pkey).cloned().unwrap_or_default();
305
306                    let models: Vec<_> = fkeys
307                        .into_iter()
308                        .filter_map(|fkey| data.get(&fkey).cloned())
309                        .collect();
310
311                    models
312                })
313                .collect();
314
315            Ok(result)
316        } else {
317            return Err(query_err("Relation is not ManyToMany"));
318        }
319    }
320}
321
322#[async_trait]
323impl<M> LoaderTraitEx for &[M]
324where
325    M: ModelTrait + Sync,
326{
327    type Model = M;
328
329    async fn load_one_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::ModelEx>>, DbErr>
330    where
331        C: ConnectionTrait,
332        R: EntityTrait,
333        R::Model: Send + Sync,
334        S: EntityOrSelect<R>,
335        R::ModelEx: From<R::Model>,
336        <Self::Model as ModelTrait>::Entity: Related<R>,
337    {
338        let rel_def = <<Self::Model as ModelTrait>::Entity as Related<R>>::to();
339        if rel_def.rel_type != RelationType::HasOne {
340            return Err(query_err("Relation is HasMany instead of HasOne"));
341        }
342        loader_impl(self.iter(), stmt.select(), db).await
343    }
344
345    async fn load_many_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::ModelEx>>, DbErr>
346    where
347        C: ConnectionTrait,
348        R: EntityTrait,
349        R::Model: Send + Sync,
350        S: EntityOrSelect<R>,
351        R::ModelEx: From<R::Model>,
352        <Self::Model as ModelTrait>::Entity: Related<R>,
353    {
354        loader_impl(self.iter(), stmt.select(), db).await
355    }
356}
357
358#[async_trait]
359impl<M> LoaderTraitEx for &[Option<M>]
360where
361    M: ModelTrait + Sync,
362{
363    type Model = M;
364
365    async fn load_one_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Option<R::ModelEx>>, DbErr>
366    where
367        C: ConnectionTrait,
368        R: EntityTrait,
369        R::Model: Send + Sync,
370        S: EntityOrSelect<R>,
371        R::ModelEx: From<R::Model>,
372        <Self::Model as ModelTrait>::Entity: Related<R>,
373    {
374        let rel_def = <<Self::Model as ModelTrait>::Entity as Related<R>>::to();
375        if rel_def.rel_type != RelationType::HasOne {
376            return Err(query_err("Relation is HasMany instead of HasOne"));
377        }
378        let items: Vec<Option<R::ModelEx>> =
379            loader_impl(self.iter().filter_map(|o| o.as_ref()), stmt.select(), db).await?;
380        Ok(assemble_options(self, items))
381    }
382
383    async fn load_many_ex<R, S, C>(&self, stmt: S, db: &C) -> Result<Vec<Vec<R::ModelEx>>, DbErr>
384    where
385        C: ConnectionTrait,
386        R: EntityTrait,
387        R::Model: Send + Sync,
388        S: EntityOrSelect<R>,
389        R::ModelEx: From<R::Model>,
390        <Self::Model as ModelTrait>::Entity: Related<R>,
391    {
392        let items: Vec<Vec<R::ModelEx>> =
393            loader_impl(self.iter().filter_map(|o| o.as_ref()), stmt.select(), db).await?;
394        Ok(assemble_options(self, items))
395    }
396}
397
398#[async_trait]
399impl<M> NestedLoaderTrait for &[Vec<M>]
400where
401    M: ModelTrait + Sync,
402{
403    type Model = M;
404
405    async fn load_one_ex<R, S, C>(
406        &self,
407        stmt: S,
408        db: &C,
409    ) -> Result<Vec<Vec<Option<R::ModelEx>>>, DbErr>
410    where
411        C: ConnectionTrait,
412        R: EntityTrait,
413        R::Model: Send + Sync,
414        S: EntityOrSelect<R>,
415        R::ModelEx: From<R::Model>,
416        <Self::Model as ModelTrait>::Entity: Related<R>,
417    {
418        let rel_def = <<Self::Model as ModelTrait>::Entity as Related<R>>::to();
419        if rel_def.rel_type != RelationType::HasOne {
420            return Err(query_err("Relation is HasMany instead of HasOne"));
421        }
422        let items: Vec<Option<R::ModelEx>> =
423            loader_impl(self.iter().flatten(), stmt.select(), db).await?;
424        Ok(assemble_vectors(self, items))
425    }
426
427    async fn load_many_ex<R, S, C>(
428        &self,
429        stmt: S,
430        db: &C,
431    ) -> Result<Vec<Vec<Vec<R::ModelEx>>>, DbErr>
432    where
433        C: ConnectionTrait,
434        R: EntityTrait,
435        R::Model: Send + Sync,
436        S: EntityOrSelect<R>,
437        R::ModelEx: From<R::Model>,
438        <Self::Model as ModelTrait>::Entity: Related<R>,
439    {
440        let items: Vec<Vec<R::ModelEx>> =
441            loader_impl(self.iter().flatten(), stmt.select(), db).await?;
442        Ok(assemble_vectors(self, items))
443    }
444}
445
/// Spread `items` back over the positions of `input`.
///
/// Each `Some` slot of `input` takes the next element of `items` (or
/// `T::default()` once `items` is exhausted); each `None` slot gets
/// `T::default()`. The output has exactly `input.len()` elements.
fn assemble_options<I, T: Default>(input: &[Option<I>], items: Vec<T>) -> Vec<T> {
    let mut loaded = items.into_iter();
    input
        .iter()
        .map(|slot| match slot {
            Some(_) => loaded.next().unwrap_or_default(),
            None => T::default(),
        })
        .collect()
}
458
/// Regroup the flat `items` list so that it mirrors the nesting of `input`.
///
/// The n-th output list has as many elements as the n-th input list, taken
/// from `items` in order; once `items` runs out, `T::default()` is used.
fn assemble_vectors<I, T: Default>(input: &[Vec<I>], items: Vec<T>) -> Vec<Vec<T>> {
    let mut loaded = items.into_iter();
    input
        .iter()
        .map(|group| {
            group
                .iter()
                .map(|_| loaded.next().unwrap_or_default())
                .collect::<Vec<T>>()
        })
        .collect()
}
477
/// Result holder used by `loader_impl`: something that starts out empty
/// (`Default`), can be cloned, and can receive loaded items one at a time.
trait Container: Default + Clone {
    /// Type of the items stored in the container.
    type Item;
    /// Store `item` in the container.
    fn add(&mut self, item: Self::Item);
}

/// A `Vec` keeps every item, in insertion order.
impl<T: Clone> Container for Vec<T> {
    type Item = T;
    fn add(&mut self, item: Self::Item) {
        Vec::push(self, item);
    }
}

/// An `Option` keeps only the most recently added item.
impl<T: Clone> Container for Option<T> {
    type Item = T;
    fn add(&mut self, item: Self::Item) {
        *self = Some(item);
    }
}
496
497async fn loader_impl<'a, Model, Iter, R, C, T, Output>(
498    items: Iter,
499    stmt: Select<R>,
500    db: &C,
501) -> Result<Vec<T>, DbErr>
502where
503    Model: ModelTrait + Sync + 'a,
504    Iter: Iterator<Item = &'a Model> + 'a,
505    C: ConnectionTrait,
506    R: EntityTrait,
507    R::Model: Send + Sync,
508    Model::Entity: Related<R>,
509    Output: From<R::Model>,
510    T: Container<Item = Output>,
511{
512    let (keys, hashmap) = if let Some(via_def) = <Model::Entity as Related<R>>::via() {
513        let keys = items
514            .map(|model| extract_key(&via_def.from_col, model))
515            .collect::<Result<Vec<_>, _>>()?;
516
517        if keys.is_empty() {
518            return Ok(Vec::new());
519        }
520
521        let condition = prepare_condition::<Model>(
522            &via_def.to_tbl,
523            &via_def.from_col,
524            &via_def.to_col,
525            &keys,
526            db,
527        )?;
528
529        let stmt = QueryFilter::filter(
530            stmt.join_rev(JoinType::InnerJoin, <Model::Entity as Related<R>>::to()),
531            condition,
532        );
533
534        // The idea is to do a SelectTwo with join, then extract key via a dynamic model
535        // i.e. select (baker + cake_baker) and extract cake_id from result rows
536        // SELECT "baker"."id", "baker"."name", "baker"."contact_details", "baker"."bakery_id",
537        //     "cakes_bakers"."cake_id" <- extra select
538        // FROM "baker" <- target
539        // INNER JOIN "cakes_bakers" <- junction
540        //     ON "cakes_bakers"."baker_id" = "baker"."id" <- relation
541        // WHERE "cakes_bakers"."cake_id" IN (..)
542
543        let data = stmt
544            .select_also_dyn_model(
545                via_def.to_tbl.sea_orm_table().clone(),
546                dynamic::ModelType {
547                    // we uses the left Model's type but the right Model's field
548                    fields: extract_col_type::<Model>(&via_def.from_col, &via_def.to_col)?,
549                },
550            )
551            .all(db)
552            .await?;
553
554        let mut hashmap: HashMap<ValueTuple, T> =
555            keys.iter()
556                .fold(HashMap::new(), |mut acc, key: &ValueTuple| {
557                    acc.insert(key.clone(), T::default());
558                    acc
559                });
560
561        for (item, key) in data {
562            let key = dyn_model_to_key(key)?;
563
564            let vec = hashmap.get_mut(&key).ok_or_else(|| {
565                DbErr::RecordNotFound(format!("Loader: failed to find model for {key:?}"))
566            })?;
567
568            vec.add(item.into());
569        }
570
571        (keys, hashmap)
572    } else {
573        let rel_def = <Model::Entity as Related<R>>::to();
574
575        let keys = items
576            .map(|model| extract_key(&rel_def.from_col, model))
577            .collect::<Result<Vec<_>, _>>()?;
578
579        if keys.is_empty() {
580            return Ok(Vec::new());
581        }
582
583        let condition = prepare_condition::<Model>(
584            &rel_def.to_tbl,
585            &rel_def.from_col,
586            &rel_def.to_col,
587            &keys,
588            db,
589        )?;
590
591        let stmt = QueryFilter::filter(stmt, condition);
592
593        let data = stmt.all(db).await?;
594
595        let mut hashmap: HashMap<ValueTuple, T> = Default::default();
596
597        for item in data {
598            let key = extract_key(&rel_def.to_col, &item)?;
599            let holder = hashmap.entry(key).or_default();
600            holder.add(item.into());
601        }
602
603        (keys, hashmap)
604    };
605
606    let result: Vec<T> = keys
607        .iter()
608        .map(|key: &ValueTuple| hashmap.get(key).cloned().unwrap_or_default())
609        .collect();
610
611    Ok(result)
612}
613
614fn cmp_table_ref(left: &TableRef, right: &TableRef) -> bool {
615    left == right
616}
617
618fn extract_key<Model>(target_col: &Identity, model: &Model) -> Result<ValueTuple, DbErr>
619where
620    Model: ModelTrait,
621{
622    let values = target_col
623        .iter()
624        .map(|col| {
625            let col_name = col.inner();
626            let column =
627                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
628                    &col_name,
629                )
630                .map_err(|_| DbErr::Type(format!("Failed at mapping '{col_name}' to column")))?;
631            Ok(model.get(column))
632        })
633        .collect::<Result<Vec<_>, DbErr>>()?;
634
635    Ok(match values.len() {
636        0 => return Err(DbErr::Type("Identity zero?".into())),
637        1 => ValueTuple::One(values.into_iter().next().expect("checked")),
638        2 => {
639            let mut it = values.into_iter();
640            ValueTuple::Two(it.next().expect("checked"), it.next().expect("checked"))
641        }
642        3 => {
643            let mut it = values.into_iter();
644            ValueTuple::Three(
645                it.next().expect("checked"),
646                it.next().expect("checked"),
647                it.next().expect("checked"),
648            )
649        }
650        _ => ValueTuple::Many(values),
651    })
652}
653
654fn extract_col_type<Model>(
655    left: &Identity,
656    right: &Identity,
657) -> Result<Vec<dynamic::FieldType>, DbErr>
658where
659    Model: ModelTrait,
660{
661    use itertools::Itertools;
662
663    if left.arity() != right.arity() {
664        return Err(DbErr::Type(format!(
665            "Identity mismatch: left: {} != right: {}",
666            left.arity(),
667            right.arity()
668        )));
669    }
670
671    let vec = left
672        .iter()
673        .zip_eq(right.iter())
674        .map(|(l, r)| {
675            let col_a =
676                <<<Model as ModelTrait>::Entity as EntityTrait>::Column as FromStr>::from_str(
677                    &l.inner(),
678                )
679                .map_err(|_| DbErr::Type(format!("Failed at mapping '{l}'")))?;
680            Ok(dynamic::FieldType::new(
681                r.clone(),
682                Model::get_value_type(col_a),
683            ))
684        })
685        .collect::<Result<Vec<_>, DbErr>>()?;
686
687    Ok(vec)
688}
689
690#[allow(clippy::unwrap_used)]
691fn dyn_model_to_key(dyn_model: dynamic::Model) -> Result<ValueTuple, DbErr> {
692    Ok(match dyn_model.fields.len() {
693        0 => return Err(DbErr::Type("Identity zero?".into())),
694        1 => ValueTuple::One(dyn_model.fields.into_iter().next().unwrap().value),
695        2 => {
696            let mut iter = dyn_model.fields.into_iter();
697            ValueTuple::Two(iter.next().unwrap().value, iter.next().unwrap().value)
698        }
699        3 => {
700            let mut iter = dyn_model.fields.into_iter();
701            ValueTuple::Three(
702                iter.next().unwrap().value,
703                iter.next().unwrap().value,
704                iter.next().unwrap().value,
705            )
706        }
707        _ => ValueTuple::Many(dyn_model.fields.into_iter().map(|v| v.value).collect()),
708    })
709}
710
711fn arity_mismatch(expected: usize, actual: &ValueTuple) -> DbErr {
712    DbErr::Type(format!(
713        "Loader: arity mismatch: expected {expected}, got {} in {actual:?}",
714        actual.arity()
715    ))
716}
717
718#[inline]
719fn prepare_condition<Model>(
720    table: &TableRef,
721    from: &Identity,
722    to: &Identity,
723    keys: &[ValueTuple],
724    db: &impl ConnectionTrait,
725) -> Result<Condition, DbErr>
726where
727    Model: ModelTrait,
728{
729    let db_backend = db.get_database_backend();
730    if matches!(db_backend, DbBackend::Postgres) {
731        prepare_condition_with_save_as::<Model>(table, from, to, keys)
732    } else {
733        prepare_condition_simple(table, to, keys, db_backend)
734    }
735}
736
737fn prepare_condition_with_save_as<Model>(
738    table: &TableRef,
739    from: &Identity,
740    to: &Identity,
741    keys: &[ValueTuple],
742) -> Result<Condition, DbErr>
743where
744    Model: ModelTrait,
745{
746    use itertools::Itertools;
747
748    let keys = keys.iter().unique();
749    let (from_cols, to_cols) = resolve_column_pairs::<Model>(table, from, to)?;
750
751    if from_cols.is_empty() || to_cols.is_empty() {
752        return Err(DbErr::Type(format!(
753            "Loader: resolved zero columns for identities {from:?} -> {to:?}"
754        )));
755    }
756
757    let arity = from_cols.len();
758
759    let value_tuples = keys
760        .map(|key| {
761            let key_arity = key.arity();
762            if arity != key_arity {
763                return Err(arity_mismatch(arity, key));
764            }
765
766            // For Postgres, we need to use `AS` to cast the value to the correct type
767            Ok(apply_save_as::<Model>(&from_cols, key.clone()))
768        })
769        .collect::<Result<Vec<_>, DbErr>>()?;
770
771    // Build `(c1, c2, ...) IN ((v11, v12, ...), (v21, v22, ...), ...)`
772    let expr = Expr::tuple(create_table_columns(table, to)).is_in(value_tuples);
773
774    Ok(expr.into())
775}
776
777// For loaders that do not require calling save_as
778fn prepare_condition_simple(
779    table: &TableRef,
780    to: &Identity,
781    keys: &[ValueTuple],
782    backend: DbBackend,
783) -> Result<Condition, DbErr> {
784    use itertools::Itertools;
785
786    let arity = to.arity();
787    let keys = keys.iter().unique();
788
789    let table_columns = create_table_columns(table, to);
790
791    if cfg!(feature = "sqlite-no-row-value-before-3_15") && matches!(backend, DbBackend::Sqlite) {
792        // SQLite supports row value expressions since 3.15.0
793        // https://www.sqlite.org/releaselog/3_15_0.html
794        let mut outer = Condition::any();
795
796        for key in keys {
797            let key_arity = key.arity();
798            if arity != key_arity {
799                return Err(arity_mismatch(arity, key));
800            }
801
802            let table_columns = table_columns.iter().cloned();
803            let values = key.clone().into_iter().map(Expr::val);
804
805            let inner = table_columns
806                .zip(values)
807                .fold(Condition::all(), |cond, (column, value)| {
808                    cond.add(column.eq(value))
809                });
810
811            // Build `(c1 = v11 AND c2 = v12) OR (c1 = v21 AND c2 = v22) ...`
812            outer = outer.add(inner);
813        }
814
815        Ok(outer)
816    } else {
817        // A vector of tuples of values, e.g. [(v11, v12, ...), (v21, v22, ...), ...]
818        let value_tuples = keys
819            .map(|key| {
820                let key_arity = key.arity();
821                if arity != key_arity {
822                    return Err(arity_mismatch(arity, key));
823                }
824
825                let tuple_exprs = key.clone().into_iter().map(Expr::val);
826
827                Ok(Expr::tuple(tuple_exprs))
828            })
829            .collect::<Result<Vec<_>, DbErr>>()?;
830
831        // Build `(c1, c2, ...) IN ((v11, v12, ...), (v21, v22, ...), ...)`
832        let expr = Expr::tuple(table_columns).is_in(value_tuples);
833
834        Ok(expr.into())
835    }
836}
837
/// The `Column` type of the entity that model `M` belongs to.
type ModelColumn<M> = <<M as ModelTrait>::Entity as EntityTrait>::Column;

/// Columns of model `M` paired with the column references they are matched
/// against, as returned by `resolve_column_pairs`.
type ColumnPairs<M> = (Vec<ModelColumn<M>>, Vec<ColumnRef>);
841
842fn resolve_column_pairs<Model>(
843    table: &TableRef,
844    from: &Identity,
845    to: &Identity,
846) -> Result<ColumnPairs<Model>, DbErr>
847where
848    Model: ModelTrait,
849    ModelColumn<Model>: ColumnTrait,
850{
851    let from_columns = parse_identity_columns::<Model>(from)?;
852    let to_columns = column_refs_from_identity(table, to);
853
854    if from_columns.len() != to_columns.len() {
855        return Err(DbErr::Type(format!(
856            "Loader: identity column count mismatch between {from:?} and {to:?}"
857        )));
858    }
859
860    Ok((from_columns, to_columns))
861}
862
863fn column_refs_from_identity(table: &TableRef, identity: &Identity) -> Vec<ColumnRef> {
864    identity
865        .iter()
866        .map(|col| table_column(table, col))
867        .collect()
868}
869
870fn parse_identity_columns<Model>(identity: &Identity) -> Result<Vec<ModelColumn<Model>>, DbErr>
871where
872    Model: ModelTrait,
873{
874    identity
875        .iter()
876        .map(|from_col| try_conv_ident_to_column::<Model>(from_col))
877        .collect()
878}
879
880fn try_conv_ident_to_column<Model>(ident: &DynIden) -> Result<ModelColumn<Model>, DbErr>
881where
882    Model: ModelTrait,
883{
884    let column_name = ident.inner();
885    ModelColumn::<Model>::from_str(&column_name)
886        .map_err(|_| DbErr::Type(format!("Failed at mapping '{column_name}' to column")))
887}
888
889fn table_column(tbl: &TableRef, col: &DynIden) -> ColumnRef {
890    (tbl.sea_orm_table().to_owned(), col.clone()).into_column_ref()
891}
892
893/// Create a vector of `Expr::col` from the table and identity, e.g. [Expr::col((table, col1)), Expr::col((table, col2)), ...]
894fn create_table_columns(table: &TableRef, cols: &Identity) -> Vec<Expr> {
895    cols.iter()
896        .cloned()
897        .map(|col| table_column(table, &col))
898        .map(Expr::col)
899        .collect()
900}
901
902/// Apply `save_as` to each value in the tuple, e.g. `(Cast(val1 as type1), Cast(val2 as type2), ...)`
903fn apply_save_as<M: ModelTrait>(cols: &[ModelColumn<M>], values: ValueTuple) -> Expr {
904    let values_expr_iter = values.into_iter().map(Expr::val);
905
906    let tuple_exprs: Vec<_> = cols
907        .iter()
908        .zip(values_expr_iter)
909        .map(|(model_column, value)| model_column.save_as(value))
910        .collect();
911
912    Expr::tuple(tuple_exprs)
913}
914
915#[cfg(test)]
916mod tests {
917    fn cake_model(id: i32) -> sea_orm::tests_cfg::cake::Model {
918        let name = match id {
919            1 => "apple cake",
920            2 => "orange cake",
921            3 => "fruit cake",
922            4 => "chocolate cake",
923            _ => "",
924        }
925        .to_string();
926        sea_orm::tests_cfg::cake::Model { id, name }
927    }
928
929    fn fruit_model(id: i32, cake_id: Option<i32>) -> sea_orm::tests_cfg::fruit::Model {
930        let name = match id {
931            1 => "apple",
932            2 => "orange",
933            3 => "grape",
934            4 => "strawberry",
935            _ => "",
936        }
937        .to_string();
938        sea_orm::tests_cfg::fruit::Model { id, name, cake_id }
939    }
940
941    fn filling_model(id: i32) -> sea_orm::tests_cfg::filling::Model {
942        let name = match id {
943            1 => "apple juice",
944            2 => "orange jam",
945            3 => "chocolate crust",
946            4 => "strawberry jam",
947            _ => "",
948        }
949        .to_string();
950        sea_orm::tests_cfg::filling::Model {
951            id,
952            name,
953            vendor_id: Some(1),
954            ignored_attr: 0,
955        }
956    }
957
958    fn cake_filling_model(
959        cake_id: i32,
960        filling_id: i32,
961    ) -> sea_orm::tests_cfg::cake_filling::Model {
962        sea_orm::tests_cfg::cake_filling::Model {
963            cake_id,
964            filling_id,
965        }
966    }
967
    // One fruit referencing cake 1: the mock returns cakes 1 and 2, and only
    // cake 1 must be attached to the fruit.
    #[tokio::test]
    async fn test_load_one() {
        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![fruit_model(1, Some(1))];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, [Some(cake_model(1))]);
    }
985
    // Two fruits referencing the same cake: each fruit must get its own copy
    // of cake 1 in the result.
    #[tokio::test]
    async fn test_load_one_same_cake() {
        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};

        let db = MockDatabase::new(DbBackend::Postgres)
            .append_query_results([[cake_model(1), cake_model(2)]])
            .into_connection();

        let fruits = vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))];

        let cakes = fruits
            .load_one(cake::Entity::find(), &db)
            .await
            .expect("Should return something");

        assert_eq!(cakes, [Some(cake_model(1)), Some(cake_model(1))]);
    }
1003
1004    #[tokio::test]
1005    async fn test_load_one_empty() {
1006        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};
1007
1008        let db = MockDatabase::new(DbBackend::Postgres)
1009            .append_query_results([[cake_model(1), cake_model(2)]])
1010            .into_connection();
1011
1012        let fruits: Vec<fruit::Model> = vec![];
1013
1014        let cakes = fruits
1015            .load_one(cake::Entity::find(), &db)
1016            .await
1017            .expect("Should return something");
1018
1019        assert_eq!(cakes, []);
1020    }
1021
1022    #[tokio::test]
1023    async fn test_load_many() {
1024        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};
1025
1026        let db = MockDatabase::new(DbBackend::Postgres)
1027            .append_query_results([[fruit_model(1, Some(1))]])
1028            .into_connection();
1029
1030        let cakes = vec![cake_model(1), cake_model(2)];
1031
1032        let fruits = cakes
1033            .load_many(fruit::Entity::find(), &db)
1034            .await
1035            .expect("Should return something");
1036
1037        assert_eq!(fruits, [vec![fruit_model(1, Some(1))], vec![]]);
1038    }
1039
1040    #[tokio::test]
1041    async fn test_load_many_same_fruit() {
1042        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};
1043
1044        let db = MockDatabase::new(DbBackend::Postgres)
1045            .append_query_results([[fruit_model(1, Some(1)), fruit_model(2, Some(1))]])
1046            .into_connection();
1047
1048        let cakes = vec![cake_model(1), cake_model(2)];
1049
1050        let fruits = cakes
1051            .load_many(fruit::Entity::find(), &db)
1052            .await
1053            .expect("Should return something");
1054
1055        assert_eq!(
1056            fruits,
1057            [
1058                vec![fruit_model(1, Some(1)), fruit_model(2, Some(1))],
1059                vec![]
1060            ]
1061        );
1062    }
1063
1064    #[tokio::test]
1065    async fn test_load_many_empty() {
1066        use sea_orm::{DbBackend, MockDatabase, entity::prelude::*, tests_cfg::*};
1067
1068        let db = MockDatabase::new(DbBackend::Postgres).into_connection();
1069
1070        let cakes: Vec<cake::Model> = vec![];
1071
1072        let fruits = cakes
1073            .load_many(fruit::Entity::find(), &db)
1074            .await
1075            .expect("Should return something");
1076
1077        let empty_vec: Vec<Vec<fruit::Model>> = vec![];
1078
1079        assert_eq!(fruits, empty_vec);
1080    }
1081
1082    #[tokio::test]
1083    async fn test_load_many_to_many_base() {
1084        use sea_orm::{DbBackend, IntoMockRow, LoaderTrait, MockDatabase, tests_cfg::*};
1085
1086        let db = MockDatabase::new(DbBackend::Postgres)
1087            .append_query_results([
1088                [cake_filling_model(1, 1).into_mock_row()],
1089                [filling_model(1).into_mock_row()],
1090            ])
1091            .into_connection();
1092
1093        let cakes = vec![cake_model(1)];
1094
1095        let fillings = cakes
1096            .load_many_to_many(Filling, CakeFilling, &db)
1097            .await
1098            .expect("Should return something");
1099
1100        assert_eq!(fillings, vec![vec![filling_model(1)]]);
1101    }
1102
1103    #[tokio::test]
1104    async fn test_load_many_to_many_complex() {
1105        use sea_orm::{DbBackend, IntoMockRow, LoaderTrait, MockDatabase, tests_cfg::*};
1106
1107        let db = MockDatabase::new(DbBackend::Postgres)
1108            .append_query_results([
1109                [
1110                    cake_filling_model(1, 1).into_mock_row(),
1111                    cake_filling_model(1, 2).into_mock_row(),
1112                    cake_filling_model(1, 3).into_mock_row(),
1113                    cake_filling_model(2, 1).into_mock_row(),
1114                    cake_filling_model(2, 2).into_mock_row(),
1115                ],
1116                [
1117                    filling_model(1).into_mock_row(),
1118                    filling_model(2).into_mock_row(),
1119                    filling_model(3).into_mock_row(),
1120                    filling_model(4).into_mock_row(),
1121                    filling_model(5).into_mock_row(),
1122                ],
1123            ])
1124            .into_connection();
1125
1126        let cakes = vec![cake_model(1), cake_model(2), cake_model(3)];
1127
1128        let fillings = cakes
1129            .load_many_to_many(Filling, CakeFilling, &db)
1130            .await
1131            .expect("Should return something");
1132
1133        assert_eq!(
1134            fillings,
1135            vec![
1136                vec![filling_model(1), filling_model(2), filling_model(3)],
1137                vec![filling_model(1), filling_model(2)],
1138                vec![],
1139            ]
1140        );
1141    }
1142
1143    #[tokio::test]
1144    async fn test_load_many_to_many_empty() {
1145        use sea_orm::{DbBackend, IntoMockRow, LoaderTrait, MockDatabase, tests_cfg::*};
1146
1147        let db = MockDatabase::new(DbBackend::Postgres)
1148            .append_query_results([
1149                [cake_filling_model(1, 1).into_mock_row()],
1150                [filling_model(1).into_mock_row()],
1151            ])
1152            .into_connection();
1153
1154        let cakes: Vec<cake::Model> = vec![];
1155
1156        let fillings = cakes
1157            .load_many_to_many(Filling, CakeFilling, &db)
1158            .await
1159            .expect("Should return something");
1160
1161        let empty_vec: Vec<Vec<filling::Model>> = vec![];
1162
1163        assert_eq!(fillings, empty_vec);
1164    }
1165
1166    #[tokio::test]
1167    async fn test_load_one_duplicate_keys() {
1168        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};
1169
1170        let db = MockDatabase::new(DbBackend::Postgres)
1171            .append_query_results([[cake_model(1), cake_model(2)]])
1172            .into_connection();
1173
1174        let fruits = vec![
1175            fruit_model(1, Some(1)),
1176            fruit_model(2, Some(1)),
1177            fruit_model(3, Some(1)),
1178            fruit_model(4, Some(1)),
1179        ];
1180
1181        let cakes = fruits
1182            .load_one(cake::Entity::find(), &db)
1183            .await
1184            .expect("Should return something");
1185
1186        assert_eq!(cakes.len(), 4);
1187        for cake in &cakes {
1188            assert_eq!(cake, &Some(cake_model(1)));
1189        }
1190        let logs = db.into_transaction_log();
1191        let sql = format!("{:?}", logs[0]);
1192
1193        let values_count = sql.matches("$1").count();
1194        assert_eq!(values_count, 1, "Duplicate values were not removed");
1195    }
1196
1197    #[tokio::test]
1198    async fn test_load_many_duplicate_keys() {
1199        use sea_orm::{DbBackend, LoaderTrait, MockDatabase, entity::prelude::*, tests_cfg::*};
1200
1201        let db = MockDatabase::new(DbBackend::Postgres)
1202            .append_query_results([[
1203                fruit_model(1, Some(1)),
1204                fruit_model(2, Some(1)),
1205                fruit_model(3, Some(2)),
1206            ]])
1207            .into_connection();
1208
1209        let cakes = vec![cake_model(1), cake_model(1), cake_model(2), cake_model(2)];
1210
1211        let fruits = cakes
1212            .load_many(fruit::Entity::find(), &db)
1213            .await
1214            .expect("Should return something");
1215
1216        assert_eq!(fruits.len(), 4);
1217
1218        let logs = db.into_transaction_log();
1219        let sql = format!("{:?}", logs[0]);
1220
1221        let values_count = sql.matches("$1").count() + sql.matches("$2").count();
1222        assert_eq!(values_count, 2, "Duplicate values were not removed");
1223    }
1224
1225    #[test]
1226    fn test_assemble_vectors() {
1227        use super::assemble_vectors;
1228
1229        assert_eq!(
1230            assemble_vectors(&[vec![1], vec![], vec![2, 3], vec![]], vec![11, 22, 33]),
1231            [vec![11], vec![], vec![22, 33], vec![]]
1232        );
1233    }
1234}