vortex_expr/exprs/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use vortex_array::arrays::{ConstantArray, ConstantOperator};
7use vortex_array::pipeline::OperatorRef;
8use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
9use vortex_dtype::{DType, match_each_float_ptype};
10use vortex_error::{VortexResult, vortex_bail, vortex_err};
11use vortex_proto::expr as pb;
12use vortex_scalar::Scalar;
13
14use crate::display::{DisplayAs, DisplayFormat};
15use crate::{
16    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, vtable,
17};
18
// NOTE(review): presumably generates `LiteralVTable` (implemented below) and the glue tying it to
// `LiteralExpr` / `LiteralExprEncoding` — confirm against the `vtable!` macro definition.
vtable!(Literal);
20
21#[derive(Clone, Debug, PartialEq, Eq, Hash)]
22pub struct LiteralExpr {
23    value: Scalar,
24}
25
/// Unit type used as the `Encoding` of the literal expression vtable.
pub struct LiteralExprEncoding;
27
28impl VTable for LiteralVTable {
29    type Expr = LiteralExpr;
30    type Encoding = LiteralExprEncoding;
31    type Metadata = ProstMetadata<pb::LiteralOpts>;
32
33    fn id(_encoding: &Self::Encoding) -> ExprId {
34        ExprId::new_ref("literal")
35    }
36
37    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
38        ExprEncodingRef::new_ref(LiteralExprEncoding.as_ref())
39    }
40
41    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
42        Some(ProstMetadata(pb::LiteralOpts {
43            value: Some((&expr.value).into()),
44        }))
45    }
46
47    fn children(_expr: &Self::Expr) -> Vec<&ExprRef> {
48        vec![]
49    }
50
51    fn with_children(expr: &Self::Expr, _children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
52        Ok(expr.clone())
53    }
54
55    fn build(
56        _encoding: &Self::Encoding,
57        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
58        children: Vec<ExprRef>,
59    ) -> VortexResult<Self::Expr> {
60        if !children.is_empty() {
61            vortex_bail!(
62                "Literal expression does not have children, got: {:?}",
63                children
64            );
65        }
66        let value: Scalar = metadata
67            .value
68            .as_ref()
69            .ok_or_else(|| vortex_err!("Literal metadata missing value"))?
70            .try_into()?;
71        Ok(LiteralExpr::new(value))
72    }
73
74    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
75        Ok(ConstantArray::new(expr.value.clone(), scope.len()).into_array())
76    }
77
78    fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
79        Ok(expr.value.dtype().clone())
80    }
81
82    fn operator(expr: &LiteralExpr, children: Vec<OperatorRef>) -> Option<OperatorRef> {
83        assert!(children.is_empty());
84
85        ConstantOperator::maybe_new(expr.value().clone()).map(|op| Arc::new(op) as OperatorRef)
86    }
87}
88
89impl LiteralExpr {
90    pub fn new(value: impl Into<Scalar>) -> Self {
91        Self {
92            value: value.into(),
93        }
94    }
95
96    pub fn new_expr(value: impl Into<Scalar>) -> ExprRef {
97        Self::new(value).into_expr()
98    }
99
100    pub fn value(&self) -> &Scalar {
101        &self.value
102    }
103
104    pub fn maybe_from(expr: &ExprRef) -> Option<&LiteralExpr> {
105        expr.as_opt::<LiteralVTable>()
106    }
107}
108
109impl DisplayAs for LiteralExpr {
110    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
111        match df {
112            DisplayFormat::Compact => {
113                write!(f, "{}", self.value)
114            }
115            DisplayFormat::Tree => {
116                write!(
117                    f,
118                    "Literal(value: {}, dtype: {})",
119                    self.value,
120                    self.value.dtype()
121                )
122            }
123        }
124    }
125}
126
127impl AnalysisExpr for LiteralExpr {
128    fn max(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
129        Some(lit(self.value.clone()))
130    }
131
132    fn min(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
133        Some(lit(self.value.clone()))
134    }
135
136    fn nan_count(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
137        // The NaNCount for a non-float literal is not defined.
138        // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
139        let value = self.value.as_primitive_opt()?;
140        if !value.ptype().is_float() {
141            return None;
142        }
143
144        match_each_float_ptype!(value.ptype(), |T| {
145            match value.typed_value::<T>() {
146                None => Some(lit(0u64)),
147                Some(value) if value.is_nan() => Some(lit(1u64)),
148                _ => Some(lit(0u64)),
149            }
150        })
151    }
152}
153
154/// Create a new `Literal` expression from a type that coerces to `Scalar`.
155///
156///
157/// ## Example usage
158///
159/// ```
160/// use vortex_array::arrays::PrimitiveArray;
161/// use vortex_dtype::Nullability;
162/// use vortex_expr::{lit, LiteralVTable};
163/// use vortex_scalar::Scalar;
164///
165/// let number = lit(34i32);
166///
167/// let literal = number.as_::<LiteralVTable>();
168/// assert_eq!(literal.value(), &Scalar::primitive(34i32, Nullability::NonNullable));
169/// ```
170pub fn lit(value: impl Into<Scalar>) -> ExprRef {
171    LiteralExpr::new(value.into()).into_expr()
172}
173
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use crate::{lit, test_harness};

    /// A literal's return dtype is always the dtype of its scalar, whatever the scope dtype is.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        // Primitive and boolean literals.
        let int_dtype = lit(10).return_dtype(&scope).unwrap();
        assert_eq!(
            int_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let long_dtype = lit(i64::MAX).return_dtype(&scope).unwrap();
        assert_eq!(
            long_dtype,
            DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let bool_dtype = lit(true).return_dtype(&scope).unwrap();
        assert_eq!(bool_dtype, DType::Bool(Nullability::NonNullable));

        // A null literal keeps its nullable dtype.
        let null_bool = Scalar::null(DType::Bool(Nullability::Nullable));
        let null_dtype = lit(null_bool).return_dtype(&scope).unwrap();
        assert_eq!(null_dtype, DType::Bool(Nullability::Nullable));

        // Struct literals report the struct dtype they were built with.
        let field_dtypes = vec![
            DType::Primitive(PType::U32, Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
        ];
        let sdtype = DType::Struct(
            StructFields::new(["dog", "cat"].into(), field_dtypes),
            Nullability::NonNullable,
        );
        let struct_scalar = Scalar::struct_(
            sdtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        let struct_dtype = lit(struct_scalar).return_dtype(&scope).unwrap();
        assert_eq!(struct_dtype, sdtype);
    }
}