use substrait::proto::{
r#type::{
Binary, Boolean, Fp32, Fp64, Kind, Nullability, String as SubstraitString, Struct, I16,
I32, I64, I8,
},
Type,
};
use crate::error::Result;
use crate::util::HasRequiredPropertiesRef;
use super::registry::ExtensionsRegistry;
/// Convenience queries over Substrait [`Type`] values.
pub trait TypeExt {
    /// Returns `true` when `self` and `other` share the same top-level `kind` variant.
    ///
    /// # Errors
    ///
    /// Returns an error if either type has no `kind` set.
    fn same_kind(&self, other: &Type) -> Result<bool>;

    /// Returns `true` if this is a user-defined type that `registry` resolves to the
    /// [`UNKNOWN_TYPE_URI`] / [`UNKNOWN_TYPE_NAME`] pair.
    fn is_unknown(&self, registry: &ExtensionsRegistry) -> bool;

    /// Counts this type plus, recursively, every type nested inside it.
    fn num_types(&self) -> u32;

    /// Returns the immediate child types (for `Type`, only struct fields count as children).
    fn children(&self) -> Vec<&Type>;
}
impl TypeExt for Type {
    /// Compares only the enum variant of `kind`; anything stored inside the
    /// variant (nullability, variation, nested types) is ignored.
    fn same_kind(&self, other: &Type) -> Result<bool> {
        let lhs = std::mem::discriminant(self.kind.required("kind")?);
        let rhs = std::mem::discriminant(other.kind.required("kind")?);
        Ok(lhs == rhs)
    }

    /// A type is "unknown" only when it is user-defined and its reference resolves
    /// in `registry` to the well-known unknown URI and name.
    fn is_unknown(&self, registry: &ExtensionsRegistry) -> bool {
        if let Some(Kind::UserDefined(user_defined)) = &self.kind {
            matches!(
                registry.lookup_type(user_defined.type_reference),
                Some(resolved)
                    if resolved.uri == UNKNOWN_TYPE_URI && resolved.name == UNKNOWN_TYPE_NAME
            )
        } else {
            false
        }
    }

    /// Every type counts as one; a struct additionally counts all of its fields recursively.
    fn num_types(&self) -> u32 {
        if let Some(Kind::Struct(strct)) = &self.kind {
            let nested: u32 = strct.types.iter().map(|child| child.num_types()).sum();
            nested + 1
        } else {
            1
        }
    }

    /// Struct field types are the only children; all other kinds are leaves.
    fn children(&self) -> Vec<&Type> {
        if let Some(Kind::Struct(strct)) = &self.kind {
            strct.types.iter().collect()
        } else {
            Vec::new()
        }
    }
}
/// Converts a `nullable` flag into the raw `i32` encoding of [`Nullability`]
/// that the proto message fields store.
pub(crate) const fn nullability(nullable: bool) -> i32 {
    let level = match nullable {
        true => Nullability::Nullable,
        false => Nullability::Required,
    };
    level as i32
}
/// Maps a Rust type to its corresponding Substrait [`Type`].
pub trait TypeInfer {
    /// Builds the Substrait type for `Self`, marked nullable or required per `nullable`.
    fn as_substrait(nullable: bool) -> Type;
}
impl TypeInfer for i8 {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::I8(I8 {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
impl TypeInfer for i16 {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::I16(I16 {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
impl TypeInfer for i32 {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::I32(I32 {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
impl TypeInfer for i64 {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::I64(I64 {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
impl TypeInfer for bool {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::Bool(Boolean {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
impl TypeInfer for f32 {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::Fp32(Fp32 {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
impl TypeInfer for f64 {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::Fp64(Fp64 {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
impl TypeInfer for String {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::String(SubstraitString {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
impl TypeInfer for &str {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::String(SubstraitString {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
impl TypeInfer for &[u8] {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::Binary(Binary {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
impl TypeInfer for Vec<u8> {
fn as_substrait(nullable: bool) -> Type {
Type {
kind: Some(substrait::proto::r#type::Kind::Binary(Binary {
nullability: nullability(nullable),
type_variation_reference: 0,
})),
}
}
}
/// Returns the Substrait type inferred for the Rust type `T`.
pub fn from_rust<T: TypeInfer>(nullable: bool) -> Type {
    T::as_substrait(nullable)
}
pub fn bool(nullable: bool) -> Type {
from_rust::<bool>(nullable)
}
pub fn i8(nullable: bool) -> Type {
from_rust::<i8>(nullable)
}
pub fn i16(nullable: bool) -> Type {
from_rust::<i16>(nullable)
}
pub fn i32(nullable: bool) -> Type {
from_rust::<i32>(nullable)
}
pub fn i64(nullable: bool) -> Type {
from_rust::<i64>(nullable)
}
pub fn fp32(nullable: bool) -> Type {
from_rust::<f32>(nullable)
}
pub fn fp64(nullable: bool) -> Type {
from_rust::<f64>(nullable)
}
pub fn string(nullable: bool) -> Type {
from_rust::<&str>(nullable)
}
pub fn binary(nullable: bool) -> Type {
from_rust::<&[u8]>(nullable)
}
/// Builds a Substrait struct type whose fields are `children`, in order.
///
/// All other `Struct` fields (such as the type variation reference) take their
/// `Default` values.
pub fn struct_(nullable: bool, children: Vec<Type>) -> Type {
    let fields = Struct {
        nullability: nullability(nullable),
        types: children,
        ..Default::default()
    };
    Type {
        kind: Some(Kind::Struct(fields)),
    }
}
/// Extension URI under which the "unknown" user-defined type is registered.
pub const UNKNOWN_TYPE_URI: &str = "https://substrait.io/types";
/// Name of the "unknown" user-defined type within [`UNKNOWN_TYPE_URI`].
pub const UNKNOWN_TYPE_NAME: &str = "unknown";
/// `type_variation_reference` value meaning "no type variation".
pub const NO_VARIATION: u32 = 0;