1use super::*;
2
3#[derive(Debug, Clone)]
4pub struct NativeJoin<'a> {
5 pub model: String,
6 pub table: &'a DbTable,
7 pub selection: Vec<(&'a str, &'a DbField)>,
8 pub from: String,
9 pub to: String,
10 pub relation: JoinRelation,
11 pub limit: usize,
12}
13
14pub fn resolve_native_joins<'a>(
15 schema: &'a DbSchema,
16 base_model: &str,
17 base_table: &'a DbTable,
18 joins: &IndexMap<String, JoinOption>,
19 default_limit: usize,
20) -> Result<Vec<NativeJoin<'a>>, RustAuthError> {
21 let mut resolved = Vec::new();
22 for (join_model, option) in joins {
23 if !option.enabled {
24 continue;
25 }
26 let (join_logical, join_table) = resolve_table_with_logical(schema, join_model)?;
27 let mut foreign_keys = foreign_keys_to_table(join_table, &base_table.name);
28 let is_forward_join = !foreign_keys.is_empty();
29 if foreign_keys.is_empty() {
30 foreign_keys = foreign_keys_to_table(base_table, &join_table.name);
31 }
32 let [(_foreign_key, field)] =
33 foreign_keys
34 .as_slice()
35 .try_into()
36 .map_err(|_| match foreign_keys.len() {
37 0 => RustAuthError::JoinForeignKeyNotFound {
38 base_model: base_model.to_owned(),
39 join_model: join_model.clone(),
40 },
41 _ => RustAuthError::JoinForeignKeyAmbiguous {
42 base_model: base_model.to_owned(),
43 join_model: join_model.clone(),
44 },
45 })?;
46 let reference =
47 field
48 .foreign_key
49 .as_ref()
50 .ok_or_else(|| RustAuthError::JoinForeignKeyNotFound {
51 base_model: base_model.to_owned(),
52 join_model: join_model.clone(),
53 })?;
54 let (from, to, is_unique) = if is_forward_join {
55 let (_, base_field) = resolve_field(base_table, &reference.field)?;
56 (base_field.name.clone(), field.name.clone(), field.unique)
57 } else {
58 let (_, join_field) = resolve_field(join_table, &reference.field)?;
59 (field.name.clone(), join_field.name.clone(), field.unique)
60 };
61 let relation = if !is_forward_join || is_unique {
62 JoinRelation::OneToOne
63 } else {
64 JoinRelation::OneToMany
65 };
66 let limit = if relation == JoinRelation::OneToOne {
67 1
68 } else {
69 option.limit.unwrap_or(default_limit)
70 };
71 resolved.push(NativeJoin {
72 model: join_logical.to_owned(),
73 table: join_table,
74 selection: select_fields(join_table, &[])?,
75 from,
76 to,
77 relation,
78 limit,
79 });
80 }
81 Ok(resolved)
82}
83
84pub fn internal_base_selection<'a>(
85 table: &'a DbTable,
86 select: &[String],
87 joins: &[NativeJoin<'_>],
88) -> Result<Vec<(&'a str, &'a DbField)>, RustAuthError> {
89 let mut selection = select_fields(table, select)?;
90 add_internal_field(table, &mut selection, "id")?;
91 for join in joins {
92 add_internal_field(table, &mut selection, &join.from)?;
93 }
94 Ok(selection)
95}
96
97pub fn joined_rows<Row, F>(
98 rows: &[Row],
99 base_selection: &[(&str, &DbField)],
100 output_select: &[String],
101 joins: &[NativeJoin<'_>],
102 mut row_value: F,
103) -> Result<Vec<DbRecord>, RustAuthError>
104where
105 F: FnMut(&Row, &DbField, &str) -> Result<DbValue, RustAuthError>,
106{
107 let mut records = Vec::<DbRecord>::new();
108 let mut groups = IndexMap::<String, usize>::new();
109
110 for row in rows {
111 let base_id = row_value(
112 row,
113 resolve_field_from_selection(base_selection, "id")?,
114 "__base_id",
115 )?;
116 let group_key = db_value_key(&base_id).ok_or_else(|| {
117 RustAuthError::Adapter("joined query base row is missing an id".to_owned())
118 })?;
119 let record_index = if let Some(index) = groups.get(&group_key) {
120 *index
121 } else {
122 let mut record = DbRecord::new();
123 for (index, (logical_name, field)) in base_selection.iter().enumerate() {
124 if !output_select.is_empty()
125 && !output_select.iter().any(|field| field == logical_name)
126 {
127 continue;
128 }
129 record.insert(
130 (*logical_name).to_owned(),
131 row_value(row, field, &base_alias(index))?,
132 );
133 }
134 for join in joins {
135 let value = if join.relation == JoinRelation::OneToOne {
136 DbValue::Null
137 } else {
138 DbValue::RecordArray(Vec::new())
139 };
140 record.insert(join.model.clone(), value);
141 }
142 records.push(record);
143 let index = records.len() - 1;
144 groups.insert(group_key, index);
145 index
146 };
147
148 for (join_index, join) in joins.iter().enumerate() {
149 let joined = joined_record(row, join_index, join, &mut row_value)?;
150 let Some(joined) = joined else {
151 continue;
152 };
153 if join.relation == JoinRelation::OneToOne {
154 records[record_index].insert(join.model.clone(), DbValue::Record(joined));
155 } else if let Some(DbValue::RecordArray(values)) =
156 records[record_index].get_mut(&join.model)
157 {
158 if values.len() < join.limit && !contains_record(values, &joined) {
159 values.push(joined);
160 }
161 }
162 }
163 }
164
165 Ok(records)
166}
167
/// Column alias under which the base field at position `index` of the base selection is read.
pub fn base_alias(index: usize) -> String {
    ["__base_", &index.to_string()].concat()
}
171
/// Alias string for the join at position `index`, of the form `__join_<index>`.
pub fn join_alias(index: usize) -> String {
    let mut alias = String::from("__join_");
    alias.push_str(&index.to_string());
    alias
}
175
/// Column alias for field `field_index` of join `join_index`, of the form
/// `__join_<join_index>_<field_index>`.
pub fn join_field_alias(join_index: usize, field_index: usize) -> String {
    format!("__join_{}_{}", join_index, field_index)
}
179
180fn add_internal_field<'a>(
181 table: &'a DbTable,
182 selection: &mut Vec<(&'a str, &'a DbField)>,
183 field: &str,
184) -> Result<(), RustAuthError> {
185 let resolved = resolve_field(table, field)?;
186 if !selection
187 .iter()
188 .any(|(_, existing)| existing.name == resolved.1.name)
189 {
190 selection.push(resolved);
191 }
192 Ok(())
193}
194
195fn joined_record<Row, F>(
196 row: &Row,
197 join_index: usize,
198 join: &NativeJoin<'_>,
199 row_value: &mut F,
200) -> Result<Option<DbRecord>, RustAuthError>
201where
202 F: FnMut(&Row, &DbField, &str) -> Result<DbValue, RustAuthError>,
203{
204 let mut record = DbRecord::new();
205 for (field_index, (logical_name, field)) in join.selection.iter().enumerate() {
206 record.insert(
207 (*logical_name).to_owned(),
208 row_value(row, field, &join_field_alias(join_index, field_index))?,
209 );
210 }
211 if record.values().all(|value| *value == DbValue::Null) {
212 Ok(None)
213 } else {
214 Ok(Some(record))
215 }
216}
217
218fn contains_record(records: &[DbRecord], candidate: &DbRecord) -> bool {
219 let candidate_id = candidate.get("id").and_then(db_value_key);
220 records.iter().any(|record| {
221 if let Some(candidate_id) = &candidate_id {
222 record.get("id").and_then(db_value_key).as_ref() == Some(candidate_id)
223 } else {
224 record == candidate
225 }
226 })
227}
228
229fn foreign_keys_to_table<'a>(
230 table: &'a DbTable,
231 target_table: &str,
232) -> Vec<(&'a str, &'a DbField)> {
233 table
234 .fields
235 .iter()
236 .filter_map(|(logical_name, field)| {
237 field
238 .foreign_key
239 .as_ref()
240 .filter(|foreign_key| foreign_key.table == target_table)
241 .map(|_| (logical_name.as_str(), field))
242 })
243 .collect()
244}
245
246fn db_value_key(value: &DbValue) -> Option<String> {
247 match value {
248 DbValue::String(value) => Some(value.clone()),
249 DbValue::Number(value) => Some(value.to_string()),
250 _ => None,
251 }
252}