vortex_array/expr/exprs/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::match_each_float_ptype;
9use vortex_error::VortexResult;
10use vortex_error::vortex_err;
11use vortex_proto::expr as pb;
12use vortex_scalar::Scalar;
13use vortex_vector::Datum;
14
15use crate::Array;
16use crate::ArrayRef;
17use crate::IntoArray;
18use crate::arrays::ConstantArray;
19use crate::expr::Arity;
20use crate::expr::ChildName;
21use crate::expr::ExecutionArgs;
22use crate::expr::ExprId;
23use crate::expr::Expression;
24use crate::expr::StatsCatalog;
25use crate::expr::VTable;
26use crate::expr::VTableExt;
27use crate::expr::stats::Stat;
28
/// Expression that represents a literal scalar value.
///
/// `Literal` itself is a stateless marker: the scalar a literal expression carries is stored as
/// the expression's options (`type Options = Scalar` in its `VTable` impl), so the same unit
/// value is used to build every literal expression.
#[derive(Debug, Clone, Copy, Default)]
pub struct Literal;
31
32impl VTable for Literal {
33    type Options = Scalar;
34
35    fn id(&self) -> ExprId {
36        ExprId::new_ref("vortex.literal")
37    }
38
39    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
40        Ok(Some(
41            pb::LiteralOpts {
42                value: Some(instance.as_ref().into()),
43            }
44            .encode_to_vec(),
45        ))
46    }
47
48    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
49        let ops = pb::LiteralOpts::decode(metadata)?;
50        ops.value
51            .as_ref()
52            .ok_or_else(|| vortex_err!("Literal metadata missing value"))?
53            .try_into()
54    }
55
56    fn arity(&self, _options: &Self::Options) -> Arity {
57        Arity::Exact(0)
58    }
59
60    fn child_name(&self, _instance: &Self::Options, _child_idx: usize) -> ChildName {
61        unreachable!()
62    }
63
64    fn fmt_sql(
65        &self,
66        scalar: &Scalar,
67        _expr: &Expression,
68        f: &mut Formatter<'_>,
69    ) -> std::fmt::Result {
70        write!(f, "{}", scalar)
71    }
72
73    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
74        Ok(options.dtype().clone())
75    }
76
77    fn evaluate(
78        &self,
79        scalar: &Scalar,
80        _expr: &Expression,
81        scope: &ArrayRef,
82    ) -> VortexResult<ArrayRef> {
83        Ok(ConstantArray::new(scalar.clone(), scope.len()).into_array())
84    }
85
86    fn execute(&self, scalar: &Scalar, _args: ExecutionArgs) -> VortexResult<Datum> {
87        let vector_scalar = scalar.to_vector_scalar();
88        Ok(Datum::Scalar(vector_scalar))
89    }
90
91    fn stat_expression(
92        &self,
93        scalar: &Scalar,
94        _expr: &Expression,
95        stat: Stat,
96        _catalog: &dyn StatsCatalog,
97    ) -> Option<Expression> {
98        // NOTE(ngates): we return incorrect `1` values for counts here since we don't have
99        //  row-count information. We could resolve this in the future by introducing a `count()`
100        //  expression that evaluates to the row count of the provided scope. But since this is
101        //  only currently used for pruning, it doesn't change the outcome.
102
103        match stat {
104            Stat::Min | Stat::Max => Some(lit(scalar.clone())),
105            Stat::IsConstant => Some(lit(true)),
106            Stat::NaNCount => {
107                // The NaNCount for a non-float literal is not defined.
108                // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
109                let value = scalar.as_primitive_opt()?;
110                if !value.ptype().is_float() {
111                    return None;
112                }
113
114                match_each_float_ptype!(value.ptype(), |T| {
115                    if value.typed_value::<T>().is_some_and(|v| v.is_nan()) {
116                        Some(lit(1u64))
117                    } else {
118                        Some(lit(0u64))
119                    }
120                })
121            }
122            Stat::NullCount => {
123                if scalar.is_null() {
124                    Some(lit(1u64))
125                } else {
126                    Some(lit(0u64))
127                }
128            }
129            Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => {
130                None
131            }
132        }
133    }
134
135    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
136        false
137    }
138
139    fn is_fallible(&self, _instance: &Self::Options) -> bool {
140        false
141    }
142}
143
144/// Create a new `Literal` expression from a type that coerces to `Scalar`.
145///
146///
147/// ## Example usage
148///
149/// ```
150/// use vortex_array::arrays::PrimitiveArray;
151/// use vortex_dtype::Nullability;
152/// use vortex_array::expr::{lit, Literal};
153/// use vortex_scalar::Scalar;
154///
155/// let number = lit(34i32);
156///
157/// let scalar = number.as_::<Literal>();
158/// assert_eq!(scalar, &Scalar::primitive(34i32, Nullability::NonNullable));
159/// ```
160pub fn lit(value: impl Into<Scalar>) -> Expression {
161    Literal.new_expr(value.into(), [])
162}
163
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;

    use super::lit;
    use crate::expr::test_harness;

    /// A literal's return dtype is the dtype of its scalar, independent of the scope dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let pet_dtype = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );

        // Each case pairs a literal expression with the dtype it must report.
        let cases = [
            (lit(10), DType::Primitive(PType::I32, Nullability::NonNullable)),
            (lit(i64::MAX), DType::Primitive(PType::I64, Nullability::NonNullable)),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
            (lit(pet), pet_dtype),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }
    }
}