Skip to main content

vortex_array/scalar_fn/fns/
literal.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::arrays::ConstantArray;
17use crate::dtype::DType;
18use crate::expr::Expression;
19use crate::scalar::Scalar;
20use crate::scalar_fn::Arity;
21use crate::scalar_fn::ChildName;
22use crate::scalar_fn::ExecutionArgs;
23use crate::scalar_fn::ScalarFnId;
24use crate::scalar_fn::ScalarFnVTable;
25use crate::scalar_fn::ScalarFnVTableExt;
26
27fn lit(value: impl Into<Scalar>) -> Expression {
28    Literal.new_expr(value.into(), [])
29}
30
31/// Expression that represents a literal scalar value.
32#[derive(Clone)]
33pub struct Literal;
34
35impl ScalarFnVTable for Literal {
36    type Options = Scalar;
37
38    fn id(&self) -> ScalarFnId {
39        static ID: CachedId = CachedId::new("vortex.literal");
40        *ID
41    }
42
43    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
44        Ok(Some(
45            pb::LiteralOpts {
46                value: Some(instance.into()),
47            }
48            .encode_to_vec(),
49        ))
50    }
51
52    fn deserialize(
53        &self,
54        _metadata: &[u8],
55        session: &VortexSession,
56    ) -> VortexResult<Self::Options> {
57        let ops = pb::LiteralOpts::decode(_metadata)?;
58        Scalar::from_proto(
59            ops.value
60                .as_ref()
61                .ok_or_else(|| vortex_err!("Literal metadata missing value"))?,
62            session,
63        )
64    }
65
66    fn arity(&self, _options: &Self::Options) -> Arity {
67        Arity::Exact(0)
68    }
69
70    fn child_name(&self, _instance: &Self::Options, _child_idx: usize) -> ChildName {
71        unreachable!()
72    }
73
74    fn fmt_sql(
75        &self,
76        scalar: &Scalar,
77        _expr: &Expression,
78        f: &mut Formatter<'_>,
79    ) -> std::fmt::Result {
80        write!(f, "{}", scalar)
81    }
82
83    fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
84        Ok(options.dtype().clone())
85    }
86
87    fn execute(
88        &self,
89        scalar: &Scalar,
90        args: &dyn ExecutionArgs,
91        _ctx: &mut ExecutionCtx,
92    ) -> VortexResult<ArrayRef> {
93        Ok(ConstantArray::new(scalar.clone(), args.row_count()).into_array())
94    }
95
96    fn validity(
97        &self,
98        scalar: &Scalar,
99        _expression: &Expression,
100    ) -> VortexResult<Option<Expression>> {
101        Ok(Some(lit(scalar.is_valid())))
102    }
103
104    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
105        false
106    }
107
108    fn is_fallible(&self, _instance: &Self::Options) -> bool {
109        false
110    }
111}
112
#[cfg(test)]
mod tests {
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::lit;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// A literal's return dtype is the dtype of its scalar, regardless of the
    /// scope dtype it is evaluated against.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let struct_dtype = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let struct_value = Scalar::struct_(
            struct_dtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );

        // (expression, expected return dtype)
        let cases = [
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
            (lit(struct_value), struct_dtype),
        ];

        for (idx, (expr, expected)) in cases.into_iter().enumerate() {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected, "case {idx}");
        }
    }
}