1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use vortex_array::arrays::ConstantArray;
6use vortex_array::{Array, ArrayRef};
7use vortex_dtype::DType;
8use vortex_error::VortexResult;
9use vortex_scalar::Scalar;
10
11use crate::{ExprRef, VortexExpr};
12
13#[derive(Debug, PartialEq, Eq, Hash)]
14pub struct Literal {
15 value: Scalar,
16}
17
18impl Literal {
19 pub fn new_expr(value: impl Into<Scalar>) -> ExprRef {
20 Arc::new(Self {
21 value: value.into(),
22 })
23 }
24
25 pub fn value(&self) -> &Scalar {
26 &self.value
27 }
28
29 pub fn maybe_from(expr: &ExprRef) -> Option<&Literal> {
30 expr.as_any().downcast_ref::<Literal>()
31 }
32}
33
34impl Display for Literal {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "{}", self.value)
37 }
38}
39
40impl VortexExpr for Literal {
41 fn as_any(&self) -> &dyn Any {
42 self
43 }
44
45 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
46 Ok(ConstantArray::new(self.value.clone(), batch.len()).into_array())
47 }
48
49 fn children(&self) -> Vec<&ExprRef> {
50 vec![]
51 }
52
53 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
54 assert_eq!(children.len(), 0);
55 self
56 }
57
58 fn return_dtype(&self, _scope_dtype: &DType) -> VortexResult<DType> {
59 Ok(self.value.dtype().clone())
60 }
61}
62
63pub fn lit(value: impl Into<Scalar>) -> ExprRef {
82 Literal::new_expr(value.into())
83}
84
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructDType};
    use vortex_scalar::Scalar;

    use crate::{lit, test_harness};

    /// A literal's `return_dtype` must equal the dtype of the scalar it wraps,
    /// whatever scope dtype is supplied.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();

        let scalar_cases = [
            (
                lit(10),
                DType::Primitive(PType::I32, Nullability::NonNullable),
            ),
            (
                lit(0_u8),
                DType::Primitive(PType::U8, Nullability::NonNullable),
            ),
            (
                lit(0.0_f32),
                DType::Primitive(PType::F32, Nullability::NonNullable),
            ),
            (
                lit(i64::MAX),
                DType::Primitive(PType::I64, Nullability::NonNullable),
            ),
            (lit(true), DType::Bool(Nullability::NonNullable)),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                DType::Bool(Nullability::Nullable),
            ),
        ];
        for (expr, expected) in scalar_cases {
            assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
        }

        // Nested (struct) scalars must report their full struct dtype too.
        let struct_dtype = DType::Struct(
            Arc::new(StructDType::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let struct_lit = lit(Scalar::struct_(
            struct_dtype.clone(),
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        ));
        assert_eq!(struct_lit.return_dtype(&scope).unwrap(), struct_dtype);
    }
}