use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;
use std::sync::Arc;
use itertools::Itertools;
use vortex_error::{vortex_bail, vortex_err, vortex_panic, VortexResult};
use DType::*;
use crate::field::Field;
use crate::nullability::Nullability;
use crate::{ExtDType, PType};
/// The name of a single struct field. It is an `Arc<str>`, so clones share one allocation.
pub type FieldName = Arc<str>;
/// The ordered list of field names of a struct, shared behind an `Arc`.
pub type FieldNames = Arc<[FieldName]>;
/// The logical data type of a value.
///
/// Every variant except [`DType::Null`] and [`DType::Extension`] carries its own
/// [`Nullability`].
///
/// NOTE: the derived `PartialOrd` orders values by variant declaration order, so reordering
/// the variants below changes comparison results (and may change the serialized form when the
/// `serde` feature is enabled).
#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
    /// The null type. [`DType::is_nullable`] always returns `true` for it, and
    /// [`DType::with_nullability`] returns it unchanged.
    Null,
    /// Boolean values.
    Bool(Nullability),
    /// A primitive type, identified by its [`PType`].
    Primitive(PType, Nullability),
    /// UTF-8 string values.
    Utf8(Nullability),
    /// Binary values.
    Binary(Nullability),
    /// A struct whose field names and field dtypes are described by the [`StructDType`].
    Struct(StructDType, Nullability),
    /// A list type; the contained dtype is the element dtype (displayed as `list(<elem>)`).
    List(Arc<DType>, Nullability),
    /// An extension type. It has no nullability of its own: it reports the nullability of
    /// its storage dtype.
    Extension(Arc<ExtDType>),
}
impl DType {
    /// Non-nullable `u8` primitive dtype.
    pub const BYTES: Self = Self::Primitive(PType::U8, Nullability::NonNullable);
    /// Non-nullable `u64` primitive dtype.
    pub const IDX: Self = Self::Primitive(PType::U64, Nullability::NonNullable);
    /// Non-nullable `u32` primitive dtype.
    pub const IDX_32: Self = Self::Primitive(PType::U32, Nullability::NonNullable);

    /// Returns the nullability of this dtype, as reported by [`DType::is_nullable`].
    pub fn nullability(&self) -> Nullability {
        let nullable = self.is_nullable();
        nullable.into()
    }

    /// Returns `true` if values of this dtype may be null.
    ///
    /// [`DType::Null`] is always nullable. An extension dtype reports the nullability of its
    /// storage dtype. Every other variant reports the nullability it carries.
    pub fn is_nullable(&self) -> bool {
        match self {
            Null => true,
            Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
            Bool(n) | Primitive(_, n) | Utf8(n) | Binary(n) | Struct(_, n) | List(_, n) => {
                matches!(n, Nullability::Nullable)
            }
        }
    }

    /// Returns a copy of this dtype marked non-nullable (see [`DType::with_nullability`]).
    pub fn as_nonnullable(&self) -> Self {
        self.with_nullability(Nullability::NonNullable)
    }

    /// Returns a copy of this dtype marked nullable (see [`DType::with_nullability`]).
    pub fn as_nullable(&self) -> Self {
        self.with_nullability(Nullability::Nullable)
    }

    /// Returns a copy of this dtype with its top-level nullability replaced by `nullability`.
    ///
    /// Child dtypes (struct fields, list elements) are copied as-is. [`DType::Null`] is
    /// returned unchanged. Extension dtypes delegate to `ExtDType::with_nullability`.
    pub fn with_nullability(&self, nullability: Nullability) -> Self {
        match self {
            Null => Self::Null,
            Bool(_) => Self::Bool(nullability),
            Primitive(ptype, _) => Self::Primitive(*ptype, nullability),
            Utf8(_) => Self::Utf8(nullability),
            Binary(_) => Self::Binary(nullability),
            Struct(fields, _) => Self::Struct(fields.clone(), nullability),
            List(element, _) => Self::List(Arc::clone(element), nullability),
            Extension(ext) => Self::Extension(Arc::new(ext.with_nullability(nullability))),
        }
    }

