1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3use std::sync::Arc;
4
5use DType::*;
6use itertools::Itertools;
7use static_assertions::const_assert_eq;
8
9use crate::decimal::DecimalDType;
10use crate::nullability::Nullability;
11use crate::{ExtDType, PType, StructDType};
12
/// A single field name, stored as a reference-counted string slice so that
/// clones only bump a reference count instead of copying the text.
pub type FieldName = Arc<str>;
/// A reference-counted, immutable sequence of [`FieldName`]s.
// NOTE(review): presumably the name list carried by `StructDType` — confirm against its definition.
pub type FieldNames = Arc<[FieldName]>;
17
/// A logical data type.
///
/// Every variant except [`DType::Null`] and [`DType::Extension`] carries its own
/// [`Nullability`] flag. `Null` is always reported as nullable, and an extension
/// reports the nullability of its storage dtype (see `DType::is_nullable`).
///
/// Larger payloads are held behind an [`Arc`], so cloning them is a
/// reference-count bump rather than a deep copy.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
    /// The null type; it has no nullability flag and is always nullable.
    Null,
    /// Boolean values.
    Bool(Nullability),
    /// A primitive value whose physical representation is described by a [`PType`].
    Primitive(PType, Nullability),
    /// A decimal value; the [`DecimalDType`] supplies precision and scale.
    Decimal(DecimalDType, Nullability),
    /// UTF-8 string values.
    Utf8(Nullability),
    /// Binary (byte string) values.
    Binary(Nullability),
    /// A struct of named fields, described by a shared [`StructDType`].
    Struct(Arc<StructDType>, Nullability),
    /// A list whose elements all share the wrapped element dtype.
    List(Arc<DType>, Nullability),
    /// A user-defined extension type; its [`ExtDType`] exposes an id, a storage
    /// dtype and optional metadata.
    Extension(Arc<ExtDType>),
}
44
// Compile-time layout guard: on every target except wasm32, `DType` must be
// exactly 16 bytes. Presumably this exists to catch a change that accidentally
// grows the enum — the build fails if the size changes.
#[cfg(not(target_arch = "wasm32"))]
const_assert_eq!(size_of::<DType>(), 16);

// On wasm32 the same guard expects 8 bytes.
// NOTE(review): presumably because pointers are 4 bytes there — confirm.
#[cfg(target_arch = "wasm32")]
const_assert_eq!(size_of::<DType>(), 8);
50
51impl DType {
52 pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
54
55 pub fn nullability(&self) -> Nullability {
57 self.is_nullable().into()
58 }
59
60 pub fn is_nullable(&self) -> bool {
62 use crate::nullability::Nullability::*;
63
64 match self {
65 Null => true,
66 Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
67 Bool(n)
68 | Primitive(_, n)
69 | Decimal(_, n)
70 | Utf8(n)
71 | Binary(n)
72 | Struct(_, n)
73 | List(_, n) => matches!(n, Nullable),
74 }
75 }
76
77 pub fn as_nonnullable(&self) -> Self {
79 self.with_nullability(Nullability::NonNullable)
80 }
81
82 pub fn as_nullable(&self) -> Self {
84 self.with_nullability(Nullability::Nullable)
85 }
86
87 pub fn with_nullability(&self, nullability: Nullability) -> Self {
89 match self {
90 Null => Null,
91 Bool(_) => Bool(nullability),
92 Primitive(p, _) => Primitive(*p, nullability),
93 Decimal(d, _) => Decimal(*d, nullability),
94 Utf8(_) => Utf8(nullability),
95 Binary(_) => Binary(nullability),
96 Struct(st, _) => Struct(st.clone(), nullability),
97 List(c, _) => List(c.clone(), nullability),
98 Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
99 }
100 }
101
102 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
104 match (self, other) {
105 (Null, Null) => true,
106 (Bool(_), Bool(_)) => true,
107 (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
108 (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
109 (Utf8(_), Utf8(_)) => true,
110 (Binary(_), Binary(_)) => true,
111 (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
112 (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
113 (lhs_dtype.names() == rhs_dtype.names())
114 && (lhs_dtype
115 .fields()
116 .zip_eq(rhs_dtype.fields())
117 .all(|(l, r)| l.eq_ignore_nullability(&r)))
118 }
119 (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
120 lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
121 }
122 _ => false,
123 }
124 }
125
126 pub fn is_struct(&self) -> bool {
128 matches!(self, Struct(_, _))
129 }
130
131 pub fn is_unsigned_int(&self) -> bool {
133 PType::try_from(self).is_ok_and(PType::is_unsigned_int)
134 }
135
136 pub fn is_signed_int(&self) -> bool {
138 PType::try_from(self).is_ok_and(PType::is_signed_int)
139 }
140
141 pub fn is_int(&self) -> bool {
143 PType::try_from(self).is_ok_and(PType::is_int)
144 }
145
146 pub fn is_float(&self) -> bool {
148 PType::try_from(self).is_ok_and(PType::is_float)
149 }
150
151 pub fn is_boolean(&self) -> bool {
153 matches!(self, Bool(_))
154 }
155
156 pub fn is_extension(&self) -> bool {
158 matches!(self, Extension(_))
159 }
160
161 pub fn as_struct(&self) -> Option<&Arc<StructDType>> {
163 match self {
164 Struct(s, _) => Some(s),
165 _ => None,
166 }
167 }
168
169 pub fn as_list_element(&self) -> Option<&DType> {
171 match self {
172 List(s, _) => Some(s.as_ref()),
173 _ => None,
174 }
175 }
176}
177
178impl Display for DType {
179 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
180 match self {
181 Null => write!(f, "null"),
182 Bool(n) => write!(f, "bool{}", n),
183 Primitive(pt, n) => write!(f, "{}{}", pt, n),
184 Decimal(dt, n) => write!(f, "decimal({},{}){}", dt.precision(), dt.scale(), n),
185 Utf8(n) => write!(f, "utf8{}", n),
186 Binary(n) => write!(f, "binary{}", n),
187 Struct(sdt, n) => write!(
188 f,
189 "{{{}}}{}",
190 sdt.names()
191 .iter()
192 .zip(sdt.fields())
193 .map(|(n, dt)| format!("{}={}", n, dt))
194 .join(", "),
195 n
196 ),
197 List(edt, n) => write!(f, "list({}){}", edt, n),
198 Extension(ext) => write!(
199 f,
200 "ext({}, {}{}){}",
201 ext.id(),
202 ext.storage_dtype()
203 .with_nullability(Nullability::NonNullable),
204 ext.metadata()
205 .map(|m| format!(", {:?}", m))
206 .unwrap_or_else(|| "".to_string()),
207 ext.storage_dtype().nullability(),
208 ),
209 }
210 }
211}