1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use vortex_array::arrays::ConstantArray;
6use vortex_array::{ArrayRef, IntoArray};
7use vortex_dtype::DType;
8use vortex_error::VortexResult;
9use vortex_scalar::Scalar;
10
11use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, StatsCatalog, VortexExpr};
12
13#[derive(Debug, PartialEq, Eq, Hash)]
14pub struct Literal {
15 value: Scalar,
16}
17
18impl Literal {
19 pub fn new_expr(value: impl Into<Scalar>) -> ExprRef {
20 Arc::new(Self {
21 value: value.into(),
22 })
23 }
24
25 pub fn value(&self) -> &Scalar {
26 &self.value
27 }
28
29 pub fn maybe_from(expr: &ExprRef) -> Option<&Literal> {
30 expr.as_any().downcast_ref::<Literal>()
31 }
32}
33
34impl Display for Literal {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "{}", self.value)
37 }
38}
39
/// Protobuf (de)serialization support for [`Literal`], compiled only when the
/// `proto` feature is enabled.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use kind::Kind;
    use vortex_error::{VortexResult, vortex_bail, vortex_err};
    use vortex_proto::expr::kind;
    use vortex_scalar::Scalar;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Literal};

    /// Zero-sized handle providing the id and deserializer for literal expressions.
    pub(crate) struct LiteralSerde;

    impl Id for LiteralSerde {
        /// The identifier under which literal expressions are serialized.
        fn id(&self) -> &'static str {
            "literal"
        }
    }

    impl ExprDeserialize for LiteralSerde {
        /// Rebuilds a `Literal` expression from a `Kind::Literal` payload.
        ///
        /// # Errors
        ///
        /// Fails when `kind` is not `Kind::Literal`, when the payload carries
        /// no scalar, or when converting the serialized scalar into a
        /// [`Scalar`] fails. The `_children` argument is not used.
        fn deserialize(&self, kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let literal = match kind {
                Kind::Literal(literal) => literal,
                _ => vortex_bail!("Expected literal kind"),
            };

            let Some(proto_scalar) = literal.value.as_ref() else {
                return Err(vortex_err!("empty literal scalar"));
            };

            let scalar = Scalar::try_from(proto_scalar)?;
            Ok(Literal::new_expr(scalar))
        }
    }

    impl ExprSerializable for Literal {
        /// Reuses the id exposed by [`LiteralSerde`] so both directions agree.
        fn id(&self) -> &'static str {
            LiteralSerde.id()
        }

        /// Encodes the wrapped scalar as a `Kind::Literal` payload.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            let proto_scalar = (&self.value).into();
            let payload = kind::Literal {
                value: Some(proto_scalar),
            };
            Ok(Kind::Literal(payload))
        }
    }
}
83
84impl AnalysisExpr for Literal {
85 fn max(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
86 Some(lit(self.value.clone()))
87 }
88
89 fn min(&self, _catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
90 Some(lit(self.value.clone()))
91 }
92}
93
94impl VortexExpr for Literal {
95 fn as_any(&self) -> &dyn Any {
96 self
97 }
98
99 fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
100 Ok(ConstantArray::new(self.value.clone(), ctx.len()).into_array())
101 }
102
103 fn children(&self) -> Vec<&ExprRef> {
104 vec![]
105 }
106
107 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
108 assert_eq!(children.len(), 0);
109 self
110 }
111
112 fn return_dtype(&self, _ctx: &ScopeDType) -> VortexResult<DType> {
113 Ok(self.value.dtype().clone())
114 }
115}
116
117pub fn lit(value: impl Into<Scalar>) -> ExprRef {
136 Literal::new_expr(value.into())
137}
138
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use crate::{ScopeDType, lit, test_harness};

    /// `return_dtype` of a literal must equal the dtype of the scalar it wraps.
    #[test]
    fn dtype() {
        // One scope is enough: `return_dtype` only borrows it.
        let scope = ScopeDType::new(test_harness::struct_dtype());

        let i32_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        assert_eq!(lit(10).return_dtype(&scope).unwrap(), i32_dtype);

        let i64_dtype = DType::Primitive(PType::I64, Nullability::NonNullable);
        assert_eq!(lit(i64::MAX).return_dtype(&scope).unwrap(), i64_dtype);

        let bool_dtype = DType::Bool(Nullability::NonNullable);
        assert_eq!(lit(true).return_dtype(&scope).unwrap(), bool_dtype);

        let nullable_bool = DType::Bool(Nullability::Nullable);
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable)))
                .return_dtype(&scope)
                .unwrap(),
            nullable_bool
        );

        // Struct literal: two fields, `dog: u32` and `cat: utf8`.
        let sdtype = DType::Struct(
            Arc::new(StructFields::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let struct_scalar = Scalar::struct_(
            sdtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(struct_scalar).return_dtype(&scope).unwrap(), sdtype);
    }
}