    /// Compares two dtypes after making both of them nullable at the top level.
    ///
    /// Only the outermost nullability is ignored. Nested nullability, such as inside struct
    /// fields or list elements, must still match.
    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
        self.as_nullable() == other.as_nullable()
    }

    /// Returns `true` if this is a [`DType::Struct`].
    pub fn is_struct(&self) -> bool {
        matches!(self, Struct(..))
    }

    /// Returns `true` if this dtype converts to a [`PType`] that is an unsigned integer.
    pub fn is_unsigned_int(&self) -> bool {
        matches!(PType::try_from(self), Ok(ptype) if ptype.is_unsigned_int())
    }

    /// Returns `true` if this dtype converts to a [`PType`] that is a signed integer.
    pub fn is_signed_int(&self) -> bool {
        matches!(PType::try_from(self), Ok(ptype) if ptype.is_signed_int())
    }

    /// Returns `true` if this dtype converts to a [`PType`] that is an integer.
    pub fn is_int(&self) -> bool {
        matches!(PType::try_from(self), Ok(ptype) if ptype.is_int())
    }

    /// Returns `true` if this dtype converts to a [`PType`] that is a floating-point type.
    pub fn is_float(&self) -> bool {
        matches!(PType::try_from(self), Ok(ptype) if ptype.is_float())
    }

    /// Returns `true` if this is a [`DType::Bool`].
    pub fn is_boolean(&self) -> bool {
        matches!(self, Bool(..))
    }

    /// Returns the struct description if this is a [`DType::Struct`], otherwise `None`.
    pub fn as_struct(&self) -> Option<&StructDType> {
        if let Struct(fields, _) = self {
            Some(fields)
        } else {
            None
        }
    }
}
impl Display for DType {
    // Renders a compact form with the nullability appended, e.g. `bool<n>`, `<ptype><n>`,
    // `{a=<dtype>, b=<dtype>}<n>`, `list(<elem>)<n>`, or
    // `ext(<id>, <storage>[, <metadata>])<n>`. `Null` renders as `null`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Null => write!(f, "null"),
            Bool(nullability) => write!(f, "bool{}", nullability),
            Primitive(ptype, nullability) => write!(f, "{}{}", ptype, nullability),
            Utf8(nullability) => write!(f, "utf8{}", nullability),
            Binary(nullability) => write!(f, "binary{}", nullability),
            // Stream `name=dtype` pairs straight into the formatter instead of building a
            // temporary `String` per field and then joining them into yet another `String`.
            Struct(sdt, nullability) => write!(
                f,
                "{{{}}}{}",
                sdt.names()
                    .iter()
                    .zip(sdt.dtypes().iter())
                    .format_with(", ", |(name, dtype), emit| {
                        emit(&format_args!("{}={}", name, dtype))
                    }),
                nullability
            ),
            List(element, nullability) => write!(f, "list({}){}", element, nullability),
            Extension(ext) => {
                // The storage dtype is shown without nullability; the extension's nullability
                // (taken from its storage dtype) is appended after the closing parenthesis.
                write!(
                    f,
                    "ext({}, {}",
                    ext.id(),
                    ext.storage_dtype()
                        .with_nullability(Nullability::NonNullable)
                )?;
                if let Some(metadata) = ext.metadata() {
                    write!(f, ", {:?}", metadata)?;
                }
                write!(f, "){}", ext.storage_dtype().nullability())
            }
        }
    }
}
/// The field layout of a struct dtype: field names and field dtypes stored as parallel arrays.
///
/// [`StructDType::new`] checks that `names` and `dtypes` have the same length.
/// NOTE(review): the serde `Deserialize` derive below builds values without calling `new`, so
/// that invariant is not checked for deserialized values.
#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StructDType {
    /// Field names, in field order.
    names: FieldNames,
    /// Field dtypes; entry `i` is the dtype of the field named `names[i]`.
    dtypes: Arc<[DType]>,
}
pub struct FieldInfo<'a> {
pub index: usize,
pub name: Arc<str>,
pub dtype: &'a DType,
}
impl StructDType {
pub fn new(names: FieldNames, dtypes: Vec<DType>) -> Self {
if names.len() != dtypes.len() {
vortex_panic!(
"length mismatch between names ({}) and dtypes ({})",
names.len(),
dtypes.len()
);
}
Self {
names,
dtypes: dtypes.into(),
}
}
pub fn names(&self) -> &FieldNames {
&self.names
}
pub fn find_name(&self, name: &str) -> Option<usize> {
self.names.iter().position(|n| n.as_ref() == name)
}
pub fn field_info(&self, field: &Field) -> VortexResult<FieldInfo> {
let index = match field {
Field::Name(name) => self
.find_name(name)
.ok_or_else(|| vortex_err!("Unknown field: {}", name))?,
Field::Index(index) => *index,
};
if index >= self.names.len() {
vortex_bail!("field index out of bounds: {}", index)
}
Ok(FieldInfo {
index,
name: self.names[index].clone(),
dtype: &self.dtypes[index],
})
}
pub fn dtypes(&self) -> &Arc<[DType]> {
&self.dtypes
}
pub fn project(&self, projection: &[Field]) -> VortexResult<Self> {
let mut names = Vec::with_capacity(projection.len());
let mut dtypes = Vec::with_capacity(projection.len());
for field in projection.iter() {
let FieldInfo { name, dtype, .. } = self.field_info(field)?;
names.push(name.clone());
dtypes.push(dtype.clone());
}
Ok(StructDType::new(names.into(), dtypes))
}
}
#[cfg(test)]
mod test {
    use crate::dtype::DType;
    use crate::field::Field;
    use crate::{Nullability, PType, StructDType};

