Skip to main content

trino_rust_client/types/
mod.rs

1mod boolean;
2mod data_set;
3mod date_time;
4mod decimal;
5mod fixed_char;
6mod float;
7mod integer;
8mod interval_day_to_second;
9mod interval_year_to_month;
10mod ip_address;
11pub mod json;
12mod map;
13mod option;
14mod row;
15mod seq;
16mod string;
17mod util;
18pub mod uuid;
19
20pub use self::uuid::*;
21pub use boolean::*;
22pub use data_set::*;
23pub use date_time::*;
24pub use decimal::*;
25pub use fixed_char::*;
26pub use float::*;
27pub use integer::*;
28pub use interval_day_to_second::*;
29pub use interval_year_to_month::*;
30pub use ip_address::*;
31pub use map::*;
32pub use option::*;
33pub use row::*;
34pub use seq::*;
35pub use string::*;
36
37//mod str;
38//pub use self::str::*;
39
40use std::borrow::Cow;
41use std::collections::HashMap;
42use std::iter::FromIterator;
43use std::sync::Arc;
44
45use crate::{
46    ClientTypeSignatureParameter, Column, NamedTypeSignature, RawTrinoTy, RowFieldName,
47    TypeSignature,
48};
49use derive_more::Display;
50use iterable::*;
51use serde::de::DeserializeSeed;
52use serde::Serialize;
53
54//TODO: refine it
55#[derive(Display, Debug)]
56pub enum Error {
57    InvalidTrinoType,
58    InvalidColumn,
59    InvalidTypeSignature,
60    ParseDecimalFailed(String),
61    ParseIntervalMonthFailed,
62    ParseIntervalDayFailed,
63    EmptyInTrinoRow,
64    NoneTrinoRow,
65}
66
67pub trait Trino {
68    type ValueType<'a>: Serialize
69    where
70        Self: 'a;
71    type Seed<'a, 'de>: DeserializeSeed<'de, Value = Self>;
72
73    fn value(&self) -> Self::ValueType<'_>;
74
75    fn ty() -> TrinoTy;
76
77    /// caller must provide a valid context
78    fn seed<'a, 'de>(ctx: &'a Context<'a>) -> Self::Seed<'a, 'de>;
79
80    fn empty() -> Self;
81}
82
83pub trait TrinoMapKey: Trino {}
84
85#[derive(Debug)]
86pub struct Context<'a> {
87    ty: &'a TrinoTy,
88    map: Arc<HashMap<usize, Vec<usize>>>,
89}
90
91impl<'a> Context<'a> {
92    pub fn new<T: Trino>(provided: &'a TrinoTy) -> Result<Self, Error> {
93        let target = T::ty();
94        let ret = extract(&target, provided)?;
95        let map = HashMap::from_iter(ret);
96        Ok(Context {
97            ty: provided,
98            map: Arc::new(map),
99        })
100    }
101
102    pub fn with_ty(&'a self, ty: &'a TrinoTy) -> Context<'a> {
103        Context {
104            ty,
105            map: self.map.clone(),
106        }
107    }
108
109    pub fn ty(&self) -> &TrinoTy {
110        self.ty
111    }
112
113    pub fn row_map(&self) -> Option<&[usize]> {
114        let key = self.ty as *const TrinoTy as usize;
115        self.map.get(&key).map(|r| &**r)
116    }
117}
118
119fn extract(target: &TrinoTy, provided: &TrinoTy) -> Result<Vec<(usize, Vec<usize>)>, Error> {
120    use TrinoTy::*;
121
122    match (target, provided) {
123        (Unknown, _) => Ok(vec![]),
124        (Decimal(p1, s1), Decimal(p2, s2)) if p1 == p2 && s1 == s2 => Ok(vec![]),
125        (Option(ty), provided) => extract(ty, provided),
126        (Boolean, Boolean) => Ok(vec![]),
127        (Date, Date) => Ok(vec![]),
128        (Time, Time) => Ok(vec![]),
129        (TimeWithTimeZone, TimeWithTimeZone) => Ok(vec![]),
130        (Timestamp, Timestamp) => Ok(vec![]),
131        (TimestampWithTimeZone, TimestampWithTimeZone) => Ok(vec![]),
132        (IntervalYearToMonth, IntervalYearToMonth) => Ok(vec![]),
133        (IntervalDayToSecond, IntervalDayToSecond) => Ok(vec![]),
134        (TrinoInt(_), TrinoInt(_)) => Ok(vec![]),
135        (TrinoFloat(_), TrinoFloat(_)) => Ok(vec![]),
136        (Varchar, Varchar) => Ok(vec![]),
137        (Char(a), Char(b)) if a == b => Ok(vec![]),
138        (Tuple(t1), Tuple(t2)) => {
139            if t1.len() != t2.len() {
140                Err(Error::InvalidTrinoType)
141            } else {
142                t1.lazy_zip(t2).try_flat_map(|(l, r)| extract(l, r))
143            }
144        }
145        (Row(t1), Row(t2)) => {
146            if t1.len() != t2.len() {
147                Err(Error::InvalidTrinoType)
148            } else {
149                // create a vector of the original element's reference
150                let t1k = t1.sorted_by(|t1, t2| Ord::cmp(&t1.0, &t2.0));
151                let t2k = t2.sorted_by(|t1, t2| Ord::cmp(&t1.0, &t2.0));
152
153                let ret = t1k.lazy_zip(t2k).try_flat_map(|(l, r)| {
154                    if l.0 == r.0 {
155                        extract(&l.1, &r.1)
156                    } else {
157                        Err(Error::InvalidTrinoType)
158                    }
159                })?;
160
161                let map = t2.map(|provided| t1.position(|target| provided.0 == target.0).unwrap());
162                let key = provided as *const TrinoTy as usize;
163                Ok(ret.add_one((key, map)))
164            }
165        }
166        (Array(t1), Array(t2)) => extract(t1, t2),
167        (Map(t1k, t1v), Map(t2k, t2v)) => Ok(extract(t1k, t2k)?.chain(extract(t1v, t2v)?)),
168        (IpAddress, IpAddress) => Ok(vec![]),
169        (Uuid, Uuid) => Ok(vec![]),
170        (Json, Json) => Ok(vec![]),
171        _ => Err(Error::InvalidTrinoType),
172    }
173}
174
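// TODO: Trino types not yet modelled by `TrinoTy`:
// VarBinary, HyperLogLog, P4HyperLogLog, QDigest.
// (Json, TimestampWithTimeZone and TimeWithTimeZone now have `TrinoTy` variants;
// NOTE(review): confirm their value support is complete.)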
180#[derive(Clone, Debug, Eq, PartialEq)]
181pub enum TrinoTy {
182    Date,
183    Time,
184    TimeWithTimeZone,
185    Timestamp,
186    TimestampWithTimeZone,
187    Uuid,
188    IntervalYearToMonth,
189    IntervalDayToSecond,
190    Option(Box<TrinoTy>),
191    Boolean,
192    TrinoInt(TrinoInt),
193    TrinoFloat(TrinoFloat),
194    Varchar,
195    Char(usize),
196    Tuple(Vec<TrinoTy>),
197    Row(Vec<(String, TrinoTy)>),
198    Array(Box<TrinoTy>),
199    Map(Box<TrinoTy>, Box<TrinoTy>),
200    Decimal(usize, usize),
201    IpAddress,
202    Json,
203    Unknown,
204}
205
/// Width of a Trino integer type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrinoInt {
    /// 8-bit; maps to `RawTrinoTy::TinyInt`.
    I8,
    /// 16-bit; maps to `RawTrinoTy::SmallInt`.
    I16,
    /// 32-bit; maps to `RawTrinoTy::Integer`.
    I32,
    /// 64-bit; maps to `RawTrinoTy::BigInt`.
    I64,
}
213
/// Width of a Trino floating-point type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrinoFloat {
    /// 32-bit; maps to `RawTrinoTy::Real`.
    F32,
    /// 64-bit; maps to `RawTrinoTy::Double`.
    F64,
}
219
220impl TrinoTy {
221    pub fn from_type_signature(mut sig: TypeSignature) -> Result<Self, Error> {
222        use TrinoFloat::*;
223        use TrinoInt::*;
224
225        let ty = match sig.raw_type {
226            RawTrinoTy::Date => TrinoTy::Date,
227            RawTrinoTy::Time => TrinoTy::Time,
228            RawTrinoTy::TimeWithTimeZone => TrinoTy::TimeWithTimeZone,
229            RawTrinoTy::Timestamp => TrinoTy::Timestamp,
230            RawTrinoTy::TimestampWithTimeZone => TrinoTy::TimestampWithTimeZone,
231            RawTrinoTy::IntervalYearToMonth => TrinoTy::IntervalYearToMonth,
232            RawTrinoTy::IntervalDayToSecond => TrinoTy::IntervalDayToSecond,
233            RawTrinoTy::Unknown => TrinoTy::Unknown,
234            RawTrinoTy::Decimal if sig.arguments.len() == 2 => {
235                let s_sig = sig.arguments.pop().unwrap();
236                let p_sig = sig.arguments.pop().unwrap();
237                if let (
238                    ClientTypeSignatureParameter::LongLiteral(p),
239                    ClientTypeSignatureParameter::LongLiteral(s),
240                ) = (p_sig, s_sig)
241                {
242                    TrinoTy::Decimal(p as usize, s as usize)
243                } else {
244                    return Err(Error::InvalidTypeSignature);
245                }
246            }
247            RawTrinoTy::Boolean => TrinoTy::Boolean,
248            RawTrinoTy::TinyInt => TrinoTy::TrinoInt(I8),
249            RawTrinoTy::SmallInt => TrinoTy::TrinoInt(I16),
250            RawTrinoTy::Integer => TrinoTy::TrinoInt(I32),
251            RawTrinoTy::BigInt => TrinoTy::TrinoInt(I64),
252            RawTrinoTy::Real => TrinoTy::TrinoFloat(F32),
253            RawTrinoTy::Double => TrinoTy::TrinoFloat(F64),
254            RawTrinoTy::VarChar => TrinoTy::Varchar,
255            RawTrinoTy::Char if sig.arguments.len() == 1 => {
256                if let ClientTypeSignatureParameter::LongLiteral(p) = sig.arguments.pop().unwrap() {
257                    TrinoTy::Char(p as usize)
258                } else {
259                    return Err(Error::InvalidTypeSignature);
260                }
261            }
262            RawTrinoTy::Array if sig.arguments.len() == 1 => {
263                let sig = sig.arguments.pop().unwrap();
264                if let ClientTypeSignatureParameter::TypeSignature(sig) = sig {
265                    let inner = Self::from_type_signature(sig)?;
266                    TrinoTy::Array(Box::new(inner))
267                } else {
268                    return Err(Error::InvalidTypeSignature);
269                }
270            }
271            RawTrinoTy::Map if sig.arguments.len() == 2 => {
272                let v_sig = sig.arguments.pop().unwrap();
273                let k_sig = sig.arguments.pop().unwrap();
274                if let (
275                    ClientTypeSignatureParameter::TypeSignature(k_sig),
276                    ClientTypeSignatureParameter::TypeSignature(v_sig),
277                ) = (k_sig, v_sig)
278                {
279                    let k_inner = Self::from_type_signature(k_sig)?;
280                    let v_inner = Self::from_type_signature(v_sig)?;
281                    TrinoTy::Map(Box::new(k_inner), Box::new(v_inner))
282                } else {
283                    return Err(Error::InvalidTypeSignature);
284                }
285            }
286            RawTrinoTy::Row if !sig.arguments.is_empty() => {
287                let ir = sig.arguments.try_map(|arg| match arg {
288                    ClientTypeSignatureParameter::NamedTypeSignature(sig) => {
289                        let name = sig.field_name.map(|n| n.name);
290                        let ty = Self::from_type_signature(sig.type_signature)?;
291                        Ok((name, ty))
292                    }
293                    _ => Err(Error::InvalidTypeSignature),
294                })?;
295
296                let is_named = ir[0].0.is_some();
297
298                if is_named {
299                    let row = ir.try_map(|(name, ty)| match name {
300                        Some(n) => Ok((n, ty)),
301                        None => Err(Error::InvalidTypeSignature),
302                    })?;
303                    TrinoTy::Row(row)
304                } else {
305                    let tuple = ir.try_map(|(name, ty)| match name {
306                        Some(_) => Err(Error::InvalidTypeSignature),
307                        None => Ok(ty),
308                    })?;
309                    TrinoTy::Tuple(tuple)
310                }
311            }
312            RawTrinoTy::IpAddress => TrinoTy::IpAddress,
313            RawTrinoTy::Uuid => TrinoTy::Uuid,
314            RawTrinoTy::Json => TrinoTy::Json,
315            _ => return Err(Error::InvalidTypeSignature),
316        };
317
318        Ok(ty)
319    }
320
321    pub fn from_column(column: Column) -> Result<(String, Self), Error> {
322        let name = column.name;
323        if let Some(sig) = column.type_signature {
324            let ty = Self::from_type_signature(sig)?;
325            Ok((name, ty))
326        } else {
327            Err(Error::InvalidColumn)
328        }
329    }
330
331    pub fn from_columns(columns: Vec<Column>) -> Result<Self, Error> {
332        let row = columns.try_map(Self::from_column)?;
333        Ok(TrinoTy::Row(row))
334    }
335
336    pub fn into_type_signature(self) -> TypeSignature {
337        use TrinoTy::*;
338
339        let raw_ty = self.raw_type();
340
341        let params = match self {
342            Unknown => vec![],
343            Decimal(p, s) => vec![
344                ClientTypeSignatureParameter::LongLiteral(p as u64),
345                ClientTypeSignatureParameter::LongLiteral(s as u64),
346            ],
347            Date => vec![],
348            Time => vec![],
349            TimeWithTimeZone => vec![],
350            Timestamp => vec![],
351            TimestampWithTimeZone => vec![],
352            IntervalYearToMonth => vec![],
353            IntervalDayToSecond => vec![],
354            Option(t) => return t.into_type_signature(),
355            Boolean => vec![],
356            TrinoInt(_) => vec![],
357            TrinoFloat(_) => vec![],
358            Varchar => vec![ClientTypeSignatureParameter::LongLiteral(2147483647)],
359            Char(a) => vec![ClientTypeSignatureParameter::LongLiteral(a as u64)],
360            Tuple(ts) => ts.map(|ty| {
361                ClientTypeSignatureParameter::NamedTypeSignature(NamedTypeSignature {
362                    field_name: None,
363                    type_signature: ty.into_type_signature(),
364                })
365            }),
366            Row(ts) => ts.map(|(name, ty)| {
367                ClientTypeSignatureParameter::NamedTypeSignature(NamedTypeSignature {
368                    field_name: Some(RowFieldName::new(name)),
369                    type_signature: ty.into_type_signature(),
370                })
371            }),
372            Array(t) => vec![ClientTypeSignatureParameter::TypeSignature(
373                t.into_type_signature(),
374            )],
375            Map(t1, t2) => vec![
376                ClientTypeSignatureParameter::TypeSignature(t1.into_type_signature()),
377                ClientTypeSignatureParameter::TypeSignature(t2.into_type_signature()),
378            ],
379            IpAddress => vec![],
380            Uuid => vec![],
381            Json => vec![],
382        };
383
384        TypeSignature::new(raw_ty, params)
385    }
386
387    pub fn full_type(&self) -> Cow<'static, str> {
388        use TrinoTy::*;
389
390        match self {
391            Unknown => RawTrinoTy::Unknown.to_str().into(),
392            Decimal(p, s) => format!("{}({},{})", RawTrinoTy::Decimal.to_str(), p, s).into(),
393            Option(t) => t.full_type(),
394            Date => RawTrinoTy::Date.to_str().into(),
395            Time => RawTrinoTy::Time.to_str().into(),
396            TimeWithTimeZone => RawTrinoTy::TimeWithTimeZone.to_str().into(),
397            Timestamp => RawTrinoTy::Timestamp.to_str().into(),
398            TimestampWithTimeZone => RawTrinoTy::TimestampWithTimeZone.to_str().into(),
399            IntervalYearToMonth => RawTrinoTy::IntervalYearToMonth.to_str().into(),
400            IntervalDayToSecond => RawTrinoTy::IntervalDayToSecond.to_str().into(),
401            Boolean => RawTrinoTy::Boolean.to_str().into(),
402            TrinoInt(ty) => ty.raw_type().to_str().into(),
403            TrinoFloat(ty) => ty.raw_type().to_str().into(),
404            Varchar => RawTrinoTy::VarChar.to_str().into(),
405            Char(a) => format!("{}({})", RawTrinoTy::Char.to_str(), a).into(),
406            Tuple(ts) => format!(
407                "{}({})",
408                RawTrinoTy::Row.to_str(),
409                ts.lazy_map(|ty| ty.full_type()).join(",")
410            )
411            .into(),
412            Row(ts) => format!(
413                "{}({})",
414                RawTrinoTy::Row.to_str(),
415                ts.lazy_map(|(name, ty)| format!("{} {}", name, ty.full_type()))
416                    .join(",")
417            )
418            .into(),
419            Array(t) => format!("{}({})", RawTrinoTy::Array.to_str(), t.full_type()).into(),
420            Map(t1, t2) => format!(
421                "{}({},{})",
422                RawTrinoTy::Map.to_str(),
423                t1.full_type(),
424                t2.full_type()
425            )
426            .into(),
427            IpAddress => RawTrinoTy::IpAddress.to_str().into(),
428            Uuid => RawTrinoTy::Uuid.to_str().into(),
429            Json => RawTrinoTy::Json.to_str().into(),
430        }
431    }
432
433    pub fn raw_type(&self) -> RawTrinoTy {
434        use TrinoTy::*;
435
436        match self {
437            Unknown => RawTrinoTy::Unknown,
438            Date => RawTrinoTy::Date,
439            Time => RawTrinoTy::Time,
440            TimeWithTimeZone => RawTrinoTy::TimeWithTimeZone,
441            Timestamp => RawTrinoTy::Timestamp,
442            TimestampWithTimeZone => RawTrinoTy::TimestampWithTimeZone,
443            IntervalYearToMonth => RawTrinoTy::IntervalYearToMonth,
444            IntervalDayToSecond => RawTrinoTy::IntervalDayToSecond,
445            Decimal(_, _) => RawTrinoTy::Decimal,
446            Option(ty) => ty.raw_type(),
447            Boolean => RawTrinoTy::Boolean,
448            TrinoInt(ty) => ty.raw_type(),
449            TrinoFloat(ty) => ty.raw_type(),
450            Varchar => RawTrinoTy::VarChar,
451            Char(_) => RawTrinoTy::Char,
452            Tuple(_) => RawTrinoTy::Row,
453            Row(_) => RawTrinoTy::Row,
454            Array(_) => RawTrinoTy::Array,
455            Map(_, _) => RawTrinoTy::Map,
456            IpAddress => RawTrinoTy::IpAddress,
457            Uuid => RawTrinoTy::Uuid,
458            Json => RawTrinoTy::Json,
459        }
460    }
461}
462
463impl TrinoInt {
464    pub fn raw_type(&self) -> RawTrinoTy {
465        use TrinoInt::*;
466        match self {
467            I8 => RawTrinoTy::TinyInt,
468            I16 => RawTrinoTy::SmallInt,
469            I32 => RawTrinoTy::Integer,
470            I64 => RawTrinoTy::BigInt,
471        }
472    }
473}
474
475impl TrinoFloat {
476    pub fn raw_type(&self) -> RawTrinoTy {
477        use TrinoFloat::*;
478        match self {
479            F32 => RawTrinoTy::Real,
480            F64 => RawTrinoTy::Double,
481        }
482    }
483}