use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;
use std::sync::Arc;
use itertools::Itertools;
use DType::*;
use crate::nullability::Nullability;
use crate::{ExtDType, PType, StructDType};
/// A single field name, held as a reference-counted string slice so that
/// cloning it copies a pointer and bumps a refcount instead of copying text.
pub type FieldName = Arc<str>;
/// An ordered, reference-counted slice of [`FieldName`]s.
// NOTE(review): presumably the type returned by `StructDType::names()` — confirm.
pub type FieldNames = Arc<[FieldName]>;
/// Describes a data type together with whether values of that type may be null.
///
/// Most variants carry their own [`Nullability`] flag. `Null` carries none and
/// `Extension` takes its nullability from the storage dtype of the wrapped
/// [`ExtDType`].
///
/// `PartialOrd` is derived, so the ordering between variants follows their
/// declaration order below; reordering variants changes comparison results.
#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
/// The null type. It has no nullability flag and always reports itself as nullable.
Null,
/// Boolean values.
Bool(Nullability),
/// A primitive type whose concrete kind is given by the [`PType`] (for example `PType::U8`).
Primitive(PType, Nullability),
/// String values (UTF-8, going by the variant name).
Utf8(Nullability),
/// Binary values.
Binary(Nullability),
/// A struct of named fields; the field names and per-field dtypes live in the [`StructDType`].
Struct(StructDType, Nullability),
/// A list whose element dtype is the shared inner [`DType`].
List(Arc<DType>, Nullability),
/// An extension type described by an [`ExtDType`] (id, storage dtype and optional metadata).
Extension(Arc<ExtDType>),
}
impl DType {
pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
pub fn nullability(&self) -> Nullability {
self.is_nullable().into()
}
pub fn is_nullable(&self) -> bool {
use crate::nullability::Nullability::*;
match self {
Null => true,
Bool(n) => matches!(n, Nullable),
Primitive(_, n) => matches!(n, Nullable),
Utf8(n) => matches!(n, Nullable),
Binary(n) => matches!(n, Nullable),
Struct(_, n) => matches!(n, Nullable),
List(_, n) => matches!(n, Nullable),
Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
}
}
pub fn as_nonnullable(&self) -> Self {
self.with_nullability(Nullability::NonNullable)
}
pub fn as_nullable(&self) -> Self {
self.with_nullability(Nullability::Nullable)
}
pub fn with_nullability(&self, nullability: Nullability) -> Self {
match self {
Null => Null,
Bool(_) => Bool(nullability),
Primitive(p, _) => Primitive(*p, nullability),
Utf8(_) => Utf8(nullability),
Binary(_) => Binary(nullability),
Struct(st, _) => Struct(st.clone(), nullability),
List(c, _) => List(c.clone(), nullability),
Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
}
}
pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
match (self, other) {
(Null, Null) => true,
(Null, _) => false,
(Bool(_), Bool(_)) => true,
(Bool(_), _) => false,
(Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
(Primitive(..), _) => false,
(Utf8(_), Utf8(_)) => true,
(Utf8(_), _) => false,
(Binary(_), Binary(_)) => true,
(Binary(_), _) => false,
(List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
(List(..), _) => false,
(Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
(lhs_dtype.names() == rhs_dtype.names())
&& (lhs_dtype
.dtypes()
.zip_eq(rhs_dtype.dtypes())
.all(|(l, r)| l.eq_ignore_nullability(&r)))
}
(Struct(..), _) => false,
(Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
}
(Extension(_), _) => false,
}
}
pub fn is_struct(&self) -> bool {
matches!(self, Struct(_, _))
}
pub fn is_unsigned_int(&self) -> bool {
PType::try_from(self).is_ok_and(PType::is_unsigned_int)
}
pub fn is_signed_int(&self) -> bool {
PType::try_from(self).is_ok_and(PType::is_signed_int)
}
pub fn is_int(&self) -> bool {
PType::try_from(self).is_ok_and(PType::is_int)
}
pub fn is_float(&self) -> bool {
PType::try_from(self).is_ok_and(PType::is_float)
}
pub fn is_boolean(&self) -> bool {
matches!(self, Bool(_))
}
pub fn as_struct(&self) -> Option<&StructDType> {
match self {
Struct(s, _) => Some(s),
_ => None,
}
}
pub fn as_list_element(&self) -> Option<&DType> {
match self {
List(s, _) => Some(s.as_ref()),
_ => None,
}
}
}
impl Display for DType {
    /// Writes a compact textual form of the dtype.
    ///
    /// Each form ends with the `Display` output of the relevant nullability:
    ///
    /// * `null`
    /// * `bool`, `utf8`, `binary`, or the `Display` form of the `PType`
    /// * `{name=dtype, name=dtype}` for structs
    /// * `list(element)` for lists
    /// * `ext(id, storage)` or `ext(id, storage, metadata)` for extensions. The
    ///   storage dtype is printed as non-nullable and its real nullability is
    ///   appended after the closing parenthesis. Metadata uses its `Debug` form.
    ///
    /// Every piece is written straight into the formatter, so no intermediate
    /// `String` is built for struct fields or extension metadata.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Null => f.write_str("null"),
            Bool(n) => write!(f, "bool{}", n),
            Primitive(pt, n) => write!(f, "{}{}", pt, n),
            Utf8(n) => write!(f, "utf8{}", n),
            Binary(n) => write!(f, "binary{}", n),
            Struct(sdt, n) => {
                f.write_str("{")?;
                for (idx, (name, dtype)) in sdt.names().iter().zip(sdt.dtypes()).enumerate() {
                    // Separator goes between fields only, never before the first.
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{}={}", name, dtype)?;
                }
                write!(f, "}}{}", n)
            }
            List(edt, n) => write!(f, "list({}){}", edt, n),
            Extension(ext) => {
                write!(
                    f,
                    "ext({}, {}",
                    ext.id(),
                    ext.storage_dtype()
                        .with_nullability(Nullability::NonNullable),
                )?;
                // Metadata is optional; when absent nothing (not even a comma) is written.
                if let Some(metadata) = ext.metadata() {
                    write!(f, ", {:?}", metadata)?;
                }
                write!(f, "){}", ext.storage_dtype().nullability())
            }
        }
    }
}
#[cfg(test)]
mod test {
    use crate::dtype::DType;

    /// Pins the in-memory size of `DType` so that a change which grows (or
    /// shrinks) the enum is noticed.
    // NOTE(review): 40 bytes presumably holds for 64-bit targets only — confirm.
    #[test]
    fn size_of() {
        let actual = std::mem::size_of::<DType>();
        assert_eq!(actual, 40);
    }
}