1use crate::error::{Result, SQLRiteError};
2use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
3use crate::sql::parser::create::CreateQuery;
4use std::collections::{BTreeMap, HashMap};
5use std::fmt;
6use std::sync::{Arc, Mutex};
7
8use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
9
/// Column data types understood by the table layer.
#[derive(PartialEq, Debug, Clone)]
pub enum DataType {
    /// Whole-number column.
    Integer,
    /// String column.
    Text,
    /// Floating-point column.
    Real,
    /// Boolean column.
    Bool,
    /// Vector column; the payload is the declared dimension (always > 0
    /// when produced by [`DataType::new`]).
    Vector(usize),
    /// Explicit "none" type, produced by parsing the string `none`.
    None,
    /// Returned by [`DataType::new`] for anything it cannot recognise.
    Invalid,
}

impl DataType {
    /// Parses a type name, case-insensitively.
    ///
    /// Recognised spellings are `integer`, `text`, `real`, `bool`, `none`
    /// and `vector(N)` with `N > 0` (whitespace around `N` is tolerated).
    /// Anything else yields [`DataType::Invalid`] and writes a diagnostic
    /// to stderr; this function never fails.
    pub fn new(cmd: String) -> DataType {
        let normalized = cmd.to_lowercase();

        let scalar = match normalized.as_str() {
            "integer" => Some(DataType::Integer),
            "text" => Some(DataType::Text),
            "real" => Some(DataType::Real),
            "bool" => Some(DataType::Bool),
            "none" => Some(DataType::None),
            _ => None,
        };
        if let Some(dt) = scalar {
            return dt;
        }

        // Not a scalar keyword: the only remaining valid shape is `vector(<dim>)`.
        let dimension_text = normalized
            .strip_prefix("vector(")
            .and_then(|rest| rest.strip_suffix(')'));
        let Some(dimension_text) = dimension_text else {
            eprintln!("Invalid data type given {}", cmd);
            return DataType::Invalid;
        };

        match dimension_text.trim().parse::<usize>() {
            Ok(dim) if dim > 0 => DataType::Vector(dim),
            _ => {
                eprintln!("Invalid VECTOR dimension in {cmd}");
                DataType::Invalid
            }
        }
    }

    /// Renders the type in a form that [`DataType::new`] parses back to the
    /// same variant (unlike `Display`, which prints `Boolean` for `Bool`).
    /// `Invalid` is the exception: `new` returns it for the string
    /// "Invalid" only because that string is unrecognised.
    pub fn to_wire_string(&self) -> String {
        let label = match self {
            DataType::Vector(dim) => return format!("vector({dim})"),
            DataType::Integer => "Integer",
            DataType::Text => "Text",
            DataType::Real => "Real",
            DataType::Bool => "Bool",
            DataType::None => "None",
            DataType::Invalid => "Invalid",
        };
        label.to_string()
    }
}

