1use crate::{Result, SqlRsError, Storage, Value, Schema, Column, ColumnType};
2use super::{Statement, WhereExpression, WhereClause, OrderBy, SelectColumn, AggregateFunction, Join, JoinType};
3use std::collections::HashMap;
4
5pub struct QueryExecutor<S: Storage> {
6 storage: S,
7 schemas: HashMap<String, Schema>,
8}
9
10impl<S: Storage> QueryExecutor<S> {
11 pub fn new(storage: S) -> Self {
12 Self {
13 storage,
14 schemas: HashMap::new(),
15 }
16 }
17
18 pub fn execute(&mut self, statement: Statement) -> Result<Vec<Vec<Value>>> {
19 match statement {
20 Statement::CreateTable { name, columns } => {
21 self.create_table(name, columns)?;
22 Ok(vec![])
23 }
24 Statement::Insert { table, columns, values } => {
25 self.insert(table, columns, values)?;
26 Ok(vec![])
27 }
28 Statement::Select { table, columns, joins, where_clause, group_by, order_by, limit, offset } => {
29 self.select(table, columns, joins, where_clause, group_by, order_by, limit, offset)
30 }
31 Statement::Update { table, sets, where_clause } => {
32 let count = self.update(table, sets, where_clause)?;
33 Ok(vec![vec![Value::Integer(count as i64)]])
34 }
35 Statement::Delete { table, where_clause } => {
36 let count = self.delete(table, where_clause)?;
37 Ok(vec![vec![Value::Integer(count as i64)]])
38 }
39 Statement::DropTable { name } => {
40 self.drop_table(name)?;
41 Ok(vec![])
42 }
43 Statement::CreateIndex { index_name, table, column } => {
44 self.create_index(index_name, table, column)?;
45 Ok(vec![])
46 }
47 Statement::DropIndex { index_name } => {
48 self.drop_index(index_name)?;
49 Ok(vec![])
50 }
51 }
52 }
53
54 fn create_table(&mut self, name: String, columns: Vec<(String, String)>) -> Result<()> {
55 let mut schema = Schema::new(&name);
56
57 for (col_name, col_type) in columns {
58 let column_type = match col_type.as_str() {
59 "INTEGER" | "INT" => ColumnType::Integer,
60 "FLOAT" | "REAL" | "DOUBLE" => ColumnType::Float,
61 "TEXT" | "VARCHAR" | "STRING" => ColumnType::Text,
62 "BLOB" => ColumnType::Blob,
63 "BOOLEAN" | "BOOL" => ColumnType::Boolean,
64 _ => return Err(SqlRsError::Query(format!("Unknown type: {}", col_type))),
65 };
66
67 schema = schema.add_column(Column::new(col_name, column_type));
68 }
69
70 let schema_key = format!("__schema__{}", name);
71 let schema_bytes = bincode::serialize(&schema)
72 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
73
74 self.storage.put(schema_key.as_bytes(), &schema_bytes)?;
75 self.schemas.insert(name, schema);
76
77 Ok(())
78 }
79
80 fn insert(&mut self, table: String, _columns: Vec<String>, values: Vec<Value>) -> Result<()> {
81 let schema = self.get_schema(&table)?;
82
83 if values.len() != schema.columns.len() {
84 return Err(SqlRsError::Query(format!(
85 "Expected {} values, got {}",
86 schema.columns.len(),
87 values.len()
88 )));
89 }
90
91 let row_id = self.next_row_id(&table)?;
92 let key = format!("{}:{}", table, row_id);
93
94 let row_bytes = bincode::serialize(&values)
95 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
96
97 self.storage.put(key.as_bytes(), &row_bytes)?;
98
99 Ok(())
100 }
101
102 fn update(&mut self, table: String, sets: Vec<(String, Value)>, where_clause: Option<WhereClause>) -> Result<usize> {
103 let schema = self.get_schema(&table)?;
104
105 let mut update_map = std::collections::HashMap::new();
107 for (col_name, value) in &sets {
108 if let Some(idx) = schema.columns.iter().position(|c| &c.name == col_name) {
109 update_map.insert(idx, value.clone());
110 } else {
111 return Err(SqlRsError::Query(format!("Column '{}' not found", col_name)));
112 }
113 }
114
115 let prefix = format!("{}:", table);
116 let scan_results = self.storage.scan(prefix.as_bytes(), &[0xFF; 256])?;
117
118 let mut updated_count = 0;
119 let mut keys_to_update = Vec::new();
120 let mut new_rows = Vec::new();
121
122 for (key, value_bytes) in scan_results {
124 let row: Vec<Value> = bincode::deserialize(&value_bytes)
125 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
126
127 if let Some(ref wc) = where_clause {
128 if !self.evaluate_where(&row, wc, &schema)? {
129 continue;
130 }
131 }
132
133 let mut new_row = row.clone();
135 for (idx, value) in &update_map {
136 if *idx < new_row.len() {
137 new_row[*idx] = value.clone();
138 }
139 }
140
141 keys_to_update.push(key);
142 new_rows.push(new_row);
143 updated_count += 1;
144 }
145
146 for (key, new_row) in keys_to_update.into_iter().zip(new_rows.into_iter()) {
148 let row_bytes = bincode::serialize(&new_row)
149 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
150 self.storage.put(&key, &row_bytes)?;
151 }
152
153 Ok(updated_count)
154 }
155
156 fn select(
157 &self,
158 table: String,
159 columns: Vec<SelectColumn>,
160 joins: Vec<Join>,
161 where_clause: Option<WhereClause>,
162 group_by: Option<Vec<String>>,
163 order_by: Option<OrderBy>,
164 limit: Option<usize>,
165 offset: Option<usize>,
166 ) -> Result<Vec<Vec<Value>>> {
167 let schema = self.get_schema(&table)?;
168
169 if !joins.is_empty() {
171 return self.select_with_joins(table, columns, joins, where_clause, order_by, limit, offset);
172 }
173
174 let has_aggregates = columns.iter().any(|c| matches!(c, SelectColumn::Aggregate { .. }));
176
177 if has_aggregates || group_by.is_some() {
178 return self.select_aggregate(table, columns, where_clause, group_by, schema);
180 }
181
182 let prefix = format!("{}:", table);
183 let scan_results = self.storage.scan(prefix.as_bytes(), &[0xFF; 256])?;
184
185 let mut results = Vec::new();
186
187 for (_key, value_bytes) in scan_results {
188 let row: Vec<Value> = bincode::deserialize(&value_bytes)
189 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
190
191 if let Some(ref wc) = where_clause {
192 if !self.evaluate_where(&row, wc, &schema)? {
193 continue;
194 }
195 }
196
197 if columns.len() == 1 && matches!(&columns[0], SelectColumn::Star) {
198 results.push(row);
199 } else {
200 let mut selected_row = Vec::new();
202 for col in &columns {
203 match col {
204 SelectColumn::Star => results.push(row.clone()),
205 SelectColumn::Column(col_name) => {
206 if let Some(idx) = schema.columns.iter().position(|c| &c.name == col_name) {
207 if idx < row.len() {
208 selected_row.push(row[idx].clone());
209 }
210 } else {
211 return Err(SqlRsError::Query(format!("Column '{}' not found", col_name)));
212 }
213 }
214 SelectColumn::Aggregate { .. } => {
215 return Err(SqlRsError::Query("Cannot mix aggregates and non-aggregates without GROUP BY".to_string()));
217 }
218 }
219 }
220 if !selected_row.is_empty() && (columns.len() != 1 || !matches!(&columns[0], SelectColumn::Star)) {
221 results.push(selected_row);
222 }
223 }
224 }
225
226 if let Some(order) = order_by {
228 let col_idx = schema.columns.iter().position(|c| c.name == order.column)
229 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found in ORDER BY", order.column)))?;
230
231 results.sort_by(|a, b| {
232 if col_idx >= a.len() || col_idx >= b.len() {
233 return std::cmp::Ordering::Equal;
234 }
235
236 let cmp = self.compare_values(&a[col_idx], &b[col_idx]);
237 if order.ascending {
238 cmp
239 } else {
240 cmp.reverse()
241 }
242 });
243 }
244
245 let start = offset.unwrap_or(0);
247 if start >= results.len() {
248 return Ok(vec![]);
249 }
250
251 let end = if let Some(lim) = limit {
253 std::cmp::min(start + lim, results.len())
254 } else {
255 results.len()
256 };
257
258 Ok(results[start..end].to_vec())
259 }
260
261 fn select_aggregate(
262 &self,
263 table: String,
264 columns: Vec<SelectColumn>,
265 where_clause: Option<WhereClause>,
266 group_by: Option<Vec<String>>,
267 schema: Schema,
268 ) -> Result<Vec<Vec<Value>>> {
269 let prefix = format!("{}:", table);
270 let scan_results = self.storage.scan(prefix.as_bytes(), &[0xFF; 256])?;
271
272 let mut rows = Vec::new();
274 for (_key, value_bytes) in scan_results {
275 let row: Vec<Value> = bincode::deserialize(&value_bytes)
276 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
277
278 if let Some(ref wc) = where_clause {
279 if !self.evaluate_where(&row, wc, &schema)? {
280 continue;
281 }
282 }
283
284 rows.push(row);
285 }
286
287 if let Some(ref group_cols) = group_by {
289 return self.select_with_group_by(rows, columns, group_cols, &schema);
290 }
291
292 let mut result_row = Vec::new();
294 for col in &columns {
295 if let SelectColumn::Aggregate { function, column } = col {
296 let value = match function {
297 AggregateFunction::Count => {
298 if column == "*" {
299 Value::Integer(rows.len() as i64)
300 } else {
301 let col_idx = schema.columns.iter().position(|c| &c.name == column)
303 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", column)))?;
304 let count = rows.iter().filter(|r| col_idx < r.len() && !matches!(r[col_idx], Value::Null)).count();
305 Value::Integer(count as i64)
306 }
307 }
308 AggregateFunction::Sum => {
309 let col_idx = schema.columns.iter().position(|c| &c.name == column)
310 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", column)))?;
311 let mut sum: i64 = 0;
312 let mut float_sum: f64 = 0.0;
313 let mut is_float = false;
314
315 for row in &rows {
316 if col_idx < row.len() {
317 match &row[col_idx] {
318 Value::Integer(i) => sum += i,
319 Value::Float(f) => {
320 float_sum += f;
321 is_float = true;
322 }
323 Value::Null => {}
324 _ => return Err(SqlRsError::Query(format!("Cannot SUM non-numeric column '{}'", column))),
325 }
326 }
327 }
328
329 if is_float {
330 Value::Float(float_sum + sum as f64)
331 } else {
332 Value::Integer(sum)
333 }
334 }
335 AggregateFunction::Avg => {
336 let col_idx = schema.columns.iter().position(|c| &c.name == column)
337 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", column)))?;
338 let mut sum: f64 = 0.0;
339 let mut count = 0;
340
341 for row in &rows {
342 if col_idx < row.len() {
343 match &row[col_idx] {
344 Value::Integer(i) => {
345 sum += *i as f64;
346 count += 1;
347 }
348 Value::Float(f) => {
349 sum += f;
350 count += 1;
351 }
352 Value::Null => {}
353 _ => return Err(SqlRsError::Query(format!("Cannot AVG non-numeric column '{}'", column))),
354 }
355 }
356 }
357
358 if count > 0 {
359 Value::Float(sum / count as f64)
360 } else {
361 Value::Null
362 }
363 }
364 AggregateFunction::Min => {
365 let col_idx = schema.columns.iter().position(|c| &c.name == column)
366 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", column)))?;
367 let mut min_value: Option<Value> = None;
368
369 for row in &rows {
370 if col_idx < row.len() {
371 let current = &row[col_idx];
372 if !matches!(current, Value::Null) {
373 if let Some(ref min) = min_value {
374 if self.compare_values(current, min) == std::cmp::Ordering::Less {
375 min_value = Some(current.clone());
376 }
377 } else {
378 min_value = Some(current.clone());
379 }
380 }
381 }
382 }
383
384 min_value.unwrap_or(Value::Null)
385 }
386 AggregateFunction::Max => {
387 let col_idx = schema.columns.iter().position(|c| &c.name == column)
388 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", column)))?;
389 let mut max_value: Option<Value> = None;
390
391 for row in &rows {
392 if col_idx < row.len() {
393 let current = &row[col_idx];
394 if !matches!(current, Value::Null) {
395 if let Some(ref max) = max_value {
396 if self.compare_values(current, max) == std::cmp::Ordering::Greater {
397 max_value = Some(current.clone());
398 }
399 } else {
400 max_value = Some(current.clone());
401 }
402 }
403 }
404 }
405
406 max_value.unwrap_or(Value::Null)
407 }
408 };
409 result_row.push(value);
410 }
411 }
412
413 Ok(vec![result_row])
414 }
415
416 fn select_with_joins(
417 &self,
418 table: String,
419 columns: Vec<SelectColumn>,
420 joins: Vec<Join>,
421 where_clause: Option<WhereClause>,
422 order_by: Option<OrderBy>,
423 limit: Option<usize>,
424 offset: Option<usize>,
425 ) -> Result<Vec<Vec<Value>>> {
426
427 let main_schema = self.get_schema(&table)?;
429
430 let mut join_schemas = Vec::new();
432 for join in &joins {
433 let schema = self.get_schema(&join.table)?;
434 join_schemas.push(schema);
435 }
436
437 let prefix = format!("{}:", table);
439 let scan_results = self.storage.scan(prefix.as_bytes(), &[0xFF; 256])?;
440 let mut main_rows = Vec::new();
441 for (key, value_bytes) in scan_results {
442 let key_str = String::from_utf8_lossy(&key);
443 if !key_str.starts_with(&prefix) {
444 continue;
445 }
446 let row: Vec<Value> = bincode::deserialize(&value_bytes)
447 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
448 main_rows.push(row);
449 }
450
451 let mut joined_rows = main_rows.clone();
453 let mut combined_schema = main_schema.clone();
454
455 for (join_idx, join) in joins.iter().enumerate() {
456 let join_schema = &join_schemas[join_idx];
457
458 let join_prefix = format!("{}:", join.table);
460 let join_scan = self.storage.scan(join_prefix.as_bytes(), &[0xFF; 256])?;
461 let mut join_table_rows = Vec::new();
462 for (key, value_bytes) in join_scan {
463 let key_str = String::from_utf8_lossy(&key);
464 if !key_str.starts_with(&join_prefix) {
465 continue;
466 }
467 let row: Vec<Value> = bincode::deserialize(&value_bytes)
468 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
469 join_table_rows.push(row);
470 }
471
472 let (left_table, left_col) = self.parse_table_column(&join.on_left)?;
474 let (_right_table, right_col) = self.parse_table_column(&join.on_right)?;
475
476 let left_idx = if left_table == table || left_table.is_empty() {
478 combined_schema.columns.iter().position(|c| c.name == left_col)
479 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", left_col)))?
480 } else {
481 combined_schema.columns.iter().position(|c| c.name == left_col)
482 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", left_col)))?
483 };
484
485 let right_idx = join_schema.columns.iter().position(|c| c.name == right_col)
486 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found in table '{}'", right_col, join.table)))?;
487
488 let mut new_joined_rows = Vec::new();
490
491 match join.join_type {
492 JoinType::Inner => {
493 for left_row in &joined_rows {
494 for right_row in &join_table_rows {
495 if left_idx < left_row.len() && right_idx < right_row.len() {
496 if self.compare_values(&left_row[left_idx], &right_row[right_idx]) == std::cmp::Ordering::Equal {
497 let mut combined_row = left_row.clone();
498 combined_row.extend(right_row.clone());
499 new_joined_rows.push(combined_row);
500 }
501 }
502 }
503 }
504 }
505 JoinType::Left => {
506 for left_row in &joined_rows {
507 let mut matched = false;
508 for right_row in &join_table_rows {
509 if left_idx < left_row.len() && right_idx < right_row.len() {
510 if self.compare_values(&left_row[left_idx], &right_row[right_idx]) == std::cmp::Ordering::Equal {
511 let mut combined_row = left_row.clone();
512 combined_row.extend(right_row.clone());
513 new_joined_rows.push(combined_row);
514 matched = true;
515 }
516 }
517 }
518 if !matched {
519 let mut combined_row = left_row.clone();
520 combined_row.extend(vec![Value::Null; join_schema.columns.len()]);
521 new_joined_rows.push(combined_row);
522 }
523 }
524 }
525 JoinType::Right => {
526 for right_row in &join_table_rows {
527 let mut matched = false;
528 for left_row in &joined_rows {
529 if left_idx < left_row.len() && right_idx < right_row.len() {
530 if self.compare_values(&left_row[left_idx], &right_row[right_idx]) == std::cmp::Ordering::Equal {
531 let mut combined_row = left_row.clone();
532 combined_row.extend(right_row.clone());
533 new_joined_rows.push(combined_row);
534 matched = true;
535 }
536 }
537 }
538 if !matched {
539 let mut combined_row = vec![Value::Null; combined_schema.columns.len()];
540 combined_row.extend(right_row.clone());
541 new_joined_rows.push(combined_row);
542 }
543 }
544 }
545 }
546
547 joined_rows = new_joined_rows;
548
549 for col in &join_schema.columns {
551 combined_schema = combined_schema.add_column(col.clone());
552 }
553 }
554
555 if let Some(ref wc) = where_clause {
557 joined_rows.retain(|row| {
558 self.evaluate_where(row, wc, &combined_schema).unwrap_or(false)
559 });
560 }
561
562 let mut results = Vec::new();
564 for row in joined_rows {
565 if columns.len() == 1 && matches!(&columns[0], SelectColumn::Star) {
566 results.push(row);
567 } else {
568 let mut selected_row = Vec::new();
569 for col in &columns {
570 match col {
571 SelectColumn::Star => {
572 selected_row.extend(row.clone());
573 }
574 SelectColumn::Column(col_name) => {
575 let (_table, column) = if col_name.contains('.') {
577 self.parse_table_column(col_name)?
578 } else {
579 (String::new(), col_name.clone())
580 };
581
582 if let Some(idx) = combined_schema.columns.iter().position(|c| c.name == column) {
583 if idx < row.len() {
584 selected_row.push(row[idx].clone());
585 }
586 } else {
587 return Err(SqlRsError::Query(format!("Column '{}' not found", col_name)));
588 }
589 }
590 SelectColumn::Aggregate { .. } => {
591 return Err(SqlRsError::Query("Aggregates with JOIN not yet supported".to_string()));
592 }
593 }
594 }
595 if !selected_row.is_empty() {
596 results.push(selected_row);
597 }
598 }
599 }
600
601 if let Some(order) = order_by {
603 let col_idx = combined_schema.columns.iter().position(|c| c.name == order.column)
604 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found in ORDER BY", order.column)))?;
605
606 results.sort_by(|a, b| {
607 if col_idx >= a.len() || col_idx >= b.len() {
608 return std::cmp::Ordering::Equal;
609 }
610
611 let cmp = self.compare_values(&a[col_idx], &b[col_idx]);
612 if order.ascending {
613 cmp
614 } else {
615 cmp.reverse()
616 }
617 });
618 }
619
620 let start = offset.unwrap_or(0);
622 if start >= results.len() {
623 return Ok(vec![]);
624 }
625
626 let end = if let Some(lim) = limit {
628 std::cmp::min(start + lim, results.len())
629 } else {
630 results.len()
631 };
632
633 Ok(results[start..end].to_vec())
634 }
635
636 fn parse_table_column(&self, col_str: &str) -> Result<(String, String)> {
637 if let Some(dot_pos) = col_str.find('.') {
638 let table = col_str[..dot_pos].trim().to_string();
639 let column = col_str[dot_pos + 1..].trim().to_string();
640 Ok((table, column))
641 } else {
642 Ok((String::new(), col_str.trim().to_string()))
643 }
644 }
645
646 fn select_with_group_by(
647 &self,
648 rows: Vec<Vec<Value>>,
649 columns: Vec<SelectColumn>,
650 group_cols: &[String],
651 schema: &Schema,
652 ) -> Result<Vec<Vec<Value>>> {
653 use std::collections::HashMap;
654
655 let mut group_indices = Vec::new();
657 for col_name in group_cols {
658 let idx = schema.columns.iter().position(|c| &c.name == col_name)
659 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found in GROUP BY", col_name)))?;
660 group_indices.push(idx);
661 }
662
663 let mut groups: HashMap<Vec<Value>, Vec<Vec<Value>>> = HashMap::new();
665 for row in rows {
666 let mut key = Vec::new();
667 for &idx in &group_indices {
668 if idx < row.len() {
669 key.push(row[idx].clone());
670 } else {
671 key.push(Value::Null);
672 }
673 }
674 groups.entry(key).or_insert_with(Vec::new).push(row);
675 }
676
677 let mut results = Vec::new();
679 for (group_key, group_rows) in groups {
680 let mut result_row = Vec::new();
681
682 for col in &columns {
683 match col {
684 SelectColumn::Column(col_name) => {
685 if !group_cols.contains(col_name) {
687 return Err(SqlRsError::Query(format!(
688 "Column '{}' must appear in GROUP BY or be used in an aggregate function",
689 col_name
690 )));
691 }
692 let col_idx = schema.columns.iter().position(|c| &c.name == col_name)
694 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", col_name)))?;
695 let key_idx = group_indices.iter().position(|&idx| idx == col_idx)
696 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not in GROUP BY", col_name)))?;
697 result_row.push(group_key[key_idx].clone());
698 }
699 SelectColumn::Aggregate { function, column } => {
700 let value = self.compute_aggregate(function, column, &group_rows, schema)?;
701 result_row.push(value);
702 }
703 SelectColumn::Star => {
704 return Err(SqlRsError::Query("Cannot use * with GROUP BY".to_string()));
705 }
706 }
707 }
708
709 results.push(result_row);
710 }
711
712 Ok(results)
713 }
714
715 fn compute_aggregate(
716 &self,
717 function: &AggregateFunction,
718 column: &str,
719 rows: &[Vec<Value>],
720 schema: &Schema,
721 ) -> Result<Value> {
722 match function {
723 AggregateFunction::Count => {
724 if column == "*" {
725 Ok(Value::Integer(rows.len() as i64))
726 } else {
727 let col_idx = schema.columns.iter().position(|c| &c.name == column)
728 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", column)))?;
729 let count = rows.iter().filter(|r| col_idx < r.len() && !matches!(r[col_idx], Value::Null)).count();
730 Ok(Value::Integer(count as i64))
731 }
732 }
733 AggregateFunction::Sum => {
734 let col_idx = schema.columns.iter().position(|c| &c.name == column)
735 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", column)))?;
736 let mut sum: i64 = 0;
737 let mut float_sum: f64 = 0.0;
738 let mut is_float = false;
739
740 for row in rows {
741 if col_idx < row.len() {
742 match &row[col_idx] {
743 Value::Integer(i) => sum += i,
744 Value::Float(f) => {
745 float_sum += f;
746 is_float = true;
747 }
748 Value::Null => {}
749 _ => return Err(SqlRsError::Query(format!("Cannot SUM non-numeric column '{}'", column))),
750 }
751 }
752 }
753
754 if is_float {
755 Ok(Value::Float(float_sum + sum as f64))
756 } else {
757 Ok(Value::Integer(sum))
758 }
759 }
760 AggregateFunction::Avg => {
761 let col_idx = schema.columns.iter().position(|c| &c.name == column)
762 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", column)))?;
763 let mut sum: f64 = 0.0;
764 let mut count = 0;
765
766 for row in rows {
767 if col_idx < row.len() {
768 match &row[col_idx] {
769 Value::Integer(i) => {
770 sum += *i as f64;
771 count += 1;
772 }
773 Value::Float(f) => {
774 sum += f;
775 count += 1;
776 }
777 Value::Null => {}
778 _ => return Err(SqlRsError::Query(format!("Cannot AVG non-numeric column '{}'", column))),
779 }
780 }
781 }
782
783 if count > 0 {
784 Ok(Value::Float(sum / count as f64))
785 } else {
786 Ok(Value::Null)
787 }
788 }
789 AggregateFunction::Min => {
790 let col_idx = schema.columns.iter().position(|c| &c.name == column)
791 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", column)))?;
792 let mut min_value: Option<Value> = None;
793
794 for row in rows {
795 if col_idx < row.len() {
796 let current = &row[col_idx];
797 if !matches!(current, Value::Null) {
798 if let Some(ref min) = min_value {
799 if self.compare_values(current, min) == std::cmp::Ordering::Less {
800 min_value = Some(current.clone());
801 }
802 } else {
803 min_value = Some(current.clone());
804 }
805 }
806 }
807 }
808
809 Ok(min_value.unwrap_or(Value::Null))
810 }
811 AggregateFunction::Max => {
812 let col_idx = schema.columns.iter().position(|c| &c.name == column)
813 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found", column)))?;
814 let mut max_value: Option<Value> = None;
815
816 for row in rows {
817 if col_idx < row.len() {
818 let current = &row[col_idx];
819 if !matches!(current, Value::Null) {
820 if let Some(ref max) = max_value {
821 if self.compare_values(current, max) == std::cmp::Ordering::Greater {
822 max_value = Some(current.clone());
823 }
824 } else {
825 max_value = Some(current.clone());
826 }
827 }
828 }
829 }
830
831 Ok(max_value.unwrap_or(Value::Null))
832 }
833 }
834 }
835
836 fn get_schema(&self, table: &str) -> Result<Schema> {
837 if let Some(schema) = self.schemas.get(table) {
838 return Ok(schema.clone());
839 }
840
841 let schema_key = format!("__schema__{}", table);
842 let schema_bytes = self.storage.get(schema_key.as_bytes())?
843 .ok_or_else(|| SqlRsError::NotFound(format!("Table {} not found", table)))?;
844
845 let schema: Schema = bincode::deserialize(&schema_bytes)
846 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
847
848 Ok(schema)
849 }
850
851 fn next_row_id(&mut self, table: &str) -> Result<u64> {
852 let counter_key = format!("__counter__{}", table);
853
854 let current = if let Some(bytes) = self.storage.get(counter_key.as_bytes())? {
855 u64::from_le_bytes(bytes.try_into().unwrap_or([0; 8]))
856 } else {
857 0
858 };
859
860 let next = current + 1;
861 self.storage.put(counter_key.as_bytes(), &next.to_le_bytes())?;
862
863 Ok(next)
864 }
865
866 fn evaluate_where(&self, row: &[Value], where_clause: &WhereExpression, schema: &Schema) -> Result<bool> {
867 match where_clause {
868 WhereExpression::Condition { column, operator, value } => {
869 let col_idx = schema.columns.iter()
871 .position(|c| &c.name == column)
872 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found in WHERE clause", column)))?;
873
874 let row_value = if col_idx < row.len() {
876 &row[col_idx]
877 } else {
878 &Value::Null
879 };
880
881 match operator.as_str() {
883 "=" => Ok(self.compare_values(row_value, value) == std::cmp::Ordering::Equal),
884 "!=" => Ok(self.compare_values(row_value, value) != std::cmp::Ordering::Equal),
885 ">" => Ok(self.compare_values(row_value, value) == std::cmp::Ordering::Greater),
886 "<" => Ok(self.compare_values(row_value, value) == std::cmp::Ordering::Less),
887 ">=" => Ok(self.compare_values(row_value, value) != std::cmp::Ordering::Less),
888 "<=" => Ok(self.compare_values(row_value, value) != std::cmp::Ordering::Greater),
889 _ => Err(SqlRsError::Query(format!("Unknown operator: {}", operator))),
890 }
891 }
892 WhereExpression::And { left, right } => {
893 Ok(self.evaluate_where(row, left, schema)? && self.evaluate_where(row, right, schema)?)
894 }
895 WhereExpression::Or { left, right } => {
896 Ok(self.evaluate_where(row, left, schema)? || self.evaluate_where(row, right, schema)?)
897 }
898 }
899 }
900
901 fn compare_values(&self, a: &Value, b: &Value) -> std::cmp::Ordering {
902 if matches!(a, Value::Null) || matches!(b, Value::Null) {
904 return std::cmp::Ordering::Equal; }
906
907 match (a, b) {
908 (Value::Integer(ai), Value::Integer(bi)) => ai.cmp(bi),
909 (Value::Integer(ai), Value::Float(bf)) => {
910 let af = *ai as f64;
911 af.partial_cmp(bf).unwrap_or(std::cmp::Ordering::Equal)
912 }
913 (Value::Float(af), Value::Integer(bi)) => {
914 let bf = *bi as f64;
915 af.partial_cmp(&bf).unwrap_or(std::cmp::Ordering::Equal)
916 }
917 (Value::Float(af), Value::Float(bf)) => {
918 af.partial_cmp(bf).unwrap_or(std::cmp::Ordering::Equal)
919 }
920 (Value::Text(as_), Value::Text(bs)) => as_.cmp(bs),
921 (Value::Boolean(ab), Value::Boolean(bb)) => ab.cmp(bb),
922 _ => std::cmp::Ordering::Equal, }
924 }
925
926 fn delete(&mut self, table: String, where_clause: Option<WhereClause>) -> Result<usize> {
927 let schema = self.get_schema(&table)?;
928
929 let prefix = format!("{}:", table);
930 let scan_results = self.storage.scan(prefix.as_bytes(), &[0xFF; 256])?;
931
932 let mut deleted_count = 0;
933 let mut keys_to_delete = Vec::new();
934
935 for (key, value_bytes) in scan_results {
936 let row: Vec<Value> = bincode::deserialize(&value_bytes)
937 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
938
939 if let Some(ref wc) = where_clause {
940 if !self.evaluate_where(&row, wc, &schema)? {
941 continue;
942 }
943 }
944
945 keys_to_delete.push(key);
946 deleted_count += 1;
947 }
948
949 for key in keys_to_delete {
950 self.storage.delete(&key)?;
951 }
952
953 Ok(deleted_count)
954 }
955
956 fn drop_table(&mut self, name: String) -> Result<()> {
957 let schema_key = format!("__schema__{}", name);
958 let counter_key = format!("__counter__{}", name);
959 let prefix = format!("{}:", name);
960
961 let scan_results = self.storage.scan(prefix.as_bytes(), &[0xFF; 256])?;
962
963 for (key, _) in scan_results {
964 self.storage.delete(&key)?;
965 }
966
967 self.storage.delete(schema_key.as_bytes())?;
968 self.storage.delete(counter_key.as_bytes())?;
969
970 self.schemas.remove(&name);
971
972 Ok(())
973 }
974
975 fn create_index(&mut self, index_name: String, table: String, column: String) -> Result<()> {
976 let schema = self.get_schema(&table)?;
978
979 let col_idx = schema.columns.iter().position(|c| c.name == column)
981 .ok_or_else(|| SqlRsError::Query(format!("Column '{}' not found in table '{}'", column, table)))?;
982
983 let index_key = format!("__index__{}", index_name);
985 let index_metadata = format!("{}:{}", table, column);
986 self.storage.put(index_key.as_bytes(), index_metadata.as_bytes())?;
987
988 let prefix = format!("{}:", table);
990 let scan = self.storage.scan(prefix.as_bytes(), &[0xFF; 256])?;
991
992 for (key, value_bytes) in scan {
993 let key_str = String::from_utf8_lossy(&key);
994 if !key_str.starts_with(&prefix) {
995 continue;
996 }
997
998 let row: Vec<Value> = bincode::deserialize(&value_bytes)
999 .map_err(|e| SqlRsError::Serialization(e.to_string()))?;
1000
1001 if col_idx < row.len() {
1002 let index_entry_key = format!("__idx__{}__{}__{}",
1004 index_name,
1005 value_to_index_string(&row[col_idx]),
1006 key_str
1007 );
1008 self.storage.put(index_entry_key.as_bytes(), key.as_ref())?;
1009 }
1010 }
1011
1012 Ok(())
1013 }
1014
1015 fn drop_index(&mut self, index_name: String) -> Result<()> {
1016 let index_key = format!("__index__{}", index_name);
1018 match self.storage.get(index_key.as_bytes())? {
1019 Some(_) => {},
1020 None => return Err(SqlRsError::Query(format!("Index '{}' does not exist", index_name))),
1021 }
1022
1023 self.storage.delete(index_key.as_bytes())?;
1025
1026 let index_prefix = format!("__idx__{}__", index_name);
1028 let scan = self.storage.scan(index_prefix.as_bytes(), &[0xFF; 256])?;
1029
1030 for (key, _) in scan {
1031 let key_str = String::from_utf8_lossy(&key);
1032 if key_str.starts_with(&index_prefix) {
1033 self.storage.delete(&key)?;
1034 }
1035 }
1036
1037 Ok(())
1038 }
1039}
1040
1041fn value_to_index_string(value: &Value) -> String {
1043 match value {
1044 Value::Null => "NULL".to_string(),
1045 Value::Integer(i) => format!("INT:{:020}", i), Value::Float(f) => format!("FLT:{:020.6}", f),
1047 Value::Text(s) => format!("TXT:{}", s),
1048 Value::Blob(b) => {
1049 let hex_str: String = b.iter().map(|byte| format!("{:02x}", byte)).collect();
1051 format!("BLB:{}", hex_str)
1052 },
1053 Value::Boolean(b) => format!("BOL:{}", if *b { "1" } else { "0" }),
1054 }
1055}