    // The size of `DType` depends on pointer width (`StructDType` holds two `Arc<[_]>` fat
    // pointers), so the 40-byte expectation only holds on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn size_of() {
        assert_eq!(std::mem::size_of::<DType>(), 40);
    }

    #[test]
    fn nullability() {
        assert!(!DType::Struct(
            StructDType::new(vec![].into(), Vec::new()),
            Nullability::NonNullable
        )
        .is_nullable());
        let primitive = DType::Primitive(PType::U8, Nullability::Nullable);
        assert!(primitive.is_nullable());
        assert!(!primitive.as_nonnullable().is_nullable());
        assert!(primitive.as_nonnullable().as_nullable().is_nullable());
    }

    #[test]
    fn null_is_always_nullable() {
        // `Null` ignores requested nullability changes.
        assert!(DType::Null.is_nullable());
        assert!(DType::Null.as_nonnullable().is_nullable());
    }

    #[test]
    fn eq_ignore_nullability() {
        let nullable = DType::Utf8(Nullability::Nullable);
        let non_nullable = DType::Utf8(Nullability::NonNullable);
        assert_ne!(nullable, non_nullable);
        assert!(nullable.eq_ignore_nullability(&non_nullable));
        assert!(non_nullable.eq_ignore_nullability(&nullable));
        assert!(!nullable.eq_ignore_nullability(&DType::Binary(Nullability::Nullable)));
    }

    #[test]
    fn test_struct() {
        let a_type = DType::Primitive(PType::I32, Nullability::Nullable);
        let b_type = DType::Bool(Nullability::NonNullable);
        let dtype = DType::Struct(
            StructDType::new(
                vec!["A".into(), "B".into()].into(),
                vec![a_type.clone(), b_type.clone()],
            ),
            Nullability::Nullable,
        );
        assert!(dtype.is_nullable());
        assert!(dtype.as_struct().is_some());
        assert!(a_type.as_struct().is_none());

        let sdt = dtype.as_struct().unwrap();
        assert_eq!(sdt.names().len(), 2);
        assert_eq!(sdt.dtypes().len(), 2);
        assert_eq!(sdt.names()[0], "A".into());
        assert_eq!(sdt.names()[1], "B".into());
        assert_eq!(sdt.dtypes()[0], a_type);
        assert_eq!(sdt.dtypes()[1], b_type);

        let proj = sdt
            .project(&[Field::Index(1), Field::Name("A".into())])
            .unwrap();
        assert_eq!(proj.names()[0], "B".into());
        assert_eq!(proj.dtypes()[0], b_type);
        assert_eq!(proj.names()[1], "A".into());
        assert_eq!(proj.dtypes()[1], a_type);
        // Projecting an unknown field name is an error, not a panic.
        assert!(sdt.project(&[Field::Name("C".into())]).is_err());

        let field_info = sdt.field_info(&Field::Name("B".into())).unwrap();
        assert_eq!(field_info.index, 1);
        assert_eq!(field_info.name, "B".into());
        assert_eq!(field_info.dtype, &b_type);

        let field_info = sdt.field_info(&Field::Index(0)).unwrap();
        assert_eq!(field_info.index, 0);
        assert_eq!(field_info.name, "A".into());
        assert_eq!(field_info.dtype, &a_type);

        assert!(sdt.field_info(&Field::Index(2)).is_err());
        assert!(sdt.field_info(&Field::Name("C".into())).is_err());

        assert_eq!(sdt.find_name("A"), Some(0));
        assert_eq!(sdt.find_name("B"), Some(1));
        assert_eq!(sdt.find_name("C"), None);
    }
}