/// Human-readable rendering. `Bool` prints as `Boolean` and vectors as
/// `Vector(N)`; this form is for display, not for round-tripping.
impl fmt::Display for DataType {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            DataType::Vector(dim) => return write!(f, "Vector({dim})"),
            DataType::Integer => "Integer",
            DataType::Text => "Text",
            DataType::Real => "Real",
            DataType::Bool => "Boolean",
            DataType::None => "None",
            DataType::Invalid => "Invalid",
        };
        f.write_str(label)
    }
}
98
/// An in-memory table: schema, column-oriented row storage and the
/// secondary indexes maintained alongside it.
#[derive(Debug)]
pub struct Table {
    /// Name of the table.
    pub tb_name: String,
    /// Column definitions, in declaration order.
    pub columns: Vec<Column>,
    /// Column-oriented storage: column name -> that column's [`Row`], which
    /// in turn maps rowid -> cell value. Wrapped in `Arc<Mutex<..>>`.
    pub rows: Arc<Mutex<HashMap<String, Row>>>,
    /// Secondary indexes on this table. `Table::new` creates one
    /// automatically for every PRIMARY KEY / UNIQUE column of type Integer
    /// or Text.
    pub secondary_indexes: Vec<SecondaryIndex>,
    /// Highest rowid handed out or restored so far; starts at 0.
    pub last_rowid: i64,
    /// Name of the PRIMARY KEY column, or the sentinel string `"-1"` when
    /// the table has none.
    pub primary_key: String,
}
126
127impl Table {
128 pub fn new(create_query: CreateQuery) -> Self {
129 let table_name = create_query.table_name;
130 let mut primary_key: String = String::from("-1");
131 let columns = create_query.columns;
132
133 let mut table_cols: Vec<Column> = vec![];
134 let table_rows: Arc<Mutex<HashMap<String, Row>>> = Arc::new(Mutex::new(HashMap::new()));
135 let mut secondary_indexes: Vec<SecondaryIndex> = Vec::new();
136 for col in &columns {
137 let col_name = &col.name;
138 if col.is_pk {
139 primary_key = col_name.to_string();
140 }
141 table_cols.push(Column::new(
142 col_name.to_string(),
143 col.datatype.to_string(),
144 col.is_pk,
145 col.not_null,
146 col.is_unique,
147 ));
148
149 let dt = DataType::new(col.datatype.to_string());
150 let row_storage = match &dt {
151 DataType::Integer => Row::Integer(BTreeMap::new()),
152 DataType::Real => Row::Real(BTreeMap::new()),
153 DataType::Text => Row::Text(BTreeMap::new()),
154 DataType::Bool => Row::Bool(BTreeMap::new()),
155 DataType::Vector(_dim) => Row::Vector(BTreeMap::new()),
160 DataType::Invalid | DataType::None => Row::None,
161 };
162 table_rows
163 .lock()
164 .expect("Table row storage mutex poisoned")
165 .insert(col.name.to_string(), row_storage);
166
167 if (col.is_pk || col.is_unique) && matches!(dt, DataType::Integer | DataType::Text) {
174 let name = SecondaryIndex::auto_name(&table_name, &col.name);
175 match SecondaryIndex::new(
176 name,
177 table_name.clone(),
178 col.name.clone(),
179 &dt,
180 true,
181 IndexOrigin::Auto,
182 ) {
183 Ok(idx) => secondary_indexes.push(idx),
184 Err(_) => {
185 }
188 }
189 }
190 }
191
192 Table {
193 tb_name: table_name,
194 columns: table_cols,
195 rows: table_rows,
196 secondary_indexes,
197 last_rowid: 0,
198 primary_key,
199 }
200 }
201
202 pub fn deep_clone(&self) -> Self {
211 let cloned_rows: HashMap<String, Row> = {
212 let guard = self.rows.lock().expect("row mutex poisoned");
213 guard.clone()
214 };
215 Table {
216 tb_name: self.tb_name.clone(),
217 columns: self.columns.clone(),
218 rows: Arc::new(Mutex::new(cloned_rows)),
219 secondary_indexes: self.secondary_indexes.clone(),
220 last_rowid: self.last_rowid,
221 primary_key: self.primary_key.clone(),
222 }
223 }
224
225 pub fn index_for_column(&self, column: &str) -> Option<&SecondaryIndex> {
228 self.secondary_indexes
229 .iter()
230 .find(|i| i.column_name == column)
231 }
232
233 fn index_for_column_mut(&mut self, column: &str) -> Option<&mut SecondaryIndex> {
234 self.secondary_indexes
235 .iter_mut()
236 .find(|i| i.column_name == column)
237 }
238
239 #[allow(dead_code)]
243 pub fn index_by_name(&self, name: &str) -> Option<&SecondaryIndex> {
244 self.secondary_indexes.iter().find(|i| i.name == name)
245 }
246
247 pub fn contains_column(&self, column: String) -> bool {
250 self.columns.iter().any(|col| col.column_name == column)
251 }
252
253 pub fn column_names(&self) -> Vec<String> {
255 self.columns.iter().map(|c| c.column_name.clone()).collect()
256 }
257
258 pub fn rowids(&self) -> Vec<i64> {
261 let Some(first) = self.columns.first() else {
262 return vec![];
263 };
264 let rows = self.rows.lock().expect("rows mutex poisoned");
265 rows.get(&first.column_name)
266 .map(|r| r.rowids())
267 .unwrap_or_default()
268 }
269
270 pub fn get_value(&self, column: &str, rowid: i64) -> Option<Value> {
272 let rows = self.rows.lock().expect("rows mutex poisoned");
273 rows.get(column).and_then(|r| r.get(rowid))
274 }
275
276 pub fn delete_row(&mut self, rowid: i64) {
279 let per_column_values: Vec<(String, Option<Value>)> = self
283 .columns
284 .iter()
285 .map(|c| (c.column_name.clone(), self.get_value(&c.column_name, rowid)))
286 .collect();
287
288 {
290 let rows_clone = Arc::clone(&self.rows);
291 let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
292 for col in &self.columns {
293 if let Some(r) = row_data.get_mut(&col.column_name) {
294 match r {
295 Row::Integer(m) => {
296 m.remove(&rowid);
297 }
298 Row::Text(m) => {
299 m.remove(&rowid);
300 }
301 Row::Real(m) => {
302 m.remove(&rowid);
303 }
304 Row::Bool(m) => {
305 m.remove(&rowid);
306 }
307 Row::Vector(m) => {
308 m.remove(&rowid);
309 }
310 Row::None => {}
311 }
312 }
313 }
314 }
315
316 for (col_name, value) in per_column_values {
319 if let Some(idx) = self.index_for_column_mut(&col_name) {
320 if let Some(v) = value {
321 idx.remove(&v, rowid);
322 }
323 }
324 }
325 }
326
327 pub fn restore_row(&mut self, rowid: i64, values: Vec<Option<Value>>) -> Result<()> {
333 if values.len() != self.columns.len() {
334 return Err(SQLRiteError::Internal(format!(
335 "cell has {} values but table '{}' has {} columns",
336 values.len(),
337 self.tb_name,
338 self.columns.len()
339 )));
340 }
341
342 let column_names: Vec<String> =
343 self.columns.iter().map(|c| c.column_name.clone()).collect();
344
345 for (i, value) in values.into_iter().enumerate() {
346 let col_name = &column_names[i];
347
348 {
351 let rows_clone = Arc::clone(&self.rows);
352 let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
353 let cell = row_data.get_mut(col_name).ok_or_else(|| {
354 SQLRiteError::Internal(format!("Row storage missing for column '{col_name}'"))
355 })?;
356
357 match (cell, &value) {
358 (Row::Integer(map), Some(Value::Integer(v))) => {
359 map.insert(rowid, *v as i32);
360 }
361 (Row::Integer(_), None) => {
362 return Err(SQLRiteError::Internal(format!(
363 "Integer column '{col_name}' cannot store NULL — corrupt cell?"
364 )));
365 }
366 (Row::Text(map), Some(Value::Text(s))) => {
367 map.insert(rowid, s.clone());
368 }
369 (Row::Text(map), None) => {
370 map.insert(rowid, "Null".to_string());
374 }
375 (Row::Real(map), Some(Value::Real(v))) => {
376 map.insert(rowid, *v as f32);
377 }
378 (Row::Real(_), None) => {
379 return Err(SQLRiteError::Internal(format!(
380 "Real column '{col_name}' cannot store NULL — corrupt cell?"
381 )));
382 }
383 (Row::Bool(map), Some(Value::Bool(v))) => {
384 map.insert(rowid, *v);
385 }
386 (Row::Bool(_), None) => {
387 return Err(SQLRiteError::Internal(format!(
388 "Bool column '{col_name}' cannot store NULL — corrupt cell?"
389 )));
390 }
391 (Row::Vector(map), Some(Value::Vector(v))) => {
392 map.insert(rowid, v.clone());
393 }
394 (Row::Vector(_), None) => {
395 return Err(SQLRiteError::Internal(format!(
396 "Vector column '{col_name}' cannot store NULL — corrupt cell?"
397 )));
398 }
399 (row, v) => {
400 return Err(SQLRiteError::Internal(format!(
401 "Type mismatch restoring column '{col_name}': storage {row:?} vs value {v:?}"
402 )));
403 }
404 }
405 }
406
407 if let Some(v) = &value {
410 if let Some(idx) = self.index_for_column_mut(col_name) {
411 idx.insert(v, rowid)?;
412 }
413 }
414 }
415
416 if rowid > self.last_rowid {
417 self.last_rowid = rowid;
418 }
419 Ok(())
420 }
421
422 pub fn extract_row(&self, rowid: i64) -> Vec<Option<Value>> {
426 self.columns
427 .iter()
428 .map(|c| match self.get_value(&c.column_name, rowid) {
429 Some(Value::Null) => None,
430 Some(v) => Some(v),
431 None => None,
432 })
433 .collect()
434 }
435
436 pub fn set_value(&mut self, column: &str, rowid: i64, new_val: Value) -> Result<()> {
443 let col_index = self
444 .columns
445 .iter()
446 .position(|c| c.column_name == column)
447 .ok_or_else(|| SQLRiteError::General(format!("Column '{column}' not found")))?;
448
449 let current = self.get_value(column, rowid);
451 if current.as_ref() == Some(&new_val) {
452 return Ok(());
453 }
454
455 if self.columns[col_index].is_unique && !matches!(new_val, Value::Null) {
459 if let Some(idx) = self.index_for_column(column) {
460 for other in idx.lookup(&new_val) {
461 if other != rowid {
462 return Err(SQLRiteError::General(format!(
463 "UNIQUE constraint violated for column '{column}'"
464 )));
465 }
466 }
467 } else {
468 for other in self.rowids() {
469 if other == rowid {
470 continue;
471 }
472 if self.get_value(column, other).as_ref() == Some(&new_val) {
473 return Err(SQLRiteError::General(format!(
474 "UNIQUE constraint violated for column '{column}'"
475 )));
476 }
477 }
478 }
479 }
480
481 if let Some(old) = current {
484 if let Some(idx) = self.index_for_column_mut(column) {
485 idx.remove(&old, rowid);
486 }
487 }
488
489 let declared = &self.columns[col_index].datatype;
491 {
492 let rows_clone = Arc::clone(&self.rows);
493 let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
494 let cell = row_data.get_mut(column).ok_or_else(|| {
495 SQLRiteError::Internal(format!("Row storage missing for column '{column}'"))
496 })?;
497
498 match (cell, &new_val, declared) {
499 (Row::Integer(m), Value::Integer(v), _) => {
500 m.insert(rowid, *v as i32);
501 }
502 (Row::Real(m), Value::Real(v), _) => {
503 m.insert(rowid, *v as f32);
504 }
505 (Row::Real(m), Value::Integer(v), _) => {
506 m.insert(rowid, *v as f32);
507 }
508 (Row::Text(m), Value::Text(v), _) => {
509 m.insert(rowid, v.clone());
510 }
511 (Row::Bool(m), Value::Bool(v), _) => {
512 m.insert(rowid, *v);
513 }
514 (Row::Vector(m), Value::Vector(v), DataType::Vector(declared_dim)) => {
515 if v.len() != *declared_dim {
516 return Err(SQLRiteError::General(format!(
517 "Vector dimension mismatch for column '{column}': declared {declared_dim}, got {}",
518 v.len()
519 )));
520 }
521 m.insert(rowid, v.clone());
522 }
523 (Row::Text(m), Value::Null, _) => {
526 m.insert(rowid, "Null".to_string());
527 }
528 (_, new, dt) => {
529 return Err(SQLRiteError::General(format!(
530 "Type mismatch: cannot assign {} to column '{column}' of type {dt}",
531 new.to_display_string()
532 )));
533 }
534 }
535 }
536
537 if !matches!(new_val, Value::Null) {
540 if let Some(idx) = self.index_for_column_mut(column) {
541 idx.insert(&new_val, rowid)?;
542 }
543 }
544
545 Ok(())
546 }
547
548 #[allow(dead_code)]
552 pub fn get_column(&mut self, column_name: String) -> Result<&Column> {
553 if let Some(column) = self
554 .columns
555 .iter()
556 .filter(|c| c.column_name == column_name)
557 .collect::<Vec<&Column>>()
558 .first()
559 {
560 Ok(column)
561 } else {
562 Err(SQLRiteError::General(String::from("Column not found.")))
563 }
564 }
565
566 pub fn validate_unique_constraint(
572 &mut self,
573 cols: &Vec<String>,
574 values: &Vec<String>,
575 ) -> Result<()> {
576 for (idx, name) in cols.iter().enumerate() {
577 let column = self
578 .columns
579 .iter()
580 .find(|c| &c.column_name == name)
581 .ok_or_else(|| SQLRiteError::General(format!("Column '{name}' not found")))?;
582 if !column.is_unique {
583 continue;
584 }
585 let datatype = &column.datatype;
586 let val = &values[idx];
587
588 let parsed = match datatype {
593 DataType::Integer => val.parse::<i64>().map(Value::Integer).map_err(|_| {
594 SQLRiteError::General(format!(
595 "Type mismatch: expected INTEGER for column '{name}', got '{val}'"
596 ))
597 })?,
598 DataType::Text => Value::Text(val.clone()),
599 DataType::Real => val.parse::<f64>().map(Value::Real).map_err(|_| {
600 SQLRiteError::General(format!(
601 "Type mismatch: expected REAL for column '{name}', got '{val}'"
602 ))
603 })?,
604 DataType::Bool => val.parse::<bool>().map(Value::Bool).map_err(|_| {
605 SQLRiteError::General(format!(
606 "Type mismatch: expected BOOL for column '{name}', got '{val}'"
607 ))
608 })?,
609 DataType::Vector(declared_dim) => {
610 let parsed_vec = parse_vector_literal(val).map_err(|e| {
611 SQLRiteError::General(format!(
612 "Type mismatch: expected VECTOR({declared_dim}) for column '{name}', {e}"
613 ))
614 })?;
615 if parsed_vec.len() != *declared_dim {
616 return Err(SQLRiteError::General(format!(
617 "Vector dimension mismatch for column '{name}': declared {declared_dim}, got {}",
618 parsed_vec.len()
619 )));
620 }
621 Value::Vector(parsed_vec)
622 }
623 DataType::None | DataType::Invalid => {
624 return Err(SQLRiteError::Internal(format!(
625 "column '{name}' has an unsupported datatype"
626 )));
627 }
628 };
629
630 if let Some(secondary) = self.index_for_column(name) {
631 if secondary.would_violate_unique(&parsed) {
632 return Err(SQLRiteError::General(format!(
633 "UNIQUE constraint violated for column '{name}': value '{val}' already exists"
634 )));
635 }
636 } else {
637 for other in self.rowids() {
639 if self.get_value(name, other).as_ref() == Some(&parsed) {
640 return Err(SQLRiteError::General(format!(
641 "UNIQUE constraint violated for column '{name}': value '{val}' already exists"
642 )));
643 }
644 }
645 }
646 }
647 Ok(())
648 }
649
    /// Inserts one row given parallel lists of column names and textual
    /// values (`values[k]` belongs to `cols[k]`).
    ///
    /// Rowid selection:
    /// * By default the new rowid is `last_rowid + 1`.
    /// * If the table has a primary key and `cols` does not mention it, the
    ///   rowid doubles as the key value: it is written to the primary-key
    ///   column (only when that column has Integer storage) and indexed.
    /// * If `cols` does mention the primary key, the supplied value is parsed
    ///   as an integer and becomes the rowid.
    ///
    /// Column walk: table columns are visited in declaration order while a
    /// cursor advances through `cols`. A supplied value is consumed only when
    /// the next unconsumed entry of `cols` names the column being visited, so
    /// `cols` must follow declaration order. Any other column receives the
    /// placeholder `"Null"`, which is stored as-is in Text columns and
    /// fails to parse (an error) in Integer, Real, Bool and Vector columns.
    ///
    /// # Errors
    /// `SQLRiteError::General` on type or vector-dimension mismatches;
    /// `SQLRiteError::Internal` when storage is missing or inconsistent with
    /// the declared type; plus whatever the secondary index returns from
    /// `insert`. Writes made before the failing column are not rolled back.
    ///
    /// # Panics
    /// Panics if the row mutex is poisoned, or on an out-of-bounds index when
    /// `values` is shorter than `cols`.
    pub fn insert_row(&mut self, cols: &Vec<String>, values: &Vec<String>) -> Result<()> {
        let mut next_rowid = self.last_rowid + 1;

        // "-1" is the sentinel for "table has no primary key".
        if self.primary_key != "-1" {
            if !cols.iter().any(|col| col == &self.primary_key) {
                // Primary key omitted: use the rowid itself as the key value.
                let val = next_rowid as i32;
                // Scoped so the lock is released before the index is touched.
                let wrote_integer = {
                    let rows_clone = Arc::clone(&self.rows);
                    let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
                    let table_col_data = row_data.get_mut(&self.primary_key).ok_or_else(|| {
                        SQLRiteError::Internal(format!(
                            "Row storage missing for primary key column '{}'",
                            self.primary_key
                        ))
                    })?;
                    match table_col_data {
                        Row::Integer(tree) => {
                            tree.insert(next_rowid, val);
                            true
                        }
                        // Non-integer primary key: nothing is auto-assigned.
                        _ => false,
                    }
                };
                if wrote_integer {
                    let pk = self.primary_key.clone();
                    if let Some(idx) = self.index_for_column_mut(&pk) {
                        idx.insert(&Value::Integer(val as i64), next_rowid)?;
                    }
                }
            } else {
                // Primary key supplied: its value becomes the rowid.
                for i in 0..cols.len() {
                    if cols[i] == self.primary_key {
                        let val = &values[i];
                        next_rowid = val.parse::<i64>().map_err(|_| {
                            SQLRiteError::General(format!(
                                "Type mismatch: PRIMARY KEY column '{}' expects INTEGER, got '{val}'",
                                self.primary_key
                            ))
                        })?;
                    }
                }
            }
        }

        let column_names = self
            .columns
            .iter()
            .map(|col| col.column_name.to_string())
            .collect::<Vec<String>>();
        // Cursor into `cols` / `values`; advances only when a supplied column
        // matches the table column currently being visited.
        let mut j: usize = 0;
        for i in 0..column_names.len() {
            let mut val = String::from("Null");
            let key = &column_names[i];

            if let Some(supplied_key) = cols.get(j) {
                if supplied_key == &column_names[i] {
                    val = values[j].to_string();
                    j += 1;
                } else if self.primary_key == column_names[i] {
                    // Omitted primary key: already written above.
                    continue;
                }
            } else if self.primary_key == column_names[i] {
                // Omitted primary key: already written above.
                continue;
            }

            // Parse and store the cell; yields the typed value to index, or
            // `None` for a Text NULL placeholder (which is not indexed).
            let typed_value: Option<Value> = {
                let rows_clone = Arc::clone(&self.rows);
                let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
                let table_col_data = row_data.get_mut(key).ok_or_else(|| {
                    SQLRiteError::Internal(format!("Row storage missing for column '{key}'"))
                })?;

                match table_col_data {
                    Row::Integer(tree) => {
                        let parsed = val.parse::<i32>().map_err(|_| {
                            SQLRiteError::General(format!(
                                "Type mismatch: expected INTEGER for column '{key}', got '{val}'"
                            ))
                        })?;
                        tree.insert(next_rowid, parsed);
                        Some(Value::Integer(parsed as i64))
                    }
                    Row::Text(tree) => {
                        tree.insert(next_rowid, val.to_string());
                        if val != "Null" {
                            Some(Value::Text(val.to_string()))
                        } else {
                            None
                        }
                    }
                    Row::Real(tree) => {
                        let parsed = val.parse::<f32>().map_err(|_| {
                            SQLRiteError::General(format!(
                                "Type mismatch: expected REAL for column '{key}', got '{val}'"
                            ))
                        })?;
                        tree.insert(next_rowid, parsed);
                        Some(Value::Real(parsed as f64))
                    }
                    Row::Bool(tree) => {
                        let parsed = val.parse::<bool>().map_err(|_| {
                            SQLRiteError::General(format!(
                                "Type mismatch: expected BOOL for column '{key}', got '{val}'"
                            ))
                        })?;
                        tree.insert(next_rowid, parsed);
                        Some(Value::Bool(parsed))
                    }
                    Row::Vector(tree) => {
                        let parsed = parse_vector_literal(&val).map_err(|e| {
                            SQLRiteError::General(format!(
                                "Type mismatch: expected VECTOR for column '{key}', {e}"
                            ))
                        })?;
                        // The dimension lives on the column definition, not
                        // on the storage, so it is looked up there.
                        let declared_dim = match &self.columns[i].datatype {
                            DataType::Vector(d) => *d,
                            other => {
                                return Err(SQLRiteError::Internal(format!(
                                    "Row::Vector storage on non-Vector column '{key}' (declared as {other})"
                                )));
                            }
                        };
                        if parsed.len() != declared_dim {
                            return Err(SQLRiteError::General(format!(
                                "Vector dimension mismatch for column '{key}': declared {declared_dim}, got {}",
                                parsed.len()
                            )));
                        }
                        tree.insert(next_rowid, parsed.clone());
                        Some(Value::Vector(parsed))
                    }
                    Row::None => {
                        return Err(SQLRiteError::Internal(format!(
                            "Column '{key}' has no row storage"
                        )));
                    }
                }
            };

            if let Some(v) = typed_value {
                if let Some(idx) = self.index_for_column_mut(key) {
                    idx.insert(&v, next_rowid)?;
                }
            }
        }
        self.last_rowid = next_rowid;
        Ok(())
    }
825
826 pub fn print_table_schema(&self) -> Result<usize> {
847 let mut table = PrintTable::new();
848 table.add_row(row![
849 "Column Name",
850 "Data Type",
851 "PRIMARY KEY",
852 "UNIQUE",
853 "NOT NULL"
854 ]);
855
856 for col in &self.columns {
857 table.add_row(row![
858 col.column_name,
859 col.datatype,
860 col.is_pk,
861 col.is_unique,
862 col.not_null
863 ]);
864 }
865
866 table.printstd();
867 Ok(table.len() * 2 + 1)
868 }
869
870 pub fn print_table_data(&self) {
891 let mut print_table = PrintTable::new();
892
893 let column_names = self
894 .columns
895 .iter()
896 .map(|col| col.column_name.to_string())
897 .collect::<Vec<String>>();
898
899 let header_row = PrintRow::new(
900 column_names
901 .iter()
902 .map(|col| PrintCell::new(col))
903 .collect::<Vec<PrintCell>>(),
904 );
905
906 let rows_clone = Arc::clone(&self.rows);
907 let row_data = rows_clone.lock().expect("rows mutex poisoned");
908 let first_col_data = row_data
909 .get(&self.columns.first().unwrap().column_name)
910 .unwrap();
911 let num_rows = first_col_data.count();
912 let mut print_table_rows: Vec<PrintRow> = vec![PrintRow::new(vec![]); num_rows];
913
914 for col_name in &column_names {
915 let col_val = row_data
916 .get(col_name)
917 .expect("Can't find any rows with the given column");
918 let columns: Vec<String> = col_val.get_serialized_col_data();
919
920 for i in 0..num_rows {
921 if let Some(cell) = &columns.get(i) {
922 print_table_rows[i].add_cell(PrintCell::new(cell));
923 } else {
924 print_table_rows[i].add_cell(PrintCell::new(""));
925 }
926 }
927 }
928
929 print_table.add_row(header_row);
930 for row in print_table_rows {
931 print_table.add_row(row);
932 }
933
934 print_table.printstd();
935 }
936}
937
/// Schema definition of a single table column.
#[derive(PartialEq, Debug, Clone)]
pub struct Column {
    /// Column name, as declared.
    pub column_name: String,
    /// Declared data type (`DataType::Invalid` if the declaration could not
    /// be parsed).
    pub datatype: DataType,
    /// Whether the column is the PRIMARY KEY.
    pub is_pk: bool,
    /// Whether the column was declared NOT NULL.
    pub not_null: bool,
    /// Whether the column carries a UNIQUE constraint.
    pub is_unique: bool,
}
951
952impl Column {
953 pub fn new(
954 name: String,
955 datatype: String,
956 is_pk: bool,
957 not_null: bool,
958 is_unique: bool,
959 ) -> Self {
960 let dt = DataType::new(datatype);
961 Column {
962 column_name: name,
963 datatype: dt,
964 is_pk,
965 not_null,
966 is_unique,
967 }
968 }
969}
970
/// Storage for one column: a map from rowid to cell value, typed by the
/// column's data type.
#[derive(PartialEq, Debug, Clone)]
pub enum Row {
    /// Integer cells, stored as `i32`.
    Integer(BTreeMap<i64, i32>),
    /// Text cells. The literal string `"Null"` is read back as NULL.
    Text(BTreeMap<i64, String>),
    /// Real cells, stored as `f32`.
    Real(BTreeMap<i64, f32>),
    /// Boolean cells.
    Bool(BTreeMap<i64, bool>),
    /// Vector cells.
    Vector(BTreeMap<i64, Vec<f32>>),
    /// No storage at all.
    None,
}

