1use crate::error::{Result, SQLRiteError};
2use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
3use crate::sql::hnsw::HnswIndex;
4use crate::sql::parser::create::CreateQuery;
5use std::collections::{BTreeMap, HashMap};
6use std::fmt;
7use std::sync::{Arc, Mutex};
8
9use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
10
/// Declared type of a table column.
///
/// `Vector(n)` carries the declared dimension. `None` and `Invalid` are produced
/// by [`DataType::new`] for the literal `none` and for unparseable type names.
#[derive(Debug, Clone, PartialEq)]
pub enum DataType {
    Integer,
    Text,
    Real,
    Bool,
    Vector(usize),
    None,
    Invalid,
}
32
33impl DataType {
34 pub fn new(cmd: String) -> DataType {
41 let lower = cmd.to_lowercase();
42 match lower.as_str() {
43 "integer" => DataType::Integer,
44 "text" => DataType::Text,
45 "real" => DataType::Real,
46 "bool" => DataType::Bool,
47 "none" => DataType::None,
48 other if other.starts_with("vector(") && other.ends_with(')') => {
49 let inside = &other["vector(".len()..other.len() - 1];
53 match inside.trim().parse::<usize>() {
54 Ok(dim) if dim > 0 => DataType::Vector(dim),
55 _ => {
56 eprintln!("Invalid VECTOR dimension in {cmd}");
57 DataType::Invalid
58 }
59 }
60 }
61 _ => {
62 eprintln!("Invalid data type given {}", cmd);
63 DataType::Invalid
64 }
65 }
66 }
67
68 pub fn to_wire_string(&self) -> String {
74 match self {
75 DataType::Integer => "Integer".to_string(),
76 DataType::Text => "Text".to_string(),
77 DataType::Real => "Real".to_string(),
78 DataType::Bool => "Bool".to_string(),
79 DataType::Vector(dim) => format!("vector({dim})"),
80 DataType::None => "None".to_string(),
81 DataType::Invalid => "Invalid".to_string(),
82 }
83 }
84}
85
86impl fmt::Display for DataType {
87 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
88 match self {
89 DataType::Integer => f.write_str("Integer"),
90 DataType::Text => f.write_str("Text"),
91 DataType::Real => f.write_str("Real"),
92 DataType::Bool => f.write_str("Boolean"),
93 DataType::Vector(dim) => write!(f, "Vector({dim})"),
94 DataType::None => f.write_str("None"),
95 DataType::Invalid => f.write_str("Invalid"),
96 }
97 }
98}
99
/// An in-memory table: schema, column-oriented cell storage, and the indexes kept
/// alongside that storage.
#[derive(Debug)]
pub struct Table {
    /// Table name as given in `CREATE TABLE`.
    pub tb_name: String,
    /// Column definitions, in declaration order.
    pub columns: Vec<Column>,
    /// Cell storage keyed by column name. Each [`Row`] maps rowid -> cell value for
    /// one column.
    pub rows: Arc<Mutex<HashMap<String, Row>>>,
    /// Secondary indexes. `Table::new` creates one automatically for every
    /// PRIMARY KEY / UNIQUE column of INTEGER or TEXT type.
    pub secondary_indexes: Vec<SecondaryIndex>,
    /// HNSW vector indexes, each covering one column (see `HnswIndexEntry`).
    pub hnsw_indexes: Vec<HnswIndexEntry>,
    /// Rowid bookkeeping: an INSERT without an explicit primary key uses
    /// `last_rowid + 1`.
    pub last_rowid: i64,
    /// Name of the PRIMARY KEY column, or the sentinel `"-1"` when the table has none.
    pub primary_key: String,
}
132
133#[derive(Debug, Clone)]
138pub struct HnswIndexEntry {
139 pub name: String,
142 pub column_name: String,
144 pub index: HnswIndex,
146}
147
148impl Table {
149 pub fn new(create_query: CreateQuery) -> Self {
150 let table_name = create_query.table_name;
151 let mut primary_key: String = String::from("-1");
152 let columns = create_query.columns;
153
154 let mut table_cols: Vec<Column> = vec![];
155 let table_rows: Arc<Mutex<HashMap<String, Row>>> = Arc::new(Mutex::new(HashMap::new()));
156 let mut secondary_indexes: Vec<SecondaryIndex> = Vec::new();
157 for col in &columns {
158 let col_name = &col.name;
159 if col.is_pk {
160 primary_key = col_name.to_string();
161 }
162 table_cols.push(Column::new(
163 col_name.to_string(),
164 col.datatype.to_string(),
165 col.is_pk,
166 col.not_null,
167 col.is_unique,
168 ));
169
170 let dt = DataType::new(col.datatype.to_string());
171 let row_storage = match &dt {
172 DataType::Integer => Row::Integer(BTreeMap::new()),
173 DataType::Real => Row::Real(BTreeMap::new()),
174 DataType::Text => Row::Text(BTreeMap::new()),
175 DataType::Bool => Row::Bool(BTreeMap::new()),
176 DataType::Vector(_dim) => Row::Vector(BTreeMap::new()),
181 DataType::Invalid | DataType::None => Row::None,
182 };
183 table_rows
184 .lock()
185 .expect("Table row storage mutex poisoned")
186 .insert(col.name.to_string(), row_storage);
187
188 if (col.is_pk || col.is_unique) && matches!(dt, DataType::Integer | DataType::Text) {
195 let name = SecondaryIndex::auto_name(&table_name, &col.name);
196 match SecondaryIndex::new(
197 name,
198 table_name.clone(),
199 col.name.clone(),
200 &dt,
201 true,
202 IndexOrigin::Auto,
203 ) {
204 Ok(idx) => secondary_indexes.push(idx),
205 Err(_) => {
206 }
209 }
210 }
211 }
212
213 Table {
214 tb_name: table_name,
215 columns: table_cols,
216 rows: table_rows,
217 secondary_indexes,
218 hnsw_indexes: Vec::new(),
223 last_rowid: 0,
224 primary_key,
225 }
226 }
227
228 pub fn deep_clone(&self) -> Self {
237 let cloned_rows: HashMap<String, Row> = {
238 let guard = self.rows.lock().expect("row mutex poisoned");
239 guard.clone()
240 };
241 Table {
242 tb_name: self.tb_name.clone(),
243 columns: self.columns.clone(),
244 rows: Arc::new(Mutex::new(cloned_rows)),
245 secondary_indexes: self.secondary_indexes.clone(),
246 hnsw_indexes: self.hnsw_indexes.clone(),
250 last_rowid: self.last_rowid,
251 primary_key: self.primary_key.clone(),
252 }
253 }
254
255 pub fn index_for_column(&self, column: &str) -> Option<&SecondaryIndex> {
258 self.secondary_indexes
259 .iter()
260 .find(|i| i.column_name == column)
261 }
262
263 fn index_for_column_mut(&mut self, column: &str) -> Option<&mut SecondaryIndex> {
264 self.secondary_indexes
265 .iter_mut()
266 .find(|i| i.column_name == column)
267 }
268
269 #[allow(dead_code)]
273 pub fn index_by_name(&self, name: &str) -> Option<&SecondaryIndex> {
274 self.secondary_indexes.iter().find(|i| i.name == name)
275 }
276
277 pub fn contains_column(&self, column: String) -> bool {
280 self.columns.iter().any(|col| col.column_name == column)
281 }
282
283 pub fn column_names(&self) -> Vec<String> {
285 self.columns.iter().map(|c| c.column_name.clone()).collect()
286 }
287
288 pub fn rowids(&self) -> Vec<i64> {
291 let Some(first) = self.columns.first() else {
292 return vec![];
293 };
294 let rows = self.rows.lock().expect("rows mutex poisoned");
295 rows.get(&first.column_name)
296 .map(|r| r.rowids())
297 .unwrap_or_default()
298 }
299
300 pub fn get_value(&self, column: &str, rowid: i64) -> Option<Value> {
302 let rows = self.rows.lock().expect("rows mutex poisoned");
303 rows.get(column).and_then(|r| r.get(rowid))
304 }
305
306 pub fn delete_row(&mut self, rowid: i64) {
309 let per_column_values: Vec<(String, Option<Value>)> = self
313 .columns
314 .iter()
315 .map(|c| (c.column_name.clone(), self.get_value(&c.column_name, rowid)))
316 .collect();
317
318 {
320 let rows_clone = Arc::clone(&self.rows);
321 let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
322 for col in &self.columns {
323 if let Some(r) = row_data.get_mut(&col.column_name) {
324 match r {
325 Row::Integer(m) => {
326 m.remove(&rowid);
327 }
328 Row::Text(m) => {
329 m.remove(&rowid);
330 }
331 Row::Real(m) => {
332 m.remove(&rowid);
333 }
334 Row::Bool(m) => {
335 m.remove(&rowid);
336 }
337 Row::Vector(m) => {
338 m.remove(&rowid);
339 }
340 Row::None => {}
341 }
342 }
343 }
344 }
345
346 for (col_name, value) in per_column_values {
349 if let Some(idx) = self.index_for_column_mut(&col_name) {
350 if let Some(v) = value {
351 idx.remove(&v, rowid);
352 }
353 }
354 }
355 }
356
357 pub fn restore_row(&mut self, rowid: i64, values: Vec<Option<Value>>) -> Result<()> {
363 if values.len() != self.columns.len() {
364 return Err(SQLRiteError::Internal(format!(
365 "cell has {} values but table '{}' has {} columns",
366 values.len(),
367 self.tb_name,
368 self.columns.len()
369 )));
370 }
371
372 let column_names: Vec<String> =
373 self.columns.iter().map(|c| c.column_name.clone()).collect();
374
375 for (i, value) in values.into_iter().enumerate() {
376 let col_name = &column_names[i];
377
378 {
381 let rows_clone = Arc::clone(&self.rows);
382 let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
383 let cell = row_data.get_mut(col_name).ok_or_else(|| {
384 SQLRiteError::Internal(format!("Row storage missing for column '{col_name}'"))
385 })?;
386
387 match (cell, &value) {
388 (Row::Integer(map), Some(Value::Integer(v))) => {
389 map.insert(rowid, *v as i32);
390 }
391 (Row::Integer(_), None) => {
392 return Err(SQLRiteError::Internal(format!(
393 "Integer column '{col_name}' cannot store NULL — corrupt cell?"
394 )));
395 }
396 (Row::Text(map), Some(Value::Text(s))) => {
397 map.insert(rowid, s.clone());
398 }
399 (Row::Text(map), None) => {
400 map.insert(rowid, "Null".to_string());
404 }
405 (Row::Real(map), Some(Value::Real(v))) => {
406 map.insert(rowid, *v as f32);
407 }
408 (Row::Real(_), None) => {
409 return Err(SQLRiteError::Internal(format!(
410 "Real column '{col_name}' cannot store NULL — corrupt cell?"
411 )));
412 }
413 (Row::Bool(map), Some(Value::Bool(v))) => {
414 map.insert(rowid, *v);
415 }
416 (Row::Bool(_), None) => {
417 return Err(SQLRiteError::Internal(format!(
418 "Bool column '{col_name}' cannot store NULL — corrupt cell?"
419 )));
420 }
421 (Row::Vector(map), Some(Value::Vector(v))) => {
422 map.insert(rowid, v.clone());
423 }
424 (Row::Vector(_), None) => {
425 return Err(SQLRiteError::Internal(format!(
426 "Vector column '{col_name}' cannot store NULL — corrupt cell?"
427 )));
428 }
429 (row, v) => {
430 return Err(SQLRiteError::Internal(format!(
431 "Type mismatch restoring column '{col_name}': storage {row:?} vs value {v:?}"
432 )));
433 }
434 }
435 }
436
437 if let Some(v) = &value {
440 if let Some(idx) = self.index_for_column_mut(col_name) {
441 idx.insert(v, rowid)?;
442 }
443 }
444 }
445
446 if rowid > self.last_rowid {
447 self.last_rowid = rowid;
448 }
449 Ok(())
450 }
451
452 pub fn extract_row(&self, rowid: i64) -> Vec<Option<Value>> {
456 self.columns
457 .iter()
458 .map(|c| match self.get_value(&c.column_name, rowid) {
459 Some(Value::Null) => None,
460 Some(v) => Some(v),
461 None => None,
462 })
463 .collect()
464 }
465
466 pub fn set_value(&mut self, column: &str, rowid: i64, new_val: Value) -> Result<()> {
473 let col_index = self
474 .columns
475 .iter()
476 .position(|c| c.column_name == column)
477 .ok_or_else(|| SQLRiteError::General(format!("Column '{column}' not found")))?;
478
479 let current = self.get_value(column, rowid);
481 if current.as_ref() == Some(&new_val) {
482 return Ok(());
483 }
484
485 if self.columns[col_index].is_unique && !matches!(new_val, Value::Null) {
489 if let Some(idx) = self.index_for_column(column) {
490 for other in idx.lookup(&new_val) {
491 if other != rowid {
492 return Err(SQLRiteError::General(format!(
493 "UNIQUE constraint violated for column '{column}'"
494 )));
495 }
496 }
497 } else {
498 for other in self.rowids() {
499 if other == rowid {
500 continue;
501 }
502 if self.get_value(column, other).as_ref() == Some(&new_val) {
503 return Err(SQLRiteError::General(format!(
504 "UNIQUE constraint violated for column '{column}'"
505 )));
506 }
507 }
508 }
509 }
510
511 if let Some(old) = current {
514 if let Some(idx) = self.index_for_column_mut(column) {
515 idx.remove(&old, rowid);
516 }
517 }
518
519 let declared = &self.columns[col_index].datatype;
521 {
522 let rows_clone = Arc::clone(&self.rows);
523 let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
524 let cell = row_data.get_mut(column).ok_or_else(|| {
525 SQLRiteError::Internal(format!("Row storage missing for column '{column}'"))
526 })?;
527
528 match (cell, &new_val, declared) {
529 (Row::Integer(m), Value::Integer(v), _) => {
530 m.insert(rowid, *v as i32);
531 }
532 (Row::Real(m), Value::Real(v), _) => {
533 m.insert(rowid, *v as f32);
534 }
535 (Row::Real(m), Value::Integer(v), _) => {
536 m.insert(rowid, *v as f32);
537 }
538 (Row::Text(m), Value::Text(v), _) => {
539 m.insert(rowid, v.clone());
540 }
541 (Row::Bool(m), Value::Bool(v), _) => {
542 m.insert(rowid, *v);
543 }
544 (Row::Vector(m), Value::Vector(v), DataType::Vector(declared_dim)) => {
545 if v.len() != *declared_dim {
546 return Err(SQLRiteError::General(format!(
547 "Vector dimension mismatch for column '{column}': declared {declared_dim}, got {}",
548 v.len()
549 )));
550 }
551 m.insert(rowid, v.clone());
552 }
553 (Row::Text(m), Value::Null, _) => {
556 m.insert(rowid, "Null".to_string());
557 }
558 (_, new, dt) => {
559 return Err(SQLRiteError::General(format!(
560 "Type mismatch: cannot assign {} to column '{column}' of type {dt}",
561 new.to_display_string()
562 )));
563 }
564 }
565 }
566
567 if !matches!(new_val, Value::Null) {
570 if let Some(idx) = self.index_for_column_mut(column) {
571 idx.insert(&new_val, rowid)?;
572 }
573 }
574
575 Ok(())
576 }
577
578 #[allow(dead_code)]
582 pub fn get_column(&mut self, column_name: String) -> Result<&Column> {
583 if let Some(column) = self
584 .columns
585 .iter()
586 .filter(|c| c.column_name == column_name)
587 .collect::<Vec<&Column>>()
588 .first()
589 {
590 Ok(column)
591 } else {
592 Err(SQLRiteError::General(String::from("Column not found.")))
593 }
594 }
595
596 pub fn validate_unique_constraint(
602 &mut self,
603 cols: &Vec<String>,
604 values: &Vec<String>,
605 ) -> Result<()> {
606 for (idx, name) in cols.iter().enumerate() {
607 let column = self
608 .columns
609 .iter()
610 .find(|c| &c.column_name == name)
611 .ok_or_else(|| SQLRiteError::General(format!("Column '{name}' not found")))?;
612 if !column.is_unique {
613 continue;
614 }
615 let datatype = &column.datatype;
616 let val = &values[idx];
617
618 let parsed = match datatype {
623 DataType::Integer => val.parse::<i64>().map(Value::Integer).map_err(|_| {
624 SQLRiteError::General(format!(
625 "Type mismatch: expected INTEGER for column '{name}', got '{val}'"
626 ))
627 })?,
628 DataType::Text => Value::Text(val.clone()),
629 DataType::Real => val.parse::<f64>().map(Value::Real).map_err(|_| {
630 SQLRiteError::General(format!(
631 "Type mismatch: expected REAL for column '{name}', got '{val}'"
632 ))
633 })?,
634 DataType::Bool => val.parse::<bool>().map(Value::Bool).map_err(|_| {
635 SQLRiteError::General(format!(
636 "Type mismatch: expected BOOL for column '{name}', got '{val}'"
637 ))
638 })?,
639 DataType::Vector(declared_dim) => {
640 let parsed_vec = parse_vector_literal(val).map_err(|e| {
641 SQLRiteError::General(format!(
642 "Type mismatch: expected VECTOR({declared_dim}) for column '{name}', {e}"
643 ))
644 })?;
645 if parsed_vec.len() != *declared_dim {
646 return Err(SQLRiteError::General(format!(
647 "Vector dimension mismatch for column '{name}': declared {declared_dim}, got {}",
648 parsed_vec.len()
649 )));
650 }
651 Value::Vector(parsed_vec)
652 }
653 DataType::None | DataType::Invalid => {
654 return Err(SQLRiteError::Internal(format!(
655 "column '{name}' has an unsupported datatype"
656 )));
657 }
658 };
659
660 if let Some(secondary) = self.index_for_column(name) {
661 if secondary.would_violate_unique(&parsed) {
662 return Err(SQLRiteError::General(format!(
663 "UNIQUE constraint violated for column '{name}': value '{val}' already exists"
664 )));
665 }
666 } else {
667 for other in self.rowids() {
669 if self.get_value(name, other).as_ref() == Some(&parsed) {
670 return Err(SQLRiteError::General(format!(
671 "UNIQUE constraint violated for column '{name}': value '{val}' already exists"
672 )));
673 }
674 }
675 }
676 }
677 Ok(())
678 }
679
680 pub fn insert_row(&mut self, cols: &Vec<String>, values: &Vec<String>) -> Result<()> {
691 let mut next_rowid = self.last_rowid + 1;
692
693 if self.primary_key != "-1" {
696 if !cols.iter().any(|col| col == &self.primary_key) {
697 let val = next_rowid as i32;
700 let wrote_integer = {
701 let rows_clone = Arc::clone(&self.rows);
702 let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
703 let table_col_data = row_data.get_mut(&self.primary_key).ok_or_else(|| {
704 SQLRiteError::Internal(format!(
705 "Row storage missing for primary key column '{}'",
706 self.primary_key
707 ))
708 })?;
709 match table_col_data {
710 Row::Integer(tree) => {
711 tree.insert(next_rowid, val);
712 true
713 }
714 _ => false, }
716 };
717 if wrote_integer {
718 let pk = self.primary_key.clone();
719 if let Some(idx) = self.index_for_column_mut(&pk) {
720 idx.insert(&Value::Integer(val as i64), next_rowid)?;
721 }
722 }
723 } else {
724 for i in 0..cols.len() {
725 if cols[i] == self.primary_key {
726 let val = &values[i];
727 next_rowid = val.parse::<i64>().map_err(|_| {
728 SQLRiteError::General(format!(
729 "Type mismatch: PRIMARY KEY column '{}' expects INTEGER, got '{val}'",
730 self.primary_key
731 ))
732 })?;
733 }
734 }
735 }
736 }
737
738 let column_names = self
741 .columns
742 .iter()
743 .map(|col| col.column_name.to_string())
744 .collect::<Vec<String>>();
745 let mut j: usize = 0;
746 for i in 0..column_names.len() {
747 let mut val = String::from("Null");
748 let key = &column_names[i];
749
750 if let Some(supplied_key) = cols.get(j) {
751 if supplied_key == &column_names[i] {
752 val = values[j].to_string();
753 j += 1;
754 } else if self.primary_key == column_names[i] {
755 continue;
757 }
758 } else if self.primary_key == column_names[i] {
759 continue;
760 }
761
762 let typed_value: Option<Value> = {
765 let rows_clone = Arc::clone(&self.rows);
766 let mut row_data = rows_clone.lock().expect("rows mutex poisoned");
767 let table_col_data = row_data.get_mut(key).ok_or_else(|| {
768 SQLRiteError::Internal(format!("Row storage missing for column '{key}'"))
769 })?;
770
771 match table_col_data {
772 Row::Integer(tree) => {
773 let parsed = val.parse::<i32>().map_err(|_| {
774 SQLRiteError::General(format!(
775 "Type mismatch: expected INTEGER for column '{key}', got '{val}'"
776 ))
777 })?;
778 tree.insert(next_rowid, parsed);
779 Some(Value::Integer(parsed as i64))
780 }
781 Row::Text(tree) => {
782 tree.insert(next_rowid, val.to_string());
783 if val != "Null" {
786 Some(Value::Text(val.to_string()))
787 } else {
788 None
789 }
790 }
791 Row::Real(tree) => {
792 let parsed = val.parse::<f32>().map_err(|_| {
793 SQLRiteError::General(format!(
794 "Type mismatch: expected REAL for column '{key}', got '{val}'"
795 ))
796 })?;
797 tree.insert(next_rowid, parsed);
798 Some(Value::Real(parsed as f64))
799 }
800 Row::Bool(tree) => {
801 let parsed = val.parse::<bool>().map_err(|_| {
802 SQLRiteError::General(format!(
803 "Type mismatch: expected BOOL for column '{key}', got '{val}'"
804 ))
805 })?;
806 tree.insert(next_rowid, parsed);
807 Some(Value::Bool(parsed))
808 }
809 Row::Vector(tree) => {
810 let parsed = parse_vector_literal(&val).map_err(|e| {
815 SQLRiteError::General(format!(
816 "Type mismatch: expected VECTOR for column '{key}', {e}"
817 ))
818 })?;
819 let declared_dim = match &self.columns[i].datatype {
820 DataType::Vector(d) => *d,
821 other => {
822 return Err(SQLRiteError::Internal(format!(
823 "Row::Vector storage on non-Vector column '{key}' (declared as {other})"
824 )));
825 }
826 };
827 if parsed.len() != declared_dim {
828 return Err(SQLRiteError::General(format!(
829 "Vector dimension mismatch for column '{key}': declared {declared_dim}, got {}",
830 parsed.len()
831 )));
832 }
833 tree.insert(next_rowid, parsed.clone());
834 Some(Value::Vector(parsed))
835 }
836 Row::None => {
837 return Err(SQLRiteError::Internal(format!(
838 "Column '{key}' has no row storage"
839 )));
840 }
841 }
842 };
843
844 if let Some(v) = typed_value.clone() {
847 if let Some(idx) = self.index_for_column_mut(key) {
848 idx.insert(&v, next_rowid)?;
849 }
850 }
851
852 if let Some(Value::Vector(new_vec)) = typed_value {
858 self.maintain_hnsw_on_insert(key, next_rowid, &new_vec);
859 }
860 }
861 self.last_rowid = next_rowid;
862 Ok(())
863 }
864
865 fn maintain_hnsw_on_insert(&mut self, column: &str, rowid: i64, new_vec: &[f32]) {
871 let mut vec_snapshot: HashMap<i64, Vec<f32>> = HashMap::new();
876 {
877 let row_data = self.rows.lock().expect("rows mutex poisoned");
878 if let Some(Row::Vector(map)) = row_data.get(column) {
879 for (id, v) in map.iter() {
880 vec_snapshot.insert(*id, v.clone());
881 }
882 }
883 }
884 vec_snapshot.insert(rowid, new_vec.to_vec());
887
888 for entry in &mut self.hnsw_indexes {
889 if entry.column_name == column {
890 entry.index.insert(rowid, new_vec, |id| {
891 vec_snapshot.get(&id).cloned().unwrap_or_default()
892 });
893 }
894 }
895 }
896
897 pub fn print_table_schema(&self) -> Result<usize> {
918 let mut table = PrintTable::new();
919 table.add_row(row![
920 "Column Name",
921 "Data Type",
922 "PRIMARY KEY",
923 "UNIQUE",
924 "NOT NULL"
925 ]);
926
927 for col in &self.columns {
928 table.add_row(row![
929 col.column_name,
930 col.datatype,
931 col.is_pk,
932 col.is_unique,
933 col.not_null
934 ]);
935 }
936
937 table.printstd();
938 Ok(table.len() * 2 + 1)
939 }
940
941 pub fn print_table_data(&self) {
962 let mut print_table = PrintTable::new();
963
964 let column_names = self
965 .columns
966 .iter()
967 .map(|col| col.column_name.to_string())
968 .collect::<Vec<String>>();
969
970 let header_row = PrintRow::new(
971 column_names
972 .iter()
973 .map(|col| PrintCell::new(col))
974 .collect::<Vec<PrintCell>>(),
975 );
976
977 let rows_clone = Arc::clone(&self.rows);
978 let row_data = rows_clone.lock().expect("rows mutex poisoned");
979 let first_col_data = row_data
980 .get(&self.columns.first().unwrap().column_name)
981 .unwrap();
982 let num_rows = first_col_data.count();
983 let mut print_table_rows: Vec<PrintRow> = vec![PrintRow::new(vec![]); num_rows];
984
985 for col_name in &column_names {
986 let col_val = row_data
987 .get(col_name)
988 .expect("Can't find any rows with the given column");
989 let columns: Vec<String> = col_val.get_serialized_col_data();
990
991 for i in 0..num_rows {
992 if let Some(cell) = &columns.get(i) {
993 print_table_rows[i].add_cell(PrintCell::new(cell));
994 } else {
995 print_table_rows[i].add_cell(PrintCell::new(""));
996 }
997 }
998 }
999
1000 print_table.add_row(header_row);
1001 for row in print_table_rows {
1002 print_table.add_row(row);
1003 }
1004
1005 print_table.printstd();
1006 }
1007}
1008
1009#[derive(PartialEq, Debug, Clone)]
1015pub struct Column {
1016 pub column_name: String,
1017 pub datatype: DataType,
1018 pub is_pk: bool,
1019 pub not_null: bool,
1020 pub is_unique: bool,
1021}
1022
1023impl Column {
1024 pub fn new(
1025 name: String,
1026 datatype: String,
1027 is_pk: bool,
1028 not_null: bool,
1029 is_unique: bool,
1030 ) -> Self {
1031 let dt = DataType::new(datatype);
1032 Column {
1033 column_name: name,
1034 datatype: dt,
1035 is_pk,
1036 not_null,
1037 is_unique,
1038 }
1039 }
1040}
1041
/// Column-oriented cell storage: one ordered map per column from rowid to the
/// typed cell value.
///
/// TEXT storage uses the literal string `"Null"` for NULL. `None` is the storage
/// for columns whose declared type could not be parsed (see `Table::new`).
#[derive(Clone, Debug, PartialEq)]
pub enum Row {
    Integer(BTreeMap<i64, i32>),
    Text(BTreeMap<i64, String>),
    Real(BTreeMap<i64, f32>),
    Bool(BTreeMap<i64, bool>),
    Vector(BTreeMap<i64, Vec<f32>>),
    None,
}
1060
1061impl Row {
1062 fn get_serialized_col_data(&self) -> Vec<String> {
1063 match self {
1064 Row::Integer(cd) => cd.values().map(|v| v.to_string()).collect(),
1065 Row::Real(cd) => cd.values().map(|v| v.to_string()).collect(),
1066 Row::Text(cd) => cd.values().map(|v| v.to_string()).collect(),
1067 Row::Bool(cd) => cd.values().map(|v| v.to_string()).collect(),
1068 Row::Vector(cd) => cd.values().map(format_vector_for_display).collect(),
1069 Row::None => panic!("Found None in columns"),
1070 }
1071 }
1072
1073 fn count(&self) -> usize {
1074 match self {
1075 Row::Integer(cd) => cd.len(),
1076 Row::Real(cd) => cd.len(),
1077 Row::Text(cd) => cd.len(),
1078 Row::Bool(cd) => cd.len(),
1079 Row::Vector(cd) => cd.len(),
1080 Row::None => panic!("Found None in columns"),
1081 }
1082 }
1083
1084 pub fn rowids(&self) -> Vec<i64> {
1088 match self {
1089 Row::Integer(m) => m.keys().copied().collect(),
1090 Row::Text(m) => m.keys().copied().collect(),
1091 Row::Real(m) => m.keys().copied().collect(),
1092 Row::Bool(m) => m.keys().copied().collect(),
1093 Row::Vector(m) => m.keys().copied().collect(),
1094 Row::None => vec![],
1095 }
1096 }
1097
1098 pub fn get(&self, rowid: i64) -> Option<Value> {
1099 match self {
1100 Row::Integer(m) => m.get(&rowid).map(|v| Value::Integer(i64::from(*v))),
1101 Row::Text(m) => m.get(&rowid).map(|v| {
1104 if v == "Null" {
1105 Value::Null
1106 } else {
1107 Value::Text(v.clone())
1108 }
1109 }),
1110 Row::Real(m) => m.get(&rowid).map(|v| Value::Real(f64::from(*v))),
1111 Row::Bool(m) => m.get(&rowid).map(|v| Value::Bool(*v)),
1112 Row::Vector(m) => m.get(&rowid).map(|v| Value::Vector(v.clone())),
1113 Row::None => None,
1114 }
1115 }
1116}
1117
/// Formats a vector as `[a, b, c]`, using each component's shortest `Display` form.
///
/// Takes `&Vec<f32>` (not a slice) so it can be passed directly to
/// `BTreeMap::values().map(..)`.
fn format_vector_for_display(v: &Vec<f32>) -> String {
    let body = v
        .iter()
        .map(|component| component.to_string())
        .collect::<Vec<String>>()
        .join(", ");
    format!("[{body}]")
}
1139
/// A single typed cell value as seen by callers (storage widths widened:
/// INTEGER -> i64, REAL -> f64).
#[derive(PartialEq, Clone, Debug)]
pub enum Value {
    Integer(i64),
    Text(String),
    Real(f64),
    Bool(bool),
    Vector(Vec<f32>),
    Null,
}
1155
1156impl Value {
1157 pub fn to_display_string(&self) -> String {
1158 match self {
1159 Value::Integer(v) => v.to_string(),
1160 Value::Text(s) => s.clone(),
1161 Value::Real(f) => f.to_string(),
1162 Value::Bool(b) => b.to_string(),
1163 Value::Vector(v) => format_vector_for_display(v),
1164 Value::Null => String::from("NULL"),
1165 }
1166 }
1167}
1168
1169pub fn parse_vector_literal(s: &str) -> Result<Vec<f32>> {
1189 let trimmed = s.trim();
1190 if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
1191 return Err(SQLRiteError::General(format!(
1192 "expected bracket-array literal `[...]`, got `{s}`"
1193 )));
1194 }
1195 let inner = &trimmed[1..trimmed.len() - 1].trim();
1196 if inner.is_empty() {
1197 return Ok(Vec::new());
1198 }
1199 let mut out = Vec::new();
1200 for (i, part) in inner.split(',').enumerate() {
1201 let element = part.trim();
1202 let parsed: f32 = element.parse().map_err(|_| {
1203 SQLRiteError::General(format!("vector element {i} (`{element}`) is not a number"))
1204 })?;
1205 out.push(parsed);
1206 }
1207 Ok(out)
1208}
1209
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Parses exactly one `CREATE TABLE` statement and builds a [`Table`] from it.
    fn table_from_sql(sql: &str) -> Table {
        let dialect = SQLiteDialect {};
        let mut statements = Parser::parse_sql(&dialect, sql).unwrap();
        if statements.len() > 1 {
            panic!("Expected a single query statement, but there are more then 1.")
        }
        let statement = statements.pop().unwrap();
        Table::new(CreateQuery::new(&statement).unwrap())
    }

    #[test]
    fn datatype_display_trait_test() {
        let cases = [
            (DataType::Integer, "Integer"),
            (DataType::Text, "Text"),
            (DataType::Real, "Real"),
            (DataType::Bool, "Boolean"),
            (DataType::Vector(384), "Vector(384)"),
            (DataType::None, "None"),
            (DataType::Invalid, "Invalid"),
        ];
        for (datatype, expected) in cases {
            assert_eq!(format!("{}", datatype), expected);
        }
    }

    #[test]
    fn datatype_new_parses_vector_dim() {
        let cases = [
            ("vector(1)", 1),
            ("vector(384)", 384),
            ("vector(1536)", 1536),
            // Type names are case-insensitive.
            ("VECTOR(384)", 384),
            // Whitespace around the dimension is tolerated.
            ("vector( 64 )", 64),
        ];
        for (input, dim) in cases {
            assert_eq!(DataType::new(input.to_string()), DataType::Vector(dim));
        }
    }

    #[test]
    fn datatype_new_rejects_bad_vector_strings() {
        for bad in ["vector(0)", "vector(abc)", "vector()", "vector(-3)"] {
            assert_eq!(DataType::new(bad.to_string()), DataType::Invalid);
        }
    }

    #[test]
    fn datatype_to_wire_string_round_trips_vector() {
        let wire = DataType::Vector(384).to_wire_string();
        assert_eq!(wire, "vector(384)");
        assert_eq!(DataType::new(wire), DataType::Vector(384));
    }

    #[test]
    fn parse_vector_literal_accepts_floats() {
        let parsed = parse_vector_literal("[0.1, 0.2, 0.3]").expect("parse");
        assert_eq!(parsed, vec![0.1f32, 0.2, 0.3]);
    }

    #[test]
    fn parse_vector_literal_accepts_ints_widening_to_f32() {
        let parsed = parse_vector_literal("[1, 2, 3]").expect("parse");
        assert_eq!(parsed, vec![1.0f32, 2.0, 3.0]);
    }

    #[test]
    fn parse_vector_literal_handles_negatives_and_whitespace() {
        let parsed = parse_vector_literal("[ -1.5 , 2.0, -3.5 ]").expect("parse");
        assert_eq!(parsed, vec![-1.5f32, 2.0, -3.5]);
    }

    #[test]
    fn parse_vector_literal_empty_brackets_is_empty_vec() {
        assert!(parse_vector_literal("[]").expect("parse").is_empty());
    }

    #[test]
    fn parse_vector_literal_rejects_non_bracketed() {
        for bad in ["0.1, 0.2", "(0.1, 0.2)", "[0.1, 0.2", "0.1, 0.2]"] {
            assert!(parse_vector_literal(bad).is_err());
        }
    }

    #[test]
    fn parse_vector_literal_rejects_non_numeric_elements() {
        let message = format!("{}", parse_vector_literal("[1.0, 'foo', 3.0]").unwrap_err());
        assert!(
            message.contains("vector element 1") && message.contains("'foo'"),
            "error message should pinpoint the bad element: got `{message}`"
        );
    }

    #[test]
    fn value_vector_display_format() {
        assert_eq!(
            Value::Vector(vec![0.1, 0.2, 0.3]).to_display_string(),
            "[0.1, 0.2, 0.3]"
        );
        assert_eq!(Value::Vector(vec![]).to_display_string(), "[]");
    }

    #[test]
    fn create_new_table_test() {
        let table = table_from_sql(
            "CREATE TABLE contacts (
            id INTEGER PRIMARY KEY,
            first_name TEXT NOT NULL,
            last_name TEXT NOT NULl,
            email TEXT NOT NULL UNIQUE,
            active BOOL,
            score REAL
        );",
        );

        assert_eq!(table.columns.len(), 6);
        assert_eq!(table.last_rowid, 0);

        let id_column = table
            .columns
            .iter()
            .find(|c| c.column_name == "id")
            .unwrap_or_else(|| panic!("column not found"));
        assert!(id_column.is_pk);
        assert_eq!(id_column.datatype, DataType::Integer);
    }

    #[test]
    fn print_table_schema_test() {
        let table = table_from_sql(
            "CREATE TABLE contacts (
            id INTEGER PRIMARY KEY,
            first_name TEXT NOT NULL,
            last_name TEXT NOT NULl
        );",
        );
        assert_eq!(table.print_table_schema(), Ok(9));
    }
}