vortex_array/scalar_fn/fns/
literal.rs1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::arrays::ConstantArray;
16use crate::dtype::DType;
17use crate::expr::Expression;
18use crate::expr::StatsCatalog;
19use crate::expr::stats::Stat;
20use crate::match_each_float_ptype;
21use crate::scalar::Scalar;
22use crate::scalar_fn::Arity;
23use crate::scalar_fn::ChildName;
24use crate::scalar_fn::ExecutionArgs;
25use crate::scalar_fn::ScalarFnId;
26use crate::scalar_fn::ScalarFnVTable;
27use crate::scalar_fn::ScalarFnVTableExt;
28
29fn lit(value: impl Into<Scalar>) -> Expression {
30 Literal.new_expr(value.into(), [])
31}
32
/// Zero-arity scalar function whose options are the [`Scalar`] it evaluates to.
///
/// When executed, the scalar is broadcast to every row of the input scope.
#[derive(Clone, Debug)]
pub struct Literal;
36
37impl ScalarFnVTable for Literal {
38 type Options = Scalar;
39
40 fn id(&self) -> ScalarFnId {
41 ScalarFnId::new_ref("vortex.literal")
42 }
43
44 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
45 Ok(Some(
46 pb::LiteralOpts {
47 value: Some(instance.into()),
48 }
49 .encode_to_vec(),
50 ))
51 }
52
53 fn deserialize(
54 &self,
55 _metadata: &[u8],
56 session: &VortexSession,
57 ) -> VortexResult<Self::Options> {
58 let ops = pb::LiteralOpts::decode(_metadata)?;
59 Scalar::from_proto(
60 ops.value
61 .as_ref()
62 .ok_or_else(|| vortex_err!("Literal metadata missing value"))?,
63 session,
64 )
65 }
66
67 fn arity(&self, _options: &Self::Options) -> Arity {
68 Arity::Exact(0)
69 }
70
71 fn child_name(&self, _instance: &Self::Options, _child_idx: usize) -> ChildName {
72 unreachable!()
73 }
74
75 fn fmt_sql(
76 &self,
77 scalar: &Scalar,
78 _expr: &Expression,
79 f: &mut Formatter<'_>,
80 ) -> std::fmt::Result {
81 write!(f, "{}", scalar)
82 }
83
84 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
85 Ok(options.dtype().clone())
86 }
87
88 fn execute(
89 &self,
90 scalar: &Scalar,
91 args: &dyn ExecutionArgs,
92 _ctx: &mut ExecutionCtx,
93 ) -> VortexResult<ArrayRef> {
94 Ok(ConstantArray::new(scalar.clone(), args.row_count()).into_array())
95 }
96
97 fn stat_expression(
98 &self,
99 scalar: &Scalar,
100 _expr: &Expression,
101 stat: Stat,
102 _catalog: &dyn StatsCatalog,
103 ) -> Option<Expression> {
104 match stat {
110 Stat::Min | Stat::Max => Some(lit(scalar.clone())),
111 Stat::IsConstant => Some(lit(true)),
112 Stat::NaNCount => {
113 let value = scalar.as_primitive_opt()?;
116 if !value.ptype().is_float() {
117 return None;
118 }
119
120 match_each_float_ptype!(value.ptype(), |T| {
121 if value.typed_value::<T>().is_some_and(|v| v.is_nan()) {
122 Some(lit(1u64))
123 } else {
124 Some(lit(0u64))
125 }
126 })
127 }
128 Stat::NullCount => {
129 if scalar.is_null() {
130 Some(lit(1u64))
131 } else {
132 Some(lit(0u64))
133 }
134 }
135 Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => {
136 None
137 }
138 }
139 }
140
141 fn validity(
142 &self,
143 scalar: &Scalar,
144 _expression: &Expression,
145 ) -> VortexResult<Option<Expression>> {
146 Ok(Some(lit(scalar.is_valid())))
147 }
148
149 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
150 false
151 }
152
153 fn is_fallible(&self, _instance: &Self::Options) -> bool {
154 false
155 }
156}
157
#[cfg(test)]
mod tests {
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::lit;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// A literal expression reports the dtype of the scalar it wraps, whatever scope dtype it
    /// is resolved against.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let pet = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet_value = Scalar::struct_(
            pet.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );

        let cases = [
            (lit(10), DType::Primitive(PType::I32, Nullability::NonNullable)),
            (lit(i64::MAX), DType::Primitive(PType::I64, Nullability::NonNullable)),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
            (lit(pet_value), pet),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }
    }
}