Skip to main content

vortex_array/expr/exprs/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::match_each_float_ptype;
9use vortex_error::VortexResult;
10use vortex_error::vortex_err;
11use vortex_proto::expr as pb;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::arrays::ConstantArray;
17use crate::expr::Arity;
18use crate::expr::ChildName;
19use crate::expr::ExecutionArgs;
20use crate::expr::ExprId;
21use crate::expr::Expression;
22use crate::expr::StatsCatalog;
23use crate::expr::VTable;
24use crate::expr::VTableExt;
25use crate::expr::stats::Stat;
26use crate::scalar::Scalar;
27
28/// Expression that represents a literal scalar value.
29pub struct Literal;
30
31impl VTable for Literal {
32    type Options = Scalar;
33
34    fn id(&self) -> ExprId {
35        ExprId::new_ref("vortex.literal")
36    }
37
38    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
39        Ok(Some(
40            pb::LiteralOpts {
41                value: Some(instance.into()),
42            }
43            .encode_to_vec(),
44        ))
45    }
46
47    fn deserialize(
48        &self,
49        _metadata: &[u8],
50        session: &VortexSession,
51    ) -> VortexResult<Self::Options> {
52        let ops = pb::LiteralOpts::decode(_metadata)?;
53        Scalar::from_proto(
54            ops.value
55                .as_ref()
56                .ok_or_else(|| vortex_err!("Literal metadata missing value"))?,
57            session,
58        )
59    }
60
61    fn arity(&self, _options: &Self::Options) -> Arity {
62        Arity::Exact(0)
63    }
64
65    fn child_name(&self, _instance: &Self::Options, _child_idx: usize) -> ChildName {
66        unreachable!()
67    }
68
69    fn fmt_sql(
70        &self,
71        scalar: &Scalar,
72        _expr: &Expression,
73        f: &mut Formatter<'_>,
74    ) -> std::fmt::Result {
75        write!(f, "{}", scalar)
76    }
77
78    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
79        Ok(options.dtype().clone())
80    }
81
82    fn execute(&self, scalar: &Scalar, args: ExecutionArgs) -> VortexResult<ArrayRef> {
83        Ok(ConstantArray::new(scalar.clone(), args.row_count).into_array())
84    }
85
86    fn stat_expression(
87        &self,
88        scalar: &Scalar,
89        _expr: &Expression,
90        stat: Stat,
91        _catalog: &dyn StatsCatalog,
92    ) -> Option<Expression> {
93        // NOTE(ngates): we return incorrect `1` values for counts here since we don't have
94        //  row-count information. We could resolve this in the future by introducing a `count()`
95        //  expression that evaluates to the row count of the provided scope. But since this is
96        //  only currently used for pruning, it doesn't change the outcome.
97
98        match stat {
99            Stat::Min | Stat::Max => Some(lit(scalar.clone())),
100            Stat::IsConstant => Some(lit(true)),
101            Stat::NaNCount => {
102                // The NaNCount for a non-float literal is not defined.
103                // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
104                let value = scalar.as_primitive_opt()?;
105                if !value.ptype().is_float() {
106                    return None;
107                }
108
109                match_each_float_ptype!(value.ptype(), |T| {
110                    if value.typed_value::<T>().is_some_and(|v| v.is_nan()) {
111                        Some(lit(1u64))
112                    } else {
113                        Some(lit(0u64))
114                    }
115                })
116            }
117            Stat::NullCount => {
118                if scalar.is_null() {
119                    Some(lit(1u64))
120                } else {
121                    Some(lit(0u64))
122                }
123            }
124            Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => {
125                None
126            }
127        }
128    }
129
130    fn validity(
131        &self,
132        scalar: &Scalar,
133        _expression: &Expression,
134    ) -> VortexResult<Option<Expression>> {
135        Ok(Some(lit(scalar.is_valid())))
136    }
137
138    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
139        false
140    }
141
142    fn is_fallible(&self, _instance: &Self::Options) -> bool {
143        false
144    }
145}
146
147/// Create a new `Literal` expression from a type that coerces to `Scalar`.
148///
149///
150/// ## Example usage
151///
152/// ```
153/// use vortex_array::arrays::PrimitiveArray;
154/// use vortex_dtype::Nullability;
155/// use vortex_array::expr::{lit, Literal};
156/// use vortex_array::scalar::Scalar;
157///
158/// let number = lit(34i32);
159///
160/// let scalar = number.as_::<Literal>();
161/// assert_eq!(scalar, &Scalar::primitive(34i32, Nullability::NonNullable));
162/// ```
163pub fn lit(value: impl Into<Scalar>) -> Expression {
164    Literal.new_expr(value.into(), [])
165}
166
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;

    use super::lit;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// A literal's return dtype must equal the dtype of the scalar it wraps, regardless of
    /// the scope dtype it is resolved against.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        // Struct-typed literal with one numeric and one string field.
        let struct_dtype = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let struct_literal = lit(Scalar::struct_(
            struct_dtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        ));

        // Each case pairs a literal expression with the dtype it is expected to report.
        let cases = [
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
            (struct_literal, struct_dtype),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }
    }
}