impl Row {
    /// Renders every cell as text, in ascending rowid order.
    ///
    /// # Panics
    /// Panics on `Row::None`.
    fn get_serialized_col_data(&self) -> Vec<String> {
        match self {
            Row::Integer(cells) => cells.values().map(ToString::to_string).collect(),
            Row::Real(cells) => cells.values().map(ToString::to_string).collect(),
            Row::Text(cells) => cells.values().cloned().collect(),
            Row::Bool(cells) => cells.values().map(ToString::to_string).collect(),
            Row::Vector(cells) => cells.values().map(format_vector_for_display).collect(),
            Row::None => panic!("Found None in columns"),
        }
    }

    /// Number of cells stored.
    ///
    /// # Panics
    /// Panics on `Row::None`.
    fn count(&self) -> usize {
        match self {
            Row::Integer(cells) => cells.len(),
            Row::Real(cells) => cells.len(),
            Row::Text(cells) => cells.len(),
            Row::Bool(cells) => cells.len(),
            Row::Vector(cells) => cells.len(),
            Row::None => panic!("Found None in columns"),
        }
    }

    /// All rowids that have a cell in this column, ascending. `Row::None`
    /// yields an empty vector.
    pub fn rowids(&self) -> Vec<i64> {
        match self {
            Row::Integer(cells) => cells.keys().cloned().collect(),
            Row::Text(cells) => cells.keys().cloned().collect(),
            Row::Real(cells) => cells.keys().cloned().collect(),
            Row::Bool(cells) => cells.keys().cloned().collect(),
            Row::Vector(cells) => cells.keys().cloned().collect(),
            Row::None => Vec::new(),
        }
    }

