//! rustauth_core/db/sql/joins.rs — resolves native SQL joins from schema foreign keys and
//! folds flat joined rows back into nested records.

1use super::*;
2
/// A join between a base table and one related table, as produced by `resolve_native_joins`.
#[derive(Debug, Clone)]
pub struct NativeJoin<'a> {
    /// Logical model name of the joined table; `joined_rows` also uses it as the key under which
    /// the joined value is stored in each output record.
    pub model: String,
    /// Schema table of the joined model.
    pub table: &'a DbTable,
    /// Fields read from the joined table, as `(logical name, field)` pairs.
    pub selection: Vec<(&'a str, &'a DbField)>,
    /// `DbField::name` of the base-table column on the base side of the join.
    pub from: String,
    /// `DbField::name` of the joined-table column paired with `from` through a foreign key.
    pub to: String,
    /// Whether each base record holds a single joined record or a list of them.
    pub relation: JoinRelation,
    /// Cap on joined records collected per base record for one-to-many joins;
    /// `resolve_native_joins` sets it to 1 for one-to-one joins.
    pub limit: usize,
}

14pub fn resolve_native_joins<'a>(
15    schema: &'a DbSchema,
16    base_model: &str,
17    base_table: &'a DbTable,
18    joins: &IndexMap<String, JoinOption>,
19    default_limit: usize,
20) -> Result<Vec<NativeJoin<'a>>, RustAuthError> {
21    let mut resolved = Vec::new();
22    for (join_model, option) in joins {
23        if !option.enabled {
24            continue;
25        }
26        let (join_logical, join_table) = resolve_table_with_logical(schema, join_model)?;
27        let mut foreign_keys = foreign_keys_to_table(join_table, &base_table.name);
28        let is_forward_join = !foreign_keys.is_empty();
29        if foreign_keys.is_empty() {
30            foreign_keys = foreign_keys_to_table(base_table, &join_table.name);
31        }
32        let [(_foreign_key, field)] =
33            foreign_keys
34                .as_slice()
35                .try_into()
36                .map_err(|_| match foreign_keys.len() {
37                    0 => RustAuthError::JoinForeignKeyNotFound {
38                        base_model: base_model.to_owned(),
39                        join_model: join_model.clone(),
40                    },
41                    _ => RustAuthError::JoinForeignKeyAmbiguous {
42                        base_model: base_model.to_owned(),
43                        join_model: join_model.clone(),
44                    },
45                })?;
46        let reference =
47            field
48                .foreign_key
49                .as_ref()
50                .ok_or_else(|| RustAuthError::JoinForeignKeyNotFound {
51                    base_model: base_model.to_owned(),
52                    join_model: join_model.clone(),
53                })?;
54        let (from, to, is_unique) = if is_forward_join {
55            let (_, base_field) = resolve_field(base_table, &reference.field)?;
56            (base_field.name.clone(), field.name.clone(), field.unique)
57        } else {
58            let (_, join_field) = resolve_field(join_table, &reference.field)?;
59            (field.name.clone(), join_field.name.clone(), field.unique)
60        };
61        let relation = if !is_forward_join || is_unique {
62            JoinRelation::OneToOne
63        } else {
64            JoinRelation::OneToMany
65        };
66        let limit = if relation == JoinRelation::OneToOne {
67            1
68        } else {
69            option.limit.unwrap_or(default_limit)
70        };
71        resolved.push(NativeJoin {
72            model: join_logical.to_owned(),
73            table: join_table,
74            selection: select_fields(join_table, &[])?,
75            from,
76            to,
77            relation,
78            limit,
79        });
80    }
81    Ok(resolved)
82}
83
84pub fn internal_base_selection<'a>(
85    table: &'a DbTable,
86    select: &[String],
87    joins: &[NativeJoin<'_>],
88) -> Result<Vec<(&'a str, &'a DbField)>, RustAuthError> {
89    let mut selection = select_fields(table, select)?;
90    add_internal_field(table, &mut selection, "id")?;
91    for join in joins {
92        add_internal_field(table, &mut selection, &join.from)?;
93    }
94    Ok(selection)
95}
96
97pub fn joined_rows<Row, F>(
98    rows: &[Row],
99    base_selection: &[(&str, &DbField)],
100    output_select: &[String],
101    joins: &[NativeJoin<'_>],
102    mut row_value: F,
103) -> Result<Vec<DbRecord>, RustAuthError>
104where
105    F: FnMut(&Row, &DbField, &str) -> Result<DbValue, RustAuthError>,
106{
107    let mut records = Vec::<DbRecord>::new();
108    let mut groups = IndexMap::<String, usize>::new();
109
110    for row in rows {
111        let base_id = row_value(
112            row,
113            resolve_field_from_selection(base_selection, "id")?,
114            "__base_id",
115        )?;
116        let group_key = db_value_key(&base_id).ok_or_else(|| {
117            RustAuthError::Adapter("joined query base row is missing an id".to_owned())
118        })?;
119        let record_index = if let Some(index) = groups.get(&group_key) {
120            *index
121        } else {
122            let mut record = DbRecord::new();
123            for (index, (logical_name, field)) in base_selection.iter().enumerate() {
124                if !output_select.is_empty()
125                    && !output_select.iter().any(|field| field == logical_name)
126                {
127                    continue;
128                }
129                record.insert(
130                    (*logical_name).to_owned(),
131                    row_value(row, field, &base_alias(index))?,
132                );
133            }
134            for join in joins {
135                let value = if join.relation == JoinRelation::OneToOne {
136                    DbValue::Null
137                } else {
138                    DbValue::RecordArray(Vec::new())
139                };
140                record.insert(join.model.clone(), value);
141            }
142            records.push(record);
143            let index = records.len() - 1;
144            groups.insert(group_key, index);
145            index
146        };
147
148        for (join_index, join) in joins.iter().enumerate() {
149            let joined = joined_record(row, join_index, join, &mut row_value)?;
150            let Some(joined) = joined else {
151                continue;
152            };
153            if join.relation == JoinRelation::OneToOne {
154                records[record_index].insert(join.model.clone(), DbValue::Record(joined));
155            } else if let Some(DbValue::RecordArray(values)) =
156                records[record_index].get_mut(&join.model)
157            {
158                if values.len() < join.limit && !contains_record(values, &joined) {
159                    values.push(joined);
160                }
161            }
162        }
163    }
164
165    Ok(records)
166}
167
/// Column alias (`__base_{index}`) for the base field at position `index` of the base selection.
pub fn base_alias(index: usize) -> String {
    let mut alias = String::from("__base_");
    alias.push_str(&index.to_string());
    alias
}

/// Alias of the form `__join_{index}`, presumably naming the `index`-th joined table in the
/// generated SQL (its use is not visible in this file).
pub fn join_alias(index: usize) -> String {
    String::from("__join_") + &index.to_string()
}

/// Column alias (`__join_{join_index}_{field_index}`) for field `field_index` of join `join_index`.
pub fn join_field_alias(join_index: usize, field_index: usize) -> String {
    let suffix = [join_index.to_string(), field_index.to_string()].join("_");
    format!("__join_{suffix}")
}

180fn add_internal_field<'a>(
181    table: &'a DbTable,
182    selection: &mut Vec<(&'a str, &'a DbField)>,
183    field: &str,
184) -> Result<(), RustAuthError> {
185    let resolved = resolve_field(table, field)?;
186    if !selection
187        .iter()
188        .any(|(_, existing)| existing.name == resolved.1.name)
189    {
190        selection.push(resolved);
191    }
192    Ok(())
193}
194
195fn joined_record<Row, F>(
196    row: &Row,
197    join_index: usize,
198    join: &NativeJoin<'_>,
199    row_value: &mut F,
200) -> Result<Option<DbRecord>, RustAuthError>
201where
202    F: FnMut(&Row, &DbField, &str) -> Result<DbValue, RustAuthError>,
203{
204    let mut record = DbRecord::new();
205    for (field_index, (logical_name, field)) in join.selection.iter().enumerate() {
206        record.insert(
207            (*logical_name).to_owned(),
208            row_value(row, field, &join_field_alias(join_index, field_index))?,
209        );
210    }
211    if record.values().all(|value| *value == DbValue::Null) {
212        Ok(None)
213    } else {
214        Ok(Some(record))
215    }
216}
217
218fn contains_record(records: &[DbRecord], candidate: &DbRecord) -> bool {
219    let candidate_id = candidate.get("id").and_then(db_value_key);
220    records.iter().any(|record| {
221        if let Some(candidate_id) = &candidate_id {
222            record.get("id").and_then(db_value_key).as_ref() == Some(candidate_id)
223        } else {
224            record == candidate
225        }
226    })
227}
228
229fn foreign_keys_to_table<'a>(
230    table: &'a DbTable,
231    target_table: &str,
232) -> Vec<(&'a str, &'a DbField)> {
233    table
234        .fields
235        .iter()
236        .filter_map(|(logical_name, field)| {
237            field
238                .foreign_key
239                .as_ref()
240                .filter(|foreign_key| foreign_key.table == target_table)
241                .map(|_| (logical_name.as_str(), field))
242        })
243        .collect()
244}
245
246fn db_value_key(value: &DbValue) -> Option<String> {
247    match value {
248        DbValue::String(value) => Some(value.clone()),
249        DbValue::Number(value) => Some(value.to_string()),
250        _ => None,
251    }
252}