Skip to main content

vortex_array/scalar_fn/fns/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::arrays::ConstantArray;
15use crate::dtype::DType;
16use crate::expr::Expression;
17use crate::expr::StatsCatalog;
18use crate::expr::stats::Stat;
19use crate::match_each_float_ptype;
20use crate::scalar::Scalar;
21use crate::scalar_fn::Arity;
22use crate::scalar_fn::ChildName;
23use crate::scalar_fn::ExecutionArgs;
24use crate::scalar_fn::ScalarFnId;
25use crate::scalar_fn::ScalarFnVTable;
26use crate::scalar_fn::ScalarFnVTableExt;
27
28fn lit(value: impl Into<Scalar>) -> Expression {
29    Literal.new_expr(value.into(), [])
30}
31
/// Expression that represents a literal scalar value.
///
/// The value itself is carried as the expression's options (`Options = Scalar`). A literal
/// takes no child expressions, and when executed it produces a [`ConstantArray`] repeating the
/// value for the requested number of rows.
#[derive(Clone)]
pub struct Literal;
35
36impl ScalarFnVTable for Literal {
37    type Options = Scalar;
38
39    fn id(&self) -> ScalarFnId {
40        ScalarFnId::new_ref("vortex.literal")
41    }
42
43    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
44        Ok(Some(
45            pb::LiteralOpts {
46                value: Some(instance.into()),
47            }
48            .encode_to_vec(),
49        ))
50    }
51
52    fn deserialize(
53        &self,
54        _metadata: &[u8],
55        session: &VortexSession,
56    ) -> VortexResult<Self::Options> {
57        let ops = pb::LiteralOpts::decode(_metadata)?;
58        Scalar::from_proto(
59            ops.value
60                .as_ref()
61                .ok_or_else(|| vortex_err!("Literal metadata missing value"))?,
62            session,
63        )
64    }
65
66    fn arity(&self, _options: &Self::Options) -> Arity {
67        Arity::Exact(0)
68    }
69
70    fn child_name(&self, _instance: &Self::Options, _child_idx: usize) -> ChildName {
71        unreachable!()
72    }
73
74    fn fmt_sql(
75        &self,
76        scalar: &Scalar,
77        _expr: &Expression,
78        f: &mut Formatter<'_>,
79    ) -> std::fmt::Result {
80        write!(f, "{}", scalar)
81    }
82
83    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
84        Ok(options.dtype().clone())
85    }
86
87    fn execute(&self, scalar: &Scalar, args: ExecutionArgs) -> VortexResult<ArrayRef> {
88        Ok(ConstantArray::new(scalar.clone(), args.row_count).into_array())
89    }
90
91    fn stat_expression(
92        &self,
93        scalar: &Scalar,
94        _expr: &Expression,
95        stat: Stat,
96        _catalog: &dyn StatsCatalog,
97    ) -> Option<Expression> {
98        // NOTE(ngates): we return incorrect `1` values for counts here since we don't have
99        //  row-count information. We could resolve this in the future by introducing a `count()`
100        //  expression that evaluates to the row count of the provided scope. But since this is
101        //  only currently used for pruning, it doesn't change the outcome.
102
103        match stat {
104            Stat::Min | Stat::Max => Some(lit(scalar.clone())),
105            Stat::IsConstant => Some(lit(true)),
106            Stat::NaNCount => {
107                // The NaNCount for a non-float literal is not defined.
108                // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise.
109                let value = scalar.as_primitive_opt()?;
110                if !value.ptype().is_float() {
111                    return None;
112                }
113
114                match_each_float_ptype!(value.ptype(), |T| {
115                    if value.typed_value::<T>().is_some_and(|v| v.is_nan()) {
116                        Some(lit(1u64))
117                    } else {
118                        Some(lit(0u64))
119                    }
120                })
121            }
122            Stat::NullCount => {
123                if scalar.is_null() {
124                    Some(lit(1u64))
125                } else {
126                    Some(lit(0u64))
127                }
128            }
129            Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => {
130                None
131            }
132        }
133    }
134
135    fn validity(
136        &self,
137        scalar: &Scalar,
138        _expression: &Expression,
139    ) -> VortexResult<Option<Expression>> {
140        Ok(Some(lit(scalar.is_valid())))
141    }
142
143    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
144        false
145    }
146
147    fn is_fallible(&self, _instance: &Self::Options) -> bool {
148        false
149    }
150}
151
#[cfg(test)]
mod tests {
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::lit;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// Each literal resolves to exactly the dtype of the scalar it wraps.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let struct_dtype = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let struct_value = Scalar::struct_(
            struct_dtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );

        // (expression under test, expected return dtype)
        let cases = [
            (lit(10), DType::Primitive(PType::I32, Nullability::NonNullable)),
            (lit(i64::MAX), DType::Primitive(PType::I64, Nullability::NonNullable)),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
            (lit(struct_value), struct_dtype),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }
    }
}