1use rustc_hash::FxHashMap;
19
20use chrono::{DateTime, TimeZone, Utc};
21
22use super::{find_column_index, resolve_alias, Expression};
23use crate::common::SmartString;
24use crate::core::{DataType, Operator, Result, Row, Schema, Value};
25
26#[derive(Debug, Clone)]
30pub struct CastExpr {
31 column: String,
33 target_type: DataType,
35
36 col_index: Option<usize>,
38
39 aliases: FxHashMap<String, String>,
41 original_column: Option<String>,
43}
44
45impl CastExpr {
46 pub fn new(column: impl Into<String>, target_type: DataType) -> Self {
48 Self {
49 column: column.into(),
50 target_type,
51 col_index: None,
52 aliases: FxHashMap::default(),
53 original_column: None,
54 }
55 }
56
57 pub fn target_type(&self) -> DataType {
59 self.target_type
60 }
61
62 pub fn perform_cast(&self, value: &Value) -> Result<Value> {
64 if value.is_null() {
65 return Ok(Value::null(self.target_type));
66 }
67
68 match self.target_type {
69 DataType::Integer => cast_to_integer(value),
70 DataType::Float => cast_to_float(value),
71 DataType::Text => cast_to_string(value),
72 DataType::Boolean => cast_to_boolean(value),
73 DataType::Timestamp => cast_to_timestamp(value),
74 DataType::Json => cast_to_json(value),
75 DataType::Vector => Err(crate::core::Error::type_conversion(
76 format!("{:?}", value),
77 "VECTOR",
78 )),
79 DataType::Null => Ok(Value::null(DataType::Null)),
80 }
81 }
82}
83
84impl Expression for CastExpr {
85 fn as_any(&self) -> &dyn std::any::Any {
86 self
87 }
88
89 fn evaluate(&self, row: &Row) -> Result<bool> {
90 let col_idx = match self.col_index {
91 Some(idx) if idx < row.len() => idx,
92 _ => return Ok(false),
93 };
94
95 let col_value = &row[col_idx];
96
97 if col_value.is_null() {
98 return Ok(false);
99 }
100
101 match self.perform_cast(col_value) {
104 Ok(_) => Ok(true),
105 Err(_) => Ok(false),
106 }
107 }
108
109 fn evaluate_fast(&self, row: &Row) -> bool {
110 let col_idx = match self.col_index {
111 Some(idx) if idx < row.len() => idx,
112 _ => return false,
113 };
114
115 let col_value = &row[col_idx];
116
117 if col_value.is_null() {
118 return false;
119 }
120
121 self.perform_cast(col_value).is_ok()
122 }
123
124 fn with_aliases(&self, aliases: &FxHashMap<String, String>) -> Box<dyn Expression> {
125 let resolved = resolve_alias(&self.column, aliases);
126 let mut expr = self.clone();
127
128 if resolved != self.column {
129 expr.original_column = Some(self.column.clone());
130 expr.column = resolved.to_string();
131 }
132
133 expr.aliases = aliases.clone();
134 expr.col_index = None;
135 Box::new(expr)
136 }
137
138 fn prepare_for_schema(&mut self, schema: &Schema) {
139 if self.col_index.is_some() {
140 return;
141 }
142 self.col_index = find_column_index(schema, &self.column);
143 }
144
145 fn is_prepared(&self) -> bool {
146 self.col_index.is_some()
147 }
148
149 fn get_column_name(&self) -> Option<&str> {
150 Some(&self.column)
151 }
152
153 fn clone_box(&self) -> Box<dyn Expression> {
154 Box::new(self.clone())
155 }
156}
157
158#[derive(Debug, Clone)]
162pub struct CompoundExpr {
163 cast_expr: CastExpr,
165 operator: Operator,
167 value: Value,
169
170 is_optimized: bool,
172}
173
174impl CompoundExpr {
175 pub fn new(cast_expr: CastExpr, operator: Operator, value: Value) -> Self {
177 Self {
178 cast_expr,
179 operator,
180 value,
181 is_optimized: false,
182 }
183 }
184
185 pub fn operator(&self) -> Operator {
187 self.operator
188 }
189
190 pub fn comparison_value(&self) -> &Value {
192 &self.value
193 }
194}
195
196impl Expression for CompoundExpr {
197 fn evaluate(&self, row: &Row) -> Result<bool> {
198 let col_idx = match self.cast_expr.col_index {
199 Some(idx) if idx < row.len() => idx,
200 _ => return Ok(false),
201 };
202
203 let col_value = &row[col_idx];
204
205 if col_value.is_null() {
206 return Ok(false);
207 }
208
209 let casted = self.cast_expr.perform_cast(col_value)?;
211
212 let comp_value = self.cast_expr.perform_cast(&self.value)?;
214
215 let cmp = compare_values(&casted, &comp_value);
217
218 let result = match self.operator {
219 Operator::Eq => cmp == 0,
220 Operator::Ne => cmp != 0,
221 Operator::Gt => cmp > 0,
222 Operator::Gte => cmp >= 0,
223 Operator::Lt => cmp < 0,
224 Operator::Lte => cmp <= 0,
225 _ => false,
226 };
227
228 Ok(result)
229 }
230
231 fn evaluate_fast(&self, row: &Row) -> bool {
232 let col_idx = match self.cast_expr.col_index {
233 Some(idx) if idx < row.len() => idx,
234 _ => return false,
235 };
236
237 let col_value = &row[col_idx];
238
239 if col_value.is_null() {
240 return false;
241 }
242
243 match self.cast_expr.target_type {
245 DataType::Integer => {
246 let col_int = match col_value {
247 Value::Integer(v) => *v,
248 Value::Float(v) => *v as i64,
249 Value::Boolean(b) => {
250 if *b {
251 1
252 } else {
253 0
254 }
255 }
256 Value::Text(s) => {
257 if let Ok(i) = s.parse::<i64>() {
258 i
259 } else if let Ok(f) = s.parse::<f64>() {
260 f as i64
261 } else {
262 return false;
263 }
264 }
265 _ => return false,
266 };
267
268 let comp_int = match &self.value {
269 Value::Integer(v) => *v,
270 Value::Float(v) => *v as i64,
271 _ => return false,
272 };
273
274 match self.operator {
275 Operator::Eq => col_int == comp_int,
276 Operator::Ne => col_int != comp_int,
277 Operator::Gt => col_int > comp_int,
278 Operator::Gte => col_int >= comp_int,
279 Operator::Lt => col_int < comp_int,
280 Operator::Lte => col_int <= comp_int,
281 _ => false,
282 }
283 }
284 DataType::Float => {
285 let col_float = match col_value {
286 Value::Integer(v) => *v as f64,
287 Value::Float(v) => *v,
288 Value::Boolean(b) => {
289 if *b {
290 1.0
291 } else {
292 0.0
293 }
294 }
295 _ => return false,
296 };
297
298 let comp_float = match &self.value {
299 Value::Integer(v) => *v as f64,
300 Value::Float(v) => *v,
301 _ => return false,
302 };
303
304 match self.operator {
305 Operator::Eq => col_float == comp_float,
306 Operator::Ne => col_float != comp_float,
307 Operator::Gt => col_float > comp_float,
308 Operator::Gte => col_float >= comp_float,
309 Operator::Lt => col_float < comp_float,
310 Operator::Lte => col_float <= comp_float,
311 _ => false,
312 }
313 }
314 DataType::Text => {
315 let col_str = col_value.as_string();
316 let col_str = match col_str {
317 Some(s) => s,
318 None => return false,
319 };
320
321 let comp_str = match &self.value {
322 Value::Text(s) => &**s,
323 _ => return false,
324 };
325
326 match self.operator {
327 Operator::Eq => col_str == comp_str,
328 Operator::Ne => col_str != comp_str,
329 Operator::Gt => col_str.as_str() > comp_str,
330 Operator::Gte => col_str.as_str() >= comp_str,
331 Operator::Lt => col_str.as_str() < comp_str,
332 Operator::Lte => col_str.as_str() <= comp_str,
333 _ => false,
334 }
335 }
336 DataType::Boolean => {
337 let col_bool = match col_value {
338 Value::Integer(v) => *v != 0,
339 Value::Float(v) => *v != 0.0,
340 Value::Boolean(b) => *b,
341 _ => return false,
342 };
343
344 let comp_bool = match &self.value {
345 Value::Boolean(b) => *b,
346 Value::Integer(v) => *v != 0,
347 _ => return false,
348 };
349
350 match self.operator {
351 Operator::Eq => col_bool == comp_bool,
352 Operator::Ne => col_bool != comp_bool,
353 _ => false,
354 }
355 }
356 _ => false,
357 }
358 }
359
360 fn with_aliases(&self, aliases: &FxHashMap<String, String>) -> Box<dyn Expression> {
361 let aliased_cast = self.cast_expr.with_aliases(aliases);
362 let cast_expr = if let Some(cast) = aliased_cast.as_any().downcast_ref::<CastExpr>() {
363 cast.clone()
364 } else {
365 self.cast_expr.clone()
366 };
367
368 Box::new(CompoundExpr {
369 cast_expr,
370 operator: self.operator,
371 value: self.value.clone(),
372 is_optimized: false,
373 })
374 }
375
376 fn prepare_for_schema(&mut self, schema: &Schema) {
377 if self.is_optimized {
378 return;
379 }
380 self.cast_expr.prepare_for_schema(schema);
381 self.is_optimized = true;
382 }
383
384 fn is_prepared(&self) -> bool {
385 self.is_optimized
386 }
387
388 fn get_column_name(&self) -> Option<&str> {
389 self.cast_expr.get_column_name()
390 }
391
392 fn clone_box(&self) -> Box<dyn Expression> {
393 Box::new(self.clone())
394 }
395
396 fn as_any(&self) -> &dyn std::any::Any {
397 self
398 }
399}
400
401fn cast_to_integer(value: &Value) -> Result<Value> {
404 match value {
405 Value::Integer(v) => Ok(Value::Integer(*v)),
406 Value::Float(v) => Ok(Value::Integer(*v as i64)),
407 Value::Text(s) => {
408 if let Ok(i) = s.parse::<i64>() {
409 Ok(Value::Integer(i))
410 } else if let Ok(f) = s.parse::<f64>() {
411 Ok(Value::Integer(f as i64))
412 } else {
413 Ok(Value::Integer(0))
414 }
415 }
416 Value::Boolean(b) => Ok(Value::Integer(if *b { 1 } else { 0 })),
417 Value::Timestamp(t) => Ok(Value::Integer(t.timestamp())),
418 Value::Null(_) => Ok(Value::null(DataType::Integer)),
419 _ => Ok(Value::Integer(0)),
420 }
421}
422
423fn cast_to_float(value: &Value) -> Result<Value> {
424 match value {
425 Value::Integer(v) => Ok(Value::float(*v as f64)),
426 Value::Float(v) => Ok(Value::float(*v)),
427 Value::Text(s) => {
428 if let Ok(f) = s.parse::<f64>() {
429 Ok(Value::float(f))
430 } else {
431 Ok(Value::float(0.0))
432 }
433 }
434 Value::Boolean(b) => Ok(Value::float(if *b { 1.0 } else { 0.0 })),
435 Value::Timestamp(t) => Ok(Value::float(t.timestamp() as f64)),
436 Value::Null(_) => Ok(Value::null(DataType::Float)),
437 _ => Ok(Value::float(0.0)),
438 }
439}
440
441fn cast_to_string(value: &Value) -> Result<Value> {
442 match value {
443 Value::Integer(v) => Ok(Value::Text(SmartString::from_string(v.to_string()))),
444 Value::Float(v) => Ok(Value::Text(SmartString::from_string(v.to_string()))),
445 Value::Text(s) => Ok(Value::Text(s.clone())),
446 Value::Boolean(b) => Ok(Value::Text(SmartString::from(if *b {
447 "true"
448 } else {
449 "false"
450 }))),
451 Value::Timestamp(t) => Ok(Value::Text(SmartString::from_string(t.to_rfc3339()))),
452 Value::Extension(data) if data.first() == Some(&(DataType::Json as u8)) => {
453 let s = std::str::from_utf8(&data[1..]).unwrap_or("");
454 Ok(Value::Text(SmartString::from(s)))
455 }
456 Value::Extension(_) => Ok(Value::Text(SmartString::from(""))),
457 Value::Null(_) => Ok(Value::null(DataType::Text)),
458 }
459}
460
461fn cast_to_boolean(value: &Value) -> Result<Value> {
462 match value {
463 Value::Integer(v) => Ok(Value::Boolean(*v != 0)),
464 Value::Float(v) => Ok(Value::Boolean(*v != 0.0)),
465 Value::Text(s) => {
466 let b = s.eq_ignore_ascii_case("true")
468 || s == "1"
469 || s.eq_ignore_ascii_case("t")
470 || s.eq_ignore_ascii_case("yes")
471 || s.eq_ignore_ascii_case("y");
472 Ok(Value::Boolean(b))
473 }
474 Value::Boolean(b) => Ok(Value::Boolean(*b)),
475 Value::Null(_) => Ok(Value::null(DataType::Boolean)),
476 _ => Ok(Value::Boolean(false)),
477 }
478}
479
480fn cast_to_timestamp(value: &Value) -> Result<Value> {
481 match value {
482 Value::Integer(v) => Ok(Value::Timestamp(Utc.timestamp_opt(*v, 0).unwrap())),
483 Value::Float(v) => Ok(Value::Timestamp(Utc.timestamp_opt(*v as i64, 0).unwrap())),
484 Value::Timestamp(t) => Ok(Value::Timestamp(*t)),
485 Value::Text(s) => {
486 if let Ok(ts) = s.parse::<DateTime<Utc>>() {
488 Ok(Value::Timestamp(ts))
489 } else {
490 Ok(Value::Timestamp(Utc::now()))
492 }
493 }
494 Value::Null(_) => Ok(Value::null(DataType::Timestamp)),
495 _ => Ok(Value::Timestamp(Utc::now())),
496 }
497}
498
499fn cast_to_json(value: &Value) -> Result<Value> {
500 match value {
501 Value::Extension(data) if data.first() == Some(&(DataType::Json as u8)) => {
502 Ok(value.clone())
503 }
504 Value::Text(s) => Ok(Value::json(s.as_ref())),
505 Value::Integer(v) => Ok(Value::json(v.to_string())),
506 Value::Float(v) => Ok(Value::json(v.to_string())),
507 Value::Boolean(b) => Ok(Value::json(if *b { "true" } else { "false" })),
508 Value::Null(_) => Ok(Value::json("null")),
509 _ => Ok(Value::json("null")),
510 }
511}
512
513fn compare_values(a: &Value, b: &Value) -> i32 {
516 if a.is_null() && b.is_null() {
518 return 0;
519 }
520 if a.is_null() {
521 return -1;
522 }
523 if b.is_null() {
524 return 1;
525 }
526
527 match (a, b) {
529 (Value::Integer(av), Value::Integer(bv)) => {
530 if av < bv {
531 -1
532 } else if av > bv {
533 1
534 } else {
535 0
536 }
537 }
538 (Value::Float(av), Value::Float(bv)) => {
539 if av < bv {
540 -1
541 } else if av > bv {
542 1
543 } else {
544 0
545 }
546 }
547 (Value::Text(av), Value::Text(bv)) => {
548 if av < bv {
549 -1
550 } else if av > bv {
551 1
552 } else {
553 0
554 }
555 }
556 (Value::Boolean(av), Value::Boolean(bv)) => match (*av, *bv) {
557 (false, true) => -1,
558 (true, false) => 1,
559 _ => 0,
560 },
561 (Value::Timestamp(av), Value::Timestamp(bv)) => {
562 if av < bv {
563 -1
564 } else if av > bv {
565 1
566 } else {
567 0
568 }
569 }
570 (Value::Integer(av), Value::Float(bv)) => {
572 let af = *av as f64;
573 if af < *bv {
574 -1
575 } else if af > *bv {
576 1
577 } else {
578 0
579 }
580 }
581 (Value::Float(av), Value::Integer(bv)) => {
582 let bf = *bv as f64;
583 if *av < bf {
584 -1
585 } else if *av > bf {
586 1
587 } else {
588 0
589 }
590 }
591 _ => {
593 let as_str = a.as_string().unwrap_or_default();
594 let bs_str = b.as_string().unwrap_or_default();
595 if as_str < bs_str {
596 -1
597 } else if as_str > bs_str {
598 1
599 } else {
600 0
601 }
602 }
603 }
604}
605
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::SchemaBuilder;

    /// Schema shared by the tests: `id` (integer primary key), `value` (text), `score` (float).
    fn test_schema() -> Schema {
        SchemaBuilder::new("test")
            .add_primary_key("id", DataType::Integer)
            .add("value", DataType::Text)
            .add("score", DataType::Float)
            .build()
    }

    /// Row laid out like `test_schema` with id 1, score 3.5 and the given `value` cell.
    fn row_with_value(value: Value) -> Row {
        Row::from_values(vec![Value::integer(1), value, Value::float(3.5)])
    }

    /// Cast expression over `column`, already prepared against `test_schema`.
    fn prepared_cast(column: &str, target: DataType) -> CastExpr {
        let mut expr = CastExpr::new(column, target);
        expr.prepare_for_schema(&test_schema());
        expr
    }

    /// Compound expression over `column`, already prepared against `test_schema`.
    fn prepared_compound(column: &str, target: DataType, op: Operator, rhs: Value) -> CompoundExpr {
        let mut expr = CompoundExpr::new(CastExpr::new(column, target), op, rhs);
        expr.prepare_for_schema(&test_schema());
        expr
    }

    #[test]
    fn test_cast_to_integer() {
        let cases = vec![
            (Value::text("42"), Value::integer(42)),
            (Value::float(3.5), Value::integer(3)),
            (Value::Boolean(true), Value::integer(1)),
        ];
        for (input, expected) in cases {
            assert_eq!(cast_to_integer(&input).unwrap(), expected);
        }
    }

    #[test]
    fn test_cast_to_float() {
        let cases = vec![
            (Value::text("3.5"), Value::float(3.5)),
            (Value::integer(42), Value::float(42.0)),
        ];
        for (input, expected) in cases {
            assert_eq!(cast_to_float(&input).unwrap(), expected);
        }
    }

    #[test]
    fn test_cast_to_string() {
        let cases = vec![
            (Value::integer(42), Value::text("42")),
            (Value::Boolean(true), Value::text("true")),
        ];
        for (input, expected) in cases {
            assert_eq!(cast_to_string(&input).unwrap(), expected);
        }
    }

    #[test]
    fn test_cast_to_boolean() {
        let cases = vec![
            (Value::text("true"), true),
            (Value::text("yes"), true),
            (Value::integer(0), false),
            (Value::integer(1), true),
        ];
        for (input, expected) in cases {
            assert_eq!(cast_to_boolean(&input).unwrap(), Value::Boolean(expected));
        }
    }

    #[test]
    fn test_cast_expr_evaluate() {
        let row = row_with_value(Value::text("42"));
        let expr = prepared_cast("value", DataType::Integer);

        assert!(expr.evaluate(&row).unwrap());
        assert!(expr.evaluate_fast(&row));
    }

    #[test]
    fn test_compound_expr_integer_comparison() {
        let row = row_with_value(Value::text("42"));

        // 42 > 40 holds on both evaluation paths.
        let greater = prepared_compound(
            "value",
            DataType::Integer,
            Operator::Gt,
            Value::integer(40),
        );
        assert!(greater.evaluate(&row).unwrap());
        assert!(greater.evaluate_fast(&row));

        // 42 < 40 does not hold.
        let less = prepared_compound(
            "value",
            DataType::Integer,
            Operator::Lt,
            Value::integer(40),
        );
        assert!(!less.evaluate(&row).unwrap());
    }

    #[test]
    fn test_compound_expr_float_comparison() {
        let row = row_with_value(Value::text("3.14"));
        let expr = prepared_compound("value", DataType::Float, Operator::Gte, Value::float(3.0));

        assert!(expr.evaluate(&row).unwrap());
    }

    #[test]
    fn test_compound_expr_string_comparison() {
        let row = Row::from_values(vec![
            Value::integer(42),
            Value::text("hello"),
            Value::float(3.5),
        ]);
        let expr = prepared_compound("id", DataType::Text, Operator::Eq, Value::text("42"));

        assert!(expr.evaluate(&row).unwrap());
    }

    #[test]
    fn test_null_cast() {
        let row = row_with_value(Value::null(DataType::Text));
        let expr = prepared_cast("value", DataType::Integer);

        // A NULL column never satisfies the cast predicate.
        assert!(!expr.evaluate(&row).unwrap());
        assert!(!expr.evaluate_fast(&row));
    }

    #[test]
    fn test_with_aliases() {
        let schema = test_schema();
        let row = row_with_value(Value::text("42"));

        let mut aliases = FxHashMap::default();
        aliases.insert("v".to_string(), "value".to_string());

        let mut aliased = CastExpr::new("v", DataType::Integer).with_aliases(&aliases);
        aliased.prepare_for_schema(&schema);

        assert!(aliased.evaluate(&row).unwrap());
    }

    #[test]
    fn test_compare_values() {
        let integer_cases = vec![(1, 2, -1), (2, 2, 0), (3, 2, 1)];
        for (lhs, rhs, expected) in integer_cases {
            assert_eq!(
                compare_values(&Value::integer(lhs), &Value::integer(rhs)),
                expected
            );
        }

        assert_eq!(compare_values(&Value::float(1.0), &Value::float(2.0)), -1);
        assert_eq!(compare_values(&Value::text("a"), &Value::text("b")), -1);
    }

    #[test]
    fn test_get_column_name() {
        let expr = CastExpr::new("id", DataType::Integer);
        assert_eq!(expr.get_column_name(), Some("id"));
    }

    #[test]
    fn test_target_type() {
        let expr = CastExpr::new("id", DataType::Integer);
        assert_eq!(expr.target_type(), DataType::Integer);
    }

    #[test]
    fn test_cast_invalid_string_to_integer() {
        // Text that is not a number falls back to 0 rather than failing.
        let result = cast_to_integer(&Value::text("not_a_number")).unwrap();
        assert_eq!(result, Value::integer(0));
    }

    #[test]
    fn test_cast_float_string_to_integer() {
        // Float text is converted to its integer part.
        let result = cast_to_integer(&Value::text("3.7")).unwrap();
        assert_eq!(result, Value::integer(3));
    }
}