vortex_array/expr/exprs/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::match_each_float_ptype;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_proto::expr as pb;
13use vortex_scalar::Scalar;
14
15use crate::Array;
16use crate::ArrayRef;
17use crate::IntoArray;
18use crate::arrays::ConstantArray;
19use crate::expr::ChildName;
20use crate::expr::ExprId;
21use crate::expr::Expression;
22use crate::expr::ExpressionView;
23use crate::expr::StatsCatalog;
24use crate::expr::VTable;
25use crate::expr::VTableExt;
26use crate::expr::stats::Stat;
27
28/// Expression that represents a literal scalar value.
29pub struct Literal;
30
31impl VTable for Literal {
32    type Instance = Scalar;
33
34    fn id(&self) -> ExprId {
35        ExprId::new_ref("vortex.literal")
36    }
37
38    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
39        Ok(Some(
40            pb::LiteralOpts {
41                value: Some(instance.as_ref().into()),
42            }
43            .encode_to_vec(),
44        ))
45    }
46
47    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
48        let ops = pb::LiteralOpts::decode(metadata)?;
49        Ok(Some(
50            ops.value
51                .as_ref()
52                .ok_or_else(|| vortex_err!("Literal metadata missing value"))?
53                .try_into()?,
54        ))
55    }
56
57    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
58        if !expr.children().is_empty() {
59            vortex_bail!(
60                "Literal expression does not have children, got: {:?}",
61                expr.children()
62            );
63        }
64        Ok(())
65    }
66
67    fn child_name(&self, _instance: &Self::Instance, _child_idx: usize) -> ChildName {
68        unreachable!()
69    }
70
71    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
72        write!(f, "{}", expr.data())
73    }
74
75    fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
76        write!(f, "{}", instance)
77    }
78
79    fn return_dtype(&self, expr: &ExpressionView<Self>, _scope: &DType) -> VortexResult<DType> {
80        Ok(expr.data().dtype().clone())
81    }
82
83    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
84        Ok(ConstantArray::new(expr.data().clone(), scope.len()).into_array())
85    }
86
87    fn stat_expression(
88        &self,
89        expr: &ExpressionView<Self>,
90        stat: Stat,
91        _catalog: &dyn StatsCatalog,
92    ) -> Option<Expression> {
93        // NOTE(ngates): we return incorrect `1` values for counts here since we don't have
94        //  row-count information. We could resolve this in the future by introducing a `count()`
95        //  expression that evaluates to the row count of the provided scope. But since this is
96        //  only currently used for pruning, it doesn't change the outcome.
97
98        match stat {
99            Stat::Min | Stat::Max => Some(lit(expr.data().clone())),
100            Stat::IsConstant => Some(lit(true)),
101            Stat::NaNCount => {
102                // The NaNCount for a non-float literal is not defined.
103                // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
104                let value = expr.data().as_primitive_opt()?;
105                if !value.ptype().is_float() {
106                    return None;
107                }
108
109                match_each_float_ptype!(value.ptype(), |T| {
110                    if value.typed_value::<T>().is_some_and(|v| v.is_nan()) {
111                        Some(lit(1u64))
112                    } else {
113                        Some(lit(0u64))
114                    }
115                })
116            }
117            Stat::NullCount => {
118                if expr.data().is_null() {
119                    Some(lit(1u64))
120                } else {
121                    Some(lit(0u64))
122                }
123            }
124            Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => {
125                None
126            }
127        }
128    }
129
130    fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
131        false
132    }
133
134    fn is_fallible(&self, _instance: &Self::Instance) -> bool {
135        false
136    }
137}
138
139/// Create a new `Literal` expression from a type that coerces to `Scalar`.
140///
141///
142/// ## Example usage
143///
144/// ```
145/// use vortex_array::arrays::PrimitiveArray;
146/// use vortex_dtype::Nullability;
147/// use vortex_array::expr::{lit, Literal};
148/// use vortex_scalar::Scalar;
149///
150/// let number = lit(34i32);
151///
152/// let literal = number.as_::<Literal>();
153/// assert_eq!(literal.data(), &Scalar::primitive(34i32, Nullability::NonNullable));
154/// ```
155pub fn lit(value: impl Into<Scalar>) -> Expression {
156    Literal.new_expr(value.into(), [])
157}
158
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;

    use super::lit;
    use crate::expr::test_harness;

    /// A literal's return dtype is the dtype of its scalar, whatever the scope dtype is.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let pet_struct = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );

        // Each case pairs a literal expression with the dtype it must report.
        let cases = [
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
            (
                lit(Scalar::struct_(
                    pet_struct.clone(),
                    vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
                )),
                pet_struct,
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }
    }
}
213}