Skip to main content

vortex_array/scalar_fn/fns/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::arrays::ConstantArray;
17use crate::dtype::DType;
18use crate::expr::Expression;
19use crate::expr::StatsCatalog;
20use crate::expr::stats::Stat;
21use crate::match_each_float_ptype;
22use crate::scalar::Scalar;
23use crate::scalar_fn::Arity;
24use crate::scalar_fn::ChildName;
25use crate::scalar_fn::ExecutionArgs;
26use crate::scalar_fn::ScalarFnId;
27use crate::scalar_fn::ScalarFnVTable;
28use crate::scalar_fn::ScalarFnVTableExt;
29
30fn lit(value: impl Into<Scalar>) -> Expression {
31    Literal.new_expr(value.into(), [])
32}
33
/// Marker type for the literal scalar function: an expression that stands for a single,
/// fixed scalar value.
///
/// The type itself carries no data; the scalar value is supplied separately as the
/// function's options.
#[derive(Clone)]
pub struct Literal;
37
38impl ScalarFnVTable for Literal {
39    type Options = Scalar;
40
41    fn id(&self) -> ScalarFnId {
42        static ID: CachedId = CachedId::new("vortex.literal");
43        *ID
44    }
45
46    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
47        Ok(Some(
48            pb::LiteralOpts {
49                value: Some(instance.into()),
50            }
51            .encode_to_vec(),
52        ))
53    }
54
55    fn deserialize(
56        &self,
57        _metadata: &[u8],
58        session: &VortexSession,
59    ) -> VortexResult<Self::Options> {
60        let ops = pb::LiteralOpts::decode(_metadata)?;
61        Scalar::from_proto(
62            ops.value
63                .as_ref()
64                .ok_or_else(|| vortex_err!("Literal metadata missing value"))?,
65            session,
66        )
67    }
68
69    fn arity(&self, _options: &Self::Options) -> Arity {
70        Arity::Exact(0)
71    }
72
73    fn child_name(&self, _instance: &Self::Options, _child_idx: usize) -> ChildName {
74        unreachable!()
75    }
76
77    fn fmt_sql(
78        &self,
79        scalar: &Scalar,
80        _expr: &Expression,
81        f: &mut Formatter<'_>,
82    ) -> std::fmt::Result {
83        write!(f, "{}", scalar)
84    }
85
86    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
87        Ok(options.dtype().clone())
88    }
89
90    fn execute(
91        &self,
92        scalar: &Scalar,
93        args: &dyn ExecutionArgs,
94        _ctx: &mut ExecutionCtx,
95    ) -> VortexResult<ArrayRef> {
96        Ok(ConstantArray::new(scalar.clone(), args.row_count()).into_array())
97    }
98
99    fn stat_expression(
100        &self,
101        scalar: &Scalar,
102        _expr: &Expression,
103        stat: Stat,
104        _catalog: &dyn StatsCatalog,
105    ) -> Option<Expression> {
106        // NOTE(ngates): we return incorrect `1` values for counts here since we don't have
107        //  row-count information. We could resolve this in the future by introducing a `count()`
108        //  expression that evaluates to the row count of the provided scope. But since this is
109        //  only currently used for pruning, it doesn't change the outcome.
110
111        match stat {
112            Stat::Min | Stat::Max => Some(lit(scalar.clone())),
113            Stat::IsConstant => Some(lit(true)),
114            Stat::NaNCount => {
115                // The NaNCount for a non-float literal is not defined.
116                // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
117                let value = scalar.as_primitive_opt()?;
118                if !value.ptype().is_float() {
119                    return None;
120                }
121
122                match_each_float_ptype!(value.ptype(), |T| {
123                    if value.typed_value::<T>().is_some_and(|v| v.is_nan()) {
124                        Some(lit(1u64))
125                    } else {
126                        Some(lit(0u64))
127                    }
128                })
129            }
130            Stat::NullCount => {
131                if scalar.is_null() {
132                    Some(lit(1u64))
133                } else {
134                    Some(lit(0u64))
135                }
136            }
137            Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => {
138                None
139            }
140        }
141    }
142
143    fn validity(
144        &self,
145        scalar: &Scalar,
146        _expression: &Expression,
147    ) -> VortexResult<Option<Expression>> {
148        Ok(Some(lit(scalar.is_valid())))
149    }
150
151    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
152        false
153    }
154
155    fn is_fallible(&self, _instance: &Self::Options) -> bool {
156        false
157    }
158}
159
#[cfg(test)]
mod tests {
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::lit;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// Checks that `return_dtype` of a literal expression equals the dtype of its scalar, for
    /// primitive, boolean, null and struct literals.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        // Integer literals: an unsuffixed integer is an i32, `i64::MAX` is an i64.
        let i32_actual = lit(10).return_dtype(&scope).unwrap();
        let i32_expected = DType::Primitive(PType::I32, Nullability::NonNullable);
        assert_eq!(i32_actual, i32_expected);

        let i64_actual = lit(i64::MAX).return_dtype(&scope).unwrap();
        let i64_expected = DType::Primitive(PType::I64, Nullability::NonNullable);
        assert_eq!(i64_actual, i64_expected);

        // Boolean literal.
        let bool_actual = lit(true).return_dtype(&scope).unwrap();
        assert_eq!(bool_actual, DType::Bool(Nullability::NonNullable));

        // A null scalar keeps the nullable dtype it was created with.
        let null_scalar = Scalar::null(DType::Bool(Nullability::Nullable));
        let null_actual = lit(null_scalar).return_dtype(&scope).unwrap();
        assert_eq!(null_actual, DType::Bool(Nullability::Nullable));

        // Struct literal with a u32 field and a utf8 field.
        let field_dtypes = vec![
            DType::Primitive(PType::U32, Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
        ];
        let sdtype = DType::Struct(
            StructFields::new(["dog", "cat"].into(), field_dtypes),
            Nullability::NonNullable,
        );
        let struct_scalar = Scalar::struct_(
            sdtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        let struct_actual = lit(struct_scalar).return_dtype(&scope).unwrap();
        assert_eq!(struct_actual, sdtype);
    }
}