1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use vortex_array::arrays::ConstantArray;
6use vortex_array::{Array, ArrayRef};
7use vortex_dtype::DType;
8use vortex_error::VortexResult;
9use vortex_scalar::Scalar;
10
11use crate::{ExprRef, VortexExpr};
12
13#[derive(Debug, PartialEq, Eq, Hash)]
14pub struct Literal {
15 value: Scalar,
16}
17
18impl Literal {
19 pub fn new_expr(value: impl Into<Scalar>) -> ExprRef {
20 Arc::new(Self {
21 value: value.into(),
22 })
23 }
24
25 pub fn value(&self) -> &Scalar {
26 &self.value
27 }
28
29 pub fn maybe_from(expr: &ExprRef) -> Option<&Literal> {
30 expr.as_any().downcast_ref::<Literal>()
31 }
32}
33
34impl Display for Literal {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "{}", self.value)
37 }
38}
39
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use kind::Kind;
    use vortex_error::{VortexResult, vortex_bail, vortex_err};
    use vortex_proto::expr::kind;
    use vortex_scalar::Scalar;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Literal};

    /// Serde handle for literal expressions, registered under the id `"literal"`.
    pub(crate) struct LiteralSerde;

    impl Id for LiteralSerde {
        fn id(&self) -> &'static str {
            "literal"
        }
    }

    impl ExprDeserialize for LiteralSerde {
        /// Rebuilds a [`Literal`] expression from its protobuf `kind`.
        ///
        /// `_children` is not inspected.
        ///
        /// # Errors
        /// Fails if `kind` is not `Kind::Literal`, if its scalar payload is absent, or if
        /// the payload cannot be converted into a [`Scalar`].
        fn deserialize(&self, kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::Literal(literal) = kind else {
                vortex_bail!("Expected literal kind");
            };
            let proto_scalar = literal
                .value
                .as_ref()
                .ok_or_else(|| vortex_err!("empty literal scalar"))?;
            let scalar: Scalar = proto_scalar.try_into()?;
            Ok(Literal::new_expr(scalar))
        }
    }

    impl ExprSerializable for Literal {
        /// Shares the id used by [`LiteralSerde`] so that serialization and deserialization agree.
        fn id(&self) -> &'static str {
            LiteralSerde.id()
        }

        /// Encodes the held scalar as the payload of a `Kind::Literal` node.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            let literal = kind::Literal {
                value: Some((&self.value).into()),
            };
            Ok(Kind::Literal(literal))
        }
    }
}
83
84impl VortexExpr for Literal {
85 fn as_any(&self) -> &dyn Any {
86 self
87 }
88
89 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
90 Ok(ConstantArray::new(self.value.clone(), batch.len()).into_array())
91 }
92
93 fn children(&self) -> Vec<&ExprRef> {
94 vec![]
95 }
96
97 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
98 assert_eq!(children.len(), 0);
99 self
100 }
101
102 fn return_dtype(&self, _scope_dtype: &DType) -> VortexResult<DType> {
103 Ok(self.value.dtype().clone())
104 }
105}
106
107pub fn lit(value: impl Into<Scalar>) -> ExprRef {
126 Literal::new_expr(value.into())
127}
128
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use crate::{lit, test_harness};

    /// A literal's `return_dtype` is its scalar's dtype, independent of the scope dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let cases = [
            (lit(10), DType::Primitive(PType::I32, Nullability::NonNullable)),
            (lit(0_u8), DType::Primitive(PType::U8, Nullability::NonNullable)),
            (lit(0.0_f32), DType::Primitive(PType::F32, Nullability::NonNullable)),
            (lit(i64::MAX), DType::Primitive(PType::I64, Nullability::NonNullable)),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
        ];
        for (expr, expected) in &cases {
            assert_eq!(&expr.return_dtype(&scope).unwrap(), expected);
        }

        let sdtype = DType::Struct(
            Arc::new(StructDType::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let struct_lit = lit(Scalar::struct_(
            sdtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        ));
        assert_eq!(struct_lit.return_dtype(&scope).unwrap(), sdtype);
    }
}