Skip to main content

wasm_dbms/
join.rs

1// Rust guideline compliant 2026-03-01
2// X-WHERE-CLAUSE, M-CANONICAL-DOCS
3
4//! Join execution engine for cross-table queries.
5
6use std::collections::HashSet;
7
8use wasm_dbms_api::prelude::{
9    ColumnDef, DbmsResult, JoinColumnDef, JoinType, OrderDirection, Query, Value,
10};
11use wasm_dbms_memory::prelude::{AccessControl, AccessControlList, MemoryProvider};
12
13use crate::database::WasmDbmsDatabase;
14use crate::schema::DatabaseSchema;
15
/// A row in the joined result, organized by source table.
///
/// Each entry pairs a table name with that table's `(column, value)` pairs.
/// The `FROM` table's entry is created first, and each join step appends one
/// entry for the joined table, so entries appear in join order.
type JoinedRow = Vec<(String, Vec<(ColumnDef, Value)>)>;
18
/// Engine that executes join queries using nested-loop join.
///
/// The engine borrows a [`DatabaseSchema`] and uses it to load the rows of
/// every table taking part in a join; it holds no other state.
pub struct JoinEngine<'a, Schema: ?Sized, M, A = AccessControlList>
where
    Schema: DatabaseSchema<M, A>,
    M: MemoryProvider,
    A: AccessControl,
{
    /// Schema used to load table rows for the join.
    schema: &'a Schema,
    /// Ties the memory provider `M` and access-control type `A` to the struct
    /// without storing values of those types.
    _marker: std::marker::PhantomData<(M, A)>,
}
29
30impl<'a, Schema: ?Sized, M, A> JoinEngine<'a, Schema, M, A>
31where
32    Schema: DatabaseSchema<M, A>,
33    M: MemoryProvider,
34    A: AccessControl,
35{
36    pub fn new(schema: &'a Schema) -> Self {
37        Self {
38            schema,
39            _marker: std::marker::PhantomData,
40        }
41    }
42}
43
44impl<Schema: ?Sized, M, A> JoinEngine<'_, Schema, M, A>
45where
46    Schema: DatabaseSchema<M, A>,
47    M: MemoryProvider,
48    A: AccessControl,
49{
50    /// Executes a join query using nested-loop join.
51    pub fn join(
52        &self,
53        dbms: &WasmDbmsDatabase<'_, M, A>,
54        from_table: &str,
55        query: Query,
56    ) -> DbmsResult<Vec<Vec<(JoinColumnDef, Value)>>> {
57        let from_rows = self
58            .schema
59            .select(dbms, from_table, Query::builder().all().build())?;
60
61        let mut joined_rows: Vec<JoinedRow> = from_rows
62            .into_iter()
63            .map(|row| vec![(from_table.to_string(), row)])
64            .collect();
65
66        for join in &query.joins {
67            let (left_table, left_col) = self.resolve_column_ref(&join.left_column, from_table);
68            let (_right_table_ref, right_col) =
69                self.resolve_column_ref(&join.right_column, &join.table);
70
71            let (keep_unmatched_left, keep_unmatched_right) = match join.join_type {
72                JoinType::Inner => (false, false),
73                JoinType::Left => (true, false),
74                JoinType::Right => (false, true),
75                JoinType::Full => (true, true),
76            };
77
78            let right_rows = self.load_join_right_rows(
79                dbms,
80                &joined_rows,
81                &join.table,
82                &left_table,
83                left_col,
84                right_col,
85                keep_unmatched_right,
86            )?;
87
88            joined_rows = self.nested_loop_join(
89                joined_rows,
90                &right_rows,
91                &join.table,
92                &left_table,
93                left_col,
94                right_col,
95                keep_unmatched_left,
96                keep_unmatched_right,
97            );
98        }
99
100        if let Some(filter) = &query.filter {
101            joined_rows.retain(|row| {
102                let groups: Vec<(&str, Vec<(ColumnDef, Value)>)> = row
103                    .iter()
104                    .map(|(t, cols)| (t.as_str(), cols.clone()))
105                    .collect();
106                filter.matches_joined_row(&groups).unwrap_or(false)
107            });
108        }
109
110        for (column, direction) in query.order_by.iter().rev() {
111            self.sort_joined_rows(&mut joined_rows, column, *direction);
112        }
113
114        let offset = query.offset.unwrap_or_default();
115        if offset > 0 {
116            if offset >= joined_rows.len() {
117                joined_rows.clear();
118            } else {
119                joined_rows = joined_rows.into_iter().skip(offset).collect();
120            }
121        }
122
123        if let Some(limit) = query.limit {
124            joined_rows.truncate(limit);
125        }
126
127        let results = joined_rows
128            .into_iter()
129            .map(|row| self.flatten_joined_row(row, &query))
130            .collect::<DbmsResult<Vec<_>>>()?;
131
132        Ok(results)
133    }
134
135    #[expect(
136        clippy::too_many_arguments,
137        reason = "arguments are necessary for loading right table rows based on join conditions"
138    )]
139    fn load_join_right_rows(
140        &self,
141        dbms: &WasmDbmsDatabase<'_, M, A>,
142        left_rows: &[JoinedRow],
143        right_table: &str,
144        left_table: &str,
145        left_col: &str,
146        right_col: &str,
147        keep_unmatched_right: bool,
148    ) -> DbmsResult<Vec<Vec<(ColumnDef, Value)>>> {
149        let unique_join_values: Vec<Value> = {
150            let mut seen = HashSet::new();
151            left_rows
152                .iter()
153                .filter_map(|row| self.get_column_value(row, left_table, left_col).cloned())
154                .filter(|value| seen.insert(value.clone()))
155                .collect()
156        };
157
158        if unique_join_values.is_empty() || keep_unmatched_right {
159            return self
160                .schema
161                .select(dbms, right_table, Query::builder().all().build());
162        }
163
164        self.schema.select(
165            dbms,
166            right_table,
167            Query::builder()
168                .all()
169                .filter(Some(wasm_dbms_api::prelude::Filter::in_list(
170                    right_col,
171                    unique_join_values,
172                )))
173                .build(),
174        )
175    }
176
177    /// Unified nested-loop join.
178    #[allow(clippy::too_many_arguments)]
179    fn nested_loop_join(
180        &self,
181        left_rows: Vec<JoinedRow>,
182        right_rows: &[Vec<(ColumnDef, Value)>],
183        right_table: &str,
184        left_table: &str,
185        left_col: &str,
186        right_col: &str,
187        keep_unmatched_left: bool,
188        keep_unmatched_right: bool,
189    ) -> Vec<JoinedRow> {
190        let mut results = Vec::new();
191        let mut right_matched = vec![false; right_rows.len()];
192
193        for left_row in &left_rows {
194            let left_value = self.get_column_value(left_row, left_table, left_col);
195            let mut matched = false;
196
197            for (i, right_row) in right_rows.iter().enumerate() {
198                let right_value = right_row
199                    .iter()
200                    .find(|(c, _)| c.name == right_col)
201                    .map(|(_, v)| v);
202
203                if left_value == right_value && left_value.is_some() {
204                    let mut new_row = left_row.clone();
205                    new_row.push((right_table.to_string(), right_row.clone()));
206                    results.push(new_row);
207                    right_matched[i] = true;
208                    matched = true;
209                }
210            }
211
212            if keep_unmatched_left && !matched {
213                let mut new_row = left_row.clone();
214                let null_cols = right_rows
215                    .first()
216                    .map(|sample| self.null_pad_columns(sample))
217                    .unwrap_or_default();
218                new_row.push((right_table.to_string(), null_cols));
219                results.push(new_row);
220            }
221        }
222
223        if keep_unmatched_right {
224            for (i, right_row) in right_rows.iter().enumerate() {
225                if !right_matched[i] {
226                    let mut new_row: JoinedRow = Vec::new();
227                    if let Some(sample_left) = left_rows.first() {
228                        for (table_name, cols) in sample_left {
229                            new_row.push((table_name.clone(), self.null_pad_columns(cols)));
230                        }
231                    }
232                    new_row.push((right_table.to_string(), right_row.clone()));
233                    results.push(new_row);
234                }
235            }
236        }
237
238        results
239    }
240
241    /// Resolves a column reference to (table_name, column_name).
242    fn resolve_column_ref<'a>(&self, field: &'a str, default_table: &'a str) -> (String, &'a str) {
243        if let Some((table, column)) = field.split_once('.') {
244            (table.to_string(), column)
245        } else {
246            (default_table.to_string(), field)
247        }
248    }
249
250    /// Finds a column value in a joined row.
251    fn get_column_value<'a>(
252        &self,
253        row: &'a JoinedRow,
254        table: &str,
255        column: &str,
256    ) -> Option<&'a Value> {
257        row.iter()
258            .find(|(t, _)| t == table)
259            .and_then(|(_, cols)| cols.iter().find(|(c, _)| c.name == column).map(|(_, v)| v))
260    }
261
262    /// Creates a NULL-padded row.
263    fn null_pad_columns(&self, sample_row: &[(ColumnDef, Value)]) -> Vec<(ColumnDef, Value)> {
264        sample_row
265            .iter()
266            .map(|(col, _)| (*col, Value::Null))
267            .collect()
268    }
269
270    /// Sorts joined rows by a column.
271    fn sort_joined_rows(&self, rows: &mut [JoinedRow], column: &str, direction: OrderDirection) {
272        let (table, col) = if let Some((t, c)) = column.split_once('.') {
273            (Some(t), c)
274        } else {
275            (None, column)
276        };
277
278        rows.sort_by(|a, b| {
279            let a_val = self.find_value_in_joined_row(a, table, col);
280            let b_val = self.find_value_in_joined_row(b, table, col);
281
282            crate::database::sort_values_with_direction(a_val, b_val, direction)
283        });
284    }
285
286    /// Finds a column value in a joined row, optionally scoped to a table.
287    fn find_value_in_joined_row<'a>(
288        &self,
289        row: &'a JoinedRow,
290        table: Option<&str>,
291        column: &str,
292    ) -> Option<&'a Value> {
293        if let Some(table) = table {
294            return self.get_column_value(row, table, column);
295        }
296        row.iter()
297            .flat_map(|(_, cols)| cols)
298            .find_map(|(col, value)| {
299                if col.name == column {
300                    Some(value)
301                } else {
302                    None
303                }
304            })
305    }
306
307    /// Flattens a joined row into the output format.
308    fn flatten_joined_row(
309        &self,
310        row: JoinedRow,
311        query: &Query,
312    ) -> DbmsResult<Vec<(JoinColumnDef, Value)>> {
313        let mut result = Vec::new();
314
315        for (table_name, cols) in row {
316            for (col, val) in cols {
317                let mut candid_col = JoinColumnDef::from(col);
318                candid_col.table = Some(table_name.clone());
319
320                if !query.all_selected() {
321                    let selected = query.raw_columns();
322                    let qualified_name = format!("{table_name}.{col}", col = candid_col.name);
323                    if !selected.contains(&candid_col.name) && !selected.contains(&qualified_name) {
324                        continue;
325                    }
326                }
327
328                result.push((candid_col, val));
329            }
330        }
331
332        Ok(result)
333    }
334}
335
#[cfg(test)]
mod tests {

