Skip to main content

vortex_array/scalar_fn/fns/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::arrays::ConstantArray;
16use crate::dtype::DType;
17use crate::expr::Expression;
18use crate::expr::StatsCatalog;
19use crate::expr::stats::Stat;
20use crate::match_each_float_ptype;
21use crate::scalar::Scalar;
22use crate::scalar_fn::Arity;
23use crate::scalar_fn::ChildName;
24use crate::scalar_fn::ExecutionArgs;
25use crate::scalar_fn::ScalarFnId;
26use crate::scalar_fn::ScalarFnVTable;
27use crate::scalar_fn::ScalarFnVTableExt;
28
29fn lit(value: impl Into<Scalar>) -> Expression {
30    Literal.new_expr(value.into(), [])
31}
32
/// Scalar function that stands for a fixed, literal [`Scalar`] value.
///
/// The type itself is a stateless marker: the literal value is supplied through
/// `ScalarFnVTable::Options`, which this type sets to [`Scalar`].
#[derive(Clone)]
pub struct Literal;
36
37impl ScalarFnVTable for Literal {
38    type Options = Scalar;
39
40    fn id(&self) -> ScalarFnId {
41        ScalarFnId::new_ref("vortex.literal")
42    }
43
44    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
45        Ok(Some(
46            pb::LiteralOpts {
47                value: Some(instance.into()),
48            }
49            .encode_to_vec(),
50        ))
51    }
52
53    fn deserialize(
54        &self,
55        _metadata: &[u8],
56        session: &VortexSession,
57    ) -> VortexResult<Self::Options> {
58        let ops = pb::LiteralOpts::decode(_metadata)?;
59        Scalar::from_proto(
60            ops.value
61                .as_ref()
62                .ok_or_else(|| vortex_err!("Literal metadata missing value"))?,
63            session,
64        )
65    }
66
67    fn arity(&self, _options: &Self::Options) -> Arity {
68        Arity::Exact(0)
69    }
70
71    fn child_name(&self, _instance: &Self::Options, _child_idx: usize) -> ChildName {
72        unreachable!()
73    }
74
75    fn fmt_sql(
76        &self,
77        scalar: &Scalar,
78        _expr: &Expression,
79        f: &mut Formatter<'_>,
80    ) -> std::fmt::Result {
81        write!(f, "{}", scalar)
82    }
83
84    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
85        Ok(options.dtype().clone())
86    }
87
88    fn execute(
89        &self,
90        scalar: &Scalar,
91        args: &dyn ExecutionArgs,
92        _ctx: &mut ExecutionCtx,
93    ) -> VortexResult<ArrayRef> {
94        Ok(ConstantArray::new(scalar.clone(), args.row_count()).into_array())
95    }
96
97    fn stat_expression(
98        &self,
99        scalar: &Scalar,
100        _expr: &Expression,
101        stat: Stat,
102        _catalog: &dyn StatsCatalog,
103    ) -> Option<Expression> {
104        // NOTE(ngates): we return incorrect `1` values for counts here since we don't have
105        //  row-count information. We could resolve this in the future by introducing a `count()`
106        //  expression that evaluates to the row count of the provided scope. But since this is
107        //  only currently used for pruning, it doesn't change the outcome.
108
109        match stat {
110            Stat::Min | Stat::Max => Some(lit(scalar.clone())),
111            Stat::IsConstant => Some(lit(true)),
112            Stat::NaNCount => {
113                // The NaNCount for a non-float literal is not defined.
114                // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
115                let value = scalar.as_primitive_opt()?;
116                if !value.ptype().is_float() {
117                    return None;
118                }
119
120                match_each_float_ptype!(value.ptype(), |T| {
121                    if value.typed_value::<T>().is_some_and(|v| v.is_nan()) {
122                        Some(lit(1u64))
123                    } else {
124                        Some(lit(0u64))
125                    }
126                })
127            }
128            Stat::NullCount => {
129                if scalar.is_null() {
130                    Some(lit(1u64))
131                } else {
132                    Some(lit(0u64))
133                }
134            }
135            Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => {
136                None
137            }
138        }
139    }
140
141    fn validity(
142        &self,
143        scalar: &Scalar,
144        _expression: &Expression,
145    ) -> VortexResult<Option<Expression>> {
146        Ok(Some(lit(scalar.is_valid())))
147    }
148
149    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
150        false
151    }
152
153    fn is_fallible(&self, _instance: &Self::Options) -> bool {
154        false
155    }
156}
157
#[cfg(test)]
mod tests {
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::lit;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// A literal expression's return dtype must equal the dtype of the scalar it wraps,
    /// regardless of the scope dtype it is resolved against.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let struct_dtype = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );

        // Each case pairs a literal expression with the dtype it is expected to report.
        let cases = [
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
            (
                lit(Scalar::struct_(
                    struct_dtype.clone(),
                    vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
                )),
                struct_dtype,
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }
    }
}