    /// Reads the cell for `rowid`, widened to the engine's `Value` type.
    ///
    /// Returns `None` when there is no cell for that rowid (always the case
    /// for `Row::None`). A Text cell holding the placeholder `"Null"` is
    /// returned as `Value::Null`.
    pub fn get(&self, rowid: i64) -> Option<Value> {
        let value = match self {
            Row::Integer(cells) => Value::Integer(i64::from(*cells.get(&rowid)?)),
            Row::Text(cells) => {
                let text = cells.get(&rowid)?;
                if text == "Null" {
                    Value::Null
                } else {
                    Value::Text(text.clone())
                }
            }
            Row::Real(cells) => Value::Real(f64::from(*cells.get(&rowid)?)),
            Row::Bool(cells) => Value::Bool(*cells.get(&rowid)?),
            Row::Vector(cells) => Value::Vector(cells.get(&rowid)?.clone()),
            Row::None => return None,
        };
        Some(value)
    }
}

/// Formats a vector as `[a, b, c]`, using each element's default `f32`
/// rendering. An empty vector formats as `[]`.
fn format_vector_for_display(v: &Vec<f32>) -> String {
    let parts: Vec<String> = v.iter().map(|x| x.to_string()).collect();
    format!("[{}]", parts.join(", "))
}

