vortex_expr/exprs/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use vortex_array::arrays::ConstantArray;
7use vortex_array::operator::OperatorRef;
8use vortex_array::{Array, ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
9use vortex_dtype::{DType, match_each_float_ptype};
10use vortex_error::{VortexResult, vortex_bail, vortex_err};
11use vortex_proto::expr as pb;
12use vortex_scalar::Scalar;
13
14use crate::display::{DisplayAs, DisplayFormat};
15use crate::{
16    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, vtable,
17};
18
19vtable!(Literal);
20
21#[derive(Clone, Debug, PartialEq, Eq, Hash)]
22pub struct LiteralExpr {
23    value: Scalar,
24}
25
26pub struct LiteralExprEncoding;
27
28impl VTable for LiteralVTable {
29    type Expr = LiteralExpr;
30    type Encoding = LiteralExprEncoding;
31    type Metadata = ProstMetadata<pb::LiteralOpts>;
32
33    fn id(_encoding: &Self::Encoding) -> ExprId {
34        ExprId::new_ref("literal")
35    }
36
37    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
38        ExprEncodingRef::new_ref(LiteralExprEncoding.as_ref())
39    }
40
41    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
42        Some(ProstMetadata(pb::LiteralOpts {
43            value: Some((&expr.value).into()),
44        }))
45    }
46
47    fn children(_expr: &Self::Expr) -> Vec<&ExprRef> {
48        vec![]
49    }
50
51    fn with_children(expr: &Self::Expr, _children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
52        Ok(expr.clone())
53    }
54
55    fn build(
56        _encoding: &Self::Encoding,
57        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
58        children: Vec<ExprRef>,
59    ) -> VortexResult<Self::Expr> {
60        if !children.is_empty() {
61            vortex_bail!(
62                "Literal expression does not have children, got: {:?}",
63                children
64            );
65        }
66        let value: Scalar = metadata
67            .value
68            .as_ref()
69            .ok_or_else(|| vortex_err!("Literal metadata missing value"))?
70            .try_into()?;
71        Ok(LiteralExpr::new(value))
72    }
73
74    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
75        Ok(ConstantArray::new(expr.value.clone(), scope.len()).into_array())
76    }
77
78    fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
79        Ok(expr.value.dtype().clone())
80    }
81
82    fn operator(expr: &Self::Expr, scope: &OperatorRef) -> VortexResult<Option<OperatorRef>> {
83        let Some(len) = scope.bounds().maybe_len() else {
84            // Cannot return unsized operator.
85            return Ok(None);
86        };
87        Ok(Some(Arc::new(ConstantArray::new(expr.value.clone(), len))))
88    }
89}
90
91impl LiteralExpr {
92    pub fn new(value: impl Into<Scalar>) -> Self {
93        Self {
94            value: value.into(),
95        }
96    }
97
98    pub fn new_expr(value: impl Into<Scalar>) -> ExprRef {
99        Self::new(value).into_expr()
100    }
101
102    pub fn value(&self) -> &Scalar {
103        &self.value
104    }
105
106    pub fn maybe_from(expr: &ExprRef) -> Option<&LiteralExpr> {
107        expr.as_opt::<LiteralVTable>()
108    }
109}
110
111impl DisplayAs for LiteralExpr {
112    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
113        match df {
114            DisplayFormat::Compact => {
115                write!(f, "{}", self.value)
116            }
117            DisplayFormat::Tree => {
118                write!(
119                    f,
120                    "Literal(value: {}, dtype: {})",
121                    self.value,
122                    self.value.dtype()
123                )
124            }
125        }
126    }
127}
128
129impl AnalysisExpr for LiteralExpr {
130    fn max(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
131        Some(lit(self.value.clone()))
132    }
133
134    fn min(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
135        Some(lit(self.value.clone()))
136    }
137
138    fn nan_count(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
139        // The NaNCount for a non-float literal is not defined.
140        // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
141        let value = self.value.as_primitive_opt()?;
142        if !value.ptype().is_float() {
143            return None;
144        }
145
146        match_each_float_ptype!(value.ptype(), |T| {
147            match value.typed_value::<T>() {
148                None => Some(lit(0u64)),
149                Some(value) if value.is_nan() => Some(lit(1u64)),
150                _ => Some(lit(0u64)),
151            }
152        })
153    }
154}
155
156/// Create a new `Literal` expression from a type that coerces to `Scalar`.
157///
158///
159/// ## Example usage
160///
161/// ```
162/// use vortex_array::arrays::PrimitiveArray;
163/// use vortex_dtype::Nullability;
164/// use vortex_expr::{lit, LiteralVTable};
165/// use vortex_scalar::Scalar;
166///
167/// let number = lit(34i32);
168///
169/// let literal = number.as_::<LiteralVTable>();
170/// assert_eq!(literal.value(), &Scalar::primitive(34i32, Nullability::NonNullable));
171/// ```
172pub fn lit(value: impl Into<Scalar>) -> ExprRef {
173    LiteralExpr::new(value.into()).into_expr()
174}
175
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use crate::{lit, test_harness};

    /// A literal's `return_dtype` is the dtype of its scalar, independent of the scope dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let cases = [
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
        ];
        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }

        let struct_dtype = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let struct_literal = lit(Scalar::struct_(
            struct_dtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        ));
        assert_eq!(struct_literal.return_dtype(&scope).unwrap(), struct_dtype);
    }
}