vortex_array/scalar_fn/fns/
literal.rs1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::arrays::ConstantArray;
17use crate::dtype::DType;
18use crate::expr::Expression;
19use crate::expr::StatsCatalog;
20use crate::expr::stats::Stat;
21use crate::match_each_float_ptype;
22use crate::scalar::Scalar;
23use crate::scalar_fn::Arity;
24use crate::scalar_fn::ChildName;
25use crate::scalar_fn::ExecutionArgs;
26use crate::scalar_fn::ScalarFnId;
27use crate::scalar_fn::ScalarFnVTable;
28use crate::scalar_fn::ScalarFnVTableExt;
29
30fn lit(value: impl Into<Scalar>) -> Expression {
31 Literal.new_expr(value.into(), [])
32}
33
/// Scalar function representing a constant value.
///
/// The function's options are the [`Scalar`] it represents. It takes no children, and
/// executing it yields that scalar repeated once for every row.
#[derive(Clone)]
pub struct Literal;
37
38impl ScalarFnVTable for Literal {
39 type Options = Scalar;
40
41 fn id(&self) -> ScalarFnId {
42 static ID: CachedId = CachedId::new("vortex.literal");
43 *ID
44 }
45
46 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
47 Ok(Some(
48 pb::LiteralOpts {
49 value: Some(instance.into()),
50 }
51 .encode_to_vec(),
52 ))
53 }
54
55 fn deserialize(
56 &self,
57 _metadata: &[u8],
58 session: &VortexSession,
59 ) -> VortexResult<Self::Options> {
60 let ops = pb::LiteralOpts::decode(_metadata)?;
61 Scalar::from_proto(
62 ops.value
63 .as_ref()
64 .ok_or_else(|| vortex_err!("Literal metadata missing value"))?,
65 session,
66 )
67 }
68
69 fn arity(&self, _options: &Self::Options) -> Arity {
70 Arity::Exact(0)
71 }
72
73 fn child_name(&self, _instance: &Self::Options, _child_idx: usize) -> ChildName {
74 unreachable!()
75 }
76
77 fn fmt_sql(
78 &self,
79 scalar: &Scalar,
80 _expr: &Expression,
81 f: &mut Formatter<'_>,
82 ) -> std::fmt::Result {
83 write!(f, "{}", scalar)
84 }
85
86 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
87 Ok(options.dtype().clone())
88 }
89
90 fn execute(
91 &self,
92 scalar: &Scalar,
93 args: &dyn ExecutionArgs,
94 _ctx: &mut ExecutionCtx,
95 ) -> VortexResult<ArrayRef> {
96 Ok(ConstantArray::new(scalar.clone(), args.row_count()).into_array())
97 }
98
99 fn stat_expression(
100 &self,
101 scalar: &Scalar,
102 _expr: &Expression,
103 stat: Stat,
104 _catalog: &dyn StatsCatalog,
105 ) -> Option<Expression> {
106 match stat {
112 Stat::Min | Stat::Max => Some(lit(scalar.clone())),
113 Stat::IsConstant => Some(lit(true)),
114 Stat::NaNCount => {
115 let value = scalar.as_primitive_opt()?;
118 if !value.ptype().is_float() {
119 return None;
120 }
121
122 match_each_float_ptype!(value.ptype(), |T| {
123 if value.typed_value::<T>().is_some_and(|v| v.is_nan()) {
124 Some(lit(1u64))
125 } else {
126 Some(lit(0u64))
127 }
128 })
129 }
130 Stat::NullCount => {
131 if scalar.is_null() {
132 Some(lit(1u64))
133 } else {
134 Some(lit(0u64))
135 }
136 }
137 Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => {
138 None
139 }
140 }
141 }
142
143 fn validity(
144 &self,
145 scalar: &Scalar,
146 _expression: &Expression,
147 ) -> VortexResult<Option<Expression>> {
148 Ok(Some(lit(scalar.is_valid())))
149 }
150
151 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
152 false
153 }
154
155 fn is_fallible(&self, _instance: &Self::Options) -> bool {
156 false
157 }
158}
159
#[cfg(test)]
mod tests {
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::lit;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// Every literal's return dtype must equal the dtype of the scalar it wraps,
    /// independent of the scope it is evaluated against.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let pet = DType::Struct(
            StructFields::new(
                ["dog", "cat"].into(),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet_literal = lit(Scalar::struct_(
            pet.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        ));

        let cases = [
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
            (pet_literal, pet),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }
    }
}