vortex_expr/exprs/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::ConstantArray;
5use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
6use vortex_dtype::{DType, match_each_float_ptype};
7use vortex_error::{VortexResult, vortex_bail, vortex_err};
8use vortex_proto::expr as pb;
9use vortex_scalar::Scalar;
10
11use crate::display::{DisplayAs, DisplayFormat};
12use crate::{
13    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, vtable,
14};
15
16vtable!(Literal);
17
18#[derive(Clone, Debug, PartialEq, Eq, Hash)]
19pub struct LiteralExpr {
20    value: Scalar,
21}
22
23pub struct LiteralExprEncoding;
24
25impl VTable for LiteralVTable {
26    type Expr = LiteralExpr;
27    type Encoding = LiteralExprEncoding;
28    type Metadata = ProstMetadata<pb::LiteralOpts>;
29
30    fn id(_encoding: &Self::Encoding) -> ExprId {
31        ExprId::new_ref("literal")
32    }
33
34    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
35        ExprEncodingRef::new_ref(LiteralExprEncoding.as_ref())
36    }
37
38    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
39        Some(ProstMetadata(pb::LiteralOpts {
40            value: Some((&expr.value).into()),
41        }))
42    }
43
44    fn children(_expr: &Self::Expr) -> Vec<&ExprRef> {
45        vec![]
46    }
47
48    fn with_children(expr: &Self::Expr, _children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
49        Ok(expr.clone())
50    }
51
52    fn build(
53        _encoding: &Self::Encoding,
54        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
55        children: Vec<ExprRef>,
56    ) -> VortexResult<Self::Expr> {
57        if !children.is_empty() {
58            vortex_bail!(
59                "Literal expression does not have children, got: {:?}",
60                children
61            );
62        }
63        let value: Scalar = metadata
64            .value
65            .as_ref()
66            .ok_or_else(|| vortex_err!("Literal metadata missing value"))?
67            .try_into()?;
68        Ok(LiteralExpr::new(value))
69    }
70
71    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
72        Ok(ConstantArray::new(expr.value.clone(), scope.len()).into_array())
73    }
74
75    fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
76        Ok(expr.value.dtype().clone())
77    }
78}
79
80impl LiteralExpr {
81    pub fn new(value: impl Into<Scalar>) -> Self {
82        Self {
83            value: value.into(),
84        }
85    }
86
87    pub fn new_expr(value: impl Into<Scalar>) -> ExprRef {
88        Self::new(value).into_expr()
89    }
90
91    pub fn value(&self) -> &Scalar {
92        &self.value
93    }
94
95    pub fn maybe_from(expr: &ExprRef) -> Option<&LiteralExpr> {
96        expr.as_opt::<LiteralVTable>()
97    }
98}
99
100impl DisplayAs for LiteralExpr {
101    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
102        match df {
103            DisplayFormat::Compact => {
104                write!(f, "{}", self.value)
105            }
106            DisplayFormat::Tree => {
107                write!(
108                    f,
109                    "Literal(value: {}, dtype: {})",
110                    self.value,
111                    self.value.dtype()
112                )
113            }
114        }
115    }
116}
117
118impl AnalysisExpr for LiteralExpr {
119    fn max(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
120        Some(lit(self.value.clone()))
121    }
122
123    fn min(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
124        Some(lit(self.value.clone()))
125    }
126
127    fn nan_count(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
128        // The NaNCount for a non-float literal is not defined.
129        // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
130        let value = self.value.as_primitive_opt()?;
131        if !value.ptype().is_float() {
132            return None;
133        }
134
135        match_each_float_ptype!(value.ptype(), |T| {
136            match value.typed_value::<T>() {
137                None => Some(lit(0u64)),
138                Some(value) if value.is_nan() => Some(lit(1u64)),
139                _ => Some(lit(0u64)),
140            }
141        })
142    }
143}
144
145/// Create a new `Literal` expression from a type that coerces to `Scalar`.
146///
147///
148/// ## Example usage
149///
150/// ```
151/// use vortex_array::arrays::PrimitiveArray;
152/// use vortex_dtype::Nullability;
153/// use vortex_expr::{lit, LiteralVTable};
154/// use vortex_scalar::Scalar;
155///
156/// let number = lit(34i32);
157///
158/// let literal = number.as_::<LiteralVTable>();
159/// assert_eq!(literal.value(), &Scalar::primitive(34i32, Nullability::NonNullable));
160/// ```
161pub fn lit(value: impl Into<Scalar>) -> ExprRef {
162    LiteralExpr::new(value.into()).into_expr()
163}
164
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use crate::{lit, test_harness};

    /// A literal's return dtype is always the dtype of its value, whatever the scope.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let cases = [
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
        ];
        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }

        let fields = StructFields::new(
            ["dog", "cat"].into(),
            vec![
                DType::Primitive(PType::U32, Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
            ],
        );
        let sdtype = DType::Struct(fields, Nullability::NonNullable);
        let struct_lit = lit(Scalar::struct_(
            sdtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        ));
        assert_eq!(struct_lit.return_dtype(&scope).unwrap(), sdtype);
    }
}