vortex_expr/
literal.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use vortex_array::arrays::ConstantArray;
6use vortex_array::{Array, ArrayRef};
7use vortex_dtype::DType;
8use vortex_error::VortexResult;
9use vortex_scalar::Scalar;
10
11use crate::{ExprRef, VortexExpr};
12
13#[derive(Debug, PartialEq, Eq, Hash)]
14pub struct Literal {
15    value: Scalar,
16}
17
18impl Literal {
19    pub fn new_expr(value: impl Into<Scalar>) -> ExprRef {
20        Arc::new(Self {
21            value: value.into(),
22        })
23    }
24
25    pub fn value(&self) -> &Scalar {
26        &self.value
27    }
28
29    pub fn maybe_from(expr: &ExprRef) -> Option<&Literal> {
30        expr.as_any().downcast_ref::<Literal>()
31    }
32}
33
34impl Display for Literal {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "{}", self.value)
37    }
38}
39
40impl VortexExpr for Literal {
41    fn as_any(&self) -> &dyn Any {
42        self
43    }
44
45    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
46        Ok(ConstantArray::new(self.value.clone(), batch.len()).into_array())
47    }
48
49    fn children(&self) -> Vec<&ExprRef> {
50        vec![]
51    }
52
53    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
54        assert_eq!(children.len(), 0);
55        self
56    }
57
58    fn return_dtype(&self, _scope_dtype: &DType) -> VortexResult<DType> {
59        Ok(self.value.dtype().clone())
60    }
61}
62
63/// Create a new `Literal` expression from a type that coerces to `Scalar`.
64///
65///
66/// ## Example usage
67///
68/// ```
69/// use vortex_array::arrays::PrimitiveArray;
70/// use vortex_dtype::Nullability;
71/// use vortex_expr::{lit, Literal};
72/// use vortex_scalar::Scalar;
73///
74/// let number = lit(34i32);
75///
76/// let literal = number.as_any()
77///     .downcast_ref::<Literal>()
78///     .unwrap();
79/// assert_eq!(literal.value(), &Scalar::primitive(34i32, Nullability::NonNullable));
80/// ```
81pub fn lit(value: impl Into<Scalar>) -> ExprRef {
82    Literal::new_expr(value.into())
83}
84
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::Nullability::{NonNullable, Nullable};
    use vortex_dtype::{DType, PType, StructDType};
    use vortex_scalar::Scalar;

    use crate::{lit, test_harness};

    /// Each literal reports the dtype of its own scalar, independent of the scope dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let pet = DType::Struct(
            Arc::new(StructDType::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, NonNullable),
                    DType::Utf8(NonNullable),
                ],
            )),
            NonNullable,
        );
        let pet_scalar = Scalar::struct_(
            pet.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );

        let cases = [
            (lit(10), DType::Primitive(PType::I32, NonNullable)),
            (lit(0_u8), DType::Primitive(PType::U8, NonNullable)),
            (lit(0.0_f32), DType::Primitive(PType::F32, NonNullable)),
            (lit(i64::MAX), DType::Primitive(PType::I64, NonNullable)),
            (lit(true), DType::Bool(NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullable))),
                DType::Bool(Nullable),
            ),
            (lit(pet_scalar), pet),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }
    }
}