vortex_expr/exprs/
literal.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5
6use vortex_array::arrays::ConstantArray;
7use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
8use vortex_dtype::{DType, match_each_float_ptype};
9use vortex_error::{VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr as pb;
11use vortex_scalar::Scalar;
12
13use crate::{
14    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, vtable,
15};
16
17vtable!(Literal);
18
19#[derive(Clone, Debug, PartialEq, Eq, Hash)]
20pub struct LiteralExpr {
21    value: Scalar,
22}
23
24pub struct LiteralExprEncoding;
25
26impl VTable for LiteralVTable {
27    type Expr = LiteralExpr;
28    type Encoding = LiteralExprEncoding;
29    type Metadata = ProstMetadata<pb::LiteralOpts>;
30
31    fn id(_encoding: &Self::Encoding) -> ExprId {
32        ExprId::new_ref("literal")
33    }
34
35    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
36        ExprEncodingRef::new_ref(LiteralExprEncoding.as_ref())
37    }
38
39    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
40        Some(ProstMetadata(pb::LiteralOpts {
41            value: Some((&expr.value).into()),
42        }))
43    }
44
45    fn children(_expr: &Self::Expr) -> Vec<&ExprRef> {
46        vec![]
47    }
48
49    fn with_children(expr: &Self::Expr, _children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
50        Ok(expr.clone())
51    }
52
53    fn build(
54        _encoding: &Self::Encoding,
55        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
56        children: Vec<ExprRef>,
57    ) -> VortexResult<Self::Expr> {
58        if !children.is_empty() {
59            vortex_bail!(
60                "Literal expression does not have children, got: {:?}",
61                children
62            );
63        }
64        let value: Scalar = metadata
65            .value
66            .as_ref()
67            .ok_or_else(|| vortex_err!("Literal metadata missing value"))?
68            .try_into()?;
69        Ok(LiteralExpr::new(value))
70    }
71
72    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
73        Ok(ConstantArray::new(expr.value.clone(), scope.len()).into_array())
74    }
75
76    fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
77        Ok(expr.value.dtype().clone())
78    }
79}
80
81impl LiteralExpr {
82    pub fn new(value: impl Into<Scalar>) -> Self {
83        Self {
84            value: value.into(),
85        }
86    }
87
88    pub fn new_expr(value: impl Into<Scalar>) -> ExprRef {
89        Self::new(value).into_expr()
90    }
91
92    pub fn value(&self) -> &Scalar {
93        &self.value
94    }
95
96    pub fn maybe_from(expr: &ExprRef) -> Option<&LiteralExpr> {
97        expr.as_opt::<LiteralVTable>()
98    }
99}
100
101impl Display for LiteralExpr {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        write!(f, "{}", self.value)
104    }
105}
106
107impl AnalysisExpr for LiteralExpr {
108    fn max(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
109        Some(lit(self.value.clone()))
110    }
111
112    fn min(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
113        Some(lit(self.value.clone()))
114    }
115
116    fn nan_count(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
117        // The NaNCount for a non-float literal is not defined.
118        // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
119        let value = self.value.as_primitive_opt()?;
120        if !value.ptype().is_float() {
121            return None;
122        }
123
124        match_each_float_ptype!(value.ptype(), |T| {
125            match value.typed_value::<T>() {
126                None => Some(lit(0u64)),
127                Some(value) if value.is_nan() => Some(lit(1u64)),
128                _ => Some(lit(0u64)),
129            }
130        })
131    }
132}
133
134/// Create a new `Literal` expression from a type that coerces to `Scalar`.
135///
136///
137/// ## Example usage
138///
139/// ```
140/// use vortex_array::arrays::PrimitiveArray;
141/// use vortex_dtype::Nullability;
142/// use vortex_expr::{lit, LiteralVTable};
143/// use vortex_scalar::Scalar;
144///
145/// let number = lit(34i32);
146///
147/// let literal = number.as_::<LiteralVTable>();
148/// assert_eq!(literal.value(), &Scalar::primitive(34i32, Nullability::NonNullable));
149/// ```
150pub fn lit(value: impl Into<Scalar>) -> ExprRef {
151    LiteralExpr::new(value.into()).into_expr()
152}
153
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use crate::{lit, test_harness};

    /// A literal's return dtype is the dtype of the scalar it wraps, whatever the scope.
    #[test]
    fn dtype() {
        let scope_dtype = test_harness::struct_dtype();

        let struct_dtype = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let struct_value = Scalar::struct_(
            struct_dtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );

        let cases = [
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
            (lit(struct_value), struct_dtype),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope_dtype).unwrap(), expected);
        }
    }
}