vortex_expr/exprs/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use vortex_array::arrays::ConstantArray;
7use vortex_array::operator::OperatorRef;
8use vortex_array::{Array, ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
9use vortex_dtype::{DType, match_each_float_ptype};
10use vortex_error::{VortexResult, vortex_bail, vortex_err};
11use vortex_proto::expr as pb;
12use vortex_scalar::Scalar;
13
14use crate::display::{DisplayAs, DisplayFormat};
15use crate::{
16    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, vtable,
17};
18
// Presumably generates the `LiteralVTable` type (and related glue) that the `VTable`
// impl and `LiteralExpr::maybe_from` in this file refer to — TODO confirm against the
// crate's `vtable!` macro definition.
vtable!(Literal);
20
/// An expression that always evaluates to a single constant [`Scalar`].
///
/// When evaluated against a scope, the value is repeated `scope.len()` times as a
/// `ConstantArray`, and its return dtype is simply the scalar's dtype.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LiteralExpr {
    /// The constant value this expression produces.
    value: Scalar,
}
25
/// Encoding marker for [`LiteralExpr`]; its expression id is `"literal"`.
pub struct LiteralExprEncoding;
27
28impl VTable for LiteralVTable {
29    type Expr = LiteralExpr;
30    type Encoding = LiteralExprEncoding;
31    type Metadata = ProstMetadata<pb::LiteralOpts>;
32
33    fn id(_encoding: &Self::Encoding) -> ExprId {
34        ExprId::new_ref("literal")
35    }
36
37    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
38        ExprEncodingRef::new_ref(LiteralExprEncoding.as_ref())
39    }
40
41    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
42        Some(ProstMetadata(pb::LiteralOpts {
43            value: Some((&expr.value).into()),
44        }))
45    }
46
47    fn children(_expr: &Self::Expr) -> Vec<&ExprRef> {
48        vec![]
49    }
50
51    fn with_children(expr: &Self::Expr, _children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
52        Ok(expr.clone())
53    }
54
55    fn build(
56        _encoding: &Self::Encoding,
57        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
58        children: Vec<ExprRef>,
59    ) -> VortexResult<Self::Expr> {
60        if !children.is_empty() {
61            vortex_bail!(
62                "Literal expression does not have children, got: {:?}",
63                children
64            );
65        }
66        let value: Scalar = metadata
67            .value
68            .as_ref()
69            .ok_or_else(|| vortex_err!("Literal metadata missing value"))?
70            .try_into()?;
71        Ok(LiteralExpr::new(value))
72    }
73
74    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
75        Ok(ConstantArray::new(expr.value.clone(), scope.len()).into_array())
76    }
77
78    fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
79        Ok(expr.value.dtype().clone())
80    }
81
82    fn operator(expr: &Self::Expr, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
83        Ok(Some(Arc::new(ConstantArray::new(
84            expr.value.clone(),
85            scope.len(),
86        ))))
87    }
88}
89
90impl LiteralExpr {
91    pub fn new(value: impl Into<Scalar>) -> Self {
92        Self {
93            value: value.into(),
94        }
95    }
96
97    pub fn new_expr(value: impl Into<Scalar>) -> ExprRef {
98        Self::new(value).into_expr()
99    }
100
101    pub fn value(&self) -> &Scalar {
102        &self.value
103    }
104
105    pub fn maybe_from(expr: &ExprRef) -> Option<&LiteralExpr> {
106        expr.as_opt::<LiteralVTable>()
107    }
108}
109
110impl DisplayAs for LiteralExpr {
111    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
112        match df {
113            DisplayFormat::Compact => {
114                write!(f, "{}", self.value)
115            }
116            DisplayFormat::Tree => {
117                write!(
118                    f,
119                    "Literal(value: {}, dtype: {})",
120                    self.value,
121                    self.value.dtype()
122                )
123            }
124        }
125    }
126}
127
128impl AnalysisExpr for LiteralExpr {
129    fn max(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
130        Some(lit(self.value.clone()))
131    }
132
133    fn min(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
134        Some(lit(self.value.clone()))
135    }
136
137    fn nan_count(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
138        // The NaNCount for a non-float literal is not defined.
139        // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
140        let value = self.value.as_primitive_opt()?;
141        if !value.ptype().is_float() {
142            return None;
143        }
144
145        match_each_float_ptype!(value.ptype(), |T| {
146            match value.typed_value::<T>() {
147                None => Some(lit(0u64)),
148                Some(value) if value.is_nan() => Some(lit(1u64)),
149                _ => Some(lit(0u64)),
150            }
151        })
152    }
153}
154
155/// Create a new `Literal` expression from a type that coerces to `Scalar`.
156///
157///
158/// ## Example usage
159///
160/// ```
161/// use vortex_array::arrays::PrimitiveArray;
162/// use vortex_dtype::Nullability;
163/// use vortex_expr::{lit, LiteralVTable};
164/// use vortex_scalar::Scalar;
165///
166/// let number = lit(34i32);
167///
168/// let literal = number.as_::<LiteralVTable>();
169/// assert_eq!(literal.value(), &Scalar::primitive(34i32, Nullability::NonNullable));
170/// ```
171pub fn lit(value: impl Into<Scalar>) -> ExprRef {
172    LiteralExpr::new(value.into()).into_expr()
173}
174
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use crate::{lit, test_harness};

    /// A literal's return dtype is its scalar's dtype, independent of the scope dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let struct_dtype = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let struct_value = Scalar::struct_(
            struct_dtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );

        let cases = [
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
            (lit(struct_value), struct_dtype),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }
    }
}