/// A single typed cell value as seen by the rest of the engine.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// 64-bit integer.
    Integer(i64),
    /// String.
    Text(String),
    /// 64-bit float.
    Real(f64),
    /// Boolean.
    Bool(bool),
    /// Vector of `f32` components.
    Vector(Vec<f32>),
    /// SQL NULL.
    Null,
}

impl Value {
    /// Renders the value for display. Text is returned without quotes,
    /// vectors as `[a, b, c]`, and NULL as the string `NULL`.
    pub fn to_display_string(&self) -> String {
        match self {
            Value::Null => "NULL".to_string(),
            Value::Text(s) => s.to_owned(),
            Value::Integer(v) => format!("{v}"),
            Value::Real(f) => format!("{f}"),
            Value::Bool(b) => format!("{b}"),
            Value::Vector(v) => format_vector_for_display(v),
        }
    }
}
1097
1098pub fn parse_vector_literal(s: &str) -> Result<Vec<f32>> {
1118 let trimmed = s.trim();
1119 if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
1120 return Err(SQLRiteError::General(format!(
1121 "expected bracket-array literal `[...]`, got `{s}`"
1122 )));
1123 }
1124 let inner = &trimmed[1..trimmed.len() - 1].trim();
1125 if inner.is_empty() {
1126 return Ok(Vec::new());
1127 }
1128 let mut out = Vec::new();
1129 for (i, part) in inner.split(',').enumerate() {
1130 let element = part.trim();
1131 let parsed: f32 = element.parse().map_err(|_| {
1132 SQLRiteError::General(format!("vector element {i} (`{element}`) is not a number"))
1133 })?;
1134 out.push(parsed);
1135 }
1136 Ok(out)
1137}
1138
/// Unit tests for datatype parsing/formatting, vector literal parsing,
/// value display, and table construction from a parsed `CREATE TABLE`.
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// `Display` output for every `DataType` variant.
    #[test]
    fn datatype_display_trait_test() {
        let integer = DataType::Integer;
        let text = DataType::Text;
        let real = DataType::Real;
        let boolean = DataType::Bool;
        let vector = DataType::Vector(384);
        let none = DataType::None;
        let invalid = DataType::Invalid;

        assert_eq!(format!("{}", integer), "Integer");
        assert_eq!(format!("{}", text), "Text");
        assert_eq!(format!("{}", real), "Real");
        assert_eq!(format!("{}", boolean), "Boolean");
        assert_eq!(format!("{}", vector), "Vector(384)");
        assert_eq!(format!("{}", none), "None");
        assert_eq!(format!("{}", invalid), "Invalid");
    }

    /// `vector(N)` parses to `Vector(N)` for several dimensions.
    #[test]
    fn datatype_new_parses_vector_dim() {
        assert_eq!(DataType::new("vector(1)".to_string()), DataType::Vector(1));
        assert_eq!(
            DataType::new("vector(384)".to_string()),
            DataType::Vector(384)
        );
        assert_eq!(
            DataType::new("vector(1536)".to_string()),
            DataType::Vector(1536)
        );

        // Parsing is case-insensitive.
        assert_eq!(
            DataType::new("VECTOR(384)".to_string()),
            DataType::Vector(384)
        );

        // Whitespace around the dimension is tolerated.
        assert_eq!(
            DataType::new("vector( 64 )".to_string()),
            DataType::Vector(64)
        );
    }

    /// Zero, non-numeric, empty and negative dimensions are all rejected.
    #[test]
    fn datatype_new_rejects_bad_vector_strings() {
        assert_eq!(DataType::new("vector(0)".to_string()), DataType::Invalid);
        assert_eq!(DataType::new("vector(abc)".to_string()), DataType::Invalid);
        assert_eq!(DataType::new("vector()".to_string()), DataType::Invalid);
        assert_eq!(DataType::new("vector(-3)".to_string()), DataType::Invalid);
    }

    /// The wire string of a vector type parses back to the same type.
    #[test]
    fn datatype_to_wire_string_round_trips_vector() {
        let dt = DataType::Vector(384);
        let wire = dt.to_wire_string();
        assert_eq!(wire, "vector(384)");
        assert_eq!(DataType::new(wire), DataType::Vector(384));
    }

    /// Plain float elements.
    #[test]
    fn parse_vector_literal_accepts_floats() {
        let v = parse_vector_literal("[0.1, 0.2, 0.3]").expect("parse");
        assert_eq!(v, vec![0.1f32, 0.2, 0.3]);
    }

    /// Integer elements are accepted and become `f32`.
    #[test]
    fn parse_vector_literal_accepts_ints_widening_to_f32() {
        let v = parse_vector_literal("[1, 2, 3]").expect("parse");
        assert_eq!(v, vec![1.0f32, 2.0, 3.0]);
    }

    /// Negative numbers and irregular spacing.
    #[test]
    fn parse_vector_literal_handles_negatives_and_whitespace() {
        let v = parse_vector_literal("[ -1.5 , 2.0, -3.5 ]").expect("parse");
        assert_eq!(v, vec![-1.5f32, 2.0, -3.5]);
    }

    /// `[]` parses to an empty vector rather than failing.
    #[test]
    fn parse_vector_literal_empty_brackets_is_empty_vec() {
        let v = parse_vector_literal("[]").expect("parse");
        assert!(v.is_empty());
    }

    /// Missing brackets, parentheses, and unbalanced brackets are rejected.
    #[test]
    fn parse_vector_literal_rejects_non_bracketed() {
        assert!(parse_vector_literal("0.1, 0.2").is_err());
        assert!(parse_vector_literal("(0.1, 0.2)").is_err());
        assert!(parse_vector_literal("[0.1, 0.2").is_err());
        assert!(parse_vector_literal("0.1, 0.2]").is_err());
    }

    /// The error for a bad element names its position and text.
    #[test]
    fn parse_vector_literal_rejects_non_numeric_elements() {
        let err = parse_vector_literal("[1.0, 'foo', 3.0]").unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("vector element 1") && msg.contains("'foo'"),
            "error message should pinpoint the bad element: got `{msg}`"
        );
    }

    /// Display form of vector values, including the empty vector.
    #[test]
    fn value_vector_display_format() {
        let v = Value::Vector(vec![0.1, 0.2, 0.3]);
        assert_eq!(v.to_display_string(), "[0.1, 0.2, 0.3]");

        let empty = Value::Vector(vec![]);
        assert_eq!(empty.to_display_string(), "[]");
    }

    /// Building a table from `CREATE TABLE` records all six columns, starts
    /// at rowid 0, and marks `id` as an Integer primary key.
    #[test]
    fn create_new_table_test() {
        let query_statement = "CREATE TABLE contacts (
 id INTEGER PRIMARY KEY,
 first_name TEXT NOT NULL,
 last_name TEXT NOT NULl,
 email TEXT NOT NULL UNIQUE,
 active BOOL,
 score REAL
 );";
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, query_statement).unwrap();
        if ast.len() > 1 {
            panic!("Expected a single query statement, but there are more then 1.")
        }
        let query = ast.pop().unwrap();

        let create_query = CreateQuery::new(&query).unwrap();

        let table = Table::new(create_query);

        assert_eq!(table.columns.len(), 6);
        assert_eq!(table.last_rowid, 0);

        let id_column = "id".to_string();
        if let Some(column) = table
            .columns
            .iter()
            .filter(|c| c.column_name == id_column)
            .collect::<Vec<&Column>>()
            .first()
        {
            assert!(column.is_pk);
            assert_eq!(column.datatype, DataType::Integer);
        } else {
            panic!("column not found");
        }
    }

    /// For a three-column table, `print_table_schema` reports 9:
    /// (1 header + 3 column rows) * 2 + 1.
    #[test]
    fn print_table_schema_test() {
        let query_statement = "CREATE TABLE contacts (
 id INTEGER PRIMARY KEY,
 first_name TEXT NOT NULL,
 last_name TEXT NOT NULl
 );";
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, query_statement).unwrap();
        if ast.len() > 1 {
            panic!("Expected a single query statement, but there are more then 1.")
        }
        let query = ast.pop().unwrap();

        let create_query = CreateQuery::new(&query).unwrap();

        let table = Table::new(create_query);
        let lines_printed = table.print_table_schema();
        assert_eq!(lines_printed, Ok(9));
    }
}