    use wasm_dbms_api::prelude::{
        Database as _, Filter, InsertRecord as _, JoinColumnDef, Query, TableSchema as _, Text,
        Uint32, Value,
    };
    use wasm_dbms_macros::{DatabaseSchema, Table};
    use wasm_dbms_memory::prelude::HeapMemoryProvider;

    use crate::prelude::{DbmsContext, WasmDbmsDatabase};

    // None of these tables declare foreign keys. Rows on either side of a join
    // may therefore reference keys that do not exist, which is exactly what
    // exercising unmatched rows in LEFT/RIGHT/FULL joins requires.

    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "departments"]
    pub struct Department {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
    }

    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "employees"]
    pub struct Employee {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
        pub dept_id: Uint32,
    }

    #[derive(DatabaseSchema)]
    #[tables(Department = "departments", Employee = "employees")]
    pub struct TestSchema;

    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "indexed_departments"]
    pub struct IndexedDepartment {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
    }

    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "indexed_employees"]
    pub struct IndexedEmployee {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
        #[index]
        pub dept_id: Uint32,
    }

    #[derive(DatabaseSchema)]
    #[tables(
        IndexedDepartment = "indexed_departments",
        IndexedEmployee = "indexed_employees"
    )]
    pub struct IndexedJoinSchema;

    /// Heap-backed context with the plain (non-indexed) tables registered.
    fn setup() -> DbmsContext<HeapMemoryProvider> {
        let context = DbmsContext::new(HeapMemoryProvider::default());
        TestSchema::register_tables(&context).unwrap();
        context
    }

    /// Heap-backed context with the tables whose `dept_id` column is indexed.
    fn setup_indexed() -> DbmsContext<HeapMemoryProvider> {
        let context = DbmsContext::new(HeapMemoryProvider::default());
        IndexedJoinSchema::register_tables(&context).unwrap();
        context
    }

    /// Wraps a string slice as a `Value::Text`.
    fn text(s: &str) -> Value {
        Value::Text(Text(s.to_string()))
    }

    /// Value of `table.column` in a flattened join row, if that column is present.
    fn cell<'r>(
        row: &'r [(JoinColumnDef, Value)],
        table: &str,
        column: &str,
    ) -> Option<&'r Value> {
        row.iter()
            .find(|(col, _)| col.name == column && col.table.as_deref() == Some(table))
            .map(|(_, value)| value)
    }

    fn insert_dept(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
        let cols = Department::columns();
        let record = DepartmentInsertRequest::from_values(&[
            (cols[0], Value::Uint32(Uint32(id))),
            (cols[1], text(name)),
        ])
        .unwrap();
        db.insert::<Department>(record).unwrap();
    }

    fn insert_emp(
        db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
        id: u32,
        name: &str,
        dept_id: u32,
    ) {
        let cols = Employee::columns();
        let record = EmployeeInsertRequest::from_values(&[
            (cols[0], Value::Uint32(Uint32(id))),
            (cols[1], text(name)),
            (cols[2], Value::Uint32(Uint32(dept_id))),
        ])
        .unwrap();
        db.insert::<Employee>(record).unwrap();
    }

    fn insert_indexed_dept(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
        let cols = IndexedDepartment::columns();
        let record = IndexedDepartmentInsertRequest::from_values(&[
            (cols[0], Value::Uint32(Uint32(id))),
            (cols[1], text(name)),
        ])
        .unwrap();
        db.insert::<IndexedDepartment>(record).unwrap();
    }

    fn insert_indexed_emp(
        db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
        id: u32,
        name: &str,
        dept_id: u32,
    ) {
        let cols = IndexedEmployee::columns();
        let record = IndexedEmployeeInsertRequest::from_values(&[
            (cols[0], Value::Uint32(Uint32(id))),
            (cols[1], text(name)),
            (cols[2], Value::Uint32(Uint32(dept_id))),
        ])
        .unwrap();
        db.insert::<IndexedEmployee>(record).unwrap();
    }

    #[test]
    fn test_inner_join() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        insert_dept(&db, 2, "hr");
        insert_emp(&db, 10, "alice", 1);
        insert_emp(&db, 11, "bob", 1);

        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .all()
                    .inner_join("employees", "id", "dept_id")
                    .build(),
            )
            .unwrap();

        // eng has two employees and hr none, so INNER yields exactly two rows.
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_left_join() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        insert_dept(&db, 2, "hr");
        insert_emp(&db, 10, "alice", 1);

        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .all()
                    .left_join("employees", "id", "dept_id")
                    .build(),
            )
            .unwrap();

        // eng matches alice; hr matches nobody but LEFT keeps it.
        assert_eq!(rows.len(), 2);

        let hr = text("hr");
        let hr_row = rows
            .iter()
            .find(|row| cell(row, "departments", "name") == Some(&hr))
            .expect("hr should be in results");

        // The employee side of hr's row is NULL-padded.
        let emp_name =
            cell(hr_row, "employees", "name").expect("employee name column should exist for hr");
        assert_eq!(emp_name, &Value::Null);
    }

    #[test]
    fn test_right_join() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        insert_emp(&db, 10, "alice", 1);
        // charlie points at a department (999) that does not exist.
        insert_emp(&db, 11, "charlie", 999);

        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .all()
                    .right_join("employees", "id", "dept_id")
                    .build(),
            )
            .unwrap();

        // alice matches eng; charlie is an unmatched right row kept by RIGHT.
        assert_eq!(rows.len(), 2);

        let charlie = text("charlie");
        let charlie_row = rows
            .iter()
            .find(|row| cell(row, "employees", "name") == Some(&charlie))
            .expect("charlie should be in results");

        // The department side of charlie's row is NULL-padded.
        let dept_name = cell(charlie_row, "departments", "name")
            .expect("department name column should exist for charlie");
        assert_eq!(dept_name, &Value::Null);
    }

    #[test]
    fn test_full_join() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        insert_dept(&db, 2, "hr");
        insert_emp(&db, 10, "alice", 1);
        // charlie points at a department (999) that does not exist.
        insert_emp(&db, 11, "charlie", 999);

        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .all()
                    .full_join("employees", "id", "dept_id")
                    .build(),
            )
            .unwrap();

        // eng+alice matched, hr unmatched left, charlie unmatched right.
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_join_with_filter() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        insert_dept(&db, 2, "hr");
        insert_emp(&db, 10, "alice", 1);
        insert_emp(&db, 11, "bob", 2);

        let eng_only = Filter::eq("departments.name", text("eng"));
        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .all()
                    .inner_join("employees", "id", "dept_id")
                    .and_where(eng_only)
                    .build(),
            )
            .unwrap();

        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_join_with_order_by() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        insert_dept(&db, 2, "hr");
        insert_emp(&db, 10, "zzz", 1);
        insert_emp(&db, 11, "aaa", 2);

        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .all()
                    .inner_join("employees", "id", "dept_id")
                    .order_by_asc("employees.name")
                    .build(),
            )
            .unwrap();

        assert_eq!(rows.len(), 2);
        let first_name = cell(&rows[0], "employees", "name").unwrap();
        assert_eq!(first_name, &text("aaa"));
    }

    #[test]
    fn test_join_with_limit() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        insert_dept(&db, 2, "hr");
        insert_emp(&db, 10, "alice", 1);
        insert_emp(&db, 11, "bob", 2);

        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .all()
                    .inner_join("employees", "id", "dept_id")
                    .limit(1)
                    .build(),
            )
            .unwrap();

        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_join_with_offset() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        insert_dept(&db, 2, "hr");
        insert_emp(&db, 10, "alice", 1);
        insert_emp(&db, 11, "bob", 2);

        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .all()
                    .inner_join("employees", "id", "dept_id")
                    .offset(1)
                    .build(),
            )
            .unwrap();

        assert_eq!(rows.len(), 1);
    }

    #[test]
    fn test_join_with_column_selection() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        insert_emp(&db, 10, "alice", 1);

        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .field("departments.name")
                    .field("employees.name")
                    .inner_join("employees", "id", "dept_id")
                    .build(),
            )
            .unwrap();

        assert_eq!(rows.len(), 1);
        // Only the two explicitly selected columns survive.
        assert_eq!(rows[0].len(), 2);
    }

    #[test]
    fn test_inner_join_empty_result() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        // Intentionally no employees.

        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .all()
                    .inner_join("employees", "id", "dept_id")
                    .build(),
            )
            .unwrap();

        assert!(rows.is_empty());
    }

    #[test]
    fn test_join_offset_exceeding_results_returns_empty() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_dept(&db, 1, "eng");
        insert_emp(&db, 10, "alice", 1);

        let rows = db
            .select_join(
                "departments",
                Query::builder()
                    .all()
                    .inner_join("employees", "id", "dept_id")
                    .offset(100)
                    .build(),
            )
            .unwrap();

        assert!(rows.is_empty());
    }

    #[test]
    fn test_join_on_indexed_column() {
        let ctx = setup_indexed();
        let db = WasmDbmsDatabase::oneshot(&ctx, IndexedJoinSchema);
        insert_indexed_dept(&db, 1, "eng");
        insert_indexed_dept(&db, 2, "hr");
        insert_indexed_emp(&db, 10, "alice", 1);
        insert_indexed_emp(&db, 11, "bob", 2);

        let rows = db
            .select_join(
                "indexed_departments",
                Query::builder()
                    .all()
                    .inner_join(
                        "indexed_employees",
                        "indexed_departments.id",
                        "indexed_employees.dept_id",
                    )
                    .build(),
            )
            .unwrap();

        assert_eq!(rows.len(), 2);
        let alice = text("alice");
        assert!(
            rows.iter()
                .any(|row| cell(row, "indexed_employees", "name") == Some(&alice))
        );
    }
}