1use anyhow::{anyhow, Result};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tracing::{debug, info};
7
8use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
9use crate::sql::parser::ast::{JoinClause, JoinOperator, JoinType};
10
/// Executes SQL joins between two in-memory [`DataTable`]s using hash joins.
///
/// The only configuration is whether join-key column names are resolved
/// case-insensitively.
#[derive(Debug, Clone)]
pub struct HashJoinExecutor {
    /// When `true`, join-key column names are compared against table column names
    /// after lowercasing both sides.
    case_insensitive: bool,
}
15
16impl HashJoinExecutor {
17 pub fn new(case_insensitive: bool) -> Self {
18 Self { case_insensitive }
19 }
20
21 pub fn execute_join(
23 &self,
24 left_table: Arc<DataTable>,
25 join_clause: &JoinClause,
26 right_table: Arc<DataTable>,
27 ) -> Result<DataTable> {
28 info!(
29 "Executing {:?} JOIN: {} rows x {} rows",
30 join_clause.join_type,
31 left_table.row_count(),
32 right_table.row_count()
33 );
34
35 let (left_col_name, right_col_name) = self.parse_join_columns(join_clause)?;
37
38 let left_col_idx = self.find_column_index(&left_table, &left_col_name)?;
40 let right_col_idx = self.find_column_index(&right_table, &right_col_name)?;
41
42 match join_clause.join_type {
44 JoinType::Inner => self.hash_join_inner(
45 left_table,
46 right_table,
47 left_col_idx,
48 right_col_idx,
49 &left_col_name,
50 &right_col_name,
51 ),
52 JoinType::Left => self.hash_join_left(
53 left_table,
54 right_table,
55 left_col_idx,
56 right_col_idx,
57 &left_col_name,
58 &right_col_name,
59 ),
60 JoinType::Right => {
61 self.hash_join_left(
63 right_table,
64 left_table,
65 right_col_idx,
66 left_col_idx,
67 &right_col_name,
68 &left_col_name,
69 )
70 }
71 JoinType::Cross => self.cross_join(left_table, right_table),
72 JoinType::Full => {
73 return Err(anyhow!("FULL OUTER JOIN not yet implemented"));
74 }
75 }
76 }
77
78 fn parse_join_columns(&self, join_clause: &JoinClause) -> Result<(String, String)> {
80 if join_clause.condition.operator != JoinOperator::Equal {
82 return Err(anyhow!(
83 "Only equality JOIN conditions are currently supported"
84 ));
85 }
86
87 Ok((
88 join_clause.condition.left_column.clone(),
89 join_clause.condition.right_column.clone(),
90 ))
91 }
92
93 fn find_column_index(&self, table: &DataTable, col_name: &str) -> Result<usize> {
95 let col_name = if let Some(dot_pos) = col_name.rfind('.') {
97 &col_name[dot_pos + 1..]
98 } else {
99 col_name
100 };
101
102 table
103 .columns
104 .iter()
105 .position(|col| {
106 if self.case_insensitive {
107 col.name.to_lowercase() == col_name.to_lowercase()
108 } else {
109 col.name == col_name
110 }
111 })
112 .ok_or_else(|| anyhow!("Column '{}' not found in table", col_name))
113 }
114
115 fn hash_join_inner(
117 &self,
118 left_table: Arc<DataTable>,
119 right_table: Arc<DataTable>,
120 left_col_idx: usize,
121 right_col_idx: usize,
122 _left_col_name: &str,
123 _right_col_name: &str,
124 ) -> Result<DataTable> {
125 let start = std::time::Instant::now();
126
127 let (build_table, probe_table, build_col_idx, probe_col_idx, build_is_left) =
129 if left_table.row_count() <= right_table.row_count() {
130 (
131 left_table.clone(),
132 right_table.clone(),
133 left_col_idx,
134 right_col_idx,
135 true,
136 )
137 } else {
138 (
139 right_table.clone(),
140 left_table.clone(),
141 right_col_idx,
142 left_col_idx,
143 false,
144 )
145 };
146
147 debug!(
148 "Building hash index on {} table ({} rows)",
149 if build_is_left { "left" } else { "right" },
150 build_table.row_count()
151 );
152
153 let mut hash_index: HashMap<DataValue, Vec<usize>> = HashMap::new();
155 for (row_idx, row) in build_table.rows.iter().enumerate() {
156 let key = row.values[build_col_idx].clone();
157 hash_index.entry(key).or_default().push(row_idx);
158 }
159
160 debug!(
161 "Hash index built with {} unique keys in {:?}",
162 hash_index.len(),
163 start.elapsed()
164 );
165
166 let mut result = DataTable::new("joined");
168
169 for col in &left_table.columns {
171 result.add_column(DataColumn {
172 name: col.name.clone(),
173 data_type: col.data_type.clone(),
174 nullable: col.nullable,
175 unique_values: col.unique_values,
176 null_count: col.null_count,
177 metadata: col.metadata.clone(),
178 });
179 }
180
181 for col in &right_table.columns {
183 if !left_table
185 .columns
186 .iter()
187 .any(|left_col| left_col.name == col.name)
188 {
189 result.add_column(DataColumn {
190 name: col.name.clone(),
191 data_type: col.data_type.clone(),
192 nullable: col.nullable,
193 unique_values: col.unique_values,
194 null_count: col.null_count,
195 metadata: col.metadata.clone(),
196 });
197 } else {
198 result.add_column(DataColumn {
200 name: format!("{}_right", col.name),
201 data_type: col.data_type.clone(),
202 nullable: col.nullable,
203 unique_values: col.unique_values,
204 null_count: col.null_count,
205 metadata: col.metadata.clone(),
206 });
207 }
208 }
209
210 debug!(
211 "Joined table will have {} columns: {:?}",
212 result.column_count(),
213 result.column_names()
214 );
215
216 let mut match_count = 0;
218 for probe_row in &probe_table.rows {
219 let probe_key = &probe_row.values[probe_col_idx];
220
221 if let Some(matching_indices) = hash_index.get(probe_key) {
222 for &build_idx in matching_indices {
223 let build_row = &build_table.rows[build_idx];
224
225 let mut joined_row = DataRow { values: Vec::new() };
227
228 if build_is_left {
229 joined_row.values.extend_from_slice(&build_row.values);
231 joined_row.values.extend_from_slice(&probe_row.values);
232 } else {
233 joined_row.values.extend_from_slice(&probe_row.values);
235 joined_row.values.extend_from_slice(&build_row.values);
236 }
237
238 result.add_row(joined_row);
239 match_count += 1;
240 }
241 }
242 }
243
244 info!(
245 "INNER JOIN complete: {} matches found in {:?}",
246 match_count,
247 start.elapsed()
248 );
249
250 Ok(result)
251 }
252
253 fn hash_join_left(
255 &self,
256 left_table: Arc<DataTable>,
257 right_table: Arc<DataTable>,
258 left_col_idx: usize,
259 right_col_idx: usize,
260 _left_col_name: &str,
261 _right_col_name: &str,
262 ) -> Result<DataTable> {
263 let start = std::time::Instant::now();
264
265 debug!(
266 "Building hash index on right table ({} rows)",
267 right_table.row_count()
268 );
269
270 let mut hash_index: HashMap<DataValue, Vec<usize>> = HashMap::new();
272 for (row_idx, row) in right_table.rows.iter().enumerate() {
273 let key = row.values[right_col_idx].clone();
274 hash_index.entry(key).or_default().push(row_idx);
275 }
276
277 let mut result = DataTable::new("joined");
279
280 for col in &left_table.columns {
282 result.add_column(DataColumn {
283 name: col.name.clone(),
284 data_type: col.data_type.clone(),
285 nullable: col.nullable,
286 unique_values: col.unique_values,
287 null_count: col.null_count,
288 metadata: col.metadata.clone(),
289 });
290 }
291
292 for col in &right_table.columns {
294 if !left_table
296 .columns
297 .iter()
298 .any(|left_col| left_col.name == col.name)
299 {
300 result.add_column(DataColumn {
301 name: col.name.clone(),
302 data_type: col.data_type.clone(),
303 nullable: true, unique_values: col.unique_values,
305 null_count: col.null_count,
306 metadata: col.metadata.clone(),
307 });
308 } else {
309 result.add_column(DataColumn {
311 name: format!("{}_right", col.name),
312 data_type: col.data_type.clone(),
313 nullable: true, unique_values: col.unique_values,
315 null_count: col.null_count,
316 metadata: col.metadata.clone(),
317 });
318 }
319 }
320
321 debug!(
322 "LEFT JOIN table will have {} columns: {:?}",
323 result.column_count(),
324 result.column_names()
325 );
326
327 let mut match_count = 0;
329 let mut null_count = 0;
330
331 for left_row in &left_table.rows {
332 let left_key = &left_row.values[left_col_idx];
333
334 if let Some(matching_indices) = hash_index.get(left_key) {
335 for &right_idx in matching_indices {
337 let right_row = &right_table.rows[right_idx];
338
339 let mut joined_row = DataRow { values: Vec::new() };
340 joined_row.values.extend_from_slice(&left_row.values);
341 joined_row.values.extend_from_slice(&right_row.values);
342
343 result.add_row(joined_row);
344 match_count += 1;
345 }
346 } else {
347 let mut joined_row = DataRow { values: Vec::new() };
349 joined_row.values.extend_from_slice(&left_row.values);
350
351 for _ in 0..right_table.column_count() {
353 joined_row.values.push(DataValue::Null);
354 }
355
356 result.add_row(joined_row);
357 null_count += 1;
358 }
359 }
360
361 info!(
362 "LEFT JOIN complete: {} matches, {} nulls in {:?}",
363 match_count,
364 null_count,
365 start.elapsed()
366 );
367
368 Ok(result)
369 }
370
371 fn cross_join(
373 &self,
374 left_table: Arc<DataTable>,
375 right_table: Arc<DataTable>,
376 ) -> Result<DataTable> {
377 let start = std::time::Instant::now();
378
379 let result_rows = left_table.row_count() * right_table.row_count();
381 if result_rows > 1_000_000 {
382 return Err(anyhow!(
383 "CROSS JOIN would produce {} rows, which exceeds the safety limit",
384 result_rows
385 ));
386 }
387
388 let mut result = DataTable::new("joined");
390
391 for col in &left_table.columns {
393 result.add_column(col.clone());
394 }
395 for col in &right_table.columns {
396 result.add_column(col.clone());
397 }
398
399 for left_row in &left_table.rows {
401 for right_row in &right_table.rows {
402 let mut joined_row = DataRow { values: Vec::new() };
403 joined_row.values.extend_from_slice(&left_row.values);
404 joined_row.values.extend_from_slice(&right_row.values);
405 result.add_row(joined_row);
406 }
407 }
408
409 info!(
410 "CROSS JOIN complete: {} rows in {:?}",
411 result.row_count(),
412 start.elapsed()
413 );
414
415 Ok(result)
416 }
417
418 fn qualify_column_name(
420 &self,
421 col_name: &str,
422 table_side: &str,
423 left_join_col: &str,
424 right_join_col: &str,
425 ) -> String {
426 let base_name = if let Some(dot_pos) = col_name.rfind('.') {
428 &col_name[dot_pos + 1..]
429 } else {
430 col_name
431 };
432
433 let left_base = if let Some(dot_pos) = left_join_col.rfind('.') {
434 &left_join_col[dot_pos + 1..]
435 } else {
436 left_join_col
437 };
438
439 let right_base = if let Some(dot_pos) = right_join_col.rfind('.') {
440 &right_join_col[dot_pos + 1..]
441 } else {
442 right_join_col
443 };
444
445 if base_name == left_base || base_name == right_base {
447 format!("{}_{}", table_side, base_name)
448 } else {
449 col_name.to_string()
450 }